From 5811f237d822158377e995055400ec01231addb1 Mon Sep 17 00:00:00 2001 From: Yves Date: Wed, 19 Feb 2025 11:20:42 +0100 Subject: [PATCH] Format --- CMakeLists.txt | 32 +- src/http_server.cpp | 624 +++++++++++++++------------- src/include/http_server.hpp | 96 +++-- src/include/ui_extension.hpp | 6 +- src/include/utils/helpers.hpp | 105 +++-- src/include/utils/serialization.hpp | 30 +- src/ui_extension.cpp | 129 +++--- src/utils/encoding.cpp | 81 ++-- src/utils/env.cpp | 32 +- src/utils/helpers.cpp | 25 +- src/utils/serialization.cpp | 45 +- 11 files changed, 636 insertions(+), 569 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a648346..aa254f4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,39 +3,31 @@ cmake_minimum_required(VERSION 3.5...3.31.5) # Set extension name here set(TARGET_NAME ui) -# DuckDB's extension distribution supports vcpkg. As such, dependencies can be added in ./vcpkg.json and then -# used in cmake with find_package. Feel free to remove or replace with other dependencies. -# Note that it should also be removed from vcpkg.json to prevent needlessly installing it.. +# DuckDB's extension distribution supports vcpkg. As such, dependencies can be +# added in ./vcpkg.json and then used in cmake with find_package. Feel free to +# remove or replace with other dependencies. Note that it should also be removed +# from vcpkg.json to prevent needlessly installing it.. find_package(OpenSSL REQUIRED) set(EXTENSION_NAME ${TARGET_NAME}_extension) project(${TARGET_NAME}) -include_directories( - src/include - ${DuckDB_SOURCE_DIR}/third_party/httplib -) +include_directories(src/include ${DuckDB_SOURCE_DIR}/third_party/httplib) set(EXTENSION_SOURCES - src/ui_extension.cpp - src/http_server.cpp - src/utils/encoding.cpp - src/utils/env.cpp - src/utils/helpers.cpp - src/utils/serialization.cpp -) + src/ui_extension.cpp src/http_server.cpp src/utils/encoding.cpp + src/utils/env.cpp src/utils/helpers.cpp src/utils/serialization.cpp) find_package(Git) -if (NOT Git_FOUND) +if(NOT Git_FOUND) message(FATAL_ERROR "Git not found, unable to determine git sha") endif() execute_process( - COMMAND ${GIT_EXECUTABLE} rev-parse --short=10 HEAD - WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} - OUTPUT_VARIABLE UI_EXTENSION_GIT_SHA - OUTPUT_STRIP_TRAILING_WHITESPACE -) + COMMAND ${GIT_EXECUTABLE} rev-parse --short=10 HEAD + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + OUTPUT_VARIABLE UI_EXTENSION_GIT_SHA + OUTPUT_STRIP_TRAILING_WHITESPACE) message(STATUS "UI_EXTENSION_GIT_SHA=${UI_EXTENSION_GIT_SHA}") add_definitions(-DUI_EXTENSION_GIT_SHA="${UI_EXTENSION_GIT_SHA}") diff --git a/src/http_server.cpp b/src/http_server.cpp index a9580df..1e530cb 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -21,401 +21,437 @@ constexpr const char *EMPTY_SSE_MESSAGE = ":\r\r"; constexpr idx_t EMPTY_SSE_MESSAGE_LENGTH = 3; bool EventDispatcher::WaitEvent(httplib::DataSink *sink) { - std::unique_lock lock(mutex_); - // Don't allow too many simultaneous waits, because each consumes a thread in the httplib thread pool, and also - // browsers limit the number of server-sent event connections. - if (closed_ || wait_count_ >= MAX_EVENT_WAIT_COUNT) { - return false; - } - int target_id = next_id_; - wait_count_++; - cv_.wait_for(lock, std::chrono::seconds(5)); - wait_count_--; - if (closed_) { - return false; - } - if (current_id_ == target_id) { - sink->write(message_.data(), message_.size()); - } else { - // Our wait timer expired. Write an empty, no-op message. - // This enables detecting when the client is gone. - sink->write(EMPTY_SSE_MESSAGE, EMPTY_SSE_MESSAGE_LENGTH); - } - return true; + std::unique_lock lock(mutex_); + // Don't allow too many simultaneous waits, because each consumes a thread in + // the httplib thread pool, and also browsers limit the number of server-sent + // event connections. + if (closed_ || wait_count_ >= MAX_EVENT_WAIT_COUNT) { + return false; + } + int target_id = next_id_; + wait_count_++; + cv_.wait_for(lock, std::chrono::seconds(5)); + wait_count_--; + if (closed_) { + return false; + } + if (current_id_ == target_id) { + sink->write(message_.data(), message_.size()); + } else { + // Our wait timer expired. Write an empty, no-op message. + // This enables detecting when the client is gone. + sink->write(EMPTY_SSE_MESSAGE, EMPTY_SSE_MESSAGE_LENGTH); + } + return true; } void EventDispatcher::SendEvent(const std::string &message) { - std::lock_guard guard(mutex_); - if (closed_) { - return; - } + std::lock_guard guard(mutex_); + if (closed_) { + return; + } - current_id_ = next_id_++; - message_ = message; - cv_.notify_all(); + current_id_ = next_id_++; + message_ = message; + cv_.notify_all(); } void EventDispatcher::Close() { - std::lock_guard guard(mutex_); - if (closed_) { - return; - } - current_id_ = next_id_++; - closed_ = true; - cv_.notify_all(); + std::lock_guard guard(mutex_); + if (closed_) { + return; + } + current_id_ = next_id_++; + closed_ = true; + cv_.notify_all(); } unique_ptr HttpServer::instance_; -HttpServer* HttpServer::instance() { - if (!instance_) { - instance_ = make_uniq(); - std::atexit(HttpServer::StopInstance); - } - return instance_.get(); +HttpServer *HttpServer::instance() { + if (!instance_) { + instance_ = make_uniq(); + std::atexit(HttpServer::StopInstance); + } + return instance_.get(); } -bool HttpServer::Started() { - return instance_ && instance_->thread_; -} +bool HttpServer::Started() { return instance_ && instance_->thread_; } void HttpServer::StopInstance() { - if (instance_) { - instance_->Stop(); - } + if (instance_) { + instance_->Stop(); + } } bool HttpServer::Start(const uint16_t local_port, const std::string &remote_url, - const shared_ptr &ddb_instance) { - if (thread_) { - return false; - } + const shared_ptr &ddb_instance) { + if (thread_) { + return false; + } - local_port_ = local_port; - remote_url_ = remote_url; - ddb_instance_ = ddb_instance; + local_port_ = local_port; + remote_url_ = remote_url; + ddb_instance_ = ddb_instance; #ifndef UI_EXTENSION_GIT_SHA #error "UI_EXTENSION_GIT_SHA must be defined" #endif - user_agent_ = StringUtil::Format("duckdb-ui/%s(%s)", UI_EXTENSION_GIT_SHA, DuckDB::Platform()); - event_dispatcher_ = make_uniq(); - thread_ = make_uniq(&HttpServer::Run, this); - return true; + user_agent_ = StringUtil::Format("duckdb-ui/%s(%s)", UI_EXTENSION_GIT_SHA, + DuckDB::Platform()); + event_dispatcher_ = make_uniq(); + thread_ = make_uniq(&HttpServer::Run, this); + return true; } bool HttpServer::Stop() { - if (!thread_) { - return false; - } + if (!thread_) { + return false; + } - event_dispatcher_->Close(); - server_.stop(); - thread_->join(); - thread_.reset(); - event_dispatcher_.reset(); - connections_.clear(); - ddb_instance_.reset(); - remote_url_ = ""; - local_port_ = 0; - return true; + event_dispatcher_->Close(); + server_.stop(); + thread_->join(); + thread_.reset(); + event_dispatcher_.reset(); + connections_.clear(); + ddb_instance_.reset(); + remote_url_ = ""; + local_port_ = 0; + return true; } -uint16_t HttpServer::LocalPort() { - return local_port_; -} +uint16_t HttpServer::LocalPort() { return local_port_; } void HttpServer::SendConnectedEvent(const std::string &token) { - SendEvent(StringUtil::Format("event: ConnectedEvent\ndata: %s\n\n", token)); + SendEvent(StringUtil::Format("event: ConnectedEvent\ndata: %s\n\n", token)); } void HttpServer::SendCatalogChangedEvent() { - SendEvent("event: CatalogChangeEvent\ndata:\n\n"); + SendEvent("event: CatalogChangeEvent\ndata:\n\n"); } void HttpServer::SendEvent(const std::string &message) { - if (event_dispatcher_) { - event_dispatcher_->SendEvent(message); - } + if (event_dispatcher_) { + event_dispatcher_->SendEvent(message); + } } void HttpServer::Run() { - server_.Get("/localEvents", - [&](const httplib::Request &req, httplib::Response &res) { HandleGetLocalEvents(req, res); }); - server_.Get("/localToken", - [&](const httplib::Request &req, httplib::Response &res) { HandleGetLocalToken(req, res); }); - server_.Get("/.*", [&](const httplib::Request &req, httplib::Response &res) { HandleGet(req, res); }); - server_.Post("/ddb/interrupt", - [&](const httplib::Request &req, httplib::Response &res) { HandleInterrupt(req, res); }); - server_.Post("/ddb/run", - [&](const httplib::Request &req, httplib::Response &res, - const httplib::ContentReader &content_reader) { HandleRun(req, res, content_reader); }); - server_.Post("/ddb/tokenize", - [&](const httplib::Request &req, httplib::Response &res, - const httplib::ContentReader &content_reader) { HandleTokenize(req, res, content_reader); }); - server_.listen("localhost", local_port_); + server_.Get("/localEvents", + [&](const httplib::Request &req, httplib::Response &res) { + HandleGetLocalEvents(req, res); + }); + server_.Get("/localToken", + [&](const httplib::Request &req, httplib::Response &res) { + HandleGetLocalToken(req, res); + }); + server_.Get("/.*", [&](const httplib::Request &req, httplib::Response &res) { + HandleGet(req, res); + }); + server_.Post("/ddb/interrupt", + [&](const httplib::Request &req, httplib::Response &res) { + HandleInterrupt(req, res); + }); + server_.Post("/ddb/run", + [&](const httplib::Request &req, httplib::Response &res, + const httplib::ContentReader &content_reader) { + HandleRun(req, res, content_reader); + }); + server_.Post("/ddb/tokenize", + [&](const httplib::Request &req, httplib::Response &res, + const httplib::ContentReader &content_reader) { + HandleTokenize(req, res, content_reader); + }); + server_.listen("localhost", local_port_); } -void HttpServer::HandleGetLocalEvents(const httplib::Request &req, httplib::Response &res) { - res.set_chunked_content_provider("text/event-stream", [&](size_t /*offset*/, httplib::DataSink &sink) { - if (event_dispatcher_->WaitEvent(&sink)) { - return true; - } - sink.done(); - return false; - }); +void HttpServer::HandleGetLocalEvents(const httplib::Request &req, + httplib::Response &res) { + res.set_chunked_content_provider( + "text/event-stream", [&](size_t /*offset*/, httplib::DataSink &sink) { + if (event_dispatcher_->WaitEvent(&sink)) { + return true; + } + sink.done(); + return false; + }); } -void HttpServer::HandleGetLocalToken(const httplib::Request &req, httplib::Response &res) { - if (!ddb_instance_->ExtensionIsLoaded("motherduck")) { - res.set_content("", "text/plain"); // UI expects an empty response if the extension is not loaded - return; - } +void HttpServer::HandleGetLocalToken(const httplib::Request &req, + httplib::Response &res) { + if (!ddb_instance_->ExtensionIsLoaded("motherduck")) { + res.set_content("", "text/plain"); // UI expects an empty response if the + // extension is not loaded + return; + } - Connection connection(*ddb_instance_); - auto query_res = connection.Query("CALL get_md_token()"); - if (query_res->HasError()) { - res.status = 500; - res.set_content("Could not get token: " + query_res->GetError(), "text/plain"); - return; - } + Connection connection(*ddb_instance_); + auto query_res = connection.Query("CALL get_md_token()"); + if (query_res->HasError()) { + res.status = 500; + res.set_content("Could not get token: " + query_res->GetError(), + "text/plain"); + return; + } - auto chunk = query_res->Fetch(); - auto token = chunk->GetValue(0, 0).GetValue(); - res.status = 200; - res.set_content(token, "text/plain"); + auto chunk = query_res->Fetch(); + auto token = chunk->GetValue(0, 0).GetValue(); + res.status = 200; + res.set_content(token, "text/plain"); } -void HttpServer::HandleGet(const httplib::Request &req, httplib::Response &res) { - // Create HTTP client to remote URL - // TODO: Can this be created once and shared? - httplib::Client client(remote_url_); - client.set_keep_alive(true); +void HttpServer::HandleGet(const httplib::Request &req, + httplib::Response &res) { + // Create HTTP client to remote URL + // TODO: Can this be created once and shared? + httplib::Client client(remote_url_); + client.set_keep_alive(true); - // Provide a way to turn on or off server certificate verification, at least for now, because it requires httplib to - // correctly get the root certficates on each platform, which doesn't appear to always work. - // Currently, default to no verification, until we understand when it breaks things. - if (IsEnvEnabled("ui_enable_server_certificate_verification")) { - client.enable_server_certificate_verification(true); - } else { - client.enable_server_certificate_verification(false); - } + // Provide a way to turn on or off server certificate verification, at least + // for now, because it requires httplib to correctly get the root certficates + // on each platform, which doesn't appear to always work. Currently, default + // to no verification, until we understand when it breaks things. + if (IsEnvEnabled("ui_enable_server_certificate_verification")) { + client.enable_server_certificate_verification(true); + } else { + client.enable_server_certificate_verification(false); + } - // forward GET to remote URL - auto result = client.Get(req.path, req.params, {{"User-Agent", user_agent_}}); - if (!result) { - res.status = 500; - return; - } + // forward GET to remote URL + auto result = client.Get(req.path, req.params, {{"User-Agent", user_agent_}}); + if (!result) { + res.status = 500; + return; + } - // Repond with result of forwarded GET - res = result.value(); + // Repond with result of forwarded GET + res = result.value(); - // If this is the config request, set the X-MD-DuckDB-Mode header to HTTP. - // The UI looks for this to select the appropriate DuckDB mode (HTTP or Wasm). - if (req.path == "/config") { - res.set_header("X-MD-DuckDB-Mode", "HTTP"); - } + // If this is the config request, set the X-MD-DuckDB-Mode header to HTTP. + // The UI looks for this to select the appropriate DuckDB mode (HTTP or Wasm). + if (req.path == "/config") { + res.set_header("X-MD-DuckDB-Mode", "HTTP"); + } } -void HttpServer::HandleInterrupt(const httplib::Request &req, httplib::Response &res) { - auto description = req.get_header_value("X-MD-Description"); - auto connection_name = req.get_header_value("X-MD-Connection-Name"); +void HttpServer::HandleInterrupt(const httplib::Request &req, + httplib::Response &res) { + auto description = req.get_header_value("X-MD-Description"); + auto connection_name = req.get_header_value("X-MD-Connection-Name"); - auto connection = FindConnection(connection_name); - if (!connection) { - res.status = 404; - return; - } + auto connection = FindConnection(connection_name); + if (!connection) { + res.status = 404; + return; + } - connection->Interrupt(); + connection->Interrupt(); - SetResponseEmptyResult(res); + SetResponseEmptyResult(res); } void HttpServer::HandleRun(const httplib::Request &req, httplib::Response &res, + const httplib::ContentReader &content_reader) { + try { + DoHandleRun(req, res, content_reader); + } catch (const std::exception &ex) { + SetResponseErrorResult(res, ex.what()); + } +} + +void HttpServer::DoHandleRun(const httplib::Request &req, + httplib::Response &res, const httplib::ContentReader &content_reader) { - try { - auto description = req.get_header_value("X-MD-Description"); - auto connection_name = req.get_header_value("X-MD-Connection-Name"); + auto description = req.get_header_value("X-MD-Description"); + auto connection_name = req.get_header_value("X-MD-Connection-Name"); - auto database_name = DecodeBase64(req.get_header_value("X-MD-Database-Name")); - auto parameter_count = req.get_header_value_count("X-MD-Parameter"); + auto database_name = DecodeBase64(req.get_header_value("X-MD-Database-Name")); + auto parameter_count = req.get_header_value_count("X-MD-Parameter"); - std::string content = ReadContent(content_reader); + std::string content = ReadContent(content_reader); + auto connection = FindOrCreateConnection(connection_name); - auto connection = FindOrCreateConnection(connection_name); + // Set current database if optional header is provided. + if (!database_name.empty()) { + connection->context->RunFunctionInTransaction([&] { + ddb_instance_->GetDatabaseManager().SetDefaultDatabase( + *connection->context, database_name); + }); + } - // Set current database if optional header is provided. - if (!database_name.empty()) { - connection->context->RunFunctionInTransaction( - [&] { ddb_instance_->GetDatabaseManager().SetDefaultDatabase(*connection->context, database_name); }); - } + // We use a pending query so we can execute tasks and fetch chunks + // incrementally. This enables cancellation. + unique_ptr pending; - // We use a pending query so we can execute tasks and fetch chunks incrementally. - // This enables cancellation. - unique_ptr pending; + // Create pending query, with request content as SQL. + if (parameter_count > 0) { + auto prepared = connection->Prepare(content); + if (prepared->HasError()) { + SetResponseErrorResult(res, prepared->GetError()); + return; + } - // Create pending query, with request content as SQL. - if (parameter_count > 0) { - auto prepared = connection->Prepare(content); - if (prepared->HasError()) { - SetResponseErrorResult(res, prepared->GetError()); - return; - } + vector values; + for (idx_t i = 0; i < parameter_count; ++i) { + auto parameter = DecodeBase64(req.get_header_value("X-MD-Parameter", i)); + values.push_back( + Value(parameter)); // TODO: support non-string parameters? (SURF-1546) + } + pending = prepared->PendingQuery(values, true); + } else { + pending = connection->PendingQuery(content, true); + } - vector values; - for (idx_t i = 0; i < parameter_count; ++i) { - auto parameter = DecodeBase64(req.get_header_value("X-MD-Parameter", i)); - values.push_back(Value(parameter)); // TODO: support non-string parameters? (SURF-1546) - } - pending = prepared->PendingQuery(values, true); - } else { - pending = connection->PendingQuery(content, true); - } + if (pending->HasError()) { + SetResponseErrorResult(res, pending->GetError()); + return; + } - if (pending->HasError()) { - SetResponseErrorResult(res, pending->GetError()); - return; - } + // Execute tasks until result is ready (or there's an error). + auto exec_result = PendingExecutionResult::RESULT_NOT_READY; + while (!PendingQueryResult::IsResultReady(exec_result)) { + exec_result = pending->ExecuteTask(); + if (exec_result == PendingExecutionResult::BLOCKED || + exec_result == PendingExecutionResult::NO_TASKS_AVAILABLE) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + } - // Execute tasks until result is ready (or there's an error). - auto exec_result = PendingExecutionResult::RESULT_NOT_READY; - while (!PendingQueryResult::IsResultReady(exec_result)) { - exec_result = pending->ExecuteTask(); - if (exec_result == PendingExecutionResult::BLOCKED || - exec_result == PendingExecutionResult::NO_TASKS_AVAILABLE) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - } + switch (exec_result) { - switch (exec_result) { + case PendingExecutionResult::EXECUTION_ERROR: + SetResponseErrorResult(res, pending->GetError()); + break; - case PendingExecutionResult::EXECUTION_ERROR: { - SetResponseErrorResult(res, pending->GetError()); - } break; + case PendingExecutionResult::EXECUTION_FINISHED: + case PendingExecutionResult::RESULT_READY: { + // Get the result. This should be quick because it's ready. + auto result = pending->Execute(); - case PendingExecutionResult::EXECUTION_FINISHED: - case PendingExecutionResult::RESULT_READY: { - // Get the result. This should be quick because it's ready. - auto result = pending->Execute(); + // Fetch the chunks and serialize the result. + SuccessResult success_result; + success_result.column_names_and_types = {std::move(result->names), + std::move(result->types)}; - // Fetch the chunks and serialize the result. - SuccessResult success_result; - success_result.column_names_and_types = {std::move(result->names), std::move(result->types)}; + // TODO: support limiting the number of chunks fetched (SURF-1540) + auto chunk = result->Fetch(); + while (chunk) { + success_result.chunks.push_back( + {static_cast(chunk->size()), std::move(chunk->data)}); + chunk = result->Fetch(); + } - // TODO: support limiting the number of chunks fetched (SURF-1540) - auto chunk = result->Fetch(); - while (chunk) { - success_result.chunks.push_back({static_cast(chunk->size()), std::move(chunk->data)}); - chunk = result->Fetch(); - } - - MemoryStream success_response_content; - BinarySerializer::Serialize(success_result, success_response_content); - SetResponseContent(res, success_response_content); - } break; - - default: { - SetResponseErrorResult(res, "Unexpected PendingExecutionResult"); - } break; - } - - } catch (const std::exception &ex) { - SetResponseErrorResult(res, ex.what()); - } + MemoryStream success_response_content; + BinarySerializer::Serialize(success_result, success_response_content); + SetResponseContent(res, success_response_content); + break; + } + default: + SetResponseErrorResult(res, "Unexpected PendingExecutionResult"); + break; + } } -void HttpServer::HandleTokenize(const httplib::Request &req, httplib::Response &res, - const httplib::ContentReader &content_reader) { - auto description = req.get_header_value("X-MD-Description"); +void HttpServer::HandleTokenize(const httplib::Request &req, + httplib::Response &res, + const httplib::ContentReader &content_reader) { + auto description = req.get_header_value("X-MD-Description"); - std::string content = ReadContent(content_reader); + std::string content = ReadContent(content_reader); + auto tokens = Parser::Tokenize(content); - auto tokens = Parser::Tokenize(content); + // Read and serialize result + TokenizeResult result; + result.offsets.reserve(tokens.size()); + result.types.reserve(tokens.size()); - // Read and serialize result - TokenizeResult result; - result.offsets.reserve(tokens.size()); - result.types.reserve(tokens.size()); + for (auto token : tokens) { + result.offsets.push_back(token.start); + result.types.push_back(token.type); + } - for (auto token : tokens) { - result.offsets.push_back(token.start); - result.types.push_back(token.type); - } - - MemoryStream response_content; - BinarySerializer::Serialize(result, response_content); - SetResponseContent(res, response_content); + MemoryStream response_content; + BinarySerializer::Serialize(result, response_content); + SetResponseContent(res, response_content); } -std::string HttpServer::ReadContent(const httplib::ContentReader &content_reader) { - std::ostringstream oss; - content_reader([&](const char *data, size_t data_length) { - oss.write(data, data_length); - return true; - }); - return oss.str(); +std::string +HttpServer::ReadContent(const httplib::ContentReader &content_reader) { + std::ostringstream oss; + content_reader([&](const char *data, size_t data_length) { + oss.write(data, data_length); + return true; + }); + return oss.str(); } -shared_ptr HttpServer::FindConnection(const std::string &connection_name) { - if (connection_name.empty()) { - return nullptr; - } +shared_ptr +HttpServer::FindConnection(const std::string &connection_name) { + if (connection_name.empty()) { + return nullptr; + } - // Need to protect access to the connections map because this can be called from multiple threads. - std::lock_guard guard(connections_mutex_); + // Need to protect access to the connections map because this can be called + // from multiple threads. + std::lock_guard guard(connections_mutex_); - auto result = connections_.find(connection_name); - if (result != connections_.end()) { - return result->second; - } + auto result = connections_.find(connection_name); + if (result != connections_.end()) { + return result->second; + } - return nullptr; + return nullptr; } -shared_ptr HttpServer::FindOrCreateConnection(const std::string &connection_name) { - if (connection_name.empty()) { - // If no connection name was provided, create and return a new connection but don't remember it. - return make_shared_ptr(*ddb_instance_); - } +shared_ptr +HttpServer::FindOrCreateConnection(const std::string &connection_name) { + if (connection_name.empty()) { + // If no connection name was provided, create and return a new connection + // but don't remember it. + return make_shared_ptr(*ddb_instance_); + } - // Need to protect access to the connections map because this can be called from multiple threads. - std::lock_guard guard(connections_mutex_); + // Need to protect access to the connections map because this can be called + // from multiple threads. + std::lock_guard guard(connections_mutex_); - // If an existing connection with the provided name was found, return it. - auto result = connections_.find(connection_name); - if (result != connections_.end()) { - return result->second; - } + // If an existing connection with the provided name was found, return it. + auto result = connections_.find(connection_name); + if (result != connections_.end()) { + return result->second; + } - // Otherwise, create a new one, remember it, and return it. - auto connection = make_shared_ptr(*ddb_instance_); - connections_[connection_name] = connection; - return connection; + // Otherwise, create a new one, remember it, and return it. + auto connection = make_shared_ptr(*ddb_instance_); + connections_[connection_name] = connection; + return connection; } -void HttpServer::SetResponseContent(httplib::Response &res, const MemoryStream &content) { - auto data = content.GetData(); - auto length = content.GetPosition(); - res.set_content(reinterpret_cast(data), length, "application/octet-stream"); +void HttpServer::SetResponseContent(httplib::Response &res, + const MemoryStream &content) { + auto data = content.GetData(); + auto length = content.GetPosition(); + res.set_content(reinterpret_cast(data), length, + "application/octet-stream"); } void HttpServer::SetResponseEmptyResult(httplib::Response &res) { - EmptyResult empty_result; - MemoryStream response_content; - BinarySerializer::Serialize(empty_result, response_content); - SetResponseContent(res, response_content); + EmptyResult empty_result; + MemoryStream response_content; + BinarySerializer::Serialize(empty_result, response_content); + SetResponseContent(res, response_content); } -void HttpServer::SetResponseErrorResult(httplib::Response &res, const std::string &error) { - ErrorResult error_result; - error_result.error = error; - MemoryStream response_content; - BinarySerializer::Serialize(error_result, response_content); - SetResponseContent(res, response_content); +void HttpServer::SetResponseErrorResult(httplib::Response &res, + const std::string &error) { + ErrorResult error_result; + error_result.error = error; + MemoryStream response_content; + BinarySerializer::Serialize(error_result, response_content); + SetResponseContent(res, response_content); } } // namespace ui -} // namespace md +} // namespace duckdb diff --git a/src/include/http_server.hpp b/src/include/http_server.hpp index 12e811e..4fe1d1c 100644 --- a/src/include/http_server.hpp +++ b/src/include/http_server.hpp @@ -20,63 +20,69 @@ namespace ui { class EventDispatcher { public: - bool WaitEvent(httplib::DataSink *sink); - void SendEvent(const std::string &message); - void Close(); + bool WaitEvent(httplib::DataSink *sink); + void SendEvent(const std::string &message); + void Close(); private: - std::mutex mutex_; - std::condition_variable cv_; - std::atomic_int next_id_ {0}; - std::atomic_int current_id_ {-1}; - std::atomic_int wait_count_ {0}; - std::string message_; - std::atomic_bool closed_ {false}; + std::mutex mutex_; + std::condition_variable cv_; + std::atomic_int next_id_{0}; + std::atomic_int current_id_{-1}; + std::atomic_int wait_count_{0}; + std::string message_; + std::atomic_bool closed_{false}; }; class HttpServer { public: - static HttpServer* instance(); - static bool Started(); - static void StopInstance(); + static HttpServer *instance(); + static bool Started(); + static void StopInstance(); - bool Start(const uint16_t localPort, const std::string &remoteUrl, - const shared_ptr &ddbInstance); - bool Stop(); - uint16_t LocalPort(); - void SendConnectedEvent(const std::string &token); - void SendCatalogChangedEvent(); + bool Start(const uint16_t localPort, const std::string &remoteUrl, + const shared_ptr &ddbInstance); + bool Stop(); + uint16_t LocalPort(); + void SendConnectedEvent(const std::string &token); + void SendCatalogChangedEvent(); private: - void SendEvent(const std::string &message); - void Run(); - void HandleGetLocalEvents(const httplib::Request &req, httplib::Response &res); - void HandleGetLocalToken(const httplib::Request &req, httplib::Response &res); - void HandleGet(const httplib::Request &req, httplib::Response &res); - void HandleInterrupt(const httplib::Request &req, httplib::Response &res); - void HandleRun(const httplib::Request &req, httplib::Response &res, const httplib::ContentReader &contentReader); - void HandleTokenize(const httplib::Request &req, httplib::Response &res, - const httplib::ContentReader &contentReader); - std::string ReadContent(const httplib::ContentReader &contentReader); - shared_ptr FindConnection(const std::string &connectionName); - shared_ptr FindOrCreateConnection(const std::string &connectionName); - void SetResponseContent(httplib::Response &res, const MemoryStream &content); - void SetResponseEmptyResult(httplib::Response &res); - void SetResponseErrorResult(httplib::Response &res, const std::string &error); + void SendEvent(const std::string &message); + void Run(); + void HandleGetLocalEvents(const httplib::Request &req, + httplib::Response &res); + void HandleGetLocalToken(const httplib::Request &req, httplib::Response &res); + void HandleGet(const httplib::Request &req, httplib::Response &res); + void HandleInterrupt(const httplib::Request &req, httplib::Response &res); + void DoHandleRun(const httplib::Request &req, httplib::Response &res, + const httplib::ContentReader &contentReader); + void HandleRun(const httplib::Request &req, httplib::Response &res, + const httplib::ContentReader &contentReader); + void HandleTokenize(const httplib::Request &req, httplib::Response &res, + const httplib::ContentReader &contentReader); + std::string ReadContent(const httplib::ContentReader &contentReader); + shared_ptr FindConnection(const std::string &connectionName); + shared_ptr + FindOrCreateConnection(const std::string &connectionName); + void SetResponseContent(httplib::Response &res, const MemoryStream &content); + void SetResponseEmptyResult(httplib::Response &res); + void SetResponseErrorResult(httplib::Response &res, const std::string &error); - uint16_t local_port_; - std::string remote_url_; - shared_ptr ddb_instance_; - std::string user_agent_; - httplib::Server server_; - unique_ptr thread_; - std::mutex connections_mutex_; - std::unordered_map> connections_; - unique_ptr event_dispatcher_; + uint16_t local_port_; + std::string remote_url_; + shared_ptr ddb_instance_; + std::string user_agent_; + httplib::Server server_; + unique_ptr thread_; + std::mutex connections_mutex_; + std::unordered_map> connections_; + unique_ptr event_dispatcher_; - static unique_ptr instance_; -};; + static unique_ptr instance_; +}; +; } // namespace ui } // namespace duckdb diff --git a/src/include/ui_extension.hpp b/src/include/ui_extension.hpp index 0f2e73f..de39c83 100644 --- a/src/include/ui_extension.hpp +++ b/src/include/ui_extension.hpp @@ -6,9 +6,9 @@ namespace duckdb { class UiExtension : public Extension { public: - void Load(DuckDB &db) override; - std::string Name() override; - std::string Version() const override; + void Load(DuckDB &db) override; + std::string Name() override; + std::string Version() const override; }; } // namespace duckdb diff --git a/src/include/utils/helpers.hpp b/src/include/utils/helpers.hpp index 7c90af9..53fa38e 100644 --- a/src/include/utils/helpers.hpp +++ b/src/include/utils/helpers.hpp @@ -6,82 +6,95 @@ namespace duckdb { -typedef std::string (*simple_tf_t) (ClientContext &); +typedef std::string (*simple_tf_t)(ClientContext &); struct RunOnceTableFunctionState : GlobalTableFunctionState { - RunOnceTableFunctionState() : run(false) {}; - std::atomic run; + RunOnceTableFunctionState() : run(false){}; + std::atomic run; - static unique_ptr Init(ClientContext &, - TableFunctionInitInput &) { - return make_uniq(); - } + static unique_ptr Init(ClientContext &, + TableFunctionInitInput &) { + return make_uniq(); + } }; template -T GetSetting(const ClientContext &context, const char *setting_name, const T default_value) { - Value value; - return context.TryGetCurrentSetting(setting_name, value) ? value.GetValue() : default_value; +T GetSetting(const ClientContext &context, const char *setting_name, + const T default_value) { + Value value; + return context.TryGetCurrentSetting(setting_name, value) ? value.GetValue() + : default_value; } namespace internal { unique_ptr ResultBind(ClientContext &, TableFunctionBindInput &, - vector &, - vector &); + vector &, + vector &); bool ShouldRun(TableFunctionInput &input); -template -struct CallFunctionHelper; +template struct CallFunctionHelper; -template <> -struct CallFunctionHelper { - static std::string call(ClientContext &context, TableFunctionInput &input, std::string(*f)()) { - return f(); - } +template <> struct CallFunctionHelper { + static std::string call(ClientContext &context, TableFunctionInput &input, + std::string (*f)()) { + return f(); + } +}; + +template <> struct CallFunctionHelper { + static std::string call(ClientContext &context, TableFunctionInput &input, + std::string (*f)(ClientContext &)) { + return f(context); + } }; template <> -struct CallFunctionHelper { - static std::string call(ClientContext &context, TableFunctionInput &input, std::string(*f)(ClientContext &)) { - return f(context); - } -}; - -template <> -struct CallFunctionHelper { - static std::string call(ClientContext &context, TableFunctionInput &input, std::string(*f)(ClientContext &, TableFunctionInput &)) { - return f(context, input); - } +struct CallFunctionHelper { + static std::string call(ClientContext &context, TableFunctionInput &input, + std::string (*f)(ClientContext &, + TableFunctionInput &)) { + return f(context, input); + } }; template -void TableFunc(ClientContext &context, TableFunctionInput &input, DataChunk &output) { - if (!ShouldRun(input)) { - return; - } +void TableFunc(ClientContext &context, TableFunctionInput &input, + DataChunk &output) { + if (!ShouldRun(input)) { + return; + } - const std::string result = CallFunctionHelper::call(context, input, func); - output.SetCardinality(1); - output.SetValue(0, 0, result); + const std::string result = + CallFunctionHelper::call(context, input, func); + output.SetCardinality(1); + output.SetValue(0, 0, result); } template -void RegisterTF(DatabaseInstance &instance, const char* name) { - TableFunction tf(name, {}, internal::TableFunc, internal::ResultBind, RunOnceTableFunctionState::Init); - ExtensionUtil::RegisterFunction(instance, tf); +void RegisterTF(DatabaseInstance &instance, const char *name) { + TableFunction tf(name, {}, internal::TableFunc, + internal::ResultBind, RunOnceTableFunctionState::Init); + ExtensionUtil::RegisterFunction(instance, tf); } template -void RegisterTFWithArgs(DatabaseInstance &instance, const char* name, vector arguments, table_function_bind_t bind) { - TableFunction tf(name, arguments, internal::TableFunc, bind, RunOnceTableFunctionState::Init); - ExtensionUtil::RegisterFunction(instance, tf); +void RegisterTFWithArgs(DatabaseInstance &instance, const char *name, + vector arguments, + table_function_bind_t bind) { + TableFunction tf(name, arguments, internal::TableFunc, bind, + RunOnceTableFunctionState::Init); + ExtensionUtil::RegisterFunction(instance, tf); } -} +} // namespace internal -#define RESISTER_TF(name, func) internal::RegisterTF(instance, name) +#define RESISTER_TF(name, func) \ + internal::RegisterTF(instance, name) -#define RESISTER_TF_ARGS(name, args, func, bind) internal::RegisterTFWithArgs(instance, name, args, bind) +#define RESISTER_TF_ARGS(name, args, func, bind) \ + internal::RegisterTFWithArgs(instance, name, args, \ + bind) } // namespace duckdb diff --git a/src/include/utils/serialization.hpp b/src/include/utils/serialization.hpp index 9772a4c..47955a4 100644 --- a/src/include/utils/serialization.hpp +++ b/src/include/utils/serialization.hpp @@ -8,41 +8,41 @@ namespace duckdb { namespace ui { struct EmptyResult { - void Serialize(duckdb::Serializer &serializer) const; + void Serialize(duckdb::Serializer &serializer) const; }; struct TokenizeResult { - duckdb::vector offsets; - duckdb::vector types; + duckdb::vector offsets; + duckdb::vector types; - void Serialize(duckdb::Serializer &serializer) const; + void Serialize(duckdb::Serializer &serializer) const; }; struct ColumnNamesAndTypes { - duckdb::vector names; - duckdb::vector types; + duckdb::vector names; + duckdb::vector types; - void Serialize(duckdb::Serializer &serializer) const; + void Serialize(duckdb::Serializer &serializer) const; }; struct Chunk { - uint16_t row_count; - duckdb::vector vectors; + uint16_t row_count; + duckdb::vector vectors; - void Serialize(duckdb::Serializer &serializer) const; + void Serialize(duckdb::Serializer &serializer) const; }; struct SuccessResult { - ColumnNamesAndTypes column_names_and_types; - duckdb::vector chunks; + ColumnNamesAndTypes column_names_and_types; + duckdb::vector chunks; - void Serialize(duckdb::Serializer &serializer) const; + void Serialize(duckdb::Serializer &serializer) const; }; struct ErrorResult { - std::string error; + std::string error; - void Serialize(duckdb::Serializer &serializer) const; + void Serialize(duckdb::Serializer &serializer) const; }; } // namespace ui diff --git a/src/ui_extension.cpp b/src/ui_extension.cpp index 2963cac..0d88fcf 100644 --- a/src/ui_extension.cpp +++ b/src/ui_extension.cpp @@ -15,110 +15,123 @@ #define OPEN_COMMAND "open" #endif -#define UI_LOCAL_PORT_SETTING_NAME "ui_local_port" -#define UI_LOCAL_PORT_SETTING_DESCRIPTION "Local port on which the UI server listens" -#define UI_LOCAL_PORT_SETTING_DEFAULT 4213 +#define UI_LOCAL_PORT_SETTING_NAME "ui_local_port" +#define UI_LOCAL_PORT_SETTING_DESCRIPTION \ + "Local port on which the UI server listens" +#define UI_LOCAL_PORT_SETTING_DEFAULT 4213 -#define UI_REMOTE_URL_SETTING_NAME "ui_remote_url" -#define UI_REMOTE_URL_SETTING_DESCRIPTION "Remote URL to which the UI server forwards GET requests" -#define UI_REMOTE_URL_SETTING_DEFAULT "https://app.motherduck.com" +#define UI_REMOTE_URL_SETTING_NAME "ui_remote_url" +#define UI_REMOTE_URL_SETTING_DESCRIPTION \ + "Remote URL to which the UI server forwards GET requests" +#define UI_REMOTE_URL_SETTING_DEFAULT "https://app.motherduck.com" namespace duckdb { namespace internal { bool StartHttpServer(const ClientContext &context) { - const auto url = GetSetting(context, UI_REMOTE_URL_SETTING_NAME, - GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DEFAULT)); - const uint16_t port = GetSetting(context, UI_LOCAL_PORT_SETTING_NAME, UI_LOCAL_PORT_SETTING_DEFAULT);; - return ui::HttpServer::instance()->Start(port, url, context.db); + const auto url = + GetSetting(context, UI_REMOTE_URL_SETTING_NAME, + GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME, + UI_REMOTE_URL_SETTING_DEFAULT)); + const uint16_t port = GetSetting(context, UI_LOCAL_PORT_SETTING_NAME, + UI_LOCAL_PORT_SETTING_DEFAULT); + ; + return ui::HttpServer::instance()->Start(port, url, context.db); } std::string GetHttpServerLocalURL() { - return StringUtil::Format("http://localhost:%d/", ui::HttpServer::instance()->LocalPort()); + return StringUtil::Format("http://localhost:%d/", + ui::HttpServer::instance()->LocalPort()); } } // namespace internal std::string StartUIFunction(ClientContext &context) { - internal::StartHttpServer(context); - auto local_url = internal::GetHttpServerLocalURL(); + internal::StartHttpServer(context); + auto local_url = internal::GetHttpServerLocalURL(); - const std::string command = StringUtil::Format("%s %s", OPEN_COMMAND, local_url); - return system(command.c_str()) ? - StringUtil::Format("Navigate browser to %s", local_url) // open command failed - : StringUtil::Format("MotherDuck UI started at %s", local_url); + const std::string command = + StringUtil::Format("%s %s", OPEN_COMMAND, local_url); + return system(command.c_str()) + ? StringUtil::Format("Navigate browser to %s", + local_url) // open command failed + : StringUtil::Format("MotherDuck UI started at %s", local_url); } std::string StartUIServerFunction(ClientContext &context) { - const char* already = internal::StartHttpServer(context) ? "already " : ""; - return StringUtil::Format("MotherDuck UI server %sstarted at %s", already, internal::GetHttpServerLocalURL()); + const char *already = internal::StartHttpServer(context) ? "already " : ""; + return StringUtil::Format("MotherDuck UI server %sstarted at %s", already, + internal::GetHttpServerLocalURL()); } std::string StopUIServerFunction() { - return ui::HttpServer::instance()->Stop() ? "UI server stopped" : "UI server already stopped"; + return ui::HttpServer::instance()->Stop() ? "UI server stopped" + : "UI server already stopped"; } // Connected notification struct NotifyConnectedFunctionData : public TableFunctionData { - NotifyConnectedFunctionData(std::string _token) : token(_token) {} + NotifyConnectedFunctionData(std::string _token) : token(_token) {} - std::string token; + std::string token; }; -static unique_ptr NotifyConnectedBind(ClientContext &, TableFunctionBindInput &input, - vector &out_types, vector &out_names) { - if (input.inputs[0].IsNull()) { - throw BinderException("Must provide a token"); - } +static unique_ptr +NotifyConnectedBind(ClientContext &, TableFunctionBindInput &input, + vector &out_types, vector &out_names) { + if (input.inputs[0].IsNull()) { + throw BinderException("Must provide a token"); + } - out_names.emplace_back("result"); - out_types.emplace_back(LogicalType::VARCHAR); - return make_uniq(input.inputs[0].ToString()); + out_names.emplace_back("result"); + out_types.emplace_back(LogicalType::VARCHAR); + return make_uniq(input.inputs[0].ToString()); } -std::string NotifyConnectedFunction(ClientContext &context, TableFunctionInput &input) { - auto &inputs = input.bind_data->Cast(); - ui::HttpServer::instance()->SendConnectedEvent(inputs.token); - return "OK"; +std::string NotifyConnectedFunction(ClientContext &context, + TableFunctionInput &input) { + auto &inputs = input.bind_data->Cast(); + ui::HttpServer::instance()->SendConnectedEvent(inputs.token); + return "OK"; } // - connected notification std::string NotifyCatalogChangedFunction() { - ui::HttpServer::instance()->SendCatalogChangedEvent(); - return "OK"; + ui::HttpServer::instance()->SendCatalogChangedEvent(); + return "OK"; } static void LoadInternal(DatabaseInstance &instance) { - auto &config = DBConfig::GetConfig(instance); + auto &config = DBConfig::GetConfig(instance); - config.AddExtensionOption(UI_LOCAL_PORT_SETTING_NAME, UI_LOCAL_PORT_SETTING_DESCRIPTION, - LogicalType::USMALLINT, Value::USMALLINT(UI_LOCAL_PORT_SETTING_DEFAULT)); + config.AddExtensionOption( + UI_LOCAL_PORT_SETTING_NAME, UI_LOCAL_PORT_SETTING_DESCRIPTION, + LogicalType::USMALLINT, Value::USMALLINT(UI_LOCAL_PORT_SETTING_DEFAULT)); - config.AddExtensionOption( - UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DESCRIPTION, LogicalType::VARCHAR, - Value(GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DEFAULT))); + config.AddExtensionOption( + UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DESCRIPTION, + LogicalType::VARCHAR, + Value(GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME, + UI_REMOTE_URL_SETTING_DEFAULT))); - RESISTER_TF("start_ui", StartUIFunction); - RESISTER_TF("start_ui_server", StartUIServerFunction); - RESISTER_TF("stop_ui_server", StopUIServerFunction); - RESISTER_TF("notify_ui_catalog_changed", NotifyCatalogChangedFunction); - RESISTER_TF_ARGS("notify_ui_connected", {LogicalType::VARCHAR}, NotifyConnectedFunction, NotifyConnectedBind); + RESISTER_TF("start_ui", StartUIFunction); + RESISTER_TF("start_ui_server", StartUIServerFunction); + RESISTER_TF("stop_ui_server", StopUIServerFunction); + RESISTER_TF("notify_ui_catalog_changed", NotifyCatalogChangedFunction); + RESISTER_TF_ARGS("notify_ui_connected", {LogicalType::VARCHAR}, + NotifyConnectedFunction, NotifyConnectedBind); } -void UiExtension::Load(DuckDB &db) { - LoadInternal(*db.instance); -} -std::string UiExtension::Name() { - return "ui"; -} +void UiExtension::Load(DuckDB &db) { LoadInternal(*db.instance); } +std::string UiExtension::Name() { return "ui"; } std::string UiExtension::Version() const { #ifdef UI_EXTENSION_GIT_SHA - return UI_EXTENSION_GIT_SHA; + return UI_EXTENSION_GIT_SHA; #else - return ""; + return ""; #endif } @@ -127,12 +140,12 @@ std::string UiExtension::Version() const { extern "C" { DUCKDB_EXTENSION_API void ui_init(duckdb::DatabaseInstance &db) { - duckdb::DuckDB db_wrapper(db); - db_wrapper.LoadExtension(); + duckdb::DuckDB db_wrapper(db); + db_wrapper.LoadExtension(); } DUCKDB_EXTENSION_API const char *ui_version() { - return duckdb::DuckDB::LibraryVersion(); + return duckdb::DuckDB::LibraryVersion(); } } diff --git a/src/utils/encoding.cpp b/src/utils/encoding.cpp index 32054db..137f5c9 100644 --- a/src/utils/encoding.cpp +++ b/src/utils/encoding.cpp @@ -5,57 +5,60 @@ namespace duckdb { -// Copied from https://www.mycplus.com/source-code/c-source-code/base64-encode-decode/ -constexpr char k_encoding_table[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+_"; +// Copied from +// https://www.mycplus.com/source-code/c-source-code/base64-encode-decode/ +constexpr char k_encoding_table[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+_"; std::vector BuildDecodingTable() { - std::vector decoding_table; - decoding_table.resize(256); - for (int i = 0; i < 64; ++i) { - decoding_table[static_cast(k_encoding_table[i])] = i; - } - return decoding_table; + std::vector decoding_table; + decoding_table.resize(256); + for (int i = 0; i < 64; ++i) { + decoding_table[static_cast(k_encoding_table[i])] = i; + } + return decoding_table; } const static std::vector k_decoding_table = BuildDecodingTable(); std::string DecodeBase64(const std::string &data) { - size_t input_length = data.size(); - if (input_length < 4 || input_length % 4 != 0) { - // Handle this exception - return ""; - } + size_t input_length = data.size(); + if (input_length < 4 || input_length % 4 != 0) { + // Handle this exception + return ""; + } - size_t output_length = input_length / 4 * 3; - if (data[input_length - 1] == '=') { - output_length--; - } - if (data[input_length - 2] == '=') { - output_length--; - } + size_t output_length = input_length / 4 * 3; + if (data[input_length - 1] == '=') { + output_length--; + } + if (data[input_length - 2] == '=') { + output_length--; + } - std::string decoded_data; - decoded_data.resize(output_length); - for (size_t i = 0, j = 0; i < input_length;) { - uint32_t sextet_a = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; - uint32_t sextet_b = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; - uint32_t sextet_c = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; - uint32_t sextet_d = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; + std::string decoded_data; + decoded_data.resize(output_length); + for (size_t i = 0, j = 0; i < input_length;) { + uint32_t sextet_a = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; + uint32_t sextet_b = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; + uint32_t sextet_c = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; + uint32_t sextet_d = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; - uint32_t triple = (sextet_a << 3 * 6) + (sextet_b << 2 * 6) + (sextet_c << 1 * 6) + (sextet_d << 0 * 6); + uint32_t triple = (sextet_a << 3 * 6) + (sextet_b << 2 * 6) + + (sextet_c << 1 * 6) + (sextet_d << 0 * 6); - if (j < output_length) { - decoded_data[j++] = (triple >> 2 * 8) & 0xFF; - } - if (j < output_length) { - decoded_data[j++] = (triple >> 1 * 8) & 0xFF; - } - if (j < output_length) { - decoded_data[j++] = (triple >> 0 * 8) & 0xFF; - } - } + if (j < output_length) { + decoded_data[j++] = (triple >> 2 * 8) & 0xFF; + } + if (j < output_length) { + decoded_data[j++] = (triple >> 1 * 8) & 0xFF; + } + if (j < output_length) { + decoded_data[j++] = (triple >> 0 * 8) & 0xFF; + } + } - return decoded_data; + return decoded_data; } } // namespace duckdb diff --git a/src/utils/env.cpp b/src/utils/env.cpp index 6c96110..720cce7 100644 --- a/src/utils/env.cpp +++ b/src/utils/env.cpp @@ -6,29 +6,29 @@ namespace duckdb { const char *TryGetEnv(const char *name) { - const char *res = std::getenv(name); - if (res) { - return res; - } - return std::getenv(StringUtil::Upper(name).c_str()); + const char *res = std::getenv(name); + if (res) { + return res; + } + return std::getenv(StringUtil::Upper(name).c_str()); } std::string GetEnvOrDefault(const char *name, const char *default_value) { - const char *res = TryGetEnv(name); - if (res) { - return res; - } - return default_value; + const char *res = TryGetEnv(name); + if (res) { + return res; + } + return default_value; } bool IsEnvEnabled(const char *name) { - const char *res = TryGetEnv(name); - if (!res) { - return false; - } + const char *res = TryGetEnv(name); + if (!res) { + return false; + } - auto lc_res = StringUtil::Lower(res); - return lc_res == "1" || lc_res == "true"; + auto lc_res = StringUtil::Lower(res); + return lc_res == "1" || lc_res == "true"; } } // namespace duckdb diff --git a/src/utils/helpers.cpp b/src/utils/helpers.cpp index c359306..4de921d 100644 --- a/src/utils/helpers.cpp +++ b/src/utils/helpers.cpp @@ -4,22 +4,23 @@ namespace duckdb { namespace internal { bool ShouldRun(TableFunctionInput &input) { - auto state = dynamic_cast(input.global_state.get()); - D_ASSERT(state != nullptr); - if (state->run) { - return false; - } + auto state = + dynamic_cast(input.global_state.get()); + D_ASSERT(state != nullptr); + if (state->run) { + return false; + } - state->run = true; - return true; + state->run = true; + return true; } unique_ptr ResultBind(ClientContext &, TableFunctionBindInput &, - vector &out_types, - vector &out_names) { - out_names.emplace_back("result"); - out_types.emplace_back(LogicalType::VARCHAR); - return nullptr; + vector &out_types, + vector &out_names) { + out_names.emplace_back("result"); + out_types.emplace_back(LogicalType::VARCHAR); + return nullptr; } } // namespace internal diff --git a/src/utils/serialization.cpp b/src/utils/serialization.cpp index c04fc12..d0e097d 100644 --- a/src/utils/serialization.cpp +++ b/src/utils/serialization.cpp @@ -6,43 +6,46 @@ namespace duckdb { namespace ui { -void EmptyResult::Serialize(Serializer &) const { -} +void EmptyResult::Serialize(Serializer &) const {} void TokenizeResult::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "offsets", offsets); - serializer.WriteProperty(101, "types", types); + serializer.WriteProperty(100, "offsets", offsets); + serializer.WriteProperty(101, "types", types); } // Adapted from parts of DataChunk::Serialize void ColumnNamesAndTypes::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "names", names); - serializer.WriteProperty(101, "types", types); + serializer.WriteProperty(100, "names", names); + serializer.WriteProperty(101, "types", types); } // Adapted from parts of DataChunk::Serialize void Chunk::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "row_count", row_count); - serializer.WriteList(101, "vectors", vectors.size(), [&](Serializer::List &list, idx_t i) { - list.WriteObject([&](Serializer &object) { - // Reference the vector to avoid potentially mutating it during serialization - Vector serialized_vector(vectors[i].GetType()); - serialized_vector.Reference(vectors[i]); - serialized_vector.Serialize(object, row_count); - }); - }); + serializer.WriteProperty(100, "row_count", row_count); + serializer.WriteList(101, "vectors", vectors.size(), + [&](Serializer::List &list, idx_t i) { + list.WriteObject([&](Serializer &object) { + // Reference the vector to avoid potentially mutating + // it during serialization + Vector serialized_vector(vectors[i].GetType()); + serialized_vector.Reference(vectors[i]); + serialized_vector.Serialize(object, row_count); + }); + }); } void SuccessResult::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "success", true); - serializer.WriteProperty(101, "column_names_and_types", column_names_and_types); - serializer.WriteList(102, "chunks", chunks.size(), - [&](Serializer::List &list, idx_t i) { list.WriteElement(chunks[i]); }); + serializer.WriteProperty(100, "success", true); + serializer.WriteProperty(101, "column_names_and_types", + column_names_and_types); + serializer.WriteList( + 102, "chunks", chunks.size(), + [&](Serializer::List &list, idx_t i) { list.WriteElement(chunks[i]); }); } void ErrorResult::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "success", false); - serializer.WriteProperty(101, "error", error); + serializer.WriteProperty(100, "success", false); + serializer.WriteProperty(101, "error", error); } } // namespace ui