diff --git a/src/http_server.cpp b/src/http_server.cpp index e960702..ad7ad04 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -1,5 +1,6 @@ #include "http_server.hpp" +#include "settings.hpp" #include "state.hpp" #include "utils/encoding.hpp" #include "utils/env.hpp" @@ -105,18 +106,31 @@ bool HttpServer::Started() { void HttpServer::StopInstance() { if (server_instance) { - server_instance->Stop(); + server_instance->DoStop(); } } -const HttpServer &HttpServer::Start(const uint16_t _local_port, - const std::string &_remote_url, - bool *was_started) { - if (main_thread) { +const HttpServer &HttpServer::Start(ClientContext &context, bool *was_started) { + if (Started()) { if (was_started) { *was_started = true; } - return *this; + return *GetInstance(context); + } + if (was_started) { + *was_started = false; + } + const auto remote_url = GetRemoteUrl(context); + const auto port = GetLocalPort(context); + auto server = GetInstance(context); + server->DoStart(port, remote_url); + return *server; +} + +void HttpServer::DoStart(const uint16_t _local_port, + const std::string &_remote_url) { + if (Started()) { + throw std::runtime_error("HttpServer already started"); } local_port = _local_port; @@ -130,7 +144,6 @@ const HttpServer &HttpServer::Start(const uint16_t _local_port, event_dispatcher = make_uniq(); main_thread = make_uniq(&HttpServer::Run, this); StartWatcher(); - return *this; } void HttpServer::StartWatcher() { @@ -159,10 +172,14 @@ void HttpServer::StopWatcher() { } bool HttpServer::Stop() { - if (!main_thread) { + if (!Started()) { return false; } + server_instance->DoStop(); + return true; +} +void HttpServer::DoStop() { event_dispatcher->Close(); server.stop(); @@ -174,7 +191,6 @@ bool HttpServer::Stop() { ddb_instance.reset(); remote_url = ""; local_port = 0; - return true; } std::string HttpServer::LocalUrl() const { diff --git a/src/include/http_server.hpp b/src/include/http_server.hpp index 9706a7f..ffeb0b2 100644 --- a/src/include/http_server.hpp +++ b/src/include/http_server.hpp @@ -48,21 +48,23 @@ public: static bool Started(); static void StopInstance(); - const HttpServer &Start(const uint16_t local_port, - const std::string &remote_url, - bool *was_started = nullptr); - bool Stop(); + static const HttpServer &Start(ClientContext &, bool *was_started = nullptr); + static bool Stop(); std::string LocalUrl() const; - void SendConnectedEvent(const std::string &token); - void SendCatalogChangedEvent(); private: - void UpdateDatabaseInstance(shared_ptr context_db); - void SendEvent(const std::string &message); + // Lifecycle + void DoStart(const uint16_t local_port, const std::string &remote_url); + void DoStop(); void Run(); + void UpdateDatabaseInstance(shared_ptr context_db); + + // Watcher void Watch(); void StartWatcher(); void StopWatcher(); + + // Http handlers void HandleGetLocalEvents(const httplib::Request &req, httplib::Response &res); void HandleGetLocalToken(const httplib::Request &req, httplib::Response &res); @@ -76,10 +78,16 @@ private: const httplib::ContentReader &content_reader); std::string ReadContent(const httplib::ContentReader &content_reader); + // Http responses void SetResponseContent(httplib::Response &res, const MemoryStream &content); void SetResponseEmptyResult(httplib::Response &res); void SetResponseErrorResult(httplib::Response &res, const std::string &error); + // Events + void SendEvent(const std::string &message); + void SendConnectedEvent(const std::string &token); + void SendCatalogChangedEvent(); + uint16_t local_port; std::string remote_url; weak_ptr ddb_instance; diff --git a/src/ui_extension.cpp b/src/ui_extension.cpp index 52423ff..da2e0d7 100644 --- a/src/ui_extension.cpp +++ b/src/ui_extension.cpp @@ -1,13 +1,15 @@ #define DUCKDB_EXTENSION_MAIN -#include "ui_extension.hpp" -#include "http_server.hpp" -#include "state.hpp" -#include "utils/env.hpp" -#include "utils/helpers.hpp" #include #include +#include "http_server.hpp" +#include "settings.hpp" +#include "state.hpp" +#include "ui_extension.hpp" +#include "utils/env.hpp" +#include "utils/helpers.hpp" + #ifdef _WIN32 #define OPEN_COMMAND "start" #elif __linux__ @@ -16,36 +18,13 @@ #define OPEN_COMMAND "open" #endif -#define UI_LOCAL_PORT_SETTING_NAME "ui_local_port" -#define UI_LOCAL_PORT_SETTING_DEFAULT 4213 - -#define UI_REMOTE_URL_SETTING_NAME "ui_remote_url" -#define UI_REMOTE_URL_SETTING_DEFAULT "https://app.motherduck.com" - namespace duckdb { -namespace internal { - -const ui::HttpServer &StartHttpServer(ClientContext &context, - bool *was_started = nullptr) { - const auto remote_url = - GetSetting(context, UI_REMOTE_URL_SETTING_NAME, - GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME, - UI_REMOTE_URL_SETTING_DEFAULT)); - const uint16_t port = GetSetting(context, UI_LOCAL_PORT_SETTING_NAME, - UI_LOCAL_PORT_SETTING_DEFAULT); - return ui::HttpServer::GetInstance(context)->Start(port, remote_url, - was_started); -} - -} // namespace internal - std::string StartUIFunction(ClientContext &context) { - const auto &server = internal::StartHttpServer(context); + const auto &server = ui::HttpServer::Start(context); const auto local_url = server.LocalUrl(); - const std::string command = - StringUtil::Format("%s %s", OPEN_COMMAND, local_url); + const auto command = StringUtil::Format("%s %s", OPEN_COMMAND, local_url); return system(command.c_str()) ? StringUtil::Format("Navigate browser to %s", local_url) // open command failed @@ -54,25 +33,15 @@ std::string StartUIFunction(ClientContext &context) { std::string StartUIServerFunction(ClientContext &context) { bool was_started = false; - const auto &server = internal::StartHttpServer(context, &was_started); + const auto &server = ui::HttpServer::Start(context, &was_started); const char *already = was_started ? "already " : ""; return StringUtil::Format("UI server %sstarted at %s", already, server.LocalUrl()); } std::string StopUIServerFunction(ClientContext &context) { - return ui::HttpServer::GetInstance(context)->Stop() - ? "UI server stopped" - : "UI server already stopped"; -} - -unique_ptr SingleBoolResultBind(ClientContext &, - TableFunctionBindInput &, - vector &out_types, - vector &out_names) { - out_names.emplace_back("result"); - out_types.emplace_back(LogicalType::BOOLEAN); - return nullptr; + return ui::HttpServer::Stop() ? "UI server stopped" + : "UI server already stopped"; } void IsUIStartedTableFunc(ClientContext &context, TableFunctionInput &input, @@ -101,16 +70,21 @@ static void LoadInternal(DatabaseInstance &instance) { ui::HttpServer::UpdateDatabaseInstanceIfRunning(instance.shared_from_this()); auto &config = DBConfig::GetConfig(instance); - config.AddExtensionOption( - UI_LOCAL_PORT_SETTING_NAME, "Local port on which the UI server listens", - LogicalType::USMALLINT, Value::USMALLINT(UI_LOCAL_PORT_SETTING_DEFAULT)); + { + auto default_port = GetEnvOrDefaultInt(UI_LOCAL_PORT_SETTING_NAME, 4213); + config.AddExtensionOption( + UI_LOCAL_PORT_SETTING_NAME, "Local port on which the UI server listens", + LogicalType::USMALLINT, Value::USMALLINT(default_port)); + } - config.AddExtensionOption( - UI_REMOTE_URL_SETTING_NAME, - "Remote URL to which the UI server forwards GET requests", - LogicalType::VARCHAR, - Value(GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME, - UI_REMOTE_URL_SETTING_DEFAULT))); + { + auto def = GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME, + "https://app.motherduck.com"); + config.AddExtensionOption( + UI_REMOTE_URL_SETTING_NAME, + "Remote URL to which the UI server forwards GET requests", + LogicalType::VARCHAR, Value(def)); + } RESISTER_TF("start_ui", StartUIFunction); RESISTER_TF("start_ui_server", StartUIServerFunction);