From f1cc3c667d37150da7bf10f790243684a53a9db9 Mon Sep 17 00:00:00 2001 From: Yves Date: Fri, 21 Feb 2025 12:46:41 +0100 Subject: [PATCH] Split server in smaller modules --- CMakeLists.txt | 3 + src/event_dispatcher.cpp | 77 ++++++++++ src/http_server.cpp | 237 ++++--------------------------- src/include/event_dispatcher.hpp | 32 +++++ src/include/http_server.hpp | 44 ++---- src/include/utils/md_helpers.hpp | 8 ++ src/include/watcher.hpp | 28 ++++ src/settings.cpp | 2 + src/utils/md_helpers.cpp | 33 +++++ src/watcher.cpp | 120 ++++++++++++++++ 10 files changed, 339 insertions(+), 245 deletions(-) create mode 100644 src/event_dispatcher.cpp create mode 100644 src/include/event_dispatcher.hpp create mode 100644 src/include/utils/md_helpers.hpp create mode 100644 src/include/watcher.hpp create mode 100644 src/utils/md_helpers.cpp create mode 100644 src/watcher.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 047243d..3ef3c84 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,12 +16,15 @@ include_directories(src/include ${DuckDB_SOURCE_DIR}/third_party/httplib) set(EXTENSION_SOURCES src/ui_extension.cpp + src/event_dispatcher.cpp src/http_server.cpp src/settings.cpp src/state.cpp + src/watcher.cpp src/utils/encoding.cpp src/utils/env.cpp src/utils/helpers.cpp + src/utils/md_helpers.cpp src/utils/serialization.cpp) find_package(Git) diff --git a/src/event_dispatcher.cpp b/src/event_dispatcher.cpp new file mode 100644 index 0000000..8a33548 --- /dev/null +++ b/src/event_dispatcher.cpp @@ -0,0 +1,77 @@ +#include "event_dispatcher.hpp" + +#include + +#define CPPHTTPLIB_OPENSSL_SUPPORT +#include "httplib.hpp" + +namespace httplib = duckdb_httplib_openssl; + +// Chosen to be no more than half of the lesser of the two limits: +// - The default httplib thread pool size = 8 +// - The browser limit on the number of server-sent event connections = 6 +#define MAX_EVENT_WAIT_COUNT 3 + +namespace duckdb { +namespace ui { +// An empty Server-Sent Events message. See +// https://html.spec.whatwg.org/multipage/server-sent-events.html#authoring-notes +constexpr const char *EMPTY_SSE_MESSAGE = ":\r\r"; +constexpr idx_t EMPTY_SSE_MESSAGE_LENGTH = 3; + +bool EventDispatcher::WaitEvent(httplib::DataSink *sink) { + std::unique_lock lock(mutex); + // Don't allow too many simultaneous waits, because each consumes a thread in + // the httplib thread pool, and also browsers limit the number of server-sent + // event connections. + if (closed || wait_count >= MAX_EVENT_WAIT_COUNT) { + return false; + } + int target_id = next_id; + wait_count++; + cv.wait_for(lock, std::chrono::seconds(5)); + wait_count--; + if (closed) { + return false; + } + if (current_id == target_id) { + sink->write(message.data(), message.size()); + } else { + // Our wait timer expired. Write an empty, no-op message. + // This enables detecting when the client is gone. + sink->write(EMPTY_SSE_MESSAGE, EMPTY_SSE_MESSAGE_LENGTH); + } + return true; +} + +void EventDispatcher::SendEvent(const std::string &_message) { + std::lock_guard guard(mutex); + if (closed) { + return; + } + + current_id = next_id++; + message = _message; + cv.notify_all(); +} + +void EventDispatcher::SendConnectedEvent(const std::string &token) { + SendEvent(StringUtil::Format("event: ConnectedEvent\ndata: %s\n\n", token)); +} + +void EventDispatcher::SendCatalogChangedEvent() { + SendEvent("event: CatalogChangeEvent\ndata:\n\n"); +} + +void EventDispatcher::Close() { + std::lock_guard guard(mutex); + if (closed) { + return; + } + + current_id = next_id++; + closed = true; + cv.notify_all(); +} +} // namespace ui +} // namespace duckdb diff --git a/src/http_server.cpp b/src/http_server.cpp index 73d0ebe..f73daba 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -4,72 +4,18 @@ #include "state.hpp" #include "utils/encoding.hpp" #include "utils/env.hpp" +#include "event_dispatcher.hpp" +#include "utils/md_helpers.hpp" #include "utils/serialization.hpp" +#include "watcher.hpp" #include #include #include #include -// Chosen to be no more than half of the lesser of the two limits: -// - The default httplib thread pool size = 8 -// - The browser limit on the number of server-sent event connections = 6 -#define MAX_EVENT_WAIT_COUNT 3 - namespace duckdb { namespace ui { -// An empty Server-Sent Events message. See -// https://html.spec.whatwg.org/multipage/server-sent-events.html#authoring-notes -constexpr const char *EMPTY_SSE_MESSAGE = ":\r\r"; -constexpr idx_t EMPTY_SSE_MESSAGE_LENGTH = 3; - -bool EventDispatcher::WaitEvent(httplib::DataSink *sink) { - std::unique_lock lock(mutex); - // Don't allow too many simultaneous waits, because each consumes a thread in - // the httplib thread pool, and also browsers limit the number of server-sent - // event connections. - if (closed || wait_count >= MAX_EVENT_WAIT_COUNT) { - return false; - } - int target_id = next_id; - wait_count++; - cv.wait_for(lock, std::chrono::seconds(5)); - wait_count--; - if (closed) { - return false; - } - if (current_id == target_id) { - sink->write(message.data(), message.size()); - } else { - // Our wait timer expired. Write an empty, no-op message. - // This enables detecting when the client is gone. - sink->write(EMPTY_SSE_MESSAGE, EMPTY_SSE_MESSAGE_LENGTH); - } - return true; -} - -void EventDispatcher::SendEvent(const std::string &_message) { - std::lock_guard guard(mutex); - if (closed) { - return; - } - - current_id = next_id++; - message = _message; - cv.notify_all(); -} - -void EventDispatcher::Close() { - std::lock_guard guard(mutex); - if (closed) { - return; - } - - current_id = next_id++; - closed = true; - cv.notify_all(); -} - unique_ptr HttpServer::server_instance; HttpServer *HttpServer::GetInstance(ClientContext &context) { @@ -92,11 +38,11 @@ void HttpServer::UpdateDatabaseInstanceIfRunning( void HttpServer::UpdateDatabaseInstance( shared_ptr context_db) { - const auto current_db = server_instance->ddb_instance.lock(); + const auto current_db = server_instance->LockDatabaseInstance(); if (current_db != context_db) { - server_instance->StopWatcher(); // Likely already stopped, but just in case + server_instance->watcher->Stop(); server_instance->ddb_instance = context_db; - server_instance->StartWatcher(); + server_instance->watcher->Start(); } } @@ -105,7 +51,7 @@ bool HttpServer::Started() { } void HttpServer::StopInstance() { - if (server_instance) { + if (Started()) { server_instance->DoStop(); } } @@ -117,9 +63,11 @@ const HttpServer &HttpServer::Start(ClientContext &context, bool *was_started) { } return *GetInstance(context); } + if (was_started) { *was_started = false; } + const auto remote_url = GetRemoteUrl(context); const auto port = GetLocalPort(context); auto server = GetInstance(context); @@ -143,51 +91,36 @@ void HttpServer::DoStart(const uint16_t _local_port, UI_EXTENSION_GIT_SHA, DuckDB::Platform()); event_dispatcher = make_uniq(); main_thread = make_uniq(&HttpServer::Run, this); - StartWatcher(); -} - -void HttpServer::StartWatcher() { - { - std::lock_guard guard(watcher_mutex); - watcher_should_run = true; - } - - if (!watcher_thread) { - watcher_thread = make_uniq(&HttpServer::Watch, this); - } -} - -void HttpServer::StopWatcher() { - if (!watcher_thread) { - return; - } - - { - std::lock_guard guard(watcher_mutex); - watcher_should_run = false; - } - watcher_cv.notify_all(); - watcher_thread->join(); - watcher_thread.reset(); + watcher = make_uniq(*this); + watcher->Start(); } bool HttpServer::Stop() { if (!Started()) { return false; } + server_instance->DoStop(); return true; } void HttpServer::DoStop() { - event_dispatcher->Close(); + if (event_dispatcher) { + event_dispatcher->Close(); + event_dispatcher = nullptr; + } server.stop(); - StopWatcher(); + if (watcher) { + watcher->Stop(); + watcher = nullptr; + } + + if (main_thread) { + main_thread->join(); + main_thread.reset(); + } - main_thread->join(); - main_thread.reset(); - event_dispatcher.reset(); ddb_instance.reset(); remote_url = ""; local_port = 0; @@ -197,124 +130,8 @@ std::string HttpServer::LocalUrl() const { return StringUtil::Format("http://localhost:%d/", local_port); } -void HttpServer::SendConnectedEvent(const std::string &token) { - SendEvent(StringUtil::Format("event: ConnectedEvent\ndata: %s\n\n", token)); -} - -void HttpServer::SendCatalogChangedEvent() { - SendEvent("event: CatalogChangeEvent\ndata:\n\n"); -} - -void HttpServer::SendEvent(const std::string &message) { - if (event_dispatcher) { - event_dispatcher->SendEvent(message); - } -} - -bool WasCatalogUpdated(DatabaseInstance &db, Connection &connection, - CatalogState &last_state) { - bool has_change = false; - auto &context = *connection.context; - connection.BeginTransaction(); - - const auto &databases = db.GetDatabaseManager().GetDatabases(context); - std::set db_oids; - - // Check currently attached databases - for (const auto &db_ref : databases) { - auto &db = db_ref.get(); - if (db.IsTemporary()) { - continue; // ignore temp databases - } - - db_oids.insert(db.oid); - auto &catalog = db.GetCatalog(); - auto current_version = catalog.GetCatalogVersion(context); - auto last_version_it = last_state.db_to_catalog_version.find(db.oid); - if (last_version_it == last_state.db_to_catalog_version.end() // first time - || !(last_version_it->second == current_version)) { // updated - has_change = true; - last_state.db_to_catalog_version[db.oid] = current_version; - } - } - - // Now check if any databases have been detached - for (auto it = last_state.db_to_catalog_version.begin(); - it != last_state.db_to_catalog_version.end();) { - if (db_oids.find(it->first) == db_oids.end()) { - has_change = true; - it = last_state.db_to_catalog_version.erase(it); - } else { - ++it; - } - } - - connection.Rollback(); - return has_change; -} - -std::string GetMDToken(Connection &connection) { - auto query_res = connection.Query("CALL GET_MD_TOKEN()"); - if (query_res->HasError()) { - query_res->ThrowError(); - return ""; // unreachable - } - - auto chunk = query_res->Fetch(); - return chunk->GetValue(0, 0).GetValue(); -} - -bool IsMDConnected(Connection &con) { - if (!con.context->db->ExtensionIsLoaded("motherduck")) { - return false; - } - auto query_res = con.Query("CALL MD_IS_CONNECTED()"); - if (query_res->HasError()) { - std::cerr << "Error in IsMDConnected: " << query_res->GetError() - << std::endl; - return false; - } - - auto chunk = query_res->Fetch(); - return chunk->GetValue(0, 0).GetValue(); -} - -void HttpServer::Watch() { - CatalogState last_state; - bool is_md_connected = false; - while (watcher_should_run) { - auto db = ddb_instance.lock(); - if (!db) { - break; // DB went away, nothing to watch - } - - duckdb::Connection con{*db}; - auto polling_interval = GetPollingInterval(*con.context); - if (polling_interval == 0) { - return; // Disable watcher - } - - try { - if (WasCatalogUpdated(*db, con, last_state)) { - SendCatalogChangedEvent(); - } - - if (!is_md_connected && IsMDConnected(con)) { - is_md_connected = true; - SendConnectedEvent(GetMDToken(con)); - } - } catch (std::exception &ex) { - // Do not crash with uncaught exception, but quit. - std::cerr << "Error in watcher: " << ex.what() << std::endl; - std::cerr << "Will now terminate." << std::endl; - return; - } - - { - std::unique_lock lock(watcher_mutex); - watcher_cv.wait_for(lock, std::chrono::milliseconds(polling_interval)); - } - } +shared_ptr HttpServer::LockDatabaseInstance() { + return ddb_instance.lock(); } void HttpServer::Run() { diff --git a/src/include/event_dispatcher.hpp b/src/include/event_dispatcher.hpp new file mode 100644 index 0000000..c062fab --- /dev/null +++ b/src/include/event_dispatcher.hpp @@ -0,0 +1,32 @@ +#pragma once + +#include + +namespace duckdb_httplib_openssl { +class DataSink; +} + +namespace duckdb { + +namespace ui { + +class EventDispatcher { +public: + void SendConnectedEvent(const std::string &token); + void SendCatalogChangedEvent(); + + bool WaitEvent(duckdb_httplib_openssl::DataSink *sink); + void Close(); + +private: + void SendEvent(const std::string &message); + std::mutex mutex; + std::condition_variable cv; + std::atomic_int next_id{0}; + std::atomic_int current_id{-1}; + std::atomic_int wait_count{0}; + std::string message; + std::atomic_bool closed{false}; +}; +} // namespace ui +} // namespace duckdb diff --git a/src/include/http_server.hpp b/src/include/http_server.hpp index ffeb0b2..0c3dccd 100644 --- a/src/include/http_server.hpp +++ b/src/include/http_server.hpp @@ -1,6 +1,6 @@ #pragma once -#include "duckdb.hpp" +#include #define CPPHTTPLIB_OPENSSL_SUPPORT #include "httplib.hpp" @@ -11,6 +11,9 @@ #include #include +#include "watcher.hpp" +#include "event_dispatcher.hpp" + namespace httplib = duckdb_httplib_openssl; namespace duckdb { @@ -18,26 +21,6 @@ class MemoryStream; namespace ui { -struct CatalogState { - std::map db_to_catalog_version; -}; - -class EventDispatcher { -public: - bool WaitEvent(httplib::DataSink *sink); - void SendEvent(const std::string &message); - void Close(); - -private: - std::mutex mutex; - std::condition_variable cv; - std::atomic_int next_id{0}; - std::atomic_int current_id{-1}; - std::atomic_int wait_count{0}; - std::string message; - std::atomic_bool closed{false}; -}; - class HttpServer { public: @@ -53,17 +36,14 @@ public: std::string LocalUrl() const; private: + friend class Watcher; + // Lifecycle void DoStart(const uint16_t local_port, const std::string &remote_url); void DoStop(); void Run(); void UpdateDatabaseInstance(shared_ptr context_db); - // Watcher - void Watch(); - void StartWatcher(); - void StopWatcher(); - // Http handlers void HandleGetLocalEvents(const httplib::Request &req, httplib::Response &res); @@ -83,10 +63,8 @@ private: void SetResponseEmptyResult(httplib::Response &res); void SetResponseErrorResult(httplib::Response &res, const std::string &error); - // Events - void SendEvent(const std::string &message); - void SendConnectedEvent(const std::string &token); - void SendCatalogChangedEvent(); + // Misc + shared_ptr LockDatabaseInstance(); uint16_t local_port; std::string remote_url; @@ -94,12 +72,8 @@ private: std::string user_agent; httplib::Server server; unique_ptr main_thread; - unique_ptr watcher_thread; - std::mutex watcher_mutex; - std::condition_variable watcher_cv; - std::atomic watcher_should_run; - unique_ptr event_dispatcher; + unique_ptr watcher; static unique_ptr server_instance; }; diff --git a/src/include/utils/md_helpers.hpp b/src/include/utils/md_helpers.hpp new file mode 100644 index 0000000..d417eec --- /dev/null +++ b/src/include/utils/md_helpers.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include + +namespace duckdb { +bool IsMDConnected(Connection &); +std::string GetMDToken(Connection &); +} // namespace duckdb diff --git a/src/include/watcher.hpp b/src/include/watcher.hpp new file mode 100644 index 0000000..282a0b2 --- /dev/null +++ b/src/include/watcher.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +namespace duckdb { +namespace ui { +struct CatalogState { + std::map db_to_catalog_version; +}; +class HttpServer; +class Watcher { +public: + Watcher(HttpServer &server); + + void Start(); + void Stop(); + +private: + void Watch(); + unique_ptr thread; + std::mutex mutex; + std::condition_variable cv; + std::atomic should_run; + HttpServer &server; +}; +} // namespace ui +} // namespace duckdb diff --git a/src/settings.cpp b/src/settings.cpp index f5af7e5..ad5c782 100644 --- a/src/settings.cpp +++ b/src/settings.cpp @@ -5,9 +5,11 @@ namespace duckdb { std::string GetRemoteUrl(const ClientContext &context) { return internal::GetSetting(context, UI_REMOTE_URL_SETTING_NAME); } + uint16_t GetLocalPort(const ClientContext &context) { return internal::GetSetting(context, UI_LOCAL_PORT_SETTING_NAME); } + uint32_t GetPollingInterval(const ClientContext &context) { return internal::GetSetting(context, UI_POLLING_INTERVAL_SETTING_NAME); diff --git a/src/utils/md_helpers.cpp b/src/utils/md_helpers.cpp new file mode 100644 index 0000000..a5adf33 --- /dev/null +++ b/src/utils/md_helpers.cpp @@ -0,0 +1,33 @@ +#include "utils/md_helpers.hpp" + +#include +#include + +namespace duckdb { +std::string GetMDToken(Connection &connection) { + auto query_res = connection.Query("CALL GET_MD_TOKEN()"); + if (query_res->HasError()) { + query_res->ThrowError(); + return ""; // unreachable + } + + auto chunk = query_res->Fetch(); + return chunk->GetValue(0, 0).GetValue(); +} + +bool IsMDConnected(Connection &con) { + if (!con.context->db->ExtensionIsLoaded("motherduck")) { + return false; + } + + auto query_res = con.Query("CALL MD_IS_CONNECTED()"); + if (query_res->HasError()) { + std::cerr << "Error in IsMDConnected: " << query_res->GetError() + << std::endl; + return false; + } + + auto chunk = query_res->Fetch(); + return chunk->GetValue(0, 0).GetValue(); +} +} // namespace duckdb diff --git a/src/watcher.cpp b/src/watcher.cpp new file mode 100644 index 0000000..9e64b44 --- /dev/null +++ b/src/watcher.cpp @@ -0,0 +1,120 @@ +#include "watcher.hpp" + +#include + +#include "utils/md_helpers.hpp" +#include "http_server.hpp" +#include "settings.hpp" + +namespace duckdb { +namespace ui { + +Watcher::Watcher(HttpServer &_server) : server(_server), should_run(false) {} + +bool WasCatalogUpdated(DatabaseInstance &db, Connection &connection, + CatalogState &last_state) { + bool has_change = false; + auto &context = *connection.context; + connection.BeginTransaction(); + + const auto &databases = db.GetDatabaseManager().GetDatabases(context); + std::set db_oids; + + // Check currently attached databases + for (const auto &db_ref : databases) { + auto &db = db_ref.get(); + if (db.IsTemporary()) { + continue; // ignore temp databases + } + + db_oids.insert(db.oid); + auto &catalog = db.GetCatalog(); + auto current_version = catalog.GetCatalogVersion(context); + auto last_version_it = last_state.db_to_catalog_version.find(db.oid); + if (last_version_it == last_state.db_to_catalog_version.end() // first time + || !(last_version_it->second == current_version)) { // updated + has_change = true; + last_state.db_to_catalog_version[db.oid] = current_version; + } + } + + // Now check if any databases have been detached + for (auto it = last_state.db_to_catalog_version.begin(); + it != last_state.db_to_catalog_version.end();) { + if (db_oids.find(it->first) == db_oids.end()) { + has_change = true; + it = last_state.db_to_catalog_version.erase(it); + } else { + ++it; + } + } + + connection.Rollback(); + return has_change; +} + +void Watcher::Watch() { + CatalogState last_state; + bool is_md_connected = false; + while (should_run) { + auto db = server.LockDatabaseInstance(); + if (!db) { + break; // DB went away, nothing to watch + } + + duckdb::Connection con{*db}; + auto polling_interval = GetPollingInterval(*con.context); + if (polling_interval == 0) { + return; // Disable watcher + } + + try { + if (WasCatalogUpdated(*db, con, last_state)) { + server.event_dispatcher->SendCatalogChangedEvent(); + } + + if (!is_md_connected && IsMDConnected(con)) { + is_md_connected = true; + server.event_dispatcher->SendConnectedEvent(GetMDToken(con)); + } + } catch (std::exception &ex) { + // Do not crash with uncaught exception, but quit. + std::cerr << "Error in watcher: " << ex.what() << std::endl; + std::cerr << "Will now terminate." << std::endl; + return; + } + + { + std::unique_lock lock(mutex); + cv.wait_for(lock, std::chrono::milliseconds(polling_interval)); + } + } +} + +void Watcher::Start() { + { + std::lock_guard guard(mutex); + should_run = true; + } + + if (!thread) { + thread = make_uniq(&Watcher::Watch, this); + } +} + +void Watcher::Stop() { + if (!thread) { + return; + } + + { + std::lock_guard guard(mutex); + should_run = false; + } + cv.notify_all(); + thread->join(); + thread.reset(); +} + +} // namespace ui +} // namespace duckdb