diff --git a/src/http_server.cpp b/src/http_server.cpp index 89cd7e3..13886f1 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -67,7 +67,8 @@ bool HttpServer::IsRunningOnMachine(ClientContext &context) { } const auto local_port = GetLocalPort(context); - auto local_url = StringUtil::Format("http://localhost:%d", local_port); + const auto local_host = GetLocalHost(context); + auto local_url = StringUtil::Format("http://%s:%d", local_host, local_port); httplib::Client client(local_url); return client.Get("/info"); @@ -97,15 +98,17 @@ const HttpServer &HttpServer::Start(ClientContext &context, bool *was_started) { const auto remote_url = GetRemoteUrl(context); const auto port = GetLocalPort(context); + const auto host = GetLocalHost(context); + auto server = GetInstance(context); auto &http_util = HTTPUtil::Get(*context.db); // FIXME - https://github.com/duckdb/duckdb/pull/17655 will remove `unused` auto http_params = http_util.InitializeParameters(context, "unused"); - auto server = GetInstance(context); - server->DoStart(port, remote_url, std::move(http_params)); + server->DoStart(port, host, remote_url, std::move(http_params)); return *server; } void HttpServer::DoStart(const uint16_t _local_port, + const std::string &_local_host, const std::string &_remote_url, unique_ptr _http_params) { if (Started()) { @@ -113,7 +116,8 @@ void HttpServer::DoStart(const uint16_t _local_port, } local_port = _local_port; - local_url = StringUtil::Format("http://localhost:%d", local_port); + local_host = _local_host; + local_url = StringUtil::Format("http://%s:%d", local_host, local_port); remote_url = _remote_url; http_params = std::move(_http_params); user_agent = @@ -155,10 +159,11 @@ void HttpServer::DoStop() { http_params = nullptr; remote_url = ""; local_port = 0; + local_host = ""; } std::string HttpServer::LocalUrl() const { - return StringUtil::Format("http://localhost:%d/", local_port); + return StringUtil::Format("http://%s:%d/", local_host, local_port); } shared_ptr HttpServer::LockDatabaseInstance() { @@ -194,7 +199,7 @@ void HttpServer::Run() { const httplib::ContentReader &content_reader) { HandleTokenize(req, res, content_reader); }); - server.listen("localhost", local_port); + server.listen(local_host, local_port); } void HttpServer::HandleGetInfo(const httplib::Request &req, diff --git a/src/include/http_server.hpp b/src/include/http_server.hpp index a22d131..f62f8ed 100644 --- a/src/include/http_server.hpp +++ b/src/include/http_server.hpp @@ -42,8 +42,8 @@ private: friend class Watcher; // Lifecycle - void DoStart(const uint16_t local_port, const std::string &remote_url, - unique_ptr); + void DoStart(const uint16_t local_port, const std::string &local_host, + const std::string &remote_url, unique_ptr); void DoStop(); void Run(); void UpdateDatabaseInstance(shared_ptr context_db); @@ -73,6 +73,7 @@ private: void InitClientFromParams(httplib::Client &); uint16_t local_port; + std::string local_host; std::string local_url; std::string remote_url; weak_ptr ddb_instance; diff --git a/src/include/settings.hpp b/src/include/settings.hpp index b9bdd88..6a0c1ba 100644 --- a/src/include/settings.hpp +++ b/src/include/settings.hpp @@ -5,6 +5,8 @@ #define UI_LOCAL_PORT_SETTING_NAME "ui_local_port" #define UI_LOCAL_PORT_SETTING_DEFAULT 4213 +#define UI_LOCAL_HOST_SETTING_NAME "ui_local_host" +#define UI_LOCAL_HOST_SETTING_DEFAULT "localhost" #define UI_REMOTE_URL_SETTING_NAME "ui_remote_url" #define UI_REMOTE_URL_SETTING_DEFAULT "https://ui.duckdb.org" #define UI_POLLING_INTERVAL_SETTING_NAME "ui_polling_interval" @@ -27,6 +29,7 @@ T GetSetting(const ClientContext &context, const char *setting_name) { std::string GetRemoteUrl(const ClientContext &); uint16_t GetLocalPort(const ClientContext &); +std::string GetLocalHost(const ClientContext &); uint32_t GetPollingInterval(const ClientContext &); } // namespace duckdb diff --git a/src/settings.cpp b/src/settings.cpp index e08aa59..d1af658 100644 --- a/src/settings.cpp +++ b/src/settings.cpp @@ -15,6 +15,10 @@ uint16_t GetLocalPort(const ClientContext &context) { return internal::GetSetting(context, UI_LOCAL_PORT_SETTING_NAME); } +std::string GetLocalHost(const ClientContext &context) { + return internal::GetSetting(context, UI_LOCAL_HOST_SETTING_NAME); +} + uint32_t GetPollingInterval(const ClientContext &context) { return internal::GetSetting(context, UI_POLLING_INTERVAL_SETTING_NAME); diff --git a/src/ui_extension.cpp b/src/ui_extension.cpp index 2e012aa..c84ba95 100644 --- a/src/ui_extension.cpp +++ b/src/ui_extension.cpp @@ -108,6 +108,14 @@ static void LoadInternal(DatabaseInstance &instance) { LogicalType::USMALLINT, Value::USMALLINT(default_port)); } + { + auto default_host = GetEnvOrDefault(UI_LOCAL_HOST_SETTING_NAME, + UI_LOCAL_HOST_SETTING_DEFAULT); + config.AddExtensionOption(UI_LOCAL_HOST_SETTING_NAME, + "Local host on which the UI server listens", + LogicalType::VARCHAR, Value(default_host)); + } + { auto def = GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DEFAULT);