diff --git a/CMakeLists.txt b/CMakeLists.txt index aa254f4..8011a16 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,8 +15,13 @@ project(${TARGET_NAME}) include_directories(src/include ${DuckDB_SOURCE_DIR}/third_party/httplib) set(EXTENSION_SOURCES - src/ui_extension.cpp src/http_server.cpp src/utils/encoding.cpp - src/utils/env.cpp src/utils/helpers.cpp src/utils/serialization.cpp) + src/ui_extension.cpp + src/http_server.cpp + src/state.cpp + src/utils/encoding.cpp + src/utils/env.cpp + src/utils/helpers.cpp + src/utils/serialization.cpp) find_package(Git) if(NOT Git_FOUND) diff --git a/src/http_server.cpp b/src/http_server.cpp index 5fba5a1..d9507af 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -7,6 +7,7 @@ #include "utils/env.hpp" #include "utils/serialization.hpp" #include "utils/encoding.hpp" +#include "state.hpp" // Chosen to be no more than half of the lesser of the two limits: // - The default httplib thread pool size = 8 @@ -70,14 +71,34 @@ void EventDispatcher::Close() { unique_ptr HttpServer::server_instance; -HttpServer *HttpServer::instance() { - if (!server_instance) { - server_instance = make_uniq(); +HttpServer *HttpServer::GetInstance(ClientContext &context) { + if (server_instance) { + // We already have an instance, make sure we're running on the right DB + server_instance->UpdateDatabaseInstance(context.db); + } else { + server_instance = make_uniq(context.db); std::atexit(HttpServer::StopInstance); } return server_instance.get(); } +void HttpServer::UpdateDatabaseInstanceIfRunning( + shared_ptr db) { + if (server_instance) { + server_instance->UpdateDatabaseInstance(db); + } +} + +void HttpServer::UpdateDatabaseInstance( + shared_ptr context_db) { + const auto current_db = server_instance->ddb_instance.lock(); + if (current_db != context_db) { + server_instance->StopWatcher(); // Likely already stopped, but just in case + server_instance->ddb_instance = context_db; + server_instance->StartWatcher(); + } +} + bool HttpServer::Started() { return server_instance && server_instance->main_thread; } @@ -88,16 +109,18 @@ void HttpServer::StopInstance() { } } -bool HttpServer::Start(const uint16_t _local_port, - const std::string &_remote_url, - const shared_ptr &_ddb_instance) { +const HttpServer &HttpServer::Start(const uint16_t _local_port, + const std::string &_remote_url, + bool *was_started) { if (main_thread) { - return false; + if (was_started) { + *was_started = true; + } + return *this; } local_port = _local_port; remote_url = _remote_url; - ddb_instance = _ddb_instance; #ifndef UI_EXTENSION_GIT_SHA #error "UI_EXTENSION_GIT_SHA must be defined" #endif @@ -105,12 +128,33 @@ bool HttpServer::Start(const uint16_t _local_port, DuckDB::Platform()); event_dispatcher = make_uniq(); main_thread = make_uniq(&HttpServer::Run, this); + StartWatcher(); + return *this; +} + +void HttpServer::StartWatcher() { { std::lock_guard guard(watcher_mutex); watcher_should_run = true; } - watcher_thread = make_uniq(&HttpServer::Watch, this); - return true; + + if (!watcher_thread) { + watcher_thread = make_uniq(&HttpServer::Watch, this); + } +} + +void HttpServer::StopWatcher() { + if (!watcher_thread) { + return; + } + + { + std::lock_guard guard(watcher_mutex); + watcher_should_run = false; + } + watcher_cv.notify_all(); + watcher_thread->join(); + watcher_thread.reset(); } bool HttpServer::Stop() { @@ -121,27 +165,20 @@ bool HttpServer::Stop() { event_dispatcher->Close(); server.stop(); - if (watcher_thread) { - { - std::lock_guard guard(watcher_mutex); - watcher_should_run = false; - } - watcher_cv.notify_all(); - watcher_thread->join(); - watcher_thread.reset(); - } + StopWatcher(); main_thread->join(); main_thread.reset(); event_dispatcher.reset(); - connections.clear(); ddb_instance.reset(); remote_url = ""; local_port = 0; return true; } -uint16_t HttpServer::LocalPort() { return local_port; } +std::string HttpServer::LocalUrl() const { + return StringUtil::Format("http://localhost:%d/", local_port); +} void HttpServer::SendConnectedEvent(const std::string &token) { SendEvent(StringUtil::Format("event: ConnectedEvent\ndata: %s\n\n", token)); @@ -157,13 +194,14 @@ void HttpServer::SendEvent(const std::string &message) { } } -void HttpServer::WatchForCatalogUpdate(CatalogState &last_state) { +void HttpServer::WatchForCatalogUpdate(DatabaseInstance &db, + CatalogState &last_state) { bool has_change = false; - duckdb::Connection con{*ddb_instance}; + duckdb::Connection con{db}; auto &context = *con.context; con.BeginTransaction(); - const auto &databases = - ddb_instance->GetDatabaseManager().GetDatabases(context); + + const auto &databases = db.GetDatabaseManager().GetDatabases(context); std::set db_oids; // Check currently attached databases @@ -204,7 +242,11 @@ void HttpServer::WatchForCatalogUpdate(CatalogState &last_state) { void HttpServer::Watch() { CatalogState last_state; while (watcher_should_run) { - WatchForCatalogUpdate(last_state); + auto db = ddb_instance.lock(); + if (!db) { + break; // DB went away, nothing to watch + } + WatchForCatalogUpdate(*db, last_state); { std::unique_lock lock(watcher_mutex); watcher_cv.wait_for(lock, @@ -256,13 +298,21 @@ void HttpServer::HandleGetLocalEvents(const httplib::Request &req, void HttpServer::HandleGetLocalToken(const httplib::Request &req, httplib::Response &res) { - if (!ddb_instance->ExtensionIsLoaded("motherduck")) { + auto db = ddb_instance.lock(); + if (!db) { + res.status = 500; + res.set_content("Database was invalidated, UI needs to be restarted", + "text/plain"); + return; + } + + if (!db->ExtensionIsLoaded("motherduck")) { res.set_content("", "text/plain"); // UI expects an empty response if the // extension is not loaded return; } - Connection connection(*ddb_instance); + Connection connection(*db); auto query_res = connection.Query("CALL get_md_token()"); if (query_res->HasError()) { res.status = 500; @@ -316,7 +366,14 @@ void HttpServer::HandleInterrupt(const httplib::Request &req, auto description = req.get_header_value("X-MD-Description"); auto connection_name = req.get_header_value("X-MD-Connection-Name"); - auto connection = FindConnection(connection_name); + auto db = ddb_instance.lock(); + if (!db) { + res.status = 404; + return; + } + + auto connection = + UIStorageExtensionInfo::GetState(*db).FindConnection(connection_name); if (!connection) { res.status = 404; return; @@ -347,13 +404,23 @@ void HttpServer::DoHandleRun(const httplib::Request &req, std::string content = ReadContent(content_reader); - auto connection = FindOrCreateConnection(connection_name); + auto db = ddb_instance.lock(); + if (!db) { + SetResponseErrorResult( + res, "Database was invalidated, UI needs to be restarted"); + return; + } + + auto connection = + UIStorageExtensionInfo::GetState(*db).FindOrCreateConnection( + *db, connection_name); // Set current database if optional header is provided. if (!database_name.empty()) { - connection->context->RunFunctionInTransaction([&] { - ddb_instance->GetDatabaseManager().SetDefaultDatabase( - *connection->context, database_name); + auto &context = *connection->context; + context.RunFunctionInTransaction([&] { + auto &manager = context.db->GetDatabaseManager(); + manager.SetDefaultDatabase(context, database_name); }); } @@ -464,48 +531,6 @@ HttpServer::ReadContent(const httplib::ContentReader &content_reader) { return oss.str(); } -shared_ptr -HttpServer::FindConnection(const std::string &connection_name) { - if (connection_name.empty()) { - return nullptr; - } - - // Need to protect access to the connections map because this can be called - // from multiple threads. - std::lock_guard guard(connections_mutex); - - auto result = connections.find(connection_name); - if (result != connections.end()) { - return result->second; - } - - return nullptr; -} - -shared_ptr -HttpServer::FindOrCreateConnection(const std::string &connection_name) { - if (connection_name.empty()) { - // If no connection name was provided, create and return a new connection - // but don't remember it. - return make_shared_ptr(*ddb_instance); - } - - // Need to protect access to the connections map because this can be called - // from multiple threads. - std::lock_guard guard(connections_mutex); - - // If an existing connection with the provided name was found, return it. - auto result = connections.find(connection_name); - if (result != connections.end()) { - return result->second; - } - - // Otherwise, create a new one, remember it, and return it. - auto connection = make_shared_ptr(*ddb_instance); - connections[connection_name] = connection; - return connection; -} - void HttpServer::SetResponseContent(httplib::Response &res, const MemoryStream &content) { auto data = content.GetData(); diff --git a/src/include/http_server.hpp b/src/include/http_server.hpp index e910ff3..52d0f60 100644 --- a/src/include/http_server.hpp +++ b/src/include/http_server.hpp @@ -41,46 +41,51 @@ private: class HttpServer { public: - static HttpServer *instance(); + HttpServer(shared_ptr _ddb_instance) + : ddb_instance(_ddb_instance) {} + static HttpServer *GetInstance(ClientContext &); + static void UpdateDatabaseInstanceIfRunning(shared_ptr); static bool Started(); static void StopInstance(); - bool Start(const uint16_t localPort, const std::string &remoteUrl, - const shared_ptr &ddbInstance); + const HttpServer &Start(const uint16_t local_port, + const std::string &remote_url, + bool *was_started = nullptr); bool Stop(); - uint16_t LocalPort(); + std::string LocalUrl() const; void SendConnectedEvent(const std::string &token); void SendCatalogChangedEvent(); private: + void UpdateDatabaseInstance(shared_ptr context_db); void SendEvent(const std::string &message); void Run(); void Watch(); + void StartWatcher(); + void StopWatcher(); void HandleGetLocalEvents(const httplib::Request &req, httplib::Response &res); void HandleGetLocalToken(const httplib::Request &req, httplib::Response &res); void HandleGet(const httplib::Request &req, httplib::Response &res); void HandleInterrupt(const httplib::Request &req, httplib::Response &res); void DoHandleRun(const httplib::Request &req, httplib::Response &res, - const httplib::ContentReader &contentReader); + const httplib::ContentReader &content_reader); void HandleRun(const httplib::Request &req, httplib::Response &res, - const httplib::ContentReader &contentReader); + const httplib::ContentReader &content_reader); void HandleTokenize(const httplib::Request &req, httplib::Response &res, - const httplib::ContentReader &contentReader); - std::string ReadContent(const httplib::ContentReader &contentReader); - shared_ptr FindConnection(const std::string &connectionName); - shared_ptr - FindOrCreateConnection(const std::string &connectionName); + const httplib::ContentReader &content_reader); + std::string ReadContent(const httplib::ContentReader &content_reader); + void SetResponseContent(httplib::Response &res, const MemoryStream &content); void SetResponseEmptyResult(httplib::Response &res); void SetResponseErrorResult(httplib::Response &res, const std::string &error); // Watchers - void WatchForCatalogUpdate(CatalogState &last_state); + void WatchForCatalogUpdate(DatabaseInstance &, CatalogState &last_state); uint16_t local_port; std::string remote_url; - shared_ptr ddb_instance; + weak_ptr ddb_instance; std::string user_agent; httplib::Server server; unique_ptr main_thread; @@ -89,8 +94,6 @@ private: std::condition_variable watcher_cv; std::atomic watcher_should_run; - std::mutex connections_mutex; - std::unordered_map> connections; unique_ptr event_dispatcher; static unique_ptr server_instance; diff --git a/src/include/state.hpp b/src/include/state.hpp new file mode 100644 index 0000000..a7d2bfc --- /dev/null +++ b/src/include/state.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +#include + +namespace duckdb { +const static std::string STORAGE_EXTENSION_KEY = "ui"; + +class UIStorageExtensionInfo : public StorageExtensionInfo { +public: + static UIStorageExtensionInfo &GetState(const DatabaseInstance &instance); + + shared_ptr FindConnection(const std::string &connection_name); + shared_ptr + FindOrCreateConnection(DatabaseInstance &db, + const std::string &connection_name); + +private: + std::mutex connections_mutex; + std::unordered_map> connections; +}; + +} // namespace duckdb diff --git a/src/include/utils/helpers.hpp b/src/include/utils/helpers.hpp index 76168e4..609791b 100644 --- a/src/include/utils/helpers.hpp +++ b/src/include/utils/helpers.hpp @@ -35,12 +35,6 @@ bool ShouldRun(TableFunctionInput &input); template struct CallFunctionHelper; -template <> struct CallFunctionHelper { - static std::string call(ClientContext &context, TableFunctionInput &input, - std::string (*f)()) { - return f(); - } -}; template <> struct CallFunctionHelper { static std::string call(ClientContext &context, TableFunctionInput &input, diff --git a/src/state.cpp b/src/state.cpp new file mode 100644 index 0000000..035b6ca --- /dev/null +++ b/src/state.cpp @@ -0,0 +1,60 @@ +#include "state.hpp" + +#include + +namespace duckdb { + +UIStorageExtensionInfo & +UIStorageExtensionInfo::GetState(const DatabaseInstance &instance) { + auto &config = instance.config; + auto it = config.storage_extensions.find(STORAGE_EXTENSION_KEY); + if (it == config.storage_extensions.end()) { + throw std::runtime_error( + "Fatal error: couldn't find the UI extension state."); + } + return *static_cast(it->second->storage_info.get()); +} + +shared_ptr +UIStorageExtensionInfo::FindConnection(const std::string &connection_name) { + if (connection_name.empty()) { + return nullptr; + } + + // Need to protect access to the connections map because this can be called + // from multiple threads. + std::lock_guard guard(connections_mutex); + + auto result = connections.find(connection_name); + if (result != connections.end()) { + return result->second; + } + + return nullptr; +} + +shared_ptr UIStorageExtensionInfo::FindOrCreateConnection( + DatabaseInstance &db, const std::string &connection_name) { + if (connection_name.empty()) { + // If no connection name was provided, create and return a new connection + // but don't remember it. + return make_shared_ptr(db); + } + + // If an existing connection with the provided name was found, return it. + auto connection = FindConnection(connection_name); + if (connection) { + return connection; + } + + // Otherwise, create a new one, remember it, and return it. + auto new_con = make_shared_ptr(db); + + // Need to protect access to the connections map because this can be called + // from multiple threads. + std::lock_guard guard(connections_mutex); + connections[connection_name] = new_con; + return new_con; +} + +} // namespace duckdb diff --git a/src/ui_extension.cpp b/src/ui_extension.cpp index 0d88fcf..a8b393f 100644 --- a/src/ui_extension.cpp +++ b/src/ui_extension.cpp @@ -4,6 +4,7 @@ #include "utils/helpers.hpp" #include "ui_extension.hpp" #include "http_server.hpp" +#include "state.hpp" #include #include @@ -29,27 +30,23 @@ namespace duckdb { namespace internal { -bool StartHttpServer(const ClientContext &context) { - const auto url = +const ui::HttpServer &StartHttpServer(ClientContext &context, + bool *was_started = nullptr) { + const auto remote_url = GetSetting(context, UI_REMOTE_URL_SETTING_NAME, GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DEFAULT)); const uint16_t port = GetSetting(context, UI_LOCAL_PORT_SETTING_NAME, UI_LOCAL_PORT_SETTING_DEFAULT); - ; - return ui::HttpServer::instance()->Start(port, url, context.db); -} - -std::string GetHttpServerLocalURL() { - return StringUtil::Format("http://localhost:%d/", - ui::HttpServer::instance()->LocalPort()); + return ui::HttpServer::GetInstance(context)->Start(port, remote_url, + was_started); } } // namespace internal std::string StartUIFunction(ClientContext &context) { - internal::StartHttpServer(context); - auto local_url = internal::GetHttpServerLocalURL(); + const auto &server = internal::StartHttpServer(context); + const auto local_url = server.LocalUrl(); const std::string command = StringUtil::Format("%s %s", OPEN_COMMAND, local_url); @@ -60,14 +57,17 @@ std::string StartUIFunction(ClientContext &context) { } std::string StartUIServerFunction(ClientContext &context) { - const char *already = internal::StartHttpServer(context) ? "already " : ""; + bool was_started = false; + const auto &server = internal::StartHttpServer(context, &was_started); + const char *already = was_started ? "already " : ""; return StringUtil::Format("MotherDuck UI server %sstarted at %s", already, - internal::GetHttpServerLocalURL()); + server.LocalUrl()); } -std::string StopUIServerFunction() { - return ui::HttpServer::instance()->Stop() ? "UI server stopped" - : "UI server already stopped"; +std::string StopUIServerFunction(ClientContext &context) { + return ui::HttpServer::GetInstance(context)->Stop() + ? "UI server stopped" + : "UI server already stopped"; } // Connected notification @@ -92,20 +92,29 @@ NotifyConnectedBind(ClientContext &, TableFunctionBindInput &input, std::string NotifyConnectedFunction(ClientContext &context, TableFunctionInput &input) { auto &inputs = input.bind_data->Cast(); - ui::HttpServer::instance()->SendConnectedEvent(inputs.token); + ui::HttpServer::GetInstance(context)->SendConnectedEvent(inputs.token); return "OK"; } // - connected notification -std::string NotifyCatalogChangedFunction() { - ui::HttpServer::instance()->SendCatalogChangedEvent(); +std::string NotifyCatalogChangedFunction(ClientContext &context) { + ui::HttpServer::GetInstance(context)->SendCatalogChangedEvent(); return "OK"; } +void InitStorageExtension(duckdb::DatabaseInstance &db) { + auto &config = db.config; + auto ext = duckdb::make_uniq(); + ext->storage_info = duckdb::make_uniq(); + config.storage_extensions[STORAGE_EXTENSION_KEY] = std::move(ext); +} + static void LoadInternal(DatabaseInstance &instance) { auto &config = DBConfig::GetConfig(instance); + InitStorageExtension(instance); + config.AddExtensionOption( UI_LOCAL_PORT_SETTING_NAME, UI_LOCAL_PORT_SETTING_DESCRIPTION, LogicalType::USMALLINT, Value::USMALLINT(UI_LOCAL_PORT_SETTING_DEFAULT)); @@ -122,6 +131,11 @@ static void LoadInternal(DatabaseInstance &instance) { RESISTER_TF("notify_ui_catalog_changed", NotifyCatalogChangedFunction); RESISTER_TF_ARGS("notify_ui_connected", {LogicalType::VARCHAR}, NotifyConnectedFunction, NotifyConnectedBind); + + // If the server is already running we need to update the database instance + // since the previous one was invalidated (eg. in the shell when we '.open' + // a new database) + ui::HttpServer::UpdateDatabaseInstanceIfRunning(instance.shared_from_this()); } void UiExtension::Load(DuckDB &db) { LoadInternal(*db.instance); }