This commit is contained in:
Yves
2025-02-19 11:20:42 +01:00
parent 80df1f7ce1
commit 5811f237d8
11 changed files with 636 additions and 569 deletions

View File

@@ -3,39 +3,31 @@ cmake_minimum_required(VERSION 3.5...3.31.5)
# Set extension name here # Set extension name here
set(TARGET_NAME ui) set(TARGET_NAME ui)
# DuckDB's extension distribution supports vcpkg. As such, dependencies can be added in ./vcpkg.json and then # DuckDB's extension distribution supports vcpkg. As such, dependencies can be
# used in cmake with find_package. Feel free to remove or replace with other dependencies. # added in ./vcpkg.json and then used in cmake with find_package. Feel free to
# Note that it should also be removed from vcpkg.json to prevent needlessly installing it.. # remove or replace with other dependencies. Note that it should also be removed
# from vcpkg.json to prevent needlessly installing it..
find_package(OpenSSL REQUIRED) find_package(OpenSSL REQUIRED)
set(EXTENSION_NAME ${TARGET_NAME}_extension) set(EXTENSION_NAME ${TARGET_NAME}_extension)
project(${TARGET_NAME}) project(${TARGET_NAME})
include_directories( include_directories(src/include ${DuckDB_SOURCE_DIR}/third_party/httplib)
src/include
${DuckDB_SOURCE_DIR}/third_party/httplib
)
set(EXTENSION_SOURCES set(EXTENSION_SOURCES
src/ui_extension.cpp src/ui_extension.cpp src/http_server.cpp src/utils/encoding.cpp
src/http_server.cpp src/utils/env.cpp src/utils/helpers.cpp src/utils/serialization.cpp)
src/utils/encoding.cpp
src/utils/env.cpp
src/utils/helpers.cpp
src/utils/serialization.cpp
)
find_package(Git) find_package(Git)
if (NOT Git_FOUND) if(NOT Git_FOUND)
message(FATAL_ERROR "Git not found, unable to determine git sha") message(FATAL_ERROR "Git not found, unable to determine git sha")
endif() endif()
execute_process( execute_process(
COMMAND ${GIT_EXECUTABLE} rev-parse --short=10 HEAD COMMAND ${GIT_EXECUTABLE} rev-parse --short=10 HEAD
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
OUTPUT_VARIABLE UI_EXTENSION_GIT_SHA OUTPUT_VARIABLE UI_EXTENSION_GIT_SHA
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_STRIP_TRAILING_WHITESPACE)
)
message(STATUS "UI_EXTENSION_GIT_SHA=${UI_EXTENSION_GIT_SHA}") message(STATUS "UI_EXTENSION_GIT_SHA=${UI_EXTENSION_GIT_SHA}")
add_definitions(-DUI_EXTENSION_GIT_SHA="${UI_EXTENSION_GIT_SHA}") add_definitions(-DUI_EXTENSION_GIT_SHA="${UI_EXTENSION_GIT_SHA}")

View File

@@ -21,401 +21,437 @@ constexpr const char *EMPTY_SSE_MESSAGE = ":\r\r";
constexpr idx_t EMPTY_SSE_MESSAGE_LENGTH = 3; constexpr idx_t EMPTY_SSE_MESSAGE_LENGTH = 3;
bool EventDispatcher::WaitEvent(httplib::DataSink *sink) { bool EventDispatcher::WaitEvent(httplib::DataSink *sink) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
// Don't allow too many simultaneous waits, because each consumes a thread in the httplib thread pool, and also // Don't allow too many simultaneous waits, because each consumes a thread in
// browsers limit the number of server-sent event connections. // the httplib thread pool, and also browsers limit the number of server-sent
if (closed_ || wait_count_ >= MAX_EVENT_WAIT_COUNT) { // event connections.
return false; if (closed_ || wait_count_ >= MAX_EVENT_WAIT_COUNT) {
} return false;
int target_id = next_id_; }
wait_count_++; int target_id = next_id_;
cv_.wait_for(lock, std::chrono::seconds(5)); wait_count_++;
wait_count_--; cv_.wait_for(lock, std::chrono::seconds(5));
if (closed_) { wait_count_--;
return false; if (closed_) {
} return false;
if (current_id_ == target_id) { }
sink->write(message_.data(), message_.size()); if (current_id_ == target_id) {
} else { sink->write(message_.data(), message_.size());
// Our wait timer expired. Write an empty, no-op message. } else {
// This enables detecting when the client is gone. // Our wait timer expired. Write an empty, no-op message.
sink->write(EMPTY_SSE_MESSAGE, EMPTY_SSE_MESSAGE_LENGTH); // This enables detecting when the client is gone.
} sink->write(EMPTY_SSE_MESSAGE, EMPTY_SSE_MESSAGE_LENGTH);
return true; }
return true;
} }
void EventDispatcher::SendEvent(const std::string &message) { void EventDispatcher::SendEvent(const std::string &message) {
std::lock_guard<std::mutex> guard(mutex_); std::lock_guard<std::mutex> guard(mutex_);
if (closed_) { if (closed_) {
return; return;
} }
current_id_ = next_id_++; current_id_ = next_id_++;
message_ = message; message_ = message;
cv_.notify_all(); cv_.notify_all();
} }
void EventDispatcher::Close() { void EventDispatcher::Close() {
std::lock_guard<std::mutex> guard(mutex_); std::lock_guard<std::mutex> guard(mutex_);
if (closed_) { if (closed_) {
return; return;
} }
current_id_ = next_id_++; current_id_ = next_id_++;
closed_ = true; closed_ = true;
cv_.notify_all(); cv_.notify_all();
} }
unique_ptr<HttpServer> HttpServer::instance_; unique_ptr<HttpServer> HttpServer::instance_;
HttpServer* HttpServer::instance() { HttpServer *HttpServer::instance() {
if (!instance_) { if (!instance_) {
instance_ = make_uniq<HttpServer>(); instance_ = make_uniq<HttpServer>();
std::atexit(HttpServer::StopInstance); std::atexit(HttpServer::StopInstance);
} }
return instance_.get(); return instance_.get();
} }
bool HttpServer::Started() { bool HttpServer::Started() { return instance_ && instance_->thread_; }
return instance_ && instance_->thread_;
}
void HttpServer::StopInstance() { void HttpServer::StopInstance() {
if (instance_) { if (instance_) {
instance_->Stop(); instance_->Stop();
} }
} }
bool HttpServer::Start(const uint16_t local_port, const std::string &remote_url, bool HttpServer::Start(const uint16_t local_port, const std::string &remote_url,
const shared_ptr<DatabaseInstance> &ddb_instance) { const shared_ptr<DatabaseInstance> &ddb_instance) {
if (thread_) { if (thread_) {
return false; return false;
} }
local_port_ = local_port; local_port_ = local_port;
remote_url_ = remote_url; remote_url_ = remote_url;
ddb_instance_ = ddb_instance; ddb_instance_ = ddb_instance;
#ifndef UI_EXTENSION_GIT_SHA #ifndef UI_EXTENSION_GIT_SHA
#error "UI_EXTENSION_GIT_SHA must be defined" #error "UI_EXTENSION_GIT_SHA must be defined"
#endif #endif
user_agent_ = StringUtil::Format("duckdb-ui/%s(%s)", UI_EXTENSION_GIT_SHA, DuckDB::Platform()); user_agent_ = StringUtil::Format("duckdb-ui/%s(%s)", UI_EXTENSION_GIT_SHA,
event_dispatcher_ = make_uniq<EventDispatcher>(); DuckDB::Platform());
thread_ = make_uniq<std::thread>(&HttpServer::Run, this); event_dispatcher_ = make_uniq<EventDispatcher>();
return true; thread_ = make_uniq<std::thread>(&HttpServer::Run, this);
return true;
} }
bool HttpServer::Stop() { bool HttpServer::Stop() {
if (!thread_) { if (!thread_) {
return false; return false;
} }
event_dispatcher_->Close(); event_dispatcher_->Close();
server_.stop(); server_.stop();
thread_->join(); thread_->join();
thread_.reset(); thread_.reset();
event_dispatcher_.reset(); event_dispatcher_.reset();
connections_.clear(); connections_.clear();
ddb_instance_.reset(); ddb_instance_.reset();
remote_url_ = ""; remote_url_ = "";
local_port_ = 0; local_port_ = 0;
return true; return true;
} }
uint16_t HttpServer::LocalPort() { uint16_t HttpServer::LocalPort() { return local_port_; }
return local_port_;
}
void HttpServer::SendConnectedEvent(const std::string &token) { void HttpServer::SendConnectedEvent(const std::string &token) {
SendEvent(StringUtil::Format("event: ConnectedEvent\ndata: %s\n\n", token)); SendEvent(StringUtil::Format("event: ConnectedEvent\ndata: %s\n\n", token));
} }
void HttpServer::SendCatalogChangedEvent() { void HttpServer::SendCatalogChangedEvent() {
SendEvent("event: CatalogChangeEvent\ndata:\n\n"); SendEvent("event: CatalogChangeEvent\ndata:\n\n");
} }
void HttpServer::SendEvent(const std::string &message) { void HttpServer::SendEvent(const std::string &message) {
if (event_dispatcher_) { if (event_dispatcher_) {
event_dispatcher_->SendEvent(message); event_dispatcher_->SendEvent(message);
} }
} }
void HttpServer::Run() { void HttpServer::Run() {
server_.Get("/localEvents", server_.Get("/localEvents",
[&](const httplib::Request &req, httplib::Response &res) { HandleGetLocalEvents(req, res); }); [&](const httplib::Request &req, httplib::Response &res) {
server_.Get("/localToken", HandleGetLocalEvents(req, res);
[&](const httplib::Request &req, httplib::Response &res) { HandleGetLocalToken(req, res); }); });
server_.Get("/.*", [&](const httplib::Request &req, httplib::Response &res) { HandleGet(req, res); }); server_.Get("/localToken",
server_.Post("/ddb/interrupt", [&](const httplib::Request &req, httplib::Response &res) {
[&](const httplib::Request &req, httplib::Response &res) { HandleInterrupt(req, res); }); HandleGetLocalToken(req, res);
server_.Post("/ddb/run", });
[&](const httplib::Request &req, httplib::Response &res, server_.Get("/.*", [&](const httplib::Request &req, httplib::Response &res) {
const httplib::ContentReader &content_reader) { HandleRun(req, res, content_reader); }); HandleGet(req, res);
server_.Post("/ddb/tokenize", });
[&](const httplib::Request &req, httplib::Response &res, server_.Post("/ddb/interrupt",
const httplib::ContentReader &content_reader) { HandleTokenize(req, res, content_reader); }); [&](const httplib::Request &req, httplib::Response &res) {
server_.listen("localhost", local_port_); HandleInterrupt(req, res);
});
server_.Post("/ddb/run",
[&](const httplib::Request &req, httplib::Response &res,
const httplib::ContentReader &content_reader) {
HandleRun(req, res, content_reader);
});
server_.Post("/ddb/tokenize",
[&](const httplib::Request &req, httplib::Response &res,
const httplib::ContentReader &content_reader) {
HandleTokenize(req, res, content_reader);
});
server_.listen("localhost", local_port_);
} }
void HttpServer::HandleGetLocalEvents(const httplib::Request &req, httplib::Response &res) { void HttpServer::HandleGetLocalEvents(const httplib::Request &req,
res.set_chunked_content_provider("text/event-stream", [&](size_t /*offset*/, httplib::DataSink &sink) { httplib::Response &res) {
if (event_dispatcher_->WaitEvent(&sink)) { res.set_chunked_content_provider(
return true; "text/event-stream", [&](size_t /*offset*/, httplib::DataSink &sink) {
} if (event_dispatcher_->WaitEvent(&sink)) {
sink.done(); return true;
return false; }
}); sink.done();
return false;
});
} }
void HttpServer::HandleGetLocalToken(const httplib::Request &req, httplib::Response &res) { void HttpServer::HandleGetLocalToken(const httplib::Request &req,
if (!ddb_instance_->ExtensionIsLoaded("motherduck")) { httplib::Response &res) {
res.set_content("", "text/plain"); // UI expects an empty response if the extension is not loaded if (!ddb_instance_->ExtensionIsLoaded("motherduck")) {
return; res.set_content("", "text/plain"); // UI expects an empty response if the
} // extension is not loaded
return;
}
Connection connection(*ddb_instance_); Connection connection(*ddb_instance_);
auto query_res = connection.Query("CALL get_md_token()"); auto query_res = connection.Query("CALL get_md_token()");
if (query_res->HasError()) { if (query_res->HasError()) {
res.status = 500; res.status = 500;
res.set_content("Could not get token: " + query_res->GetError(), "text/plain"); res.set_content("Could not get token: " + query_res->GetError(),
return; "text/plain");
} return;
}
auto chunk = query_res->Fetch(); auto chunk = query_res->Fetch();
auto token = chunk->GetValue(0, 0).GetValue<std::string>(); auto token = chunk->GetValue(0, 0).GetValue<std::string>();
res.status = 200; res.status = 200;
res.set_content(token, "text/plain"); res.set_content(token, "text/plain");
} }
void HttpServer::HandleGet(const httplib::Request &req, httplib::Response &res) { void HttpServer::HandleGet(const httplib::Request &req,
// Create HTTP client to remote URL httplib::Response &res) {
// TODO: Can this be created once and shared? // Create HTTP client to remote URL
httplib::Client client(remote_url_); // TODO: Can this be created once and shared?
client.set_keep_alive(true); httplib::Client client(remote_url_);
client.set_keep_alive(true);
// Provide a way to turn on or off server certificate verification, at least for now, because it requires httplib to // Provide a way to turn on or off server certificate verification, at least
// correctly get the root certficates on each platform, which doesn't appear to always work. // for now, because it requires httplib to correctly get the root certficates
// Currently, default to no verification, until we understand when it breaks things. // on each platform, which doesn't appear to always work. Currently, default
if (IsEnvEnabled("ui_enable_server_certificate_verification")) { // to no verification, until we understand when it breaks things.
client.enable_server_certificate_verification(true); if (IsEnvEnabled("ui_enable_server_certificate_verification")) {
} else { client.enable_server_certificate_verification(true);
client.enable_server_certificate_verification(false); } else {
} client.enable_server_certificate_verification(false);
}
// forward GET to remote URL // forward GET to remote URL
auto result = client.Get(req.path, req.params, {{"User-Agent", user_agent_}}); auto result = client.Get(req.path, req.params, {{"User-Agent", user_agent_}});
if (!result) { if (!result) {
res.status = 500; res.status = 500;
return; return;
} }
// Repond with result of forwarded GET // Repond with result of forwarded GET
res = result.value(); res = result.value();
// If this is the config request, set the X-MD-DuckDB-Mode header to HTTP. // If this is the config request, set the X-MD-DuckDB-Mode header to HTTP.
// The UI looks for this to select the appropriate DuckDB mode (HTTP or Wasm). // The UI looks for this to select the appropriate DuckDB mode (HTTP or Wasm).
if (req.path == "/config") { if (req.path == "/config") {
res.set_header("X-MD-DuckDB-Mode", "HTTP"); res.set_header("X-MD-DuckDB-Mode", "HTTP");
} }
} }
void HttpServer::HandleInterrupt(const httplib::Request &req, httplib::Response &res) { void HttpServer::HandleInterrupt(const httplib::Request &req,
auto description = req.get_header_value("X-MD-Description"); httplib::Response &res) {
auto connection_name = req.get_header_value("X-MD-Connection-Name"); auto description = req.get_header_value("X-MD-Description");
auto connection_name = req.get_header_value("X-MD-Connection-Name");
auto connection = FindConnection(connection_name); auto connection = FindConnection(connection_name);
if (!connection) { if (!connection) {
res.status = 404; res.status = 404;
return; return;
} }
connection->Interrupt(); connection->Interrupt();
SetResponseEmptyResult(res); SetResponseEmptyResult(res);
} }
void HttpServer::HandleRun(const httplib::Request &req, httplib::Response &res, void HttpServer::HandleRun(const httplib::Request &req, httplib::Response &res,
const httplib::ContentReader &content_reader) {
try {
DoHandleRun(req, res, content_reader);
} catch (const std::exception &ex) {
SetResponseErrorResult(res, ex.what());
}
}
void HttpServer::DoHandleRun(const httplib::Request &req,
httplib::Response &res,
const httplib::ContentReader &content_reader) { const httplib::ContentReader &content_reader) {
try { auto description = req.get_header_value("X-MD-Description");
auto description = req.get_header_value("X-MD-Description"); auto connection_name = req.get_header_value("X-MD-Connection-Name");
auto connection_name = req.get_header_value("X-MD-Connection-Name");
auto database_name = DecodeBase64(req.get_header_value("X-MD-Database-Name")); auto database_name = DecodeBase64(req.get_header_value("X-MD-Database-Name"));
auto parameter_count = req.get_header_value_count("X-MD-Parameter"); auto parameter_count = req.get_header_value_count("X-MD-Parameter");
std::string content = ReadContent(content_reader); std::string content = ReadContent(content_reader);
auto connection = FindOrCreateConnection(connection_name);
auto connection = FindOrCreateConnection(connection_name); // Set current database if optional header is provided.
if (!database_name.empty()) {
connection->context->RunFunctionInTransaction([&] {
ddb_instance_->GetDatabaseManager().SetDefaultDatabase(
*connection->context, database_name);
});
}
// Set current database if optional header is provided. // We use a pending query so we can execute tasks and fetch chunks
if (!database_name.empty()) { // incrementally. This enables cancellation.
connection->context->RunFunctionInTransaction( unique_ptr<PendingQueryResult> pending;
[&] { ddb_instance_->GetDatabaseManager().SetDefaultDatabase(*connection->context, database_name); });
}
// We use a pending query so we can execute tasks and fetch chunks incrementally. // Create pending query, with request content as SQL.
// This enables cancellation. if (parameter_count > 0) {
unique_ptr<PendingQueryResult> pending; auto prepared = connection->Prepare(content);
if (prepared->HasError()) {
SetResponseErrorResult(res, prepared->GetError());
return;
}
// Create pending query, with request content as SQL. vector<Value> values;
if (parameter_count > 0) { for (idx_t i = 0; i < parameter_count; ++i) {
auto prepared = connection->Prepare(content); auto parameter = DecodeBase64(req.get_header_value("X-MD-Parameter", i));
if (prepared->HasError()) { values.push_back(
SetResponseErrorResult(res, prepared->GetError()); Value(parameter)); // TODO: support non-string parameters? (SURF-1546)
return; }
} pending = prepared->PendingQuery(values, true);
} else {
pending = connection->PendingQuery(content, true);
}
vector<Value> values; if (pending->HasError()) {
for (idx_t i = 0; i < parameter_count; ++i) { SetResponseErrorResult(res, pending->GetError());
auto parameter = DecodeBase64(req.get_header_value("X-MD-Parameter", i)); return;
values.push_back(Value(parameter)); // TODO: support non-string parameters? (SURF-1546) }
}
pending = prepared->PendingQuery(values, true);
} else {
pending = connection->PendingQuery(content, true);
}
if (pending->HasError()) { // Execute tasks until result is ready (or there's an error).
SetResponseErrorResult(res, pending->GetError()); auto exec_result = PendingExecutionResult::RESULT_NOT_READY;
return; while (!PendingQueryResult::IsResultReady(exec_result)) {
} exec_result = pending->ExecuteTask();
if (exec_result == PendingExecutionResult::BLOCKED ||
exec_result == PendingExecutionResult::NO_TASKS_AVAILABLE) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
}
// Execute tasks until result is ready (or there's an error). switch (exec_result) {
auto exec_result = PendingExecutionResult::RESULT_NOT_READY;
while (!PendingQueryResult::IsResultReady(exec_result)) {
exec_result = pending->ExecuteTask();
if (exec_result == PendingExecutionResult::BLOCKED ||
exec_result == PendingExecutionResult::NO_TASKS_AVAILABLE) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
}
switch (exec_result) { case PendingExecutionResult::EXECUTION_ERROR:
SetResponseErrorResult(res, pending->GetError());
break;
case PendingExecutionResult::EXECUTION_ERROR: { case PendingExecutionResult::EXECUTION_FINISHED:
SetResponseErrorResult(res, pending->GetError()); case PendingExecutionResult::RESULT_READY: {
} break; // Get the result. This should be quick because it's ready.
auto result = pending->Execute();
case PendingExecutionResult::EXECUTION_FINISHED: // Fetch the chunks and serialize the result.
case PendingExecutionResult::RESULT_READY: { SuccessResult success_result;
// Get the result. This should be quick because it's ready. success_result.column_names_and_types = {std::move(result->names),
auto result = pending->Execute(); std::move(result->types)};
// Fetch the chunks and serialize the result. // TODO: support limiting the number of chunks fetched (SURF-1540)
SuccessResult success_result; auto chunk = result->Fetch();
success_result.column_names_and_types = {std::move(result->names), std::move(result->types)}; while (chunk) {
success_result.chunks.push_back(
{static_cast<uint16_t>(chunk->size()), std::move(chunk->data)});
chunk = result->Fetch();
}
// TODO: support limiting the number of chunks fetched (SURF-1540) MemoryStream success_response_content;
auto chunk = result->Fetch(); BinarySerializer::Serialize(success_result, success_response_content);
while (chunk) { SetResponseContent(res, success_response_content);
success_result.chunks.push_back({static_cast<uint16_t>(chunk->size()), std::move(chunk->data)}); break;
chunk = result->Fetch(); }
} default:
SetResponseErrorResult(res, "Unexpected PendingExecutionResult");
MemoryStream success_response_content; break;
BinarySerializer::Serialize(success_result, success_response_content); }
SetResponseContent(res, success_response_content);
} break;
default: {
SetResponseErrorResult(res, "Unexpected PendingExecutionResult");
} break;
}
} catch (const std::exception &ex) {
SetResponseErrorResult(res, ex.what());
}
} }
void HttpServer::HandleTokenize(const httplib::Request &req, httplib::Response &res, void HttpServer::HandleTokenize(const httplib::Request &req,
const httplib::ContentReader &content_reader) { httplib::Response &res,
auto description = req.get_header_value("X-MD-Description"); const httplib::ContentReader &content_reader) {
auto description = req.get_header_value("X-MD-Description");
std::string content = ReadContent(content_reader); std::string content = ReadContent(content_reader);
auto tokens = Parser::Tokenize(content);
auto tokens = Parser::Tokenize(content); // Read and serialize result
TokenizeResult result;
result.offsets.reserve(tokens.size());
result.types.reserve(tokens.size());
// Read and serialize result for (auto token : tokens) {
TokenizeResult result; result.offsets.push_back(token.start);
result.offsets.reserve(tokens.size()); result.types.push_back(token.type);
result.types.reserve(tokens.size()); }
for (auto token : tokens) { MemoryStream response_content;
result.offsets.push_back(token.start); BinarySerializer::Serialize(result, response_content);
result.types.push_back(token.type); SetResponseContent(res, response_content);
}
MemoryStream response_content;
BinarySerializer::Serialize(result, response_content);
SetResponseContent(res, response_content);
} }
std::string HttpServer::ReadContent(const httplib::ContentReader &content_reader) { std::string
std::ostringstream oss; HttpServer::ReadContent(const httplib::ContentReader &content_reader) {
content_reader([&](const char *data, size_t data_length) { std::ostringstream oss;
oss.write(data, data_length); content_reader([&](const char *data, size_t data_length) {
return true; oss.write(data, data_length);
}); return true;
return oss.str(); });
return oss.str();
} }
shared_ptr<Connection> HttpServer::FindConnection(const std::string &connection_name) { shared_ptr<Connection>
if (connection_name.empty()) { HttpServer::FindConnection(const std::string &connection_name) {
return nullptr; if (connection_name.empty()) {
} return nullptr;
}
// Need to protect access to the connections map because this can be called from multiple threads. // Need to protect access to the connections map because this can be called
std::lock_guard<std::mutex> guard(connections_mutex_); // from multiple threads.
std::lock_guard<std::mutex> guard(connections_mutex_);
auto result = connections_.find(connection_name); auto result = connections_.find(connection_name);
if (result != connections_.end()) { if (result != connections_.end()) {
return result->second; return result->second;
} }
return nullptr; return nullptr;
} }
shared_ptr<Connection> HttpServer::FindOrCreateConnection(const std::string &connection_name) { shared_ptr<Connection>
if (connection_name.empty()) { HttpServer::FindOrCreateConnection(const std::string &connection_name) {
// If no connection name was provided, create and return a new connection but don't remember it. if (connection_name.empty()) {
return make_shared_ptr<Connection>(*ddb_instance_); // If no connection name was provided, create and return a new connection
} // but don't remember it.
return make_shared_ptr<Connection>(*ddb_instance_);
}
// Need to protect access to the connections map because this can be called from multiple threads. // Need to protect access to the connections map because this can be called
std::lock_guard<std::mutex> guard(connections_mutex_); // from multiple threads.
std::lock_guard<std::mutex> guard(connections_mutex_);
// If an existing connection with the provided name was found, return it. // If an existing connection with the provided name was found, return it.
auto result = connections_.find(connection_name); auto result = connections_.find(connection_name);
if (result != connections_.end()) { if (result != connections_.end()) {
return result->second; return result->second;
} }
// Otherwise, create a new one, remember it, and return it. // Otherwise, create a new one, remember it, and return it.
auto connection = make_shared_ptr<Connection>(*ddb_instance_); auto connection = make_shared_ptr<Connection>(*ddb_instance_);
connections_[connection_name] = connection; connections_[connection_name] = connection;
return connection; return connection;
} }
void HttpServer::SetResponseContent(httplib::Response &res, const MemoryStream &content) { void HttpServer::SetResponseContent(httplib::Response &res,
auto data = content.GetData(); const MemoryStream &content) {
auto length = content.GetPosition(); auto data = content.GetData();
res.set_content(reinterpret_cast<const char *>(data), length, "application/octet-stream"); auto length = content.GetPosition();
res.set_content(reinterpret_cast<const char *>(data), length,
"application/octet-stream");
} }
void HttpServer::SetResponseEmptyResult(httplib::Response &res) { void HttpServer::SetResponseEmptyResult(httplib::Response &res) {
EmptyResult empty_result; EmptyResult empty_result;
MemoryStream response_content; MemoryStream response_content;
BinarySerializer::Serialize(empty_result, response_content); BinarySerializer::Serialize(empty_result, response_content);
SetResponseContent(res, response_content); SetResponseContent(res, response_content);
} }
void HttpServer::SetResponseErrorResult(httplib::Response &res, const std::string &error) { void HttpServer::SetResponseErrorResult(httplib::Response &res,
ErrorResult error_result; const std::string &error) {
error_result.error = error; ErrorResult error_result;
MemoryStream response_content; error_result.error = error;
BinarySerializer::Serialize(error_result, response_content); MemoryStream response_content;
SetResponseContent(res, response_content); BinarySerializer::Serialize(error_result, response_content);
SetResponseContent(res, response_content);
} }
} // namespace ui } // namespace ui
} // namespace md } // namespace duckdb

View File

@@ -20,63 +20,69 @@ namespace ui {
class EventDispatcher { class EventDispatcher {
public: public:
bool WaitEvent(httplib::DataSink *sink); bool WaitEvent(httplib::DataSink *sink);
void SendEvent(const std::string &message); void SendEvent(const std::string &message);
void Close(); void Close();
private: private:
std::mutex mutex_; std::mutex mutex_;
std::condition_variable cv_; std::condition_variable cv_;
std::atomic_int next_id_ {0}; std::atomic_int next_id_{0};
std::atomic_int current_id_ {-1}; std::atomic_int current_id_{-1};
std::atomic_int wait_count_ {0}; std::atomic_int wait_count_{0};
std::string message_; std::string message_;
std::atomic_bool closed_ {false}; std::atomic_bool closed_{false};
}; };
class HttpServer { class HttpServer {
public: public:
static HttpServer* instance(); static HttpServer *instance();
static bool Started(); static bool Started();
static void StopInstance(); static void StopInstance();
bool Start(const uint16_t localPort, const std::string &remoteUrl, bool Start(const uint16_t localPort, const std::string &remoteUrl,
const shared_ptr<DatabaseInstance> &ddbInstance); const shared_ptr<DatabaseInstance> &ddbInstance);
bool Stop(); bool Stop();
uint16_t LocalPort(); uint16_t LocalPort();
void SendConnectedEvent(const std::string &token); void SendConnectedEvent(const std::string &token);
void SendCatalogChangedEvent(); void SendCatalogChangedEvent();
private: private:
void SendEvent(const std::string &message); void SendEvent(const std::string &message);
void Run(); void Run();
void HandleGetLocalEvents(const httplib::Request &req, httplib::Response &res); void HandleGetLocalEvents(const httplib::Request &req,
void HandleGetLocalToken(const httplib::Request &req, httplib::Response &res); httplib::Response &res);
void HandleGet(const httplib::Request &req, httplib::Response &res); void HandleGetLocalToken(const httplib::Request &req, httplib::Response &res);
void HandleInterrupt(const httplib::Request &req, httplib::Response &res); void HandleGet(const httplib::Request &req, httplib::Response &res);
void HandleRun(const httplib::Request &req, httplib::Response &res, const httplib::ContentReader &contentReader); void HandleInterrupt(const httplib::Request &req, httplib::Response &res);
void HandleTokenize(const httplib::Request &req, httplib::Response &res, void DoHandleRun(const httplib::Request &req, httplib::Response &res,
const httplib::ContentReader &contentReader); const httplib::ContentReader &contentReader);
std::string ReadContent(const httplib::ContentReader &contentReader); void HandleRun(const httplib::Request &req, httplib::Response &res,
shared_ptr<Connection> FindConnection(const std::string &connectionName); const httplib::ContentReader &contentReader);
shared_ptr<Connection> FindOrCreateConnection(const std::string &connectionName); void HandleTokenize(const httplib::Request &req, httplib::Response &res,
void SetResponseContent(httplib::Response &res, const MemoryStream &content); const httplib::ContentReader &contentReader);
void SetResponseEmptyResult(httplib::Response &res); std::string ReadContent(const httplib::ContentReader &contentReader);
void SetResponseErrorResult(httplib::Response &res, const std::string &error); shared_ptr<Connection> FindConnection(const std::string &connectionName);
shared_ptr<Connection>
FindOrCreateConnection(const std::string &connectionName);
void SetResponseContent(httplib::Response &res, const MemoryStream &content);
void SetResponseEmptyResult(httplib::Response &res);
void SetResponseErrorResult(httplib::Response &res, const std::string &error);
uint16_t local_port_; uint16_t local_port_;
std::string remote_url_; std::string remote_url_;
shared_ptr<DatabaseInstance> ddb_instance_; shared_ptr<DatabaseInstance> ddb_instance_;
std::string user_agent_; std::string user_agent_;
httplib::Server server_; httplib::Server server_;
unique_ptr<std::thread> thread_; unique_ptr<std::thread> thread_;
std::mutex connections_mutex_; std::mutex connections_mutex_;
std::unordered_map<std::string, shared_ptr<Connection>> connections_; std::unordered_map<std::string, shared_ptr<Connection>> connections_;
unique_ptr<EventDispatcher> event_dispatcher_; unique_ptr<EventDispatcher> event_dispatcher_;
static unique_ptr<HttpServer> instance_; static unique_ptr<HttpServer> instance_;
};; };
;
} // namespace ui } // namespace ui
} // namespace duckdb } // namespace duckdb

View File

@@ -6,9 +6,9 @@ namespace duckdb {
class UiExtension : public Extension { class UiExtension : public Extension {
public: public:
void Load(DuckDB &db) override; void Load(DuckDB &db) override;
std::string Name() override; std::string Name() override;
std::string Version() const override; std::string Version() const override;
}; };
} // namespace duckdb } // namespace duckdb

View File

@@ -6,82 +6,95 @@
namespace duckdb { namespace duckdb {
typedef std::string (*simple_tf_t) (ClientContext &); typedef std::string (*simple_tf_t)(ClientContext &);
struct RunOnceTableFunctionState : GlobalTableFunctionState { struct RunOnceTableFunctionState : GlobalTableFunctionState {
RunOnceTableFunctionState() : run(false) {}; RunOnceTableFunctionState() : run(false){};
std::atomic<bool> run; std::atomic<bool> run;
static unique_ptr<GlobalTableFunctionState> Init(ClientContext &, static unique_ptr<GlobalTableFunctionState> Init(ClientContext &,
TableFunctionInitInput &) { TableFunctionInitInput &) {
return make_uniq<RunOnceTableFunctionState>(); return make_uniq<RunOnceTableFunctionState>();
} }
}; };
template <typename T> template <typename T>
T GetSetting(const ClientContext &context, const char *setting_name, const T default_value) { T GetSetting(const ClientContext &context, const char *setting_name,
Value value; const T default_value) {
return context.TryGetCurrentSetting(setting_name, value) ? value.GetValue<T>() : default_value; Value value;
return context.TryGetCurrentSetting(setting_name, value) ? value.GetValue<T>()
: default_value;
} }
namespace internal { namespace internal {
unique_ptr<FunctionData> ResultBind(ClientContext &, TableFunctionBindInput &, unique_ptr<FunctionData> ResultBind(ClientContext &, TableFunctionBindInput &,
vector<LogicalType> &, vector<LogicalType> &,
vector<std::string> &); vector<std::string> &);
bool ShouldRun(TableFunctionInput &input); bool ShouldRun(TableFunctionInput &input);
template <typename Func> template <typename Func> struct CallFunctionHelper;
struct CallFunctionHelper;
template <> template <> struct CallFunctionHelper<std::string (*)()> {
struct CallFunctionHelper<std::string(*)()> { static std::string call(ClientContext &context, TableFunctionInput &input,
static std::string call(ClientContext &context, TableFunctionInput &input, std::string(*f)()) { std::string (*f)()) {
return f(); return f();
} }
};
template <> struct CallFunctionHelper<std::string (*)(ClientContext &)> {
static std::string call(ClientContext &context, TableFunctionInput &input,
std::string (*f)(ClientContext &)) {
return f(context);
}
}; };
template <> template <>
struct CallFunctionHelper<std::string(*)(ClientContext &)> { struct CallFunctionHelper<std::string (*)(ClientContext &,
static std::string call(ClientContext &context, TableFunctionInput &input, std::string(*f)(ClientContext &)) { TableFunctionInput &)> {
return f(context); static std::string call(ClientContext &context, TableFunctionInput &input,
} std::string (*f)(ClientContext &,
}; TableFunctionInput &)) {
return f(context, input);
template <> }
struct CallFunctionHelper<std::string(*)(ClientContext &, TableFunctionInput &)> {
static std::string call(ClientContext &context, TableFunctionInput &input, std::string(*f)(ClientContext &, TableFunctionInput &)) {
return f(context, input);
}
}; };
template <typename Func, Func func> template <typename Func, Func func>
void TableFunc(ClientContext &context, TableFunctionInput &input, DataChunk &output) { void TableFunc(ClientContext &context, TableFunctionInput &input,
if (!ShouldRun(input)) { DataChunk &output) {
return; if (!ShouldRun(input)) {
} return;
}
const std::string result = CallFunctionHelper<Func>::call(context, input, func); const std::string result =
output.SetCardinality(1); CallFunctionHelper<Func>::call(context, input, func);
output.SetValue(0, 0, result); output.SetCardinality(1);
output.SetValue(0, 0, result);
} }
template <typename Func, Func func> template <typename Func, Func func>
void RegisterTF(DatabaseInstance &instance, const char* name) { void RegisterTF(DatabaseInstance &instance, const char *name) {
TableFunction tf(name, {}, internal::TableFunc<Func, func>, internal::ResultBind, RunOnceTableFunctionState::Init); TableFunction tf(name, {}, internal::TableFunc<Func, func>,
ExtensionUtil::RegisterFunction(instance, tf); internal::ResultBind, RunOnceTableFunctionState::Init);
ExtensionUtil::RegisterFunction(instance, tf);
} }
template <typename Func, Func func> template <typename Func, Func func>
void RegisterTFWithArgs(DatabaseInstance &instance, const char* name, vector<LogicalType> arguments, table_function_bind_t bind) { void RegisterTFWithArgs(DatabaseInstance &instance, const char *name,
TableFunction tf(name, arguments, internal::TableFunc<Func, func>, bind, RunOnceTableFunctionState::Init); vector<LogicalType> arguments,
ExtensionUtil::RegisterFunction(instance, tf); table_function_bind_t bind) {
TableFunction tf(name, arguments, internal::TableFunc<Func, func>, bind,
RunOnceTableFunctionState::Init);
ExtensionUtil::RegisterFunction(instance, tf);
} }
} } // namespace internal
#define RESISTER_TF(name, func) internal::RegisterTF<decltype(&func), &func>(instance, name) #define RESISTER_TF(name, func) \
internal::RegisterTF<decltype(&func), &func>(instance, name)
#define RESISTER_TF_ARGS(name, args, func, bind) internal::RegisterTFWithArgs<decltype(&func), &func>(instance, name, args, bind) #define RESISTER_TF_ARGS(name, args, func, bind) \
internal::RegisterTFWithArgs<decltype(&func), &func>(instance, name, args, \
bind)
} // namespace duckdb } // namespace duckdb

View File

@@ -8,41 +8,41 @@ namespace duckdb {
namespace ui { namespace ui {
struct EmptyResult { struct EmptyResult {
void Serialize(duckdb::Serializer &serializer) const; void Serialize(duckdb::Serializer &serializer) const;
}; };
struct TokenizeResult { struct TokenizeResult {
duckdb::vector<idx_t> offsets; duckdb::vector<idx_t> offsets;
duckdb::vector<duckdb::SimplifiedTokenType> types; duckdb::vector<duckdb::SimplifiedTokenType> types;
void Serialize(duckdb::Serializer &serializer) const; void Serialize(duckdb::Serializer &serializer) const;
}; };
struct ColumnNamesAndTypes { struct ColumnNamesAndTypes {
duckdb::vector<std::string> names; duckdb::vector<std::string> names;
duckdb::vector<duckdb::LogicalType> types; duckdb::vector<duckdb::LogicalType> types;
void Serialize(duckdb::Serializer &serializer) const; void Serialize(duckdb::Serializer &serializer) const;
}; };
struct Chunk { struct Chunk {
uint16_t row_count; uint16_t row_count;
duckdb::vector<duckdb::Vector> vectors; duckdb::vector<duckdb::Vector> vectors;
void Serialize(duckdb::Serializer &serializer) const; void Serialize(duckdb::Serializer &serializer) const;
}; };
struct SuccessResult { struct SuccessResult {
ColumnNamesAndTypes column_names_and_types; ColumnNamesAndTypes column_names_and_types;
duckdb::vector<Chunk> chunks; duckdb::vector<Chunk> chunks;
void Serialize(duckdb::Serializer &serializer) const; void Serialize(duckdb::Serializer &serializer) const;
}; };
struct ErrorResult { struct ErrorResult {
std::string error; std::string error;
void Serialize(duckdb::Serializer &serializer) const; void Serialize(duckdb::Serializer &serializer) const;
}; };
} // namespace ui } // namespace ui

View File

@@ -15,110 +15,123 @@
#define OPEN_COMMAND "open" #define OPEN_COMMAND "open"
#endif #endif
#define UI_LOCAL_PORT_SETTING_NAME "ui_local_port" #define UI_LOCAL_PORT_SETTING_NAME "ui_local_port"
#define UI_LOCAL_PORT_SETTING_DESCRIPTION "Local port on which the UI server listens" #define UI_LOCAL_PORT_SETTING_DESCRIPTION \
#define UI_LOCAL_PORT_SETTING_DEFAULT 4213 "Local port on which the UI server listens"
#define UI_LOCAL_PORT_SETTING_DEFAULT 4213
#define UI_REMOTE_URL_SETTING_NAME "ui_remote_url" #define UI_REMOTE_URL_SETTING_NAME "ui_remote_url"
#define UI_REMOTE_URL_SETTING_DESCRIPTION "Remote URL to which the UI server forwards GET requests" #define UI_REMOTE_URL_SETTING_DESCRIPTION \
#define UI_REMOTE_URL_SETTING_DEFAULT "https://app.motherduck.com" "Remote URL to which the UI server forwards GET requests"
#define UI_REMOTE_URL_SETTING_DEFAULT "https://app.motherduck.com"
namespace duckdb { namespace duckdb {
namespace internal { namespace internal {
bool StartHttpServer(const ClientContext &context) { bool StartHttpServer(const ClientContext &context) {
const auto url = GetSetting<std::string>(context, UI_REMOTE_URL_SETTING_NAME, const auto url =
GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DEFAULT)); GetSetting<std::string>(context, UI_REMOTE_URL_SETTING_NAME,
const uint16_t port = GetSetting(context, UI_LOCAL_PORT_SETTING_NAME, UI_LOCAL_PORT_SETTING_DEFAULT);; GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME,
return ui::HttpServer::instance()->Start(port, url, context.db); UI_REMOTE_URL_SETTING_DEFAULT));
const uint16_t port = GetSetting(context, UI_LOCAL_PORT_SETTING_NAME,
UI_LOCAL_PORT_SETTING_DEFAULT);
;
return ui::HttpServer::instance()->Start(port, url, context.db);
} }
std::string GetHttpServerLocalURL() { std::string GetHttpServerLocalURL() {
return StringUtil::Format("http://localhost:%d/", ui::HttpServer::instance()->LocalPort()); return StringUtil::Format("http://localhost:%d/",
ui::HttpServer::instance()->LocalPort());
} }
} // namespace internal } // namespace internal
std::string StartUIFunction(ClientContext &context) { std::string StartUIFunction(ClientContext &context) {
internal::StartHttpServer(context); internal::StartHttpServer(context);
auto local_url = internal::GetHttpServerLocalURL(); auto local_url = internal::GetHttpServerLocalURL();
const std::string command = StringUtil::Format("%s %s", OPEN_COMMAND, local_url); const std::string command =
return system(command.c_str()) ? StringUtil::Format("%s %s", OPEN_COMMAND, local_url);
StringUtil::Format("Navigate browser to %s", local_url) // open command failed return system(command.c_str())
: StringUtil::Format("MotherDuck UI started at %s", local_url); ? StringUtil::Format("Navigate browser to %s",
local_url) // open command failed
: StringUtil::Format("MotherDuck UI started at %s", local_url);
} }
std::string StartUIServerFunction(ClientContext &context) { std::string StartUIServerFunction(ClientContext &context) {
const char* already = internal::StartHttpServer(context) ? "already " : ""; const char *already = internal::StartHttpServer(context) ? "already " : "";
return StringUtil::Format("MotherDuck UI server %sstarted at %s", already, internal::GetHttpServerLocalURL()); return StringUtil::Format("MotherDuck UI server %sstarted at %s", already,
internal::GetHttpServerLocalURL());
} }
std::string StopUIServerFunction() { std::string StopUIServerFunction() {
return ui::HttpServer::instance()->Stop() ? "UI server stopped" : "UI server already stopped"; return ui::HttpServer::instance()->Stop() ? "UI server stopped"
: "UI server already stopped";
} }
// Connected notification // Connected notification
struct NotifyConnectedFunctionData : public TableFunctionData { struct NotifyConnectedFunctionData : public TableFunctionData {
NotifyConnectedFunctionData(std::string _token) : token(_token) {} NotifyConnectedFunctionData(std::string _token) : token(_token) {}
std::string token; std::string token;
}; };
static unique_ptr<FunctionData> NotifyConnectedBind(ClientContext &, TableFunctionBindInput &input, static unique_ptr<FunctionData>
vector<LogicalType> &out_types, vector<string> &out_names) { NotifyConnectedBind(ClientContext &, TableFunctionBindInput &input,
if (input.inputs[0].IsNull()) { vector<LogicalType> &out_types, vector<string> &out_names) {
throw BinderException("Must provide a token"); if (input.inputs[0].IsNull()) {
} throw BinderException("Must provide a token");
}
out_names.emplace_back("result"); out_names.emplace_back("result");
out_types.emplace_back(LogicalType::VARCHAR); out_types.emplace_back(LogicalType::VARCHAR);
return make_uniq<NotifyConnectedFunctionData>(input.inputs[0].ToString()); return make_uniq<NotifyConnectedFunctionData>(input.inputs[0].ToString());
} }
std::string NotifyConnectedFunction(ClientContext &context, TableFunctionInput &input) { std::string NotifyConnectedFunction(ClientContext &context,
auto &inputs = input.bind_data->Cast<NotifyConnectedFunctionData>(); TableFunctionInput &input) {
ui::HttpServer::instance()->SendConnectedEvent(inputs.token); auto &inputs = input.bind_data->Cast<NotifyConnectedFunctionData>();
return "OK"; ui::HttpServer::instance()->SendConnectedEvent(inputs.token);
return "OK";
} }
// - connected notification // - connected notification
std::string NotifyCatalogChangedFunction() { std::string NotifyCatalogChangedFunction() {
ui::HttpServer::instance()->SendCatalogChangedEvent(); ui::HttpServer::instance()->SendCatalogChangedEvent();
return "OK"; return "OK";
} }
static void LoadInternal(DatabaseInstance &instance) { static void LoadInternal(DatabaseInstance &instance) {
auto &config = DBConfig::GetConfig(instance); auto &config = DBConfig::GetConfig(instance);
config.AddExtensionOption(UI_LOCAL_PORT_SETTING_NAME, UI_LOCAL_PORT_SETTING_DESCRIPTION, config.AddExtensionOption(
LogicalType::USMALLINT, Value::USMALLINT(UI_LOCAL_PORT_SETTING_DEFAULT)); UI_LOCAL_PORT_SETTING_NAME, UI_LOCAL_PORT_SETTING_DESCRIPTION,
LogicalType::USMALLINT, Value::USMALLINT(UI_LOCAL_PORT_SETTING_DEFAULT));
config.AddExtensionOption( config.AddExtensionOption(
UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DESCRIPTION, LogicalType::VARCHAR, UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DESCRIPTION,
Value(GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DEFAULT))); LogicalType::VARCHAR,
Value(GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME,
UI_REMOTE_URL_SETTING_DEFAULT)));
RESISTER_TF("start_ui", StartUIFunction); RESISTER_TF("start_ui", StartUIFunction);
RESISTER_TF("start_ui_server", StartUIServerFunction); RESISTER_TF("start_ui_server", StartUIServerFunction);
RESISTER_TF("stop_ui_server", StopUIServerFunction); RESISTER_TF("stop_ui_server", StopUIServerFunction);
RESISTER_TF("notify_ui_catalog_changed", NotifyCatalogChangedFunction); RESISTER_TF("notify_ui_catalog_changed", NotifyCatalogChangedFunction);
RESISTER_TF_ARGS("notify_ui_connected", {LogicalType::VARCHAR}, NotifyConnectedFunction, NotifyConnectedBind); RESISTER_TF_ARGS("notify_ui_connected", {LogicalType::VARCHAR},
NotifyConnectedFunction, NotifyConnectedBind);
} }
void UiExtension::Load(DuckDB &db) { void UiExtension::Load(DuckDB &db) { LoadInternal(*db.instance); }
LoadInternal(*db.instance); std::string UiExtension::Name() { return "ui"; }
}
std::string UiExtension::Name() {
return "ui";
}
std::string UiExtension::Version() const { std::string UiExtension::Version() const {
#ifdef UI_EXTENSION_GIT_SHA #ifdef UI_EXTENSION_GIT_SHA
return UI_EXTENSION_GIT_SHA; return UI_EXTENSION_GIT_SHA;
#else #else
return ""; return "";
#endif #endif
} }
@@ -127,12 +140,12 @@ std::string UiExtension::Version() const {
extern "C" { extern "C" {
DUCKDB_EXTENSION_API void ui_init(duckdb::DatabaseInstance &db) { DUCKDB_EXTENSION_API void ui_init(duckdb::DatabaseInstance &db) {
duckdb::DuckDB db_wrapper(db); duckdb::DuckDB db_wrapper(db);
db_wrapper.LoadExtension<duckdb::UiExtension>(); db_wrapper.LoadExtension<duckdb::UiExtension>();
} }
DUCKDB_EXTENSION_API const char *ui_version() { DUCKDB_EXTENSION_API const char *ui_version() {
return duckdb::DuckDB::LibraryVersion(); return duckdb::DuckDB::LibraryVersion();
} }
} }

View File

@@ -5,57 +5,60 @@
namespace duckdb { namespace duckdb {
// Copied from https://www.mycplus.com/source-code/c-source-code/base64-encode-decode/ // Copied from
constexpr char k_encoding_table[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+_"; // https://www.mycplus.com/source-code/c-source-code/base64-encode-decode/
constexpr char k_encoding_table[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+_";
std::vector<char> BuildDecodingTable() { std::vector<char> BuildDecodingTable() {
std::vector<char> decoding_table; std::vector<char> decoding_table;
decoding_table.resize(256); decoding_table.resize(256);
for (int i = 0; i < 64; ++i) { for (int i = 0; i < 64; ++i) {
decoding_table[static_cast<unsigned char>(k_encoding_table[i])] = i; decoding_table[static_cast<unsigned char>(k_encoding_table[i])] = i;
} }
return decoding_table; return decoding_table;
} }
const static std::vector<char> k_decoding_table = BuildDecodingTable(); const static std::vector<char> k_decoding_table = BuildDecodingTable();
std::string DecodeBase64(const std::string &data) { std::string DecodeBase64(const std::string &data) {
size_t input_length = data.size(); size_t input_length = data.size();
if (input_length < 4 || input_length % 4 != 0) { if (input_length < 4 || input_length % 4 != 0) {
// Handle this exception // Handle this exception
return ""; return "";
} }
size_t output_length = input_length / 4 * 3; size_t output_length = input_length / 4 * 3;
if (data[input_length - 1] == '=') { if (data[input_length - 1] == '=') {
output_length--; output_length--;
} }
if (data[input_length - 2] == '=') { if (data[input_length - 2] == '=') {
output_length--; output_length--;
} }
std::string decoded_data; std::string decoded_data;
decoded_data.resize(output_length); decoded_data.resize(output_length);
for (size_t i = 0, j = 0; i < input_length;) { for (size_t i = 0, j = 0; i < input_length;) {
uint32_t sextet_a = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; uint32_t sextet_a = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]];
uint32_t sextet_b = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; uint32_t sextet_b = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]];
uint32_t sextet_c = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; uint32_t sextet_c = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]];
uint32_t sextet_d = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]]; uint32_t sextet_d = data[i] == '=' ? 0 & i++ : k_decoding_table[data[i++]];
uint32_t triple = (sextet_a << 3 * 6) + (sextet_b << 2 * 6) + (sextet_c << 1 * 6) + (sextet_d << 0 * 6); uint32_t triple = (sextet_a << 3 * 6) + (sextet_b << 2 * 6) +
(sextet_c << 1 * 6) + (sextet_d << 0 * 6);
if (j < output_length) { if (j < output_length) {
decoded_data[j++] = (triple >> 2 * 8) & 0xFF; decoded_data[j++] = (triple >> 2 * 8) & 0xFF;
} }
if (j < output_length) { if (j < output_length) {
decoded_data[j++] = (triple >> 1 * 8) & 0xFF; decoded_data[j++] = (triple >> 1 * 8) & 0xFF;
} }
if (j < output_length) { if (j < output_length) {
decoded_data[j++] = (triple >> 0 * 8) & 0xFF; decoded_data[j++] = (triple >> 0 * 8) & 0xFF;
} }
} }
return decoded_data; return decoded_data;
} }
} // namespace duckdb } // namespace duckdb

View File

@@ -6,29 +6,29 @@
namespace duckdb { namespace duckdb {
const char *TryGetEnv(const char *name) { const char *TryGetEnv(const char *name) {
const char *res = std::getenv(name); const char *res = std::getenv(name);
if (res) { if (res) {
return res; return res;
} }
return std::getenv(StringUtil::Upper(name).c_str()); return std::getenv(StringUtil::Upper(name).c_str());
} }
std::string GetEnvOrDefault(const char *name, const char *default_value) { std::string GetEnvOrDefault(const char *name, const char *default_value) {
const char *res = TryGetEnv(name); const char *res = TryGetEnv(name);
if (res) { if (res) {
return res; return res;
} }
return default_value; return default_value;
} }
bool IsEnvEnabled(const char *name) { bool IsEnvEnabled(const char *name) {
const char *res = TryGetEnv(name); const char *res = TryGetEnv(name);
if (!res) { if (!res) {
return false; return false;
} }
auto lc_res = StringUtil::Lower(res); auto lc_res = StringUtil::Lower(res);
return lc_res == "1" || lc_res == "true"; return lc_res == "1" || lc_res == "true";
} }
} // namespace duckdb } // namespace duckdb

View File

@@ -4,22 +4,23 @@ namespace duckdb {
namespace internal { namespace internal {
bool ShouldRun(TableFunctionInput &input) { bool ShouldRun(TableFunctionInput &input) {
auto state = dynamic_cast<RunOnceTableFunctionState *>(input.global_state.get()); auto state =
D_ASSERT(state != nullptr); dynamic_cast<RunOnceTableFunctionState *>(input.global_state.get());
if (state->run) { D_ASSERT(state != nullptr);
return false; if (state->run) {
} return false;
}
state->run = true; state->run = true;
return true; return true;
} }
unique_ptr<FunctionData> ResultBind(ClientContext &, TableFunctionBindInput &, unique_ptr<FunctionData> ResultBind(ClientContext &, TableFunctionBindInput &,
vector<LogicalType> &out_types, vector<LogicalType> &out_types,
vector<std::string> &out_names) { vector<std::string> &out_names) {
out_names.emplace_back("result"); out_names.emplace_back("result");
out_types.emplace_back(LogicalType::VARCHAR); out_types.emplace_back(LogicalType::VARCHAR);
return nullptr; return nullptr;
} }
} // namespace internal } // namespace internal

View File

@@ -6,43 +6,46 @@
namespace duckdb { namespace duckdb {
namespace ui { namespace ui {
void EmptyResult::Serialize(Serializer &) const { void EmptyResult::Serialize(Serializer &) const {}
}
void TokenizeResult::Serialize(Serializer &serializer) const { void TokenizeResult::Serialize(Serializer &serializer) const {
serializer.WriteProperty(100, "offsets", offsets); serializer.WriteProperty(100, "offsets", offsets);
serializer.WriteProperty(101, "types", types); serializer.WriteProperty(101, "types", types);
} }
// Adapted from parts of DataChunk::Serialize // Adapted from parts of DataChunk::Serialize
void ColumnNamesAndTypes::Serialize(Serializer &serializer) const { void ColumnNamesAndTypes::Serialize(Serializer &serializer) const {
serializer.WriteProperty(100, "names", names); serializer.WriteProperty(100, "names", names);
serializer.WriteProperty(101, "types", types); serializer.WriteProperty(101, "types", types);
} }
// Adapted from parts of DataChunk::Serialize // Adapted from parts of DataChunk::Serialize
void Chunk::Serialize(Serializer &serializer) const { void Chunk::Serialize(Serializer &serializer) const {
serializer.WriteProperty(100, "row_count", row_count); serializer.WriteProperty(100, "row_count", row_count);
serializer.WriteList(101, "vectors", vectors.size(), [&](Serializer::List &list, idx_t i) { serializer.WriteList(101, "vectors", vectors.size(),
list.WriteObject([&](Serializer &object) { [&](Serializer::List &list, idx_t i) {
// Reference the vector to avoid potentially mutating it during serialization list.WriteObject([&](Serializer &object) {
Vector serialized_vector(vectors[i].GetType()); // Reference the vector to avoid potentially mutating
serialized_vector.Reference(vectors[i]); // it during serialization
serialized_vector.Serialize(object, row_count); Vector serialized_vector(vectors[i].GetType());
}); serialized_vector.Reference(vectors[i]);
}); serialized_vector.Serialize(object, row_count);
});
});
} }
void SuccessResult::Serialize(Serializer &serializer) const { void SuccessResult::Serialize(Serializer &serializer) const {
serializer.WriteProperty(100, "success", true); serializer.WriteProperty(100, "success", true);
serializer.WriteProperty(101, "column_names_and_types", column_names_and_types); serializer.WriteProperty(101, "column_names_and_types",
serializer.WriteList(102, "chunks", chunks.size(), column_names_and_types);
[&](Serializer::List &list, idx_t i) { list.WriteElement(chunks[i]); }); serializer.WriteList(
102, "chunks", chunks.size(),
[&](Serializer::List &list, idx_t i) { list.WriteElement(chunks[i]); });
} }
void ErrorResult::Serialize(Serializer &serializer) const { void ErrorResult::Serialize(Serializer &serializer) const {
serializer.WriteProperty(100, "success", false); serializer.WriteProperty(100, "success", false);
serializer.WriteProperty(101, "error", error); serializer.WriteProperty(101, "error", error);
} }
} // namespace ui } // namespace ui