#include "http_server.hpp"

#include "event_dispatcher.hpp"
#include "settings.hpp"
#include "state.hpp"
#include "utils/encoding.hpp"
#include "utils/env.hpp"
#include "utils/md_helpers.hpp"
#include "utils/serialization.hpp"
#include "version.hpp"
#include "watcher.hpp"
// NOTE(review): in the copy under review everything between angle brackets
// had been stripped (the targets of the five includes below, and every
// template argument such as unique_ptr<...> / static_cast<...>). They are
// reconstructed here from the symbols this file uses (HTTPUtil,
// BinarySerializer, MemoryStream, the database manager, Parser) and from
// call-site usage. Confirm against version control / http_server.hpp.
#include <duckdb/common/http_util.hpp>
#include <duckdb/common/serializer/binary_serializer.hpp>
#include <duckdb/common/serializer/memory_stream.hpp>
#include <duckdb/main/attached_database.hpp>
#include <duckdb/parser/parser.hpp>

namespace duckdb {
namespace ui {

namespace {

// Returns true when `referer` starts with `origin` AND the match ends on an
// origin boundary (end of string, or the start of a path, query or fragment).
// A bare prefix comparison would also accept look-alike origins such as
// "http://localhost:42130/..." or "http://localhost:4213.example.com/..."
// for the origin "http://localhost:4213".
bool RefererMatchesOrigin(const std::string &referer,
                          const std::string &origin) {
  if (referer.compare(0, origin.size(), origin) != 0) {
    return false;
  }
  if (referer.size() == origin.size()) {
    return true;
  }
  const char next = referer[origin.size()];
  return next == '/' || next == '?' || next == '#';
}

} // namespace

unique_ptr<HttpServer> HttpServer::server_instance;

// Returns the single server instance, creating it on first use.
// On creation, StopInstance is registered with std::atexit. When an instance
// already exists it is re-pointed at the database of `context`.
HttpServer *HttpServer::GetInstance(ClientContext &context) {
  if (server_instance) {
    // We already have an instance, make sure we're running on the right DB
    server_instance->UpdateDatabaseInstance(context.db);
  } else {
    server_instance = make_uniq<HttpServer>(context.db);
    std::atexit(HttpServer::StopInstance);
  }
  return server_instance.get();
}

// Re-points the existing instance (if any) at `db`; does nothing when no
// instance has been created yet.
void HttpServer::UpdateDatabaseInstanceIfRunning(
    shared_ptr<DatabaseInstance> db) {
  if (server_instance) {
    server_instance->UpdateDatabaseInstance(db);
  }
}

// Switches this server to `context_db`. No-op when it already refers to that
// database. If a watcher exists it is stopped before the switch and a fresh
// one is created and started afterwards.
void HttpServer::UpdateDatabaseInstance(
    shared_ptr<DatabaseInstance> context_db) {
  const auto current_db = LockDatabaseInstance();
  if (current_db == context_db) {
    return;
  }

  const bool has_watcher = watcher != nullptr;
  if (has_watcher) {
    watcher->Stop();
    watcher = nullptr;
  }

  ddb_instance = context_db;

  if (has_watcher) {
    watcher = make_uniq<Watcher>(*this);
    watcher->Start();
  }
}

// Returns true if this process has started the server, or if a GET of
// "/info" on the configured local port succeeds.
bool HttpServer::IsRunningOnMachine(ClientContext &context) {
  if (Started()) {
    return true;
  }

  const auto local_port = GetLocalPort(context);
  auto local_url = StringUtil::Format("http://localhost:%d", local_port);

  httplib::Client client(local_url);
  return client.Get("/info");
}

// True when an instance exists and its server thread has been created.
bool HttpServer::Started() {
  return server_instance && server_instance->main_thread;
}

// Stops the server if it is running. Registered with std::atexit by
// GetInstance.
void HttpServer::StopInstance() {
  if (Started()) {
    server_instance->DoStop();
  }
}

// Starts the server if it is not already running and returns it.
// If `was_started` is non-null it receives whether the server was already
// running before this call.
const HttpServer &HttpServer::Start(ClientContext &context,
                                    bool *was_started) {
  const bool already_started = Started();
  if (was_started) {
    *was_started = already_started;
  }
  if (already_started) {
    return *GetInstance(context);
  }

  const auto remote_url = GetRemoteUrl(context);
  const auto port = GetLocalPort(context);
  auto &http_util = HTTPUtil::Get(*context.db);
  // FIXME - https://github.com/duckdb/duckdb/pull/17655 will remove `unused`
  auto http_params = http_util.InitializeParameters(context, "unused");
  auto server = GetInstance(context);
  server->DoStart(port, remote_url, std::move(http_params));
  return *server;
}

// Records the configuration, then creates the event dispatcher, the server
// thread (running Run) and the watcher.
// Throws std::runtime_error if the server is already started.
void HttpServer::DoStart(const uint16_t _local_port,
                         const std::string &_remote_url,
                         unique_ptr<HTTPParams> _http_params) {
  if (Started()) {
    throw std::runtime_error("HttpServer already started");
  }

  local_port = _local_port;
  // Note: no trailing slash, so it can be compared against Origin headers.
  local_url = StringUtil::Format("http://localhost:%d", local_port);
  remote_url = _remote_url;
  http_params = std::move(_http_params);
  user_agent =
      StringUtil::Format("duckdb-ui/%s-%s(%s)", DuckDB::LibraryVersion(),
                         UI_EXTENSION_VERSION, DuckDB::Platform());
  event_dispatcher = make_uniq<EventDispatcher>();
  main_thread = make_uniq<std::thread>(&HttpServer::Run, this);
  watcher = make_uniq<Watcher>(*this);
  watcher->Start();
}

// Stops the server. Returns false if it was not running, true otherwise.
bool HttpServer::Stop() {
  if (!Started()) {
    return false;
  }

  server_instance->DoStop();
  return true;
}

// Tears the server down and clears the per-run state.
void HttpServer::DoStop() {
  // The dispatcher is closed before the HTTP server is stopped.
  // NOTE(review): presumably this releases handlers blocked in WaitEvent so
  // the server thread can exit — confirm against EventDispatcher.
  if (event_dispatcher) {
    event_dispatcher->Close();
    event_dispatcher = nullptr;
  }

  server.stop();

  if (watcher) {
    watcher->Stop();
    watcher = nullptr;
  }

  if (main_thread) {
    main_thread->join();
    main_thread.reset();
  }

  ddb_instance.reset();
  http_params = nullptr;
  remote_url = "";
  local_port = 0;
}

// Returns the local URL including a trailing slash.
std::string HttpServer::LocalUrl() const {
  return StringUtil::Format("http://localhost:%d/", local_port);
}

// Returns a shared pointer to the current database instance, or an empty
// pointer if it is no longer alive.
shared_ptr<DatabaseInstance> HttpServer::LockDatabaseInstance() {
  return ddb_instance.lock();
}

// Server thread entry point: registers the routes and listens on
// localhost:local_port until the server is stopped.
void HttpServer::Run() {
  // The handlers only use members, so capture `this` explicitly.
  server.Get("/info",
             [this](const httplib::Request &req, httplib::Response &res) {
               HandleGetInfo(req, res);
             });
  server.Get("/localEvents",
             [this](const httplib::Request &req, httplib::Response &res) {
               HandleGetLocalEvents(req, res);
             });
  server.Get("/localToken",
             [this](const httplib::Request &req, httplib::Response &res) {
               HandleGetLocalToken(req, res);
             });
  // Registered after the specific GET routes above; every other GET is
  // forwarded to the remote URL.
  server.Get("/.*",
             [this](const httplib::Request &req, httplib::Response &res) {
               HandleGet(req, res);
             });
  server.Post("/ddb/interrupt",
              [this](const httplib::Request &req, httplib::Response &res) {
                HandleInterrupt(req, res);
              });
  server.Post("/ddb/run",
              [this](const httplib::Request &req, httplib::Response &res,
                     const httplib::ContentReader &content_reader) {
                HandleRun(req, res, content_reader);
              });
  server.Post("/ddb/tokenize",
              [this](const httplib::Request &req, httplib::Response &res,
                     const httplib::ContentReader &content_reader) {
                HandleTokenize(req, res, content_reader);
              });

  server.listen("localhost", local_port);
}

// GET /info: empty body; version information is returned in headers.
void HttpServer::HandleGetInfo(const httplib::Request &req,
                               httplib::Response &res) {
  res.set_header("Access-Control-Allow-Origin", "*");
  res.set_header("X-DuckDB-Version", DuckDB::LibraryVersion());
  res.set_header("X-DuckDB-Platform", DuckDB::Platform());
  res.set_header("X-DuckDB-UI-Extension-Version", UI_EXTENSION_VERSION);
  res.set_content("", "text/plain");
}

// GET /localEvents: streams events as "text/event-stream" using a chunked
// content provider that keeps going for as long as WaitEvent returns true.
void HttpServer::HandleGetLocalEvents(const httplib::Request &req,
                                      httplib::Response &res) {
  // The provider is invoked after this function has returned, so it must not
  // capture any locals by reference; it only needs `this`.
  res.set_chunked_content_provider(
      "text/event-stream", [this](size_t /*offset*/, httplib::DataSink &sink) {
        if (event_dispatcher->WaitEvent(&sink)) {
          return true;
        }
        sink.done();
        return false;
      });
}

// GET /localToken: returns the token obtained via GetMDToken as text/plain.
// Responds 401 when the Referer is not the local UI, 500 when the database is
// gone or the token cannot be obtained.
void HttpServer::HandleGetLocalToken(const httplib::Request &req,
                                     httplib::Response &res) {
  // GET requests don't include Origin, so use Referer instead.
  // Referer includes the path, so only the origin part is compared; the match
  // must end on an origin boundary so look-alike origins are rejected.
  const auto referer = req.get_header_value("Referer");
  if (!RefererMatchesOrigin(referer, local_url)) {
    res.status = 401;
    return;
  }

  auto db = ddb_instance.lock();
  if (!db) {
    res.status = 500;
    res.set_content("Database was invalidated, UI needs to be restarted",
                    "text/plain");
    return;
  }

  Connection connection{*db};
  try {
    auto token = GetMDToken(connection);
    res.status = 200;
    res.set_content(token, "text/plain");
  } catch (const std::exception &ex) {
    res.status = 500;
    res.set_content("Could not get token: " + std::string(ex.what()),
                    "text/plain");
  }
}

// Adapted from
// https://github.com/duckdb/duckdb/blob/1f8b6839ea7864c3e3fb020574f67384cb58124c/src/main/http/http_util.cpp#L129-L147
// Which is not currently exposed.
//
// Applies the timeouts and proxy settings from http_params to `client`.
void HttpServer::InitClientFromParams(httplib::Client &client) {
  auto sec = static_cast<time_t>(http_params->timeout);
  auto usec = static_cast<time_t>(http_params->timeout_usec);
  client.set_keep_alive(true);
  client.set_write_timeout(sec, usec);
  client.set_read_timeout(sec, usec);
  client.set_connection_timeout(sec, usec);

  if (!http_params->http_proxy.empty()) {
    client.set_proxy(http_params->http_proxy,
                     static_cast<int>(http_params->http_proxy_port));

    if (!http_params->http_proxy_username.empty()) {
      client.set_proxy_basic_auth(http_params->http_proxy_username,
                                  http_params->http_proxy_password);
    }
  }
}

// Catch-all GET: forwards the request (path, params, Cookie) to remote_url
// and relays the response. Responds 500 if the forwarded request fails.
void HttpServer::HandleGet(const httplib::Request &req,
                           httplib::Response &res) {
  // Create HTTP client to remote URL
  // TODO: Can this be created once and shared?
  httplib::Client client(remote_url);
  InitClientFromParams(client);

  if (IsEnvEnabled("ui_disable_server_certificate_verification")) {
    client.enable_server_certificate_verification(false);
  }

  httplib::Headers headers = {{"User-Agent", user_agent}};
  auto cookie = req.get_header_value("Cookie");
  if (!cookie.empty()) {
    headers.emplace("Cookie", cookie);
  }

  // forward GET to remote URL
  auto result = client.Get(req.path, req.params, headers);
  if (!result) {
    res.status = 500;
    res.set_content("Could not fetch: '" + req.path + "' from '" + remote_url +
                        "': " + to_string(result.error()),
                    "text/plain");
    return;
  }

  // Respond with result of forwarded GET
  res = result.value();

  // If this is the config request, return additional information.
  if (req.path == "/config") {
    res.set_header("X-DuckDB-Version", DuckDB::LibraryVersion());
    res.set_header("X-DuckDB-Platform", DuckDB::Platform());
    // The UI looks for this to select the appropriate DuckDB mode (HTTP or
    // Wasm).
    res.set_header("X-DuckDB-UI-Extension-Version", UI_EXTENSION_VERSION);
  }

  // httplib will set Content-Length, remove it so it is not duplicated.
  res.headers.erase("Content-Length");
}

// POST /ddb/interrupt: interrupts the connection named by the
// X-DuckDB-UI-Connection-Name header. Responds 401 for a non-local Origin and
// 404 when the database or the connection cannot be found.
void HttpServer::HandleInterrupt(const httplib::Request &req,
                                 httplib::Response &res) {
  const auto origin = req.get_header_value("Origin");
  if (origin != local_url) {
    res.status = 401;
    return;
  }

  const auto connection_name =
      req.get_header_value("X-DuckDB-UI-Connection-Name");

  auto db = ddb_instance.lock();
  if (!db) {
    res.status = 404;
    return;
  }

  auto connection =
      UIStorageExtensionInfo::GetState(*db).FindConnection(connection_name);
  if (!connection) {
    res.status = 404;
    return;
  }

  connection->Interrupt();

  SetResponseEmptyResult(res);
}

// POST /ddb/run: delegates to DoHandleRun and converts any std::exception it
// throws into a serialized error result.
void HttpServer::HandleRun(const httplib::Request &req, httplib::Response &res,
                           const httplib::ContentReader &content_reader) {
  try {
    DoHandleRun(req, res, content_reader);
  } catch (const std::exception &ex) {
    SetResponseErrorResult(res, ex.what());
  }
}

// Runs the request body as SQL on the connection named by the
// X-DuckDB-UI-Connection-Name header and writes the serialized result (or a
// serialized error) to `res`.
//
// Optional headers:
//   X-DuckDB-UI-Database-Name       base64; default database to switch to
//   X-DuckDB-UI-Parameter-Count     number of prepared-statement parameters
//   X-DuckDB-UI-Parameter-Value-<i> base64; value of parameter i
//
// Responds 401 for a non-local Origin. May throw (e.g. std::stoi on a
// malformed parameter count); HandleRun turns that into an error result.
void HttpServer::DoHandleRun(const httplib::Request &req,
                             httplib::Response &res,
                             const httplib::ContentReader &content_reader) {
  const auto origin = req.get_header_value("Origin");
  if (origin != local_url) {
    res.status = 401;
    return;
  }

  const auto connection_name =
      req.get_header_value("X-DuckDB-UI-Connection-Name");

  const auto database_name =
      DecodeBase64(req.get_header_value("X-DuckDB-UI-Database-Name"));

  std::vector<std::string> parameter_values;
  const auto parameter_count_string =
      req.get_header_value("X-DuckDB-UI-Parameter-Count");
  if (!parameter_count_string.empty()) {
    const int parameter_count = std::stoi(parameter_count_string);
    // A negative count must be rejected explicitly: compared against an
    // unsigned loop index it would wrap to a huge value and the loop below
    // would run (and allocate) practically without bound.
    if (parameter_count < 0) {
      SetResponseErrorResult(res, "Invalid parameter count: " +
                                      parameter_count_string);
      return;
    }
    for (int i = 0; i < parameter_count; ++i) {
      auto parameter_value = DecodeBase64(req.get_header_value(
          StringUtil::Format("X-DuckDB-UI-Parameter-Value-%d", i)));
      parameter_values.push_back(std::move(parameter_value));
    }
  }

  std::string content = ReadContent(content_reader);

  auto db = ddb_instance.lock();
  if (!db) {
    SetResponseErrorResult(
        res, "Database was invalidated, UI needs to be restarted");
    return;
  }

  auto connection = UIStorageExtensionInfo::GetState(*db).FindOrCreateConnection(
      *db, connection_name);

  // Set current database if optional header is provided.
  if (!database_name.empty()) {
    auto &context = *connection->context;
    context.RunFunctionInTransaction([&] {
      auto &manager = context.db->GetDatabaseManager();
      manager.SetDefaultDatabase(context, database_name);
    });
  }

  // We use a pending query so we can execute tasks and fetch chunks
  // incrementally. This enables cancellation.
  unique_ptr<PendingQueryResult> pending;

  // Create pending query, with request content as SQL.
  if (!parameter_values.empty()) {
    auto prepared = connection->Prepare(content);
    if (prepared->HasError()) {
      SetResponseErrorResult(res, prepared->GetError());
      return;
    }

    vector<Value> values;
    for (auto &parameter_value : parameter_values) {
      // TODO: support non-string parameters?
      values.push_back(Value(parameter_value));
    }
    pending = prepared->PendingQuery(values, true);
  } else {
    pending = connection->PendingQuery(content, true);
  }

  if (pending->HasError()) {
    SetResponseErrorResult(res, pending->GetError());
    return;
  }

  // Execute tasks until result is ready (or there's an error).
  auto exec_result = PendingExecutionResult::RESULT_NOT_READY;
  while (!PendingQueryResult::IsResultReady(exec_result)) {
    exec_result = pending->ExecuteTask();
    if (exec_result == PendingExecutionResult::BLOCKED ||
        exec_result == PendingExecutionResult::NO_TASKS_AVAILABLE) {
      // Nothing to run right now; back off briefly instead of spinning.
      std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
  }

  switch (exec_result) {

  case PendingExecutionResult::EXECUTION_ERROR:
    SetResponseErrorResult(res, pending->GetError());
    break;

  case PendingExecutionResult::EXECUTION_FINISHED:
  case PendingExecutionResult::RESULT_READY: {
    // Get the result. This should be quick because it's ready.
    auto result = pending->Execute();

    // Fetch the chunks and serialize the result.
    SuccessResult success_result;
    success_result.column_names_and_types = {std::move(result->names),
                                             std::move(result->types)};

    // TODO: support limiting the number of chunks fetched
    auto chunk = result->Fetch();
    while (chunk) {
      success_result.chunks.push_back(
          {static_cast<uint16_t>(chunk->size()), std::move(chunk->data)});
      chunk = result->Fetch();
    }

    MemoryStream success_response_content;
    BinarySerializer::Serialize(success_result, success_response_content);
    SetResponseContent(res, success_response_content);
    break;
  }

  default:
    SetResponseErrorResult(res, "Unexpected PendingExecutionResult");
    break;
  }
}

// POST /ddb/tokenize: tokenizes the request body with Parser::Tokenize and
// responds with the serialized token offsets and types. Responds 401 for a
// non-local Origin.
void HttpServer::HandleTokenize(const httplib::Request &req,
                                httplib::Response &res,
                                const httplib::ContentReader &content_reader) {
  const auto origin = req.get_header_value("Origin");
  if (origin != local_url) {
    res.status = 401;
    return;
  }

  std::string content = ReadContent(content_reader);

  auto tokens = Parser::Tokenize(content);

  // Read and serialize result
  TokenizeResult result;
  result.offsets.reserve(tokens.size());
  result.types.reserve(tokens.size());

  for (const auto &token : tokens) {
    result.offsets.push_back(token.start);
    result.types.push_back(token.type);
  }

  MemoryStream response_content;
  BinarySerializer::Serialize(result, response_content);
  SetResponseContent(res, response_content);
}

// Reads the whole request body from `content_reader` into a string.
std::string
HttpServer::ReadContent(const httplib::ContentReader &content_reader) {
  std::ostringstream oss;
  content_reader([&](const char *data, size_t data_length) {
    oss.write(data, data_length);
    return true; // keep reading
  });
  return oss.str();
}

// Sets the response body to the bytes written to `content` so far, with
// content type "application/octet-stream".
void HttpServer::SetResponseContent(httplib::Response &res,
                                    const MemoryStream &content) {
  auto data = content.GetData();
  auto length = content.GetPosition();
  res.set_content(reinterpret_cast<const char *>(data), length,
                  "application/octet-stream");
}

// Responds with a serialized EmptyResult.
void HttpServer::SetResponseEmptyResult(httplib::Response &res) {
  EmptyResult empty_result;
  MemoryStream response_content;
  BinarySerializer::Serialize(empty_result, response_content);
  SetResponseContent(res, response_content);
}

// Responds with a serialized ErrorResult carrying `error`.
void HttpServer::SetResponseErrorResult(httplib::Response &res,
                                        const std::string &error) {
  ErrorResult error_result;
  error_result.error = error;
  MemoryStream response_content;
  BinarySerializer::Serialize(error_result, response_content);
  SetResponseContent(res, response_content);
}

} // namespace ui
} // namespace duckdb