diff --git a/CMakeLists.txt b/CMakeLists.txt index 7898b2b..1c434c6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.5) +cmake_minimum_required(VERSION 3.5...3.31.5) # Set extension name here set(TARGET_NAME ui) @@ -9,19 +9,26 @@ set(TARGET_NAME ui) find_package(OpenSSL REQUIRED) set(EXTENSION_NAME ${TARGET_NAME}_extension) -set(LOADABLE_EXTENSION_NAME ${TARGET_NAME}_loadable_extension) project(${TARGET_NAME}) -include_directories(src/include) +include_directories( + src/include + ${DuckDB_SOURCE_DIR}/third_party/httplib +) -set(EXTENSION_SOURCES src/ui_extension.cpp) +set(EXTENSION_SOURCES + src/ui_extension.cpp + src/http_server.cpp + src/utils/encoding.cpp + src/utils/env.cpp + src/utils/helpers.cpp + src/utils/serialization.cpp +) build_static_extension(${TARGET_NAME} ${EXTENSION_SOURCES}) build_loadable_extension(${TARGET_NAME} " " ${EXTENSION_SOURCES}) -# Link OpenSSL in both the static library as the loadable extension target_link_libraries(${EXTENSION_NAME} OpenSSL::SSL OpenSSL::Crypto) -target_link_libraries(${LOADABLE_EXTENSION_NAME} OpenSSL::SSL OpenSSL::Crypto) install( TARGETS ${EXTENSION_NAME} diff --git a/src/http_server.cpp b/src/http_server.cpp new file mode 100644 index 0000000..1d42626 --- /dev/null +++ b/src/http_server.cpp @@ -0,0 +1,408 @@ +#include "http_server.hpp" + +#include +#include +#include +#include "utils/env.hpp" +#include "utils/serialization.hpp" +#include "utils/encoding.hpp" + +// Chosen to be no more than half of the lesser of the two limits: +// - The default httplib thread pool size = 8 +// - The browser limit on the number of server-sent event connections = 6 +#define MAX_EVENT_WAIT_COUNT 3 + +namespace duckdb { +namespace ui { + +// An empty Server-Sent Events message. See +// https://html.spec.whatwg.org/multipage/server-sent-events.html#authoring-notes +constexpr const char *EMPTY_SSE_MESSAGE = ":\r\r"; +constexpr idx_t EMPTY_SSE_MESSAGE_LENGTH = 3; + +bool EventDispatcher::WaitEvent(httplib::DataSink *sink) { + std::unique_lock lock(mutex_); + // Don't allow too many simultaneous waits, because each consumes a thread in the httplib thread pool, and also + // browsers limit the number of server-sent event connections. + if (closed_ || wait_count_ >= MAX_EVENT_WAIT_COUNT) { + return false; + } + int target_id = next_id_; + wait_count_++; + cv_.wait_for(lock, std::chrono::seconds(5)); + wait_count_--; + if (closed_) { + return false; + } + if (current_id_ == target_id) { + sink->write(message_.data(), message_.size()); + } else { + // Our wait timer expired. Write an empty, no-op message. + // This enables detecting when the client is gone. + sink->write(EMPTY_SSE_MESSAGE, EMPTY_SSE_MESSAGE_LENGTH); + } + return true; +} + +void EventDispatcher::SendEvent(const std::string &message) { + std::lock_guard guard(mutex_); + if (closed_) { + return; + } + + current_id_ = next_id_++; + message_ = message; + cv_.notify_all(); +} + +void EventDispatcher::Close() { + std::lock_guard guard(mutex_); + if (closed_) { + return; + } + current_id_ = next_id_++; + closed_ = true; + cv_.notify_all(); +} + +unique_ptr HttpServer::instance_; + +HttpServer* HttpServer::instance() { + if (!instance_) { + instance_ = make_uniq(); + std::atexit(HttpServer::StopInstance); + } + return instance_.get(); +} + +bool HttpServer::Started() { + return instance_ && instance_->thread_; +} + +void HttpServer::StopInstance() { + if (instance_) { + instance_->Stop(); + } +} + +bool HttpServer::Start(const uint16_t local_port, const std::string &remote_url, + const shared_ptr &ddb_instance) { + if (thread_) { + return false; + } + + local_port_ = local_port; + remote_url_ = remote_url; + ddb_instance_ = ddb_instance; +#ifndef EXT_VERSION_UI +#error "EXT_VERSION_UI must be defined" +#endif + user_agent_ = StringUtil::Format("duckdb-ui/%s(%s)", EXT_VERSION_UI, DuckDB::Platform()); + event_dispatcher_ = make_uniq(); + thread_ = make_uniq(&HttpServer::Run, this); + return true; +} + +bool HttpServer::Stop() { + if (!thread_) { + return false; + } + + event_dispatcher_->Close(); + server_.stop(); + thread_->join(); + thread_.reset(); + event_dispatcher_.reset(); + connections_.clear(); + ddb_instance_.reset(); + remote_url_ = ""; + local_port_ = 0; + return true; +} + +uint16_t HttpServer::LocalPort() { + return local_port_; +} + +void HttpServer::SendConnectedEvent(const std::string &token) { + SendEvent(StringUtil::Format("event: ConnectedEvent\ndata: %s\n\n", token)); +} + +void HttpServer::SendCatalogChangedEvent() { + SendEvent("event: CatalogChangeEvent\ndata:\n\n"); +} + +void HttpServer::SendEvent(const std::string &message) { + if (event_dispatcher_) { + event_dispatcher_->SendEvent(message); + } +} + +void HttpServer::Run() { + server_.Get("/localEvents", + [&](const httplib::Request &req, httplib::Response &res) { HandleGetLocalEvents(req, res); }); + server_.Get("/localToken", + [&](const httplib::Request &req, httplib::Response &res) { HandleGetLocalToken(req, res); }); + server_.Get("/.*", [&](const httplib::Request &req, httplib::Response &res) { HandleGet(req, res); }); + server_.Post("/ddb/interrupt", + [&](const httplib::Request &req, httplib::Response &res) { HandleInterrupt(req, res); }); + server_.Post("/ddb/run", + [&](const httplib::Request &req, httplib::Response &res, + const httplib::ContentReader &content_reader) { HandleRun(req, res, content_reader); }); + server_.Post("/ddb/tokenize", + [&](const httplib::Request &req, httplib::Response &res, + const httplib::ContentReader &content_reader) { HandleTokenize(req, res, content_reader); }); + server_.listen("localhost", local_port_); +} + +void HttpServer::HandleGetLocalEvents(const httplib::Request &req, httplib::Response &res) { + res.set_chunked_content_provider("text/event-stream", [&](size_t /*offset*/, httplib::DataSink &sink) { + if (event_dispatcher_->WaitEvent(&sink)) { + return true; + } + sink.done(); + return false; + }); +} + +void HttpServer::HandleGetLocalToken(const httplib::Request &req, httplib::Response &res) { + throw new std::runtime_error("Not implemented"); // FIXME + // auto md_config = ClientExtension::GetState(*ddb_instance_).GetConfig(); + // std::string md_token = md_config ? md_config->token : ""; + // res.set_content(md_token, "text/plain"); +} + +void HttpServer::HandleGet(const httplib::Request &req, httplib::Response &res) { + // Create HTTP client to remote URL + // TODO: Can this be created once and shared? + httplib::Client client(remote_url_); + client.set_keep_alive(true); + + // Provide a way to turn on or off server certificate verification, at least for now, because it requires httplib to + // correctly get the root certficates on each platform, which doesn't appear to always work. + // Currently, default to no verification, until we understand when it breaks things. + if (IsEnvEnabled("ui_enable_server_certificate_verification")) { + client.enable_server_certificate_verification(true); + } else { + client.enable_server_certificate_verification(false); + } + + // forward GET to remote URL + auto result = client.Get(req.path, req.params, {{"User-Agent", user_agent_}}); + if (!result) { + res.status = 500; + return; + } + + // Repond with result of forwarded GET + res = result.value(); + + // If this is the config request, set the X-MD-DuckDB-Mode header to HTTP. + // The UI looks for this to select the appropriate DuckDB mode (HTTP or Wasm). + if (req.path == "/config") { + res.set_header("X-MD-DuckDB-Mode", "HTTP"); + } +} + +void HttpServer::HandleInterrupt(const httplib::Request &req, httplib::Response &res) { + auto description = req.get_header_value("X-MD-Description"); + auto connection_name = req.get_header_value("X-MD-Connection-Name"); + + auto connection = FindConnection(connection_name); + if (!connection) { + res.status = 404; + return; + } + + connection->Interrupt(); + + SetResponseEmptyResult(res); +} + +void HttpServer::HandleRun(const httplib::Request &req, httplib::Response &res, + const httplib::ContentReader &content_reader) { + try { + auto description = req.get_header_value("X-MD-Description"); + auto connection_name = req.get_header_value("X-MD-Connection-Name"); + + auto database_name = DecodeBase64(req.get_header_value("X-MD-Database-Name")); + auto parameter_count = req.get_header_value_count("X-MD-Parameter"); + + std::string content = ReadContent(content_reader); + + + auto connection = FindOrCreateConnection(connection_name); + + // Set current database if optional header is provided. + if (!database_name.empty()) { + connection->context->RunFunctionInTransaction( + [&] { ddb_instance_->GetDatabaseManager().SetDefaultDatabase(*connection->context, database_name); }); + } + + // We use a pending query so we can execute tasks and fetch chunks incrementally. + // This enables cancellation. + unique_ptr pending; + + // Create pending query, with request content as SQL. + if (parameter_count > 0) { + auto prepared = connection->Prepare(content); + if (prepared->HasError()) { + SetResponseErrorResult(res, prepared->GetError()); + return; + } + + vector values; + for (idx_t i = 0; i < parameter_count; ++i) { + auto parameter = DecodeBase64(req.get_header_value("X-MD-Parameter", i)); + values.push_back(Value(parameter)); // TODO: support non-string parameters? (SURF-1546) + } + pending = prepared->PendingQuery(values, true); + } else { + pending = connection->PendingQuery(content, true); + } + + if (pending->HasError()) { + SetResponseErrorResult(res, pending->GetError()); + return; + } + + // Execute tasks until result is ready (or there's an error). + auto exec_result = PendingExecutionResult::RESULT_NOT_READY; + while (!PendingQueryResult::IsResultReady(exec_result)) { + exec_result = pending->ExecuteTask(); + if (exec_result == PendingExecutionResult::BLOCKED || + exec_result == PendingExecutionResult::NO_TASKS_AVAILABLE) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + } + + switch (exec_result) { + + case PendingExecutionResult::EXECUTION_ERROR: { + SetResponseErrorResult(res, pending->GetError()); + } break; + + case PendingExecutionResult::EXECUTION_FINISHED: + case PendingExecutionResult::RESULT_READY: { + // Get the result. This should be quick because it's ready. + auto result = pending->Execute(); + + // Fetch the chunks and serialize the result. + SuccessResult success_result; + success_result.column_names_and_types = {std::move(result->names), std::move(result->types)}; + + // TODO: support limiting the number of chunks fetched (SURF-1540) + auto chunk = result->Fetch(); + while (chunk) { + success_result.chunks.push_back({static_cast(chunk->size()), std::move(chunk->data)}); + chunk = result->Fetch(); + } + + MemoryStream success_response_content; + BinarySerializer::Serialize(success_result, success_response_content); + SetResponseContent(res, success_response_content); + } break; + + default: { + SetResponseErrorResult(res, "Unexpected PendingExecutionResult"); + } break; + } + + } catch (const std::exception &ex) { + SetResponseErrorResult(res, ex.what()); + } +} + +void HttpServer::HandleTokenize(const httplib::Request &req, httplib::Response &res, + const httplib::ContentReader &content_reader) { + auto description = req.get_header_value("X-MD-Description"); + + std::string content = ReadContent(content_reader); + + + auto tokens = Parser::Tokenize(content); + + // Read and serialize result + TokenizeResult result; + result.offsets.reserve(tokens.size()); + result.types.reserve(tokens.size()); + + for (auto token : tokens) { + result.offsets.push_back(token.start); + result.types.push_back(token.type); + } + + MemoryStream response_content; + BinarySerializer::Serialize(result, response_content); + SetResponseContent(res, response_content); +} + +std::string HttpServer::ReadContent(const httplib::ContentReader &content_reader) { + std::ostringstream oss; + content_reader([&](const char *data, size_t data_length) { + oss.write(data, data_length); + return true; + }); + return oss.str(); +} + +shared_ptr HttpServer::FindConnection(const std::string &connection_name) { + if (connection_name.empty()) { + return nullptr; + } + + // Need to protect access to the connections map because this can be called from multiple threads. + std::lock_guard guard(connections_mutex_); + + auto result = connections_.find(connection_name); + if (result != connections_.end()) { + return result->second; + } + + return nullptr; +} + +shared_ptr HttpServer::FindOrCreateConnection(const std::string &connection_name) { + if (connection_name.empty()) { + // If no connection name was provided, create and return a new connection but don't remember it. + return make_shared_ptr(*ddb_instance_); + } + + // Need to protect access to the connections map because this can be called from multiple threads. + std::lock_guard guard(connections_mutex_); + + // If an existing connection with the provided name was found, return it. + auto result = connections_.find(connection_name); + if (result != connections_.end()) { + return result->second; + } + + // Otherwise, create a new one, remember it, and return it. + auto connection = make_shared_ptr(*ddb_instance_); + connections_[connection_name] = connection; + return connection; +} + +void HttpServer::SetResponseContent(httplib::Response &res, const MemoryStream &content) { + auto data = content.GetData(); + auto length = content.GetPosition(); + res.set_content(reinterpret_cast(data), length, "application/octet-stream"); +} + +void HttpServer::SetResponseEmptyResult(httplib::Response &res) { + EmptyResult empty_result; + MemoryStream response_content; + BinarySerializer::Serialize(empty_result, response_content); + SetResponseContent(res, response_content); +} + +void HttpServer::SetResponseErrorResult(httplib::Response &res, const std::string &error) { + ErrorResult error_result; + error_result.error = error; + MemoryStream response_content; + BinarySerializer::Serialize(error_result, response_content); + SetResponseContent(res, response_content); +} + +} // namespace ui +} // namespace md diff --git a/src/include/http_server.hpp b/src/include/http_server.hpp new file mode 100644 index 0000000..12e811e --- /dev/null +++ b/src/include/http_server.hpp @@ -0,0 +1,82 @@ +#pragma once + +#include "duckdb.hpp" + +#define CPPHTTPLIB_OPENSSL_SUPPORT +#include "httplib.hpp" + +#include +#include +#include +#include +#include + +namespace httplib = duckdb_httplib_openssl; + +namespace duckdb { +class MemoryStream; + +namespace ui { + +class EventDispatcher { +public: + bool WaitEvent(httplib::DataSink *sink); + void SendEvent(const std::string &message); + void Close(); + +private: + std::mutex mutex_; + std::condition_variable cv_; + std::atomic_int next_id_ {0}; + std::atomic_int current_id_ {-1}; + std::atomic_int wait_count_ {0}; + std::string message_; + std::atomic_bool closed_ {false}; +}; + +class HttpServer { + +public: + static HttpServer* instance(); + static bool Started(); + static void StopInstance(); + + bool Start(const uint16_t localPort, const std::string &remoteUrl, + const shared_ptr &ddbInstance); + bool Stop(); + uint16_t LocalPort(); + void SendConnectedEvent(const std::string &token); + void SendCatalogChangedEvent(); + +private: + void SendEvent(const std::string &message); + void Run(); + void HandleGetLocalEvents(const httplib::Request &req, httplib::Response &res); + void HandleGetLocalToken(const httplib::Request &req, httplib::Response &res); + void HandleGet(const httplib::Request &req, httplib::Response &res); + void HandleInterrupt(const httplib::Request &req, httplib::Response &res); + void HandleRun(const httplib::Request &req, httplib::Response &res, const httplib::ContentReader &contentReader); + void HandleTokenize(const httplib::Request &req, httplib::Response &res, + const httplib::ContentReader &contentReader); + std::string ReadContent(const httplib::ContentReader &contentReader); + shared_ptr FindConnection(const std::string &connectionName); + shared_ptr FindOrCreateConnection(const std::string &connectionName); + void SetResponseContent(httplib::Response &res, const MemoryStream &content); + void SetResponseEmptyResult(httplib::Response &res); + void SetResponseErrorResult(httplib::Response &res, const std::string &error); + + uint16_t local_port_; + std::string remote_url_; + shared_ptr ddb_instance_; + std::string user_agent_; + httplib::Server server_; + unique_ptr thread_; + std::mutex connections_mutex_; + std::unordered_map> connections_; + unique_ptr event_dispatcher_; + + static unique_ptr instance_; +};; + +} // namespace ui +} // namespace duckdb diff --git a/src/include/utils/encoding.hpp b/src/include/utils/encoding.hpp new file mode 100644 index 0000000..34cb5e1 --- /dev/null +++ b/src/include/utils/encoding.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace duckdb { + +std::string DecodeBase64(const std::string &str); + +} // namespace duckdb diff --git a/src/include/utils/env.hpp b/src/include/utils/env.hpp new file mode 100644 index 0000000..74a6cba --- /dev/null +++ b/src/include/utils/env.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace duckdb { + +const char *TryGetEnv(const char *name); + +std::string GetEnvOrDefault(const char *name, const char *default_value); + +bool IsEnvEnabled(const char *name); + +} // namespace duckdb diff --git a/src/include/utils/helpers.hpp b/src/include/utils/helpers.hpp new file mode 100644 index 0000000..70c7bc0 --- /dev/null +++ b/src/include/utils/helpers.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include + +namespace duckdb { + +struct RunOnceTableFunctionState : GlobalTableFunctionState { + RunOnceTableFunctionState() : run(false) {}; + std::atomic run; + + static unique_ptr Init(ClientContext &, + TableFunctionInitInput &) { + return make_uniq(); + } +}; + +template +T GetSetting(const ClientContext &context, const char *setting_name, const T default_value) { + Value value; + return context.TryGetCurrentSetting(setting_name, value) ? value.GetValue() : default_value; +} + +bool ShouldRun(TableFunctionInput &input); + +void RegisterTF(DatabaseInstance &instance, const char* name, table_function_t func); + +} // namespace duckdb diff --git a/src/include/utils/serialization.hpp b/src/include/utils/serialization.hpp new file mode 100644 index 0000000..9772a4c --- /dev/null +++ b/src/include/utils/serialization.hpp @@ -0,0 +1,49 @@ +#pragma once + +#include "duckdb.hpp" + +#include + +namespace duckdb { +namespace ui { + +struct EmptyResult { + void Serialize(duckdb::Serializer &serializer) const; +}; + +struct TokenizeResult { + duckdb::vector offsets; + duckdb::vector types; + + void Serialize(duckdb::Serializer &serializer) const; +}; + +struct ColumnNamesAndTypes { + duckdb::vector names; + duckdb::vector types; + + void Serialize(duckdb::Serializer &serializer) const; +}; + +struct Chunk { + uint16_t row_count; + duckdb::vector vectors; + + void Serialize(duckdb::Serializer &serializer) const; +}; + +struct SuccessResult { + ColumnNamesAndTypes column_names_and_types; + duckdb::vector chunks; + + void Serialize(duckdb::Serializer &serializer) const; +}; + +struct ErrorResult { + std::string error; + + void Serialize(duckdb::Serializer &serializer) const; +}; + +} // namespace ui +} // namespace duckdb diff --git a/src/ui_extension.cpp b/src/ui_extension.cpp index e313375..b66f8f9 100644 --- a/src/ui_extension.cpp +++ b/src/ui_extension.cpp @@ -1,47 +1,111 @@ #define DUCKDB_EXTENSION_MAIN +#include "utils/env.hpp" +#include "utils/helpers.hpp" #include "ui_extension.hpp" -#include "duckdb.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/function/scalar_function.hpp" -#include "duckdb/main/extension_util.hpp" -#include +#include "http_server.hpp" +#include +#include -// OpenSSL linked through vcpkg -#include +#ifdef _WIN32 +#define OPEN_COMMAND "start" +#elif __linux__ +#define OPEN_COMMAND "xdg-open" +#else +#define OPEN_COMMAND "open" +#endif + +#define UI_LOCAL_PORT_SETTING_NAME "ui_local_port" +#define UI_LOCAL_PORT_SETTING_DESCRIPTION "Local port on which the UI server listens" +#define UI_LOCAL_PORT_SETTING_DEFAULT 4213 + +#define UI_REMOTE_URL_SETTING_NAME "ui_remote_url" +#define UI_REMOTE_URL_SETTING_DESCRIPTION "Remote URL to which the UI server forwards GET requests" +#define UI_REMOTE_URL_SETTING_DEFAULT "https://app.motherduck.com" namespace duckdb { -inline void UiScalarFun(DataChunk &args, ExpressionState &state, Vector &result) { - auto &name_vector = args.data[0]; - UnaryExecutor::Execute( - name_vector, result, args.size(), - [&](string_t name) { - return StringVector::AddString(result, "Ui "+name.GetString()+" 🐥"); - }); +namespace internal { + +bool StartHttpServer(const ClientContext &context) { + const auto url = GetSetting(context, UI_REMOTE_URL_SETTING_NAME, + GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DEFAULT)); + const uint16_t port = GetSetting(context, UI_LOCAL_PORT_SETTING_NAME, UI_LOCAL_PORT_SETTING_DEFAULT);; + return ui::HttpServer::instance()->Start(port, url, context.db); } -inline void UiOpenSSLVersionScalarFun(DataChunk &args, ExpressionState &state, Vector &result) { - auto &name_vector = args.data[0]; - UnaryExecutor::Execute( - name_vector, result, args.size(), - [&](string_t name) { - return StringVector::AddString(result, "Ui " + name.GetString() + - ", my linked OpenSSL version is " + - OPENSSL_VERSION_TEXT ); - }); +std::string GetHttpServerLocalURL() { + return StringUtil::Format("http://localhost:%d/", ui::HttpServer::instance()->LocalPort()); } +} // namespace internal + +void OutputResult(const std::string &result, DataChunk &out_chunk) { + out_chunk.SetCardinality(1); + out_chunk.SetValue(0, 0, result); +} + +void StartUIFunction(ClientContext &context, TableFunctionInput &input, + DataChunk &out_chunk) { + if (!ShouldRun(input)) { + return; + } + + internal::StartHttpServer(context); + auto local_url = internal::GetHttpServerLocalURL(); + + const std::string command = StringUtil::Format("%s %s", OPEN_COMMAND, local_url); + std::string result = system(command.c_str()) ? + StringUtil::Format("Navigate browser to %s", local_url) // open command failed + : StringUtil::Format("MotherDuck UI started at %s", local_url); + OutputResult(result, out_chunk); +} + +void StartUIServerFunction(ClientContext &context, TableFunctionInput &input, + DataChunk &out_chunk) { + if (!ShouldRun(input)) { + return; + } + + const bool already = internal::StartHttpServer(context); + const char* already_str = already ? "already " : ""; + auto result = StringUtil::Format("MotherDuck UI server %sstarted at %s", already_str, internal::GetHttpServerLocalURL()); + OutputResult(result, out_chunk); +} + +void StopUIServerFunction(ClientContext &, TableFunctionInput &input, + DataChunk &out_chunk) { + if (!ShouldRun(input)) { + return; + } + + auto result = ui::HttpServer::instance()->Stop() ? "UI server stopped" : "UI server already stopped"; + OutputResult(result, out_chunk); +} + +// FIXME +// void HandleConnected(const std::string &token) { +// ui::HttpServer::instance()->SendConnectedEvent(token); +// } + +// FIXME +// void HandleCatalogChanged() { +// ui::HttpServer::instance()->SendCatalogChangedEvent(); +// } + static void LoadInternal(DatabaseInstance &instance) { - // Register a scalar function - auto ui_scalar_function = ScalarFunction("ui", {LogicalType::VARCHAR}, LogicalType::VARCHAR, UiScalarFun); - ExtensionUtil::RegisterFunction(instance, ui_scalar_function); + auto &config = DBConfig::GetConfig(instance); - // Register another scalar function - auto ui_openssl_version_scalar_function = ScalarFunction("ui_openssl_version", {LogicalType::VARCHAR}, - LogicalType::VARCHAR, UiOpenSSLVersionScalarFun); - ExtensionUtil::RegisterFunction(instance, ui_openssl_version_scalar_function); + config.AddExtensionOption(UI_LOCAL_PORT_SETTING_NAME, UI_LOCAL_PORT_SETTING_DESCRIPTION, + LogicalType::USMALLINT, Value::USMALLINT(UI_LOCAL_PORT_SETTING_DEFAULT)); + + config.AddExtensionOption( + UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DESCRIPTION, LogicalType::VARCHAR, + Value(GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DEFAULT))); + + RegisterTF(instance, "start_ui", StartUIFunction); + RegisterTF(instance, "start_ui_server", StartUIServerFunction); + RegisterTF(instance, "stop_ui_server", StopUIServerFunction); } void UiExtension::Load(DuckDB &db) { diff --git a/src/utils/encoding.cpp b/src/utils/encoding.cpp new file mode 100644 index 0000000..1cd876d --- /dev/null +++ b/src/utils/encoding.cpp @@ -0,0 +1,60 @@ +#include "utils/encoding.hpp" + +#include + +namespace duckdb { + +// Copied from https://www.mycplus.com/source-code/c-source-code/base64-encode-decode/ +constexpr char k_encoding_table[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+_"; + +std::vector BuildDecodingTable() { + std::vector decoding_table; + decoding_table.resize(256); + for (int i = 0; i < 64; ++i) { + decoding_table[static_cast(k_encoding_table[i])] = i; + } + return decoding_table; +} + +const static std::vector k_decoding_table = BuildDecodingTable(); + +std::string DecodeBase64(const std::string &data) { + size_t input_length = data.size(); + if (input_length < 4 || input_length % 4 != 0) { + // Handle this exception + return ""; + } + + size_t output_length = input_length / 4 * 3; + if (data[input_length - 1] == '=') { + output_length--; + } + if (data[input_length - 2] == '=') { + output_length--; + } + + std::string decoded_data; + decoded_data.resize(output_length); + for (size_t i = 0, j = 0; i < input_length;) { + uint32_t sextet_a = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; + uint32_t sextet_b = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; + uint32_t sextet_c = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; + uint32_t sextet_d = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; + + uint32_t triple = (sextet_a << 3 * 6) + (sextet_b << 2 * 6) + (sextet_c << 1 * 6) + (sextet_d << 0 * 6); + + if (j < output_length) { + decoded_data[j++] = (triple >> 2 * 8) & 0xFF; + } + if (j < output_length) { + decoded_data[j++] = (triple >> 1 * 8) & 0xFF; + } + if (j < output_length) { + decoded_data[j++] = (triple >> 0 * 8) & 0xFF; + } + } + + return decoded_data; +} + +} // namespace duckdb diff --git a/src/utils/env.cpp b/src/utils/env.cpp new file mode 100644 index 0000000..6c96110 --- /dev/null +++ b/src/utils/env.cpp @@ -0,0 +1,34 @@ +#include "utils/env.hpp" + +#include +#include + +namespace duckdb { + +const char *TryGetEnv(const char *name) { + const char *res = std::getenv(name); + if (res) { + return res; + } + return std::getenv(StringUtil::Upper(name).c_str()); +} + +std::string GetEnvOrDefault(const char *name, const char *default_value) { + const char *res = TryGetEnv(name); + if (res) { + return res; + } + return default_value; +} + +bool IsEnvEnabled(const char *name) { + const char *res = TryGetEnv(name); + if (!res) { + return false; + } + + auto lc_res = StringUtil::Lower(res); + return lc_res == "1" || lc_res == "true"; +} + +} // namespace duckdb diff --git a/src/utils/helpers.cpp b/src/utils/helpers.cpp new file mode 100644 index 0000000..71c21d1 --- /dev/null +++ b/src/utils/helpers.cpp @@ -0,0 +1,30 @@ +#include "utils/helpers.hpp" +#include + +namespace duckdb { + +bool ShouldRun(TableFunctionInput &input) { + auto state = dynamic_cast(input.global_state.get()); + D_ASSERT(state != nullptr); + if (state->run) { + return false; + } + + state->run = true; + return true; +} + +unique_ptr ResultBind(ClientContext &, TableFunctionBindInput &, + vector &out_types, + vector &out_names) { + out_names.emplace_back("result"); + out_types.emplace_back(LogicalType::VARCHAR); + return nullptr; +} + +void RegisterTF(DatabaseInstance &instance, const char* name, table_function_t func) { + TableFunction tf(name, {}, func, ResultBind, RunOnceTableFunctionState::Init); + ExtensionUtil::RegisterFunction(instance, tf); +} + +} // namespace duckdb diff --git a/src/utils/serialization.cpp b/src/utils/serialization.cpp new file mode 100644 index 0000000..c04fc12 --- /dev/null +++ b/src/utils/serialization.cpp @@ -0,0 +1,49 @@ +#include "utils/serialization.hpp" + +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/serializer/serializer.hpp" + +namespace duckdb { +namespace ui { + +void EmptyResult::Serialize(Serializer &) const { +} + +void TokenizeResult::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "offsets", offsets); + serializer.WriteProperty(101, "types", types); +} + +// Adapted from parts of DataChunk::Serialize +void ColumnNamesAndTypes::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "names", names); + serializer.WriteProperty(101, "types", types); +} + +// Adapted from parts of DataChunk::Serialize +void Chunk::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "row_count", row_count); + serializer.WriteList(101, "vectors", vectors.size(), [&](Serializer::List &list, idx_t i) { + list.WriteObject([&](Serializer &object) { + // Reference the vector to avoid potentially mutating it during serialization + Vector serialized_vector(vectors[i].GetType()); + serialized_vector.Reference(vectors[i]); + serialized_vector.Serialize(object, row_count); + }); + }); +} + +void SuccessResult::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "success", true); + serializer.WriteProperty(101, "column_names_and_types", column_names_and_types); + serializer.WriteList(102, "chunks", chunks.size(), + [&](Serializer::List &list, idx_t i) { list.WriteElement(chunks[i]); }); +} + +void ErrorResult::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "success", false); + serializer.WriteProperty(101, "error", error); +} + +} // namespace ui +} // namespace duckdb