start new impl
This commit is contained in:
parent
baacf92034
commit
de082fedf6
@ -3,8 +3,7 @@ headers = [
|
|||||||
'pgxx/pgxx.hpp',
|
'pgxx/pgxx.hpp',
|
||||||
'pgxx/query_builder.hpp',
|
'pgxx/query_builder.hpp',
|
||||||
|
|
||||||
'pgxx/utils/define.hpp',
|
'pgxx/utils/aliases.hpp',
|
||||||
'pgxx/utils/using.hpp',
|
|
||||||
'pgxx/utils/var.hpp',
|
'pgxx/utils/var.hpp',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -4,55 +4,45 @@
|
|||||||
#include "hack/exception/exception.hpp"
|
#include "hack/exception/exception.hpp"
|
||||||
#include "hack/logger/logger.hpp"
|
#include "hack/logger/logger.hpp"
|
||||||
|
|
||||||
#include "pgxx/utils/define.hpp"
|
#include "utils/aliases.hpp" // IWYU pragma: keep
|
||||||
#include "pgxx/utils/using.hpp"
|
#include "utils/var.hpp"
|
||||||
#include "pgxx/utils/var.hpp"
|
#include "pool_connection.hpp"
|
||||||
#include "pgxx/query_builder.hpp"
|
|
||||||
#include "pgxx/pool_connection.hpp"
|
|
||||||
|
|
||||||
namespace pgxx
|
namespace pgxx
|
||||||
{
|
{
|
||||||
class manager : public hack::utils::singleton<manager>
|
class database : public hack::utils::singleton<database>
|
||||||
{
|
{
|
||||||
friend hack::utils::singleton<manager>;
|
friend hack::utils::singleton<database>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
~manager() = default;
|
~database() = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
manager() = default;
|
database() = default;
|
||||||
std::map<std::string, pool_connection> m_data_connections;
|
std::map<std::string, pool_connection> m_connections;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
bool ready() { return m_data_connections.size() != 0; }
|
bool ready() { return m_connections.size() != 0; }
|
||||||
|
|
||||||
void init(std::string connection_name, int connection_count, std::string connection_url)
|
void init(std::string connection_name, int connection_count, std::string connection_url)
|
||||||
{
|
{
|
||||||
m_data_connections[connection_name] = pool_connection { connection_count, connection_url };
|
m_connections[connection_name] = pool_connection { connection_count, connection_url };
|
||||||
hack::log("")("make connection [", connection_name, "] completed");
|
hack::log("")("make connection [", connection_name, "] completed");
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename... Args>
|
template<typename... Args>
|
||||||
JSON execute(const std::string connection_name, std::string func_name, const Args&... args)
|
pqxx::result execute(const std::string connection_name, std::string query)
|
||||||
{
|
{
|
||||||
auto query = prepare(func_name, args...);
|
pqxx::result r;
|
||||||
JSON result;
|
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
auto c = m_data_connections[connection_name].get();
|
auto c = m_connections[connection_name].get();
|
||||||
|
|
||||||
pqxx::result r;
|
|
||||||
pqxx::work work { *c };
|
pqxx::work work { *c };
|
||||||
r = work.exec(query);
|
r = work.exec(query);
|
||||||
|
|
||||||
std::string r_str;
|
|
||||||
for (auto row : r) r_str = row.at(0).c_str();
|
|
||||||
|
|
||||||
work.commit();
|
work.commit();
|
||||||
m_data_connections[connection_name].release(c);
|
m_connections[connection_name].release(c);
|
||||||
|
|
||||||
result = JSON::parse(r_str);
|
|
||||||
}
|
}
|
||||||
catch (const std::exception& e)
|
catch (const std::exception& e)
|
||||||
{
|
{
|
||||||
@ -61,15 +51,6 @@ namespace pgxx
|
|||||||
ex.system_error(e);
|
ex.system_error(e);
|
||||||
ex.params("connection_name", connection_name);
|
ex.params("connection_name", connection_name);
|
||||||
ex.params("query", query);
|
ex.params("query", query);
|
||||||
ex.params("result", result);
|
|
||||||
|
|
||||||
if (connection_name == var::LOG_CONNECTION)
|
|
||||||
{
|
|
||||||
hack::error()("WARNING!!! ERROR LOG TO DATABASE");
|
|
||||||
hack::error()("query", query);
|
|
||||||
hack::error()(e.what());
|
|
||||||
std::terminate();
|
|
||||||
}
|
|
||||||
throw ex;
|
throw ex;
|
||||||
}
|
}
|
||||||
catch (...)
|
catch (...)
|
||||||
@ -79,44 +60,11 @@ namespace pgxx
|
|||||||
ex.message(var::EXECUTE_ERROR);
|
ex.message(var::EXECUTE_ERROR);
|
||||||
ex.params("connection_name", connection_name);
|
ex.params("connection_name", connection_name);
|
||||||
ex.params("query", query);
|
ex.params("query", query);
|
||||||
ex.params("result", result);
|
|
||||||
|
|
||||||
if (connection_name == var::LOG_CONNECTION)
|
|
||||||
{
|
|
||||||
hack::error()("WARNING!!! ERROR LOG TO DATABASE");
|
|
||||||
hack::error()("query", query);
|
|
||||||
std::terminate();
|
|
||||||
}
|
|
||||||
throw ex;
|
throw ex;
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
template<typename... Args>
|
|
||||||
std::string prepare(const std::string func_name, const Args&... args)
|
|
||||||
{
|
|
||||||
std::string query;
|
|
||||||
|
|
||||||
try
|
|
||||||
{
|
|
||||||
query = builder::make_query(func_name, args...);
|
|
||||||
}
|
|
||||||
catch (const std::exception& e)
|
|
||||||
{
|
|
||||||
hack::exception ex;
|
|
||||||
ex.description("database dont create query from args");
|
|
||||||
ex.system_error(e);
|
|
||||||
ex.params("query", query);
|
|
||||||
ex.variadic_params(args...);
|
|
||||||
throw ex;
|
|
||||||
}
|
|
||||||
|
|
||||||
return query;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,25 +2,22 @@
|
|||||||
|
|
||||||
#include <regex>
|
#include <regex>
|
||||||
|
|
||||||
#include "hack/string/string_concat_helper.hpp"
|
|
||||||
#include "hack/concepts/concepts.hpp"
|
#include "hack/concepts/concepts.hpp"
|
||||||
|
|
||||||
#include "utils/using.hpp"
|
|
||||||
|
|
||||||
namespace pgxx::builder
|
namespace pgxx::builder
|
||||||
{
|
{
|
||||||
template<hack::concepts::is_string First>
|
template<hack::concepts::is_string First>
|
||||||
std::string make_one(First f)
|
std::string make_one(First f)
|
||||||
{
|
{
|
||||||
f = std::regex_replace(f, std::regex("'"), "[quote]");
|
f = std::regex_replace(f, std::regex("'"), "[quote]");
|
||||||
return hack::string::str_concat + "'" + f + "',";
|
return "'" + f + "',";
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::string make_one(const char* f)
|
inline std::string make_one(const char* f)
|
||||||
{
|
{
|
||||||
auto f_str = std::string(f);
|
auto f_str = std::string(f);
|
||||||
f_str = std::regex_replace(f_str, std::regex("'"), "[quote]");
|
f_str = std::regex_replace(f_str, std::regex("'"), "[quote]");
|
||||||
return hack::string::str_concat + "'" + f_str + "',";
|
return "'" + f_str + "',";
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::string make_one(char f)
|
inline std::string make_one(char f)
|
||||||
@ -34,35 +31,28 @@ namespace pgxx::builder
|
|||||||
{
|
{
|
||||||
auto f_str = std::to_string(f);
|
auto f_str = std::to_string(f);
|
||||||
f_str = std::regex_replace(f_str, std::regex("'"), "[quote]");
|
f_str = std::regex_replace(f_str, std::regex("'"), "[quote]");
|
||||||
return hack::string::str_concat + "'" + f_str + "',";
|
return "'" + f_str + "',";
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::string make_one(const float f)
|
inline std::string make_one(const float f)
|
||||||
{
|
{
|
||||||
auto f_str = std::to_string(f);
|
auto f_str = std::to_string(f);
|
||||||
f_str = std::regex_replace(f_str, std::regex("'"), "[quote]");
|
f_str = std::regex_replace(f_str, std::regex("'"), "[quote]");
|
||||||
return hack::string::str_concat + f_str + ",";
|
return f_str + ",";
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::string make_one(int f)
|
inline std::string make_one(int f)
|
||||||
{
|
{
|
||||||
auto f_str = std::to_string(f);
|
auto f_str = std::to_string(f);
|
||||||
f_str = std::regex_replace(f_str, std::regex("'"), "[quote]");
|
f_str = std::regex_replace(f_str, std::regex("'"), "[quote]");
|
||||||
return hack::string::str_concat + f_str + ",";
|
return f_str + ",";
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::string make_one(const std::string& f)
|
inline std::string make_one(const std::string& f)
|
||||||
{
|
{
|
||||||
auto f_str = f;
|
auto f_str = f;
|
||||||
f_str = std::regex_replace(f_str, std::regex("'"), "[quote]");
|
f_str = std::regex_replace(f_str, std::regex("'"), "[quote]");
|
||||||
return hack::string::str_concat + "'" + f_str + "',";
|
return "'" + f_str + "',";
|
||||||
}
|
|
||||||
|
|
||||||
inline std::string make_one(const JSON& f)
|
|
||||||
{
|
|
||||||
auto f_str = f.dump();
|
|
||||||
f_str = std::regex_replace(f_str, std::regex("'"), "[quote]");
|
|
||||||
return hack::string::str_concat + "'" + f_str + "'::jsonb,";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// это заглушкa при компиляции пустых данных
|
// это заглушкa при компиляции пустых данных
|
||||||
|
5
src/pgxx/utils/aliases.hpp
Normal file
5
src/pgxx/utils/aliases.hpp
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#ifndef PGXX
|
||||||
|
#define PGXX() pgxx::database::instance()
|
||||||
|
#endif
|
@ -1,3 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#define PGXX() pgxx::manager::instance()
|
|
@ -1,8 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "nlohmann/json.hpp"
|
|
||||||
|
|
||||||
namespace pgxx
|
|
||||||
{
|
|
||||||
using JSON = nlohmann::json;
|
|
||||||
}
|
|
194
tests/main.cpp
194
tests/main.cpp
@ -1,7 +1,177 @@
|
|||||||
#include <thread>
|
#include <future>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "pgxx/pgxx.hpp"
|
#include "pgxx/pgxx.hpp"
|
||||||
|
#include "thread/pool.hpp"
|
||||||
|
|
||||||
|
const int MAX_QUERY_IN_DB = 10'000;
|
||||||
|
const float MAX_OVERLAP = 12.f;
|
||||||
|
|
||||||
|
struct bit
|
||||||
|
{
|
||||||
|
std::string m_id;
|
||||||
|
std::string m_url;
|
||||||
|
std::string m_artist;
|
||||||
|
std::string m_song;
|
||||||
|
std::vector<int> m_key;
|
||||||
|
std::vector<int> m_duration;
|
||||||
|
};
|
||||||
|
|
||||||
|
bit target;
|
||||||
|
|
||||||
|
std::vector<int> parse_array(const std::string& arr_str)
|
||||||
|
{
|
||||||
|
std::vector<int> result;
|
||||||
|
std::stringstream ss(arr_str);
|
||||||
|
std::string item;
|
||||||
|
|
||||||
|
while (std::getline(ss, item, ','))
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
result.push_back(std::stoi(item));
|
||||||
|
}
|
||||||
|
catch (const std::invalid_argument&)
|
||||||
|
{
|
||||||
|
std::cerr << "Ошибка преобразования: " << item << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace calc
|
||||||
|
{
|
||||||
|
inline int calculation(int db_offset)
|
||||||
|
{
|
||||||
|
std::vector<bit> vb;
|
||||||
|
vb.reserve(MAX_QUERY_IN_DB);
|
||||||
|
|
||||||
|
std::string query = "SELECT * FROM t_media LIMIT " + std::to_string(MAX_QUERY_IN_DB) + " OFFSET " + std::to_string(db_offset);
|
||||||
|
auto r = PGXX().execute("con_1", query);
|
||||||
|
|
||||||
|
for (auto el : r)
|
||||||
|
{
|
||||||
|
bit b;
|
||||||
|
|
||||||
|
b.m_id = el["m_id"].as<std::string>();
|
||||||
|
b.m_url = el["m_url"].as<std::string>();
|
||||||
|
b.m_artist = el["m_artist"].as<std::string>();
|
||||||
|
b.m_song = el["m_song"].as<std::string>();
|
||||||
|
|
||||||
|
std::string str = el["m_key"].as<std::string>();
|
||||||
|
b.m_key = parse_array(str);
|
||||||
|
|
||||||
|
str = el["m_duration"].as<std::string>();
|
||||||
|
b.m_duration = parse_array(str);
|
||||||
|
vb.push_back(b);
|
||||||
|
}
|
||||||
|
|
||||||
|
return vb.size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace executor
|
||||||
|
{
|
||||||
|
inline void get_recomended()
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
int count = 0;
|
||||||
|
std::string query = "SELECT count(*) FROM t_media WHERE m_cleared = 0;";
|
||||||
|
auto r = PGXX().execute("con_1", query);
|
||||||
|
for (auto eld : r) count = eld["count"].as<int>();
|
||||||
|
|
||||||
|
thread::pool pool;
|
||||||
|
|
||||||
|
auto start = std::chrono::high_resolution_clock::now();
|
||||||
|
int i = 0;
|
||||||
|
int def = count / MAX_QUERY_IN_DB + 1;
|
||||||
|
int db_offset = 0;
|
||||||
|
|
||||||
|
std::vector<std::future<int>> futures;
|
||||||
|
futures.reserve(def);
|
||||||
|
|
||||||
|
while (def > 0)
|
||||||
|
{
|
||||||
|
futures.push_back(pool.enqueue(calc::calculation, db_offset));
|
||||||
|
db_offset += MAX_QUERY_IN_DB;
|
||||||
|
--def;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &f : futures)
|
||||||
|
{
|
||||||
|
int r = f.get();
|
||||||
|
i += r;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::chrono::duration<double> elapsed = std::chrono::high_resolution_clock::now() - start;
|
||||||
|
hack::log()(i, elapsed.count());
|
||||||
|
}
|
||||||
|
catch(const std::exception& e)
|
||||||
|
{
|
||||||
|
hack::log()(e.what());
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
catch(hack::exception& ex)
|
||||||
|
{
|
||||||
|
ex.log();
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
catch(...)
|
||||||
|
{
|
||||||
|
hack::log()("ooops!");
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void convert_db()
|
||||||
|
{
|
||||||
|
bool is_work = true;
|
||||||
|
int count = 0;
|
||||||
|
while (is_work)
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
std::string query = "SELECT * FROM t_media WHERE m_cleared = 0 LIMIT " + std::to_string(1) + ";";
|
||||||
|
auto r = PGXX().execute("con_1", query);
|
||||||
|
hack::log("")("size = ", r.size(), ", count = ", count);
|
||||||
|
|
||||||
|
for (auto el : r)
|
||||||
|
{
|
||||||
|
target.m_id = el["m_id"].as<std::string>();
|
||||||
|
target.m_url = el["m_url"].as<std::string>();
|
||||||
|
target.m_artist = el["m_artist"].as<std::string>();
|
||||||
|
target.m_song = el["m_song"].as<std::string>();
|
||||||
|
|
||||||
|
std::string key = el["m_key"].as<std::string>();
|
||||||
|
std::string duration = el["m_duration"].as<std::string>();
|
||||||
|
|
||||||
|
query = "UPDATE t_media SET m_cleared = 1 WHERE m_id = '" + target.m_id + "';";
|
||||||
|
PGXX().execute("con_1", query);
|
||||||
|
|
||||||
|
executor::get_recomended();
|
||||||
|
}
|
||||||
|
|
||||||
|
query = "SELECT count(*) FROM t_media WHERE m_cleared = 0;";
|
||||||
|
r = PGXX().execute("con_1", query);
|
||||||
|
|
||||||
|
for (auto eld : r)
|
||||||
|
{
|
||||||
|
auto c = eld["count"].as<int>();
|
||||||
|
if (count - c > 1) hack::log()("DOUBLE");
|
||||||
|
count = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (count == 0) is_work = false;
|
||||||
|
}
|
||||||
|
catch(hack::exception& e)
|
||||||
|
{
|
||||||
|
e.log();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto main(int argc, char* args[]) -> int
|
auto main(int argc, char* args[]) -> int
|
||||||
{
|
{
|
||||||
@ -10,7 +180,6 @@ auto main(int argc, char* args[]) -> int
|
|||||||
try
|
try
|
||||||
{
|
{
|
||||||
PGXX().init("con_1", 300, con);
|
PGXX().init("con_1", 300, con);
|
||||||
PGXX().init("con_2", 300, con);
|
|
||||||
}
|
}
|
||||||
catch(hack::exception& ex)
|
catch(hack::exception& ex)
|
||||||
{
|
{
|
||||||
@ -21,24 +190,5 @@ auto main(int argc, char* args[]) -> int
|
|||||||
if (!PGXX().ready())
|
if (!PGXX().ready())
|
||||||
hack::log()("error connection");
|
hack::log()("error connection");
|
||||||
|
|
||||||
pgxx::JSON j {
|
convert_db();
|
||||||
{
|
|
||||||
"params", { { "key_1", 1 }, { "key2", "value" } }
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
for (auto i = 0; i < 10; ++i)
|
|
||||||
{
|
|
||||||
std::thread th([&j](){
|
|
||||||
auto r = PGXX().execute("con_1", "read_and_write", j);
|
|
||||||
});
|
|
||||||
th.detach();
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto i = 0; i < 10; ++i)
|
|
||||||
{
|
|
||||||
auto r = PGXX().execute("con_2", "read_and_write", j);
|
|
||||||
}
|
|
||||||
|
|
||||||
hack::log()("ok");
|
|
||||||
}
|
}
|
||||||
|
288
tests/thread/pool.hpp
Normal file
288
tests/thread/pool.hpp
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <concepts>
|
||||||
|
#include <deque>
|
||||||
|
#include <functional>
|
||||||
|
#include <future>
|
||||||
|
#include <memory>
|
||||||
|
#include <semaphore>
|
||||||
|
#include <thread>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "queue.hpp"
|
||||||
|
|
||||||
|
namespace thread
|
||||||
|
{
|
||||||
|
namespace details
|
||||||
|
{
|
||||||
|
using default_function_type = std::function<void()>;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename FunctionType = details::default_function_type, typename ThreadType = std::jthread>
|
||||||
|
requires std::invocable<FunctionType> && std::is_same_v<void, std::invoke_result_t<FunctionType>>
|
||||||
|
class pool
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
template <typename InitializationFunction = std::function<void(std::size_t)>>
|
||||||
|
requires std::invocable<InitializationFunction, std::size_t> && std::is_same_v<void, std::invoke_result_t<InitializationFunction, std::size_t>>
|
||||||
|
explicit pool(const unsigned int& number_of_threads = std::thread::hardware_concurrency(), InitializationFunction init = [](std::size_t) {}) : tasks_(number_of_threads)
|
||||||
|
{
|
||||||
|
std::size_t current_id = 0;
|
||||||
|
for (std::size_t i = 0; i < number_of_threads; ++i)
|
||||||
|
{
|
||||||
|
priority_queue_.push_back(size_t(current_id));
|
||||||
|
try
|
||||||
|
{
|
||||||
|
threads_.emplace_back([&, id = current_id, init](const std::stop_token &stop_tok) {
|
||||||
|
// invoke the init function on the thread
|
||||||
|
try
|
||||||
|
{
|
||||||
|
std::invoke(init, id);
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
// suppress exceptions
|
||||||
|
}
|
||||||
|
|
||||||
|
do
|
||||||
|
{
|
||||||
|
// wait until signaled
|
||||||
|
tasks_[id].signal.acquire();
|
||||||
|
do
|
||||||
|
{
|
||||||
|
// invoke the task
|
||||||
|
while (auto task = tasks_[id].tasks.pop_front())
|
||||||
|
{
|
||||||
|
// decrement the unassigned tasks as the task is now going
|
||||||
|
// to be executed
|
||||||
|
unassigned_tasks_.fetch_sub(1, std::memory_order_release);
|
||||||
|
// invoke the task
|
||||||
|
std::invoke(std::move(task.value()));
|
||||||
|
// the above task can push more work onto the pool, so we
|
||||||
|
// only decrement the in flights once the task has been
|
||||||
|
// executed because now it's now longer "in flight"
|
||||||
|
in_flight_tasks_.fetch_sub(1, std::memory_order_release);
|
||||||
|
}
|
||||||
|
|
||||||
|
// try to steal a task
|
||||||
|
for (std::size_t j = 1; j < tasks_.size(); ++j)
|
||||||
|
{
|
||||||
|
const std::size_t index = (id + j) % tasks_.size();
|
||||||
|
if (auto task = tasks_[index].tasks.steal())
|
||||||
|
{
|
||||||
|
// steal a task
|
||||||
|
unassigned_tasks_.fetch_sub(1, std::memory_order_release);
|
||||||
|
std::invoke(std::move(task.value()));
|
||||||
|
in_flight_tasks_.fetch_sub(1, std::memory_order_release);
|
||||||
|
// stop stealing once we have invoked a stolen task
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// check if there are any unassigned tasks before rotating to the
|
||||||
|
// front and waiting for more work
|
||||||
|
} while (unassigned_tasks_.load(std::memory_order_acquire) > 0);
|
||||||
|
|
||||||
|
priority_queue_.rotate_to_front(id);
|
||||||
|
// check if all tasks are completed and release the "barrier"
|
||||||
|
if (in_flight_tasks_.load(std::memory_order_acquire) == 0)
|
||||||
|
{
|
||||||
|
// in theory, only one thread will set this
|
||||||
|
threads_complete_signal_.store(true, std::memory_order_release);
|
||||||
|
threads_complete_signal_.notify_one();
|
||||||
|
}
|
||||||
|
} while (!stop_tok.stop_requested());
|
||||||
|
});
|
||||||
|
// increment the thread id
|
||||||
|
++current_id;
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
tasks_.pop_back();
|
||||||
|
std::ignore = priority_queue_.pop_back();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
~pool()
|
||||||
|
{
|
||||||
|
wait_for_tasks();
|
||||||
|
// stop all threads
|
||||||
|
for (std::size_t i = 0; i < threads_.size(); ++i)
|
||||||
|
{
|
||||||
|
threads_[i].request_stop();
|
||||||
|
tasks_[i].signal.release();
|
||||||
|
threads_[i].join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// thread pool is non-copyable
|
||||||
|
pool(const pool &) = delete;
|
||||||
|
pool &operator=(const pool &) = delete;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Enqueue a task into the thread pool that returns a result.
|
||||||
|
* @details Note that task execution begins once the task is enqueued.
|
||||||
|
* @tparam Function An invokable type.
|
||||||
|
* @tparam Args Argument parameter pack
|
||||||
|
* @tparam ReturnType The return type of the Function
|
||||||
|
* @param f The callable function
|
||||||
|
* @param args The parameters that will be passed (copied) to the function.
|
||||||
|
* @return A std::future<ReturnType> that can be used to retrieve the returned value.
|
||||||
|
*/
|
||||||
|
template <typename Function, typename... Args, typename ReturnType = std::invoke_result_t<Function &&, Args &&...>>
|
||||||
|
requires std::invocable<Function, Args...>
|
||||||
|
[[nodiscard]] std::future<ReturnType> enqueue(Function f, Args... args)
|
||||||
|
{
|
||||||
|
/*
|
||||||
|
* use shared promise here so that we don't break the promise later (until C++23)
|
||||||
|
*
|
||||||
|
* with C++23 we can do the following:
|
||||||
|
*
|
||||||
|
* std::promise<ReturnType> promise;
|
||||||
|
* auto future = promise.get_future();
|
||||||
|
* auto task = [func = std::move(f), ...largs = std::move(args),
|
||||||
|
promise = std::move(promise)]() mutable {...};
|
||||||
|
*/
|
||||||
|
auto shared_promise = std::make_shared<std::promise<ReturnType>>();
|
||||||
|
auto task = [func = std::move(f), ... largs = std::move(args), promise = shared_promise]() {
|
||||||
|
try
|
||||||
|
{
|
||||||
|
if constexpr (std::is_same_v<ReturnType, void>)
|
||||||
|
{
|
||||||
|
func(largs...);
|
||||||
|
promise->set_value();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
promise->set_value(func(largs...));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
promise->set_exception(std::current_exception());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// get the future before enqueuing the task
|
||||||
|
auto future = shared_promise->get_future();
|
||||||
|
// enqueue the task
|
||||||
|
enqueue_task(std::move(task));
|
||||||
|
return future;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Enqueue a task to be executed in the thread pool. Any return value of the function
|
||||||
|
* will be ignored.
|
||||||
|
* @tparam Function An invokable type.
|
||||||
|
* @tparam Args Argument parameter pack for Function
|
||||||
|
* @param func The callable to be executed
|
||||||
|
* @param args Arguments that will be passed to the function.
|
||||||
|
*/
|
||||||
|
template <typename Function, typename... Args>
|
||||||
|
requires std::invocable<Function, Args...>
|
||||||
|
void enqueue_detach(Function &&func, Args &&...args)
|
||||||
|
{
|
||||||
|
enqueue_task(std::move([f = std::forward<Function>(func), ... largs = std::forward<Args>(args)]() mutable -> decltype(auto) {
|
||||||
|
// suppress exceptions
|
||||||
|
try
|
||||||
|
{
|
||||||
|
if constexpr (std::is_same_v<void, std::invoke_result_t<Function &&, Args &&...>>)
|
||||||
|
{
|
||||||
|
std::invoke(f, largs...);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// the function returns an argument, but can be ignored
|
||||||
|
std::ignore = std::invoke(f, largs...);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Returns the number of threads in the pool.
|
||||||
|
*
|
||||||
|
* @return std::size_t The number of threads in the pool.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] auto size() const { return threads_.size(); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Wait for all tasks to finish.
|
||||||
|
* @details This function will block until all tasks have been completed.
|
||||||
|
*/
|
||||||
|
void wait_for_tasks()
|
||||||
|
{
|
||||||
|
if (in_flight_tasks_.load(std::memory_order_acquire) > 0)
|
||||||
|
{
|
||||||
|
// wait for all tasks to finish
|
||||||
|
threads_complete_signal_.wait(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Makes best-case attempt to clear all tasks from the thread_pool
|
||||||
|
* @details Note that this does not guarantee that all tasks will be cleared, as currently
|
||||||
|
* running tasks could add additional tasks. Also a thread could steal a task from another
|
||||||
|
* in the middle of this.
|
||||||
|
* @return number of tasks cleared
|
||||||
|
*/
|
||||||
|
size_t clear_tasks()
|
||||||
|
{
|
||||||
|
size_t removed_task_count{0};
|
||||||
|
for (auto &task_list : tasks_)
|
||||||
|
{
|
||||||
|
removed_task_count += task_list.tasks.clear();
|
||||||
|
}
|
||||||
|
in_flight_tasks_.fetch_sub(removed_task_count, std::memory_order_release);
|
||||||
|
unassigned_tasks_.fetch_sub(removed_task_count, std::memory_order_release);
|
||||||
|
|
||||||
|
return removed_task_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename Function>
|
||||||
|
void enqueue_task(Function &&f)
|
||||||
|
{
|
||||||
|
auto i_opt = priority_queue_.copy_front_and_rotate_to_back();
|
||||||
|
if (!i_opt.has_value())
|
||||||
|
{
|
||||||
|
// would only be a problem if there are zero threads
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// get the index
|
||||||
|
auto i = *(i_opt);
|
||||||
|
|
||||||
|
// increment the unassigned tasks and in flight tasks
|
||||||
|
unassigned_tasks_.fetch_add(1, std::memory_order_release);
|
||||||
|
const auto prev_in_flight = in_flight_tasks_.fetch_add(1, std::memory_order_release);
|
||||||
|
|
||||||
|
// reset the in flight signal if the list was previously empty
|
||||||
|
if (prev_in_flight == 0)
|
||||||
|
{
|
||||||
|
threads_complete_signal_.store(false, std::memory_order_release);
|
||||||
|
}
|
||||||
|
|
||||||
|
// assign work
|
||||||
|
tasks_[i].tasks.push_back(std::forward<Function>(f));
|
||||||
|
tasks_[i].signal.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
struct task_item
|
||||||
|
{
|
||||||
|
queue<FunctionType> tasks{};
|
||||||
|
std::binary_semaphore signal{0};
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<ThreadType> threads_;
|
||||||
|
std::deque<task_item> tasks_;
|
||||||
|
queue<std::size_t> priority_queue_;
|
||||||
|
// guarantee these get zero-initialized
|
||||||
|
std::atomic_int_fast64_t unassigned_tasks_{0}, in_flight_tasks_{0};
|
||||||
|
std::atomic_bool threads_complete_signal_{false};
|
||||||
|
};
|
||||||
|
}
|
108
tests/thread/queue.hpp
Normal file
108
tests/thread/queue.hpp
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <deque>
|
||||||
|
#include <mutex>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
namespace thread
|
||||||
|
{
|
||||||
|
template <typename T>
|
||||||
|
class queue
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
using value_type = T;
|
||||||
|
using size_type = typename std::deque<T>::size_type;
|
||||||
|
|
||||||
|
public:
|
||||||
|
queue() = default;
|
||||||
|
|
||||||
|
void push_back(T&& value)
|
||||||
|
{
|
||||||
|
std::scoped_lock lock(m_mutex);
|
||||||
|
m_data.push_back(std::forward<T>(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
void push_front(T&& value)
|
||||||
|
{
|
||||||
|
std::scoped_lock lock(m_mutex);
|
||||||
|
m_data.push_front(std::forward<T>(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bool empty() const
|
||||||
|
{
|
||||||
|
std::scoped_lock lock(m_mutex);
|
||||||
|
return m_data.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_type clear()
|
||||||
|
{
|
||||||
|
std::scoped_lock lock(m_mutex);
|
||||||
|
auto size = m_data.size();
|
||||||
|
m_data.clear();
|
||||||
|
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::optional<T> pop_front()
|
||||||
|
{
|
||||||
|
std::scoped_lock lock(m_mutex);
|
||||||
|
if (m_data.empty()) return std::nullopt;
|
||||||
|
|
||||||
|
auto front = std::move(m_data.front());
|
||||||
|
m_data.pop_front();
|
||||||
|
return front;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::optional<T> pop_back()
|
||||||
|
{
|
||||||
|
std::scoped_lock lock(m_mutex);
|
||||||
|
if (m_data.empty()) return std::nullopt;
|
||||||
|
|
||||||
|
auto back = std::move(m_data.back());
|
||||||
|
m_data.pop_back();
|
||||||
|
return back;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::optional<T> steal()
|
||||||
|
{
|
||||||
|
std::scoped_lock lock(m_mutex);
|
||||||
|
if (m_data.empty()) return std::nullopt;
|
||||||
|
|
||||||
|
auto back = std::move(m_data.back());
|
||||||
|
m_data.pop_back();
|
||||||
|
return back;
|
||||||
|
}
|
||||||
|
|
||||||
|
void rotate_to_front(const T& item)
|
||||||
|
{
|
||||||
|
std::scoped_lock lock(m_mutex);
|
||||||
|
auto iter = std::find(m_data.begin(), m_data.end(), item);
|
||||||
|
|
||||||
|
if (iter != m_data.end())
|
||||||
|
{
|
||||||
|
std::ignore = m_data.erase(iter);
|
||||||
|
}
|
||||||
|
|
||||||
|
m_data.push_front(item);
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::optional<T> copy_front_and_rotate_to_back()
|
||||||
|
{
|
||||||
|
std::scoped_lock lock(m_mutex);
|
||||||
|
|
||||||
|
if (m_data.empty()) return std::nullopt;
|
||||||
|
|
||||||
|
auto front = m_data.front();
|
||||||
|
m_data.pop_front();
|
||||||
|
|
||||||
|
m_data.push_back(front);
|
||||||
|
|
||||||
|
return front;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::deque<T> m_data{};
|
||||||
|
std::mutex m_mutex{};
|
||||||
|
};
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user