Fix ownership model of DatabaseInstance
This commit is contained in:
@@ -15,8 +15,13 @@ project(${TARGET_NAME})
|
|||||||
include_directories(src/include ${DuckDB_SOURCE_DIR}/third_party/httplib)
|
include_directories(src/include ${DuckDB_SOURCE_DIR}/third_party/httplib)
|
||||||
|
|
||||||
set(EXTENSION_SOURCES
|
set(EXTENSION_SOURCES
|
||||||
src/ui_extension.cpp src/http_server.cpp src/utils/encoding.cpp
|
src/ui_extension.cpp
|
||||||
src/utils/env.cpp src/utils/helpers.cpp src/utils/serialization.cpp)
|
src/http_server.cpp
|
||||||
|
src/state.cpp
|
||||||
|
src/utils/encoding.cpp
|
||||||
|
src/utils/env.cpp
|
||||||
|
src/utils/helpers.cpp
|
||||||
|
src/utils/serialization.cpp)
|
||||||
|
|
||||||
find_package(Git)
|
find_package(Git)
|
||||||
if(NOT Git_FOUND)
|
if(NOT Git_FOUND)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
#include "utils/env.hpp"
|
#include "utils/env.hpp"
|
||||||
#include "utils/serialization.hpp"
|
#include "utils/serialization.hpp"
|
||||||
#include "utils/encoding.hpp"
|
#include "utils/encoding.hpp"
|
||||||
|
#include "state.hpp"
|
||||||
|
|
||||||
// Chosen to be no more than half of the lesser of the two limits:
|
// Chosen to be no more than half of the lesser of the two limits:
|
||||||
// - The default httplib thread pool size = 8
|
// - The default httplib thread pool size = 8
|
||||||
@@ -70,14 +71,34 @@ void EventDispatcher::Close() {
|
|||||||
|
|
||||||
unique_ptr<HttpServer> HttpServer::server_instance;
|
unique_ptr<HttpServer> HttpServer::server_instance;
|
||||||
|
|
||||||
HttpServer *HttpServer::instance() {
|
HttpServer *HttpServer::GetInstance(ClientContext &context) {
|
||||||
if (!server_instance) {
|
if (server_instance) {
|
||||||
server_instance = make_uniq<HttpServer>();
|
// We already have an instance, make sure we're running on the right DB
|
||||||
|
server_instance->UpdateDatabaseInstance(context.db);
|
||||||
|
} else {
|
||||||
|
server_instance = make_uniq<HttpServer>(context.db);
|
||||||
std::atexit(HttpServer::StopInstance);
|
std::atexit(HttpServer::StopInstance);
|
||||||
}
|
}
|
||||||
return server_instance.get();
|
return server_instance.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void HttpServer::UpdateDatabaseInstanceIfRunning(
|
||||||
|
shared_ptr<DatabaseInstance> db) {
|
||||||
|
if (server_instance) {
|
||||||
|
server_instance->UpdateDatabaseInstance(db);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void HttpServer::UpdateDatabaseInstance(
|
||||||
|
shared_ptr<DatabaseInstance> context_db) {
|
||||||
|
const auto current_db = server_instance->ddb_instance.lock();
|
||||||
|
if (current_db != context_db) {
|
||||||
|
server_instance->StopWatcher(); // Likely already stopped, but just in case
|
||||||
|
server_instance->ddb_instance = context_db;
|
||||||
|
server_instance->StartWatcher();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool HttpServer::Started() {
|
bool HttpServer::Started() {
|
||||||
return server_instance && server_instance->main_thread;
|
return server_instance && server_instance->main_thread;
|
||||||
}
|
}
|
||||||
@@ -88,16 +109,18 @@ void HttpServer::StopInstance() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HttpServer::Start(const uint16_t _local_port,
|
const HttpServer &HttpServer::Start(const uint16_t _local_port,
|
||||||
const std::string &_remote_url,
|
const std::string &_remote_url,
|
||||||
const shared_ptr<DatabaseInstance> &_ddb_instance) {
|
bool *was_started) {
|
||||||
if (main_thread) {
|
if (main_thread) {
|
||||||
return false;
|
if (was_started) {
|
||||||
|
*was_started = true;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
local_port = _local_port;
|
local_port = _local_port;
|
||||||
remote_url = _remote_url;
|
remote_url = _remote_url;
|
||||||
ddb_instance = _ddb_instance;
|
|
||||||
#ifndef UI_EXTENSION_GIT_SHA
|
#ifndef UI_EXTENSION_GIT_SHA
|
||||||
#error "UI_EXTENSION_GIT_SHA must be defined"
|
#error "UI_EXTENSION_GIT_SHA must be defined"
|
||||||
#endif
|
#endif
|
||||||
@@ -105,12 +128,33 @@ bool HttpServer::Start(const uint16_t _local_port,
|
|||||||
DuckDB::Platform());
|
DuckDB::Platform());
|
||||||
event_dispatcher = make_uniq<EventDispatcher>();
|
event_dispatcher = make_uniq<EventDispatcher>();
|
||||||
main_thread = make_uniq<std::thread>(&HttpServer::Run, this);
|
main_thread = make_uniq<std::thread>(&HttpServer::Run, this);
|
||||||
|
StartWatcher();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
void HttpServer::StartWatcher() {
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> guard(watcher_mutex);
|
std::lock_guard<std::mutex> guard(watcher_mutex);
|
||||||
watcher_should_run = true;
|
watcher_should_run = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!watcher_thread) {
|
||||||
watcher_thread = make_uniq<std::thread>(&HttpServer::Watch, this);
|
watcher_thread = make_uniq<std::thread>(&HttpServer::Watch, this);
|
||||||
return true;
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void HttpServer::StopWatcher() {
|
||||||
|
if (!watcher_thread) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> guard(watcher_mutex);
|
||||||
|
watcher_should_run = false;
|
||||||
|
}
|
||||||
|
watcher_cv.notify_all();
|
||||||
|
watcher_thread->join();
|
||||||
|
watcher_thread.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HttpServer::Stop() {
|
bool HttpServer::Stop() {
|
||||||
@@ -121,27 +165,20 @@ bool HttpServer::Stop() {
|
|||||||
event_dispatcher->Close();
|
event_dispatcher->Close();
|
||||||
server.stop();
|
server.stop();
|
||||||
|
|
||||||
if (watcher_thread) {
|
StopWatcher();
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> guard(watcher_mutex);
|
|
||||||
watcher_should_run = false;
|
|
||||||
}
|
|
||||||
watcher_cv.notify_all();
|
|
||||||
watcher_thread->join();
|
|
||||||
watcher_thread.reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
main_thread->join();
|
main_thread->join();
|
||||||
main_thread.reset();
|
main_thread.reset();
|
||||||
event_dispatcher.reset();
|
event_dispatcher.reset();
|
||||||
connections.clear();
|
|
||||||
ddb_instance.reset();
|
ddb_instance.reset();
|
||||||
remote_url = "";
|
remote_url = "";
|
||||||
local_port = 0;
|
local_port = 0;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint16_t HttpServer::LocalPort() { return local_port; }
|
std::string HttpServer::LocalUrl() const {
|
||||||
|
return StringUtil::Format("http://localhost:%d/", local_port);
|
||||||
|
}
|
||||||
|
|
||||||
void HttpServer::SendConnectedEvent(const std::string &token) {
|
void HttpServer::SendConnectedEvent(const std::string &token) {
|
||||||
SendEvent(StringUtil::Format("event: ConnectedEvent\ndata: %s\n\n", token));
|
SendEvent(StringUtil::Format("event: ConnectedEvent\ndata: %s\n\n", token));
|
||||||
@@ -157,13 +194,14 @@ void HttpServer::SendEvent(const std::string &message) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void HttpServer::WatchForCatalogUpdate(CatalogState &last_state) {
|
void HttpServer::WatchForCatalogUpdate(DatabaseInstance &db,
|
||||||
|
CatalogState &last_state) {
|
||||||
bool has_change = false;
|
bool has_change = false;
|
||||||
duckdb::Connection con{*ddb_instance};
|
duckdb::Connection con{db};
|
||||||
auto &context = *con.context;
|
auto &context = *con.context;
|
||||||
con.BeginTransaction();
|
con.BeginTransaction();
|
||||||
const auto &databases =
|
|
||||||
ddb_instance->GetDatabaseManager().GetDatabases(context);
|
const auto &databases = db.GetDatabaseManager().GetDatabases(context);
|
||||||
std::set<idx_t> db_oids;
|
std::set<idx_t> db_oids;
|
||||||
|
|
||||||
// Check currently attached databases
|
// Check currently attached databases
|
||||||
@@ -204,7 +242,11 @@ void HttpServer::WatchForCatalogUpdate(CatalogState &last_state) {
|
|||||||
void HttpServer::Watch() {
|
void HttpServer::Watch() {
|
||||||
CatalogState last_state;
|
CatalogState last_state;
|
||||||
while (watcher_should_run) {
|
while (watcher_should_run) {
|
||||||
WatchForCatalogUpdate(last_state);
|
auto db = ddb_instance.lock();
|
||||||
|
if (!db) {
|
||||||
|
break; // DB went away, nothing to watch
|
||||||
|
}
|
||||||
|
WatchForCatalogUpdate(*db, last_state);
|
||||||
{
|
{
|
||||||
std::unique_lock<std::mutex> lock(watcher_mutex);
|
std::unique_lock<std::mutex> lock(watcher_mutex);
|
||||||
watcher_cv.wait_for(lock,
|
watcher_cv.wait_for(lock,
|
||||||
@@ -256,13 +298,21 @@ void HttpServer::HandleGetLocalEvents(const httplib::Request &req,
|
|||||||
|
|
||||||
void HttpServer::HandleGetLocalToken(const httplib::Request &req,
|
void HttpServer::HandleGetLocalToken(const httplib::Request &req,
|
||||||
httplib::Response &res) {
|
httplib::Response &res) {
|
||||||
if (!ddb_instance->ExtensionIsLoaded("motherduck")) {
|
auto db = ddb_instance.lock();
|
||||||
|
if (!db) {
|
||||||
|
res.status = 500;
|
||||||
|
res.set_content("Database was invalidated, UI needs to be restarted",
|
||||||
|
"text/plain");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!db->ExtensionIsLoaded("motherduck")) {
|
||||||
res.set_content("", "text/plain"); // UI expects an empty response if the
|
res.set_content("", "text/plain"); // UI expects an empty response if the
|
||||||
// extension is not loaded
|
// extension is not loaded
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Connection connection(*ddb_instance);
|
Connection connection(*db);
|
||||||
auto query_res = connection.Query("CALL get_md_token()");
|
auto query_res = connection.Query("CALL get_md_token()");
|
||||||
if (query_res->HasError()) {
|
if (query_res->HasError()) {
|
||||||
res.status = 500;
|
res.status = 500;
|
||||||
@@ -316,7 +366,14 @@ void HttpServer::HandleInterrupt(const httplib::Request &req,
|
|||||||
auto description = req.get_header_value("X-MD-Description");
|
auto description = req.get_header_value("X-MD-Description");
|
||||||
auto connection_name = req.get_header_value("X-MD-Connection-Name");
|
auto connection_name = req.get_header_value("X-MD-Connection-Name");
|
||||||
|
|
||||||
auto connection = FindConnection(connection_name);
|
auto db = ddb_instance.lock();
|
||||||
|
if (!db) {
|
||||||
|
res.status = 404;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto connection =
|
||||||
|
UIStorageExtensionInfo::GetState(*db).FindConnection(connection_name);
|
||||||
if (!connection) {
|
if (!connection) {
|
||||||
res.status = 404;
|
res.status = 404;
|
||||||
return;
|
return;
|
||||||
@@ -347,13 +404,23 @@ void HttpServer::DoHandleRun(const httplib::Request &req,
|
|||||||
|
|
||||||
std::string content = ReadContent(content_reader);
|
std::string content = ReadContent(content_reader);
|
||||||
|
|
||||||
auto connection = FindOrCreateConnection(connection_name);
|
auto db = ddb_instance.lock();
|
||||||
|
if (!db) {
|
||||||
|
SetResponseErrorResult(
|
||||||
|
res, "Database was invalidated, UI needs to be restarted");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto connection =
|
||||||
|
UIStorageExtensionInfo::GetState(*db).FindOrCreateConnection(
|
||||||
|
*db, connection_name);
|
||||||
|
|
||||||
// Set current database if optional header is provided.
|
// Set current database if optional header is provided.
|
||||||
if (!database_name.empty()) {
|
if (!database_name.empty()) {
|
||||||
connection->context->RunFunctionInTransaction([&] {
|
auto &context = *connection->context;
|
||||||
ddb_instance->GetDatabaseManager().SetDefaultDatabase(
|
context.RunFunctionInTransaction([&] {
|
||||||
*connection->context, database_name);
|
auto &manager = context.db->GetDatabaseManager();
|
||||||
|
manager.SetDefaultDatabase(context, database_name);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -464,48 +531,6 @@ HttpServer::ReadContent(const httplib::ContentReader &content_reader) {
|
|||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Connection>
|
|
||||||
HttpServer::FindConnection(const std::string &connection_name) {
|
|
||||||
if (connection_name.empty()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Need to protect access to the connections map because this can be called
|
|
||||||
// from multiple threads.
|
|
||||||
std::lock_guard<std::mutex> guard(connections_mutex);
|
|
||||||
|
|
||||||
auto result = connections.find(connection_name);
|
|
||||||
if (result != connections.end()) {
|
|
||||||
return result->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
shared_ptr<Connection>
|
|
||||||
HttpServer::FindOrCreateConnection(const std::string &connection_name) {
|
|
||||||
if (connection_name.empty()) {
|
|
||||||
// If no connection name was provided, create and return a new connection
|
|
||||||
// but don't remember it.
|
|
||||||
return make_shared_ptr<Connection>(*ddb_instance);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Need to protect access to the connections map because this can be called
|
|
||||||
// from multiple threads.
|
|
||||||
std::lock_guard<std::mutex> guard(connections_mutex);
|
|
||||||
|
|
||||||
// If an existing connection with the provided name was found, return it.
|
|
||||||
auto result = connections.find(connection_name);
|
|
||||||
if (result != connections.end()) {
|
|
||||||
return result->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, create a new one, remember it, and return it.
|
|
||||||
auto connection = make_shared_ptr<Connection>(*ddb_instance);
|
|
||||||
connections[connection_name] = connection;
|
|
||||||
return connection;
|
|
||||||
}
|
|
||||||
|
|
||||||
void HttpServer::SetResponseContent(httplib::Response &res,
|
void HttpServer::SetResponseContent(httplib::Response &res,
|
||||||
const MemoryStream &content) {
|
const MemoryStream &content) {
|
||||||
auto data = content.GetData();
|
auto data = content.GetData();
|
||||||
|
|||||||
@@ -41,46 +41,51 @@ private:
|
|||||||
class HttpServer {
|
class HttpServer {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static HttpServer *instance();
|
HttpServer(shared_ptr<DatabaseInstance> _ddb_instance)
|
||||||
|
: ddb_instance(_ddb_instance) {}
|
||||||
|
static HttpServer *GetInstance(ClientContext &);
|
||||||
|
static void UpdateDatabaseInstanceIfRunning(shared_ptr<DatabaseInstance>);
|
||||||
static bool Started();
|
static bool Started();
|
||||||
static void StopInstance();
|
static void StopInstance();
|
||||||
|
|
||||||
bool Start(const uint16_t localPort, const std::string &remoteUrl,
|
const HttpServer &Start(const uint16_t local_port,
|
||||||
const shared_ptr<DatabaseInstance> &ddbInstance);
|
const std::string &remote_url,
|
||||||
|
bool *was_started = nullptr);
|
||||||
bool Stop();
|
bool Stop();
|
||||||
uint16_t LocalPort();
|
std::string LocalUrl() const;
|
||||||
void SendConnectedEvent(const std::string &token);
|
void SendConnectedEvent(const std::string &token);
|
||||||
void SendCatalogChangedEvent();
|
void SendCatalogChangedEvent();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void UpdateDatabaseInstance(shared_ptr<DatabaseInstance> context_db);
|
||||||
void SendEvent(const std::string &message);
|
void SendEvent(const std::string &message);
|
||||||
void Run();
|
void Run();
|
||||||
void Watch();
|
void Watch();
|
||||||
|
void StartWatcher();
|
||||||
|
void StopWatcher();
|
||||||
void HandleGetLocalEvents(const httplib::Request &req,
|
void HandleGetLocalEvents(const httplib::Request &req,
|
||||||
httplib::Response &res);
|
httplib::Response &res);
|
||||||
void HandleGetLocalToken(const httplib::Request &req, httplib::Response &res);
|
void HandleGetLocalToken(const httplib::Request &req, httplib::Response &res);
|
||||||
void HandleGet(const httplib::Request &req, httplib::Response &res);
|
void HandleGet(const httplib::Request &req, httplib::Response &res);
|
||||||
void HandleInterrupt(const httplib::Request &req, httplib::Response &res);
|
void HandleInterrupt(const httplib::Request &req, httplib::Response &res);
|
||||||
void DoHandleRun(const httplib::Request &req, httplib::Response &res,
|
void DoHandleRun(const httplib::Request &req, httplib::Response &res,
|
||||||
const httplib::ContentReader &contentReader);
|
const httplib::ContentReader &content_reader);
|
||||||
void HandleRun(const httplib::Request &req, httplib::Response &res,
|
void HandleRun(const httplib::Request &req, httplib::Response &res,
|
||||||
const httplib::ContentReader &contentReader);
|
const httplib::ContentReader &content_reader);
|
||||||
void HandleTokenize(const httplib::Request &req, httplib::Response &res,
|
void HandleTokenize(const httplib::Request &req, httplib::Response &res,
|
||||||
const httplib::ContentReader &contentReader);
|
const httplib::ContentReader &content_reader);
|
||||||
std::string ReadContent(const httplib::ContentReader &contentReader);
|
std::string ReadContent(const httplib::ContentReader &content_reader);
|
||||||
shared_ptr<Connection> FindConnection(const std::string &connectionName);
|
|
||||||
shared_ptr<Connection>
|
|
||||||
FindOrCreateConnection(const std::string &connectionName);
|
|
||||||
void SetResponseContent(httplib::Response &res, const MemoryStream &content);
|
void SetResponseContent(httplib::Response &res, const MemoryStream &content);
|
||||||
void SetResponseEmptyResult(httplib::Response &res);
|
void SetResponseEmptyResult(httplib::Response &res);
|
||||||
void SetResponseErrorResult(httplib::Response &res, const std::string &error);
|
void SetResponseErrorResult(httplib::Response &res, const std::string &error);
|
||||||
|
|
||||||
// Watchers
|
// Watchers
|
||||||
void WatchForCatalogUpdate(CatalogState &last_state);
|
void WatchForCatalogUpdate(DatabaseInstance &, CatalogState &last_state);
|
||||||
|
|
||||||
uint16_t local_port;
|
uint16_t local_port;
|
||||||
std::string remote_url;
|
std::string remote_url;
|
||||||
shared_ptr<DatabaseInstance> ddb_instance;
|
weak_ptr<DatabaseInstance> ddb_instance;
|
||||||
std::string user_agent;
|
std::string user_agent;
|
||||||
httplib::Server server;
|
httplib::Server server;
|
||||||
unique_ptr<std::thread> main_thread;
|
unique_ptr<std::thread> main_thread;
|
||||||
@@ -89,8 +94,6 @@ private:
|
|||||||
std::condition_variable watcher_cv;
|
std::condition_variable watcher_cv;
|
||||||
std::atomic<bool> watcher_should_run;
|
std::atomic<bool> watcher_should_run;
|
||||||
|
|
||||||
std::mutex connections_mutex;
|
|
||||||
std::unordered_map<std::string, shared_ptr<Connection>> connections;
|
|
||||||
unique_ptr<EventDispatcher> event_dispatcher;
|
unique_ptr<EventDispatcher> event_dispatcher;
|
||||||
|
|
||||||
static unique_ptr<HttpServer> server_instance;
|
static unique_ptr<HttpServer> server_instance;
|
||||||
|
|||||||
24
src/include/state.hpp
Normal file
24
src/include/state.hpp
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <duckdb/storage/storage_extension.hpp>
|
||||||
|
#include <duckdb/main/connection.hpp>
|
||||||
|
|
||||||
|
namespace duckdb {
|
||||||
|
const static std::string STORAGE_EXTENSION_KEY = "ui";
|
||||||
|
|
||||||
|
class UIStorageExtensionInfo : public StorageExtensionInfo {
|
||||||
|
public:
|
||||||
|
static UIStorageExtensionInfo &GetState(const DatabaseInstance &instance);
|
||||||
|
|
||||||
|
shared_ptr<Connection> FindConnection(const std::string &connection_name);
|
||||||
|
shared_ptr<Connection>
|
||||||
|
FindOrCreateConnection(DatabaseInstance &db,
|
||||||
|
const std::string &connection_name);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::mutex connections_mutex;
|
||||||
|
std::unordered_map<std::string, shared_ptr<Connection>> connections;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace duckdb
|
||||||
@@ -35,12 +35,6 @@ bool ShouldRun(TableFunctionInput &input);
|
|||||||
|
|
||||||
template <typename Func> struct CallFunctionHelper;
|
template <typename Func> struct CallFunctionHelper;
|
||||||
|
|
||||||
template <> struct CallFunctionHelper<std::string (*)()> {
|
|
||||||
static std::string call(ClientContext &context, TableFunctionInput &input,
|
|
||||||
std::string (*f)()) {
|
|
||||||
return f();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <> struct CallFunctionHelper<std::string (*)(ClientContext &)> {
|
template <> struct CallFunctionHelper<std::string (*)(ClientContext &)> {
|
||||||
static std::string call(ClientContext &context, TableFunctionInput &input,
|
static std::string call(ClientContext &context, TableFunctionInput &input,
|
||||||
|
|||||||
60
src/state.cpp
Normal file
60
src/state.cpp
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
#include "state.hpp"
|
||||||
|
|
||||||
|
#include <duckdb/main/database.hpp>
|
||||||
|
|
||||||
|
namespace duckdb {
|
||||||
|
|
||||||
|
UIStorageExtensionInfo &
|
||||||
|
UIStorageExtensionInfo::GetState(const DatabaseInstance &instance) {
|
||||||
|
auto &config = instance.config;
|
||||||
|
auto it = config.storage_extensions.find(STORAGE_EXTENSION_KEY);
|
||||||
|
if (it == config.storage_extensions.end()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Fatal error: couldn't find the UI extension state.");
|
||||||
|
}
|
||||||
|
return *static_cast<UIStorageExtensionInfo *>(it->second->storage_info.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
shared_ptr<Connection>
|
||||||
|
UIStorageExtensionInfo::FindConnection(const std::string &connection_name) {
|
||||||
|
if (connection_name.empty()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Need to protect access to the connections map because this can be called
|
||||||
|
// from multiple threads.
|
||||||
|
std::lock_guard<std::mutex> guard(connections_mutex);
|
||||||
|
|
||||||
|
auto result = connections.find(connection_name);
|
||||||
|
if (result != connections.end()) {
|
||||||
|
return result->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
shared_ptr<Connection> UIStorageExtensionInfo::FindOrCreateConnection(
|
||||||
|
DatabaseInstance &db, const std::string &connection_name) {
|
||||||
|
if (connection_name.empty()) {
|
||||||
|
// If no connection name was provided, create and return a new connection
|
||||||
|
// but don't remember it.
|
||||||
|
return make_shared_ptr<Connection>(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If an existing connection with the provided name was found, return it.
|
||||||
|
auto connection = FindConnection(connection_name);
|
||||||
|
if (connection) {
|
||||||
|
return connection;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, create a new one, remember it, and return it.
|
||||||
|
auto new_con = make_shared_ptr<Connection>(db);
|
||||||
|
|
||||||
|
// Need to protect access to the connections map because this can be called
|
||||||
|
// from multiple threads.
|
||||||
|
std::lock_guard<std::mutex> guard(connections_mutex);
|
||||||
|
connections[connection_name] = new_con;
|
||||||
|
return new_con;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace duckdb
|
||||||
@@ -4,6 +4,7 @@
|
|||||||
#include "utils/helpers.hpp"
|
#include "utils/helpers.hpp"
|
||||||
#include "ui_extension.hpp"
|
#include "ui_extension.hpp"
|
||||||
#include "http_server.hpp"
|
#include "http_server.hpp"
|
||||||
|
#include "state.hpp"
|
||||||
#include <duckdb.hpp>
|
#include <duckdb.hpp>
|
||||||
#include <duckdb/common/string_util.hpp>
|
#include <duckdb/common/string_util.hpp>
|
||||||
|
|
||||||
@@ -29,27 +30,23 @@ namespace duckdb {
|
|||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
bool StartHttpServer(const ClientContext &context) {
|
const ui::HttpServer &StartHttpServer(ClientContext &context,
|
||||||
const auto url =
|
bool *was_started = nullptr) {
|
||||||
|
const auto remote_url =
|
||||||
GetSetting<std::string>(context, UI_REMOTE_URL_SETTING_NAME,
|
GetSetting<std::string>(context, UI_REMOTE_URL_SETTING_NAME,
|
||||||
GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME,
|
GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME,
|
||||||
UI_REMOTE_URL_SETTING_DEFAULT));
|
UI_REMOTE_URL_SETTING_DEFAULT));
|
||||||
const uint16_t port = GetSetting(context, UI_LOCAL_PORT_SETTING_NAME,
|
const uint16_t port = GetSetting(context, UI_LOCAL_PORT_SETTING_NAME,
|
||||||
UI_LOCAL_PORT_SETTING_DEFAULT);
|
UI_LOCAL_PORT_SETTING_DEFAULT);
|
||||||
;
|
return ui::HttpServer::GetInstance(context)->Start(port, remote_url,
|
||||||
return ui::HttpServer::instance()->Start(port, url, context.db);
|
was_started);
|
||||||
}
|
|
||||||
|
|
||||||
std::string GetHttpServerLocalURL() {
|
|
||||||
return StringUtil::Format("http://localhost:%d/",
|
|
||||||
ui::HttpServer::instance()->LocalPort());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
|
|
||||||
std::string StartUIFunction(ClientContext &context) {
|
std::string StartUIFunction(ClientContext &context) {
|
||||||
internal::StartHttpServer(context);
|
const auto &server = internal::StartHttpServer(context);
|
||||||
auto local_url = internal::GetHttpServerLocalURL();
|
const auto local_url = server.LocalUrl();
|
||||||
|
|
||||||
const std::string command =
|
const std::string command =
|
||||||
StringUtil::Format("%s %s", OPEN_COMMAND, local_url);
|
StringUtil::Format("%s %s", OPEN_COMMAND, local_url);
|
||||||
@@ -60,13 +57,16 @@ std::string StartUIFunction(ClientContext &context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string StartUIServerFunction(ClientContext &context) {
|
std::string StartUIServerFunction(ClientContext &context) {
|
||||||
const char *already = internal::StartHttpServer(context) ? "already " : "";
|
bool was_started = false;
|
||||||
|
const auto &server = internal::StartHttpServer(context, &was_started);
|
||||||
|
const char *already = was_started ? "already " : "";
|
||||||
return StringUtil::Format("MotherDuck UI server %sstarted at %s", already,
|
return StringUtil::Format("MotherDuck UI server %sstarted at %s", already,
|
||||||
internal::GetHttpServerLocalURL());
|
server.LocalUrl());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string StopUIServerFunction() {
|
std::string StopUIServerFunction(ClientContext &context) {
|
||||||
return ui::HttpServer::instance()->Stop() ? "UI server stopped"
|
return ui::HttpServer::GetInstance(context)->Stop()
|
||||||
|
? "UI server stopped"
|
||||||
: "UI server already stopped";
|
: "UI server already stopped";
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -92,20 +92,29 @@ NotifyConnectedBind(ClientContext &, TableFunctionBindInput &input,
|
|||||||
std::string NotifyConnectedFunction(ClientContext &context,
|
std::string NotifyConnectedFunction(ClientContext &context,
|
||||||
TableFunctionInput &input) {
|
TableFunctionInput &input) {
|
||||||
auto &inputs = input.bind_data->Cast<NotifyConnectedFunctionData>();
|
auto &inputs = input.bind_data->Cast<NotifyConnectedFunctionData>();
|
||||||
ui::HttpServer::instance()->SendConnectedEvent(inputs.token);
|
ui::HttpServer::GetInstance(context)->SendConnectedEvent(inputs.token);
|
||||||
return "OK";
|
return "OK";
|
||||||
}
|
}
|
||||||
|
|
||||||
// - connected notification
|
// - connected notification
|
||||||
|
|
||||||
std::string NotifyCatalogChangedFunction() {
|
std::string NotifyCatalogChangedFunction(ClientContext &context) {
|
||||||
ui::HttpServer::instance()->SendCatalogChangedEvent();
|
ui::HttpServer::GetInstance(context)->SendCatalogChangedEvent();
|
||||||
return "OK";
|
return "OK";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void InitStorageExtension(duckdb::DatabaseInstance &db) {
|
||||||
|
auto &config = db.config;
|
||||||
|
auto ext = duckdb::make_uniq<duckdb::StorageExtension>();
|
||||||
|
ext->storage_info = duckdb::make_uniq<UIStorageExtensionInfo>();
|
||||||
|
config.storage_extensions[STORAGE_EXTENSION_KEY] = std::move(ext);
|
||||||
|
}
|
||||||
|
|
||||||
static void LoadInternal(DatabaseInstance &instance) {
|
static void LoadInternal(DatabaseInstance &instance) {
|
||||||
auto &config = DBConfig::GetConfig(instance);
|
auto &config = DBConfig::GetConfig(instance);
|
||||||
|
|
||||||
|
InitStorageExtension(instance);
|
||||||
|
|
||||||
config.AddExtensionOption(
|
config.AddExtensionOption(
|
||||||
UI_LOCAL_PORT_SETTING_NAME, UI_LOCAL_PORT_SETTING_DESCRIPTION,
|
UI_LOCAL_PORT_SETTING_NAME, UI_LOCAL_PORT_SETTING_DESCRIPTION,
|
||||||
LogicalType::USMALLINT, Value::USMALLINT(UI_LOCAL_PORT_SETTING_DEFAULT));
|
LogicalType::USMALLINT, Value::USMALLINT(UI_LOCAL_PORT_SETTING_DEFAULT));
|
||||||
@@ -122,6 +131,11 @@ static void LoadInternal(DatabaseInstance &instance) {
|
|||||||
RESISTER_TF("notify_ui_catalog_changed", NotifyCatalogChangedFunction);
|
RESISTER_TF("notify_ui_catalog_changed", NotifyCatalogChangedFunction);
|
||||||
RESISTER_TF_ARGS("notify_ui_connected", {LogicalType::VARCHAR},
|
RESISTER_TF_ARGS("notify_ui_connected", {LogicalType::VARCHAR},
|
||||||
NotifyConnectedFunction, NotifyConnectedBind);
|
NotifyConnectedFunction, NotifyConnectedBind);
|
||||||
|
|
||||||
|
// If the server is already running we need to update the database instance
|
||||||
|
// since the previous one was invalidated (eg. in the shell when we '.open'
|
||||||
|
// a new database)
|
||||||
|
ui::HttpServer::UpdateDatabaseInstanceIfRunning(instance.shared_from_this());
|
||||||
}
|
}
|
||||||
|
|
||||||
void UiExtension::Load(DuckDB &db) { LoadInternal(*db.instance); }
|
void UiExtension::Load(DuckDB &db) { LoadInternal(*db.instance); }
|
||||||
|
|||||||
Reference in New Issue
Block a user