diff --git a/src/http_server.cpp b/src/http_server.cpp index 531d8da..5fba5a1 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include "utils/env.hpp" #include "utils/serialization.hpp" @@ -104,6 +105,11 @@ bool HttpServer::Start(const uint16_t _local_port, DuckDB::Platform()); event_dispatcher = make_uniq(); main_thread = make_uniq(&HttpServer::Run, this); + { + std::lock_guard guard(watcher_mutex); + watcher_should_run = true; + } + watcher_thread = make_uniq(&HttpServer::Watch, this); return true; } @@ -114,6 +120,17 @@ bool HttpServer::Stop() { event_dispatcher->Close(); server.stop(); + + if (watcher_thread) { + { + std::lock_guard guard(watcher_mutex); + watcher_should_run = false; + } + watcher_cv.notify_all(); + watcher_thread->join(); + watcher_thread.reset(); + } + main_thread->join(); main_thread.reset(); event_dispatcher.reset(); @@ -140,6 +157,62 @@ void HttpServer::SendEvent(const std::string &message) { } } +void HttpServer::WatchForCatalogUpdate(CatalogState &last_state) { + bool has_change = false; + duckdb::Connection con{*ddb_instance}; + auto &context = *con.context; + con.BeginTransaction(); + const auto &databases = + ddb_instance->GetDatabaseManager().GetDatabases(context); + std::set db_oids; + + // Check currently attached databases + for (const auto &db_ref : databases) { + auto &db = db_ref.get(); + if (db.IsTemporary()) { + continue; // ignore temp databases + } + + db_oids.insert(db.oid); + auto &catalog = db.GetCatalog(); + auto current_version = catalog.GetCatalogVersion(context); + auto last_version_it = last_state.db_to_catalog_version.find(db.oid); + if (last_version_it == last_state.db_to_catalog_version.end() // first time + || !(last_version_it->second == current_version)) { // updated + has_change = true; + last_state.db_to_catalog_version[db.oid] = current_version; + } + } + + // Now check if any databases have been detached + for (auto it = last_state.db_to_catalog_version.begin(); + it != last_state.db_to_catalog_version.end();) { + if (db_oids.find(it->first) == db_oids.end()) { + has_change = true; + it = last_state.db_to_catalog_version.erase(it); + } else { + ++it; + } + } + + // If there was any change, notify the UI + if (has_change) { + SendCatalogChangedEvent(); + } +} + +void HttpServer::Watch() { + CatalogState last_state; + while (watcher_should_run) { + WatchForCatalogUpdate(last_state); + { + std::unique_lock lock(watcher_mutex); + watcher_cv.wait_for(lock, + std::chrono::milliseconds(2000)); // TODO - configure + } + } +} + void HttpServer::Run() { server.Get("/localEvents", [&](const httplib::Request &req, httplib::Response &res) { diff --git a/src/include/http_server.hpp b/src/include/http_server.hpp index 59fbaf7..e910ff3 100644 --- a/src/include/http_server.hpp +++ b/src/include/http_server.hpp @@ -18,6 +18,10 @@ class MemoryStream; namespace ui { +struct CatalogState { + std::map db_to_catalog_version; +}; + class EventDispatcher { public: bool WaitEvent(httplib::DataSink *sink); @@ -51,6 +55,7 @@ public: private: void SendEvent(const std::string &message); void Run(); + void Watch(); void HandleGetLocalEvents(const httplib::Request &req, httplib::Response &res); void HandleGetLocalToken(const httplib::Request &req, httplib::Response &res); @@ -70,12 +75,20 @@ private: void SetResponseEmptyResult(httplib::Response &res); void SetResponseErrorResult(httplib::Response &res, const std::string &error); + // Watchers + void WatchForCatalogUpdate(CatalogState &last_state); + uint16_t local_port; std::string remote_url; shared_ptr ddb_instance; std::string user_agent; httplib::Server server; unique_ptr main_thread; + unique_ptr watcher_thread; + std::mutex watcher_mutex; + std::condition_variable watcher_cv; + std::atomic watcher_should_run; + std::mutex connections_mutex; std::unordered_map> connections; unique_ptr event_dispatcher; diff --git a/src/include/utils/helpers.hpp b/src/include/utils/helpers.hpp index 53fa38e..76168e4 100644 --- a/src/include/utils/helpers.hpp +++ b/src/include/utils/helpers.hpp @@ -9,7 +9,7 @@ namespace duckdb { typedef std::string (*simple_tf_t)(ClientContext &); struct RunOnceTableFunctionState : GlobalTableFunctionState { - RunOnceTableFunctionState() : run(false){}; + RunOnceTableFunctionState() : run(false) {}; std::atomic run; static unique_ptr Init(ClientContext &,