Fix ownership model of DatabaseInstance
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include "utils/env.hpp"
|
||||
#include "utils/serialization.hpp"
|
||||
#include "utils/encoding.hpp"
|
||||
#include "state.hpp"
|
||||
|
||||
// Chosen to be no more than half of the lesser of the two limits:
|
||||
// - The default httplib thread pool size = 8
|
||||
@@ -70,14 +71,34 @@ void EventDispatcher::Close() {
|
||||
|
||||
unique_ptr<HttpServer> HttpServer::server_instance;
|
||||
|
||||
HttpServer *HttpServer::instance() {
|
||||
if (!server_instance) {
|
||||
server_instance = make_uniq<HttpServer>();
|
||||
HttpServer *HttpServer::GetInstance(ClientContext &context) {
|
||||
if (server_instance) {
|
||||
// We already have an instance, make sure we're running on the right DB
|
||||
server_instance->UpdateDatabaseInstance(context.db);
|
||||
} else {
|
||||
server_instance = make_uniq<HttpServer>(context.db);
|
||||
std::atexit(HttpServer::StopInstance);
|
||||
}
|
||||
return server_instance.get();
|
||||
}
|
||||
|
||||
void HttpServer::UpdateDatabaseInstanceIfRunning(
|
||||
shared_ptr<DatabaseInstance> db) {
|
||||
if (server_instance) {
|
||||
server_instance->UpdateDatabaseInstance(db);
|
||||
}
|
||||
}
|
||||
|
||||
void HttpServer::UpdateDatabaseInstance(
|
||||
shared_ptr<DatabaseInstance> context_db) {
|
||||
const auto current_db = server_instance->ddb_instance.lock();
|
||||
if (current_db != context_db) {
|
||||
server_instance->StopWatcher(); // Likely already stopped, but just in case
|
||||
server_instance->ddb_instance = context_db;
|
||||
server_instance->StartWatcher();
|
||||
}
|
||||
}
|
||||
|
||||
bool HttpServer::Started() {
|
||||
return server_instance && server_instance->main_thread;
|
||||
}
|
||||
@@ -88,16 +109,18 @@ void HttpServer::StopInstance() {
|
||||
}
|
||||
}
|
||||
|
||||
bool HttpServer::Start(const uint16_t _local_port,
|
||||
const std::string &_remote_url,
|
||||
const shared_ptr<DatabaseInstance> &_ddb_instance) {
|
||||
const HttpServer &HttpServer::Start(const uint16_t _local_port,
|
||||
const std::string &_remote_url,
|
||||
bool *was_started) {
|
||||
if (main_thread) {
|
||||
return false;
|
||||
if (was_started) {
|
||||
*was_started = true;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
local_port = _local_port;
|
||||
remote_url = _remote_url;
|
||||
ddb_instance = _ddb_instance;
|
||||
#ifndef UI_EXTENSION_GIT_SHA
|
||||
#error "UI_EXTENSION_GIT_SHA must be defined"
|
||||
#endif
|
||||
@@ -105,12 +128,33 @@ bool HttpServer::Start(const uint16_t _local_port,
|
||||
DuckDB::Platform());
|
||||
event_dispatcher = make_uniq<EventDispatcher>();
|
||||
main_thread = make_uniq<std::thread>(&HttpServer::Run, this);
|
||||
StartWatcher();
|
||||
return *this;
|
||||
}
|
||||
|
||||
void HttpServer::StartWatcher() {
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(watcher_mutex);
|
||||
watcher_should_run = true;
|
||||
}
|
||||
watcher_thread = make_uniq<std::thread>(&HttpServer::Watch, this);
|
||||
return true;
|
||||
|
||||
if (!watcher_thread) {
|
||||
watcher_thread = make_uniq<std::thread>(&HttpServer::Watch, this);
|
||||
}
|
||||
}
|
||||
|
||||
void HttpServer::StopWatcher() {
|
||||
if (!watcher_thread) {
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(watcher_mutex);
|
||||
watcher_should_run = false;
|
||||
}
|
||||
watcher_cv.notify_all();
|
||||
watcher_thread->join();
|
||||
watcher_thread.reset();
|
||||
}
|
||||
|
||||
bool HttpServer::Stop() {
|
||||
@@ -121,27 +165,20 @@ bool HttpServer::Stop() {
|
||||
event_dispatcher->Close();
|
||||
server.stop();
|
||||
|
||||
if (watcher_thread) {
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(watcher_mutex);
|
||||
watcher_should_run = false;
|
||||
}
|
||||
watcher_cv.notify_all();
|
||||
watcher_thread->join();
|
||||
watcher_thread.reset();
|
||||
}
|
||||
StopWatcher();
|
||||
|
||||
main_thread->join();
|
||||
main_thread.reset();
|
||||
event_dispatcher.reset();
|
||||
connections.clear();
|
||||
ddb_instance.reset();
|
||||
remote_url = "";
|
||||
local_port = 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
uint16_t HttpServer::LocalPort() { return local_port; }
|
||||
std::string HttpServer::LocalUrl() const {
|
||||
return StringUtil::Format("http://localhost:%d/", local_port);
|
||||
}
|
||||
|
||||
void HttpServer::SendConnectedEvent(const std::string &token) {
|
||||
SendEvent(StringUtil::Format("event: ConnectedEvent\ndata: %s\n\n", token));
|
||||
@@ -157,13 +194,14 @@ void HttpServer::SendEvent(const std::string &message) {
|
||||
}
|
||||
}
|
||||
|
||||
void HttpServer::WatchForCatalogUpdate(CatalogState &last_state) {
|
||||
void HttpServer::WatchForCatalogUpdate(DatabaseInstance &db,
|
||||
CatalogState &last_state) {
|
||||
bool has_change = false;
|
||||
duckdb::Connection con{*ddb_instance};
|
||||
duckdb::Connection con{db};
|
||||
auto &context = *con.context;
|
||||
con.BeginTransaction();
|
||||
const auto &databases =
|
||||
ddb_instance->GetDatabaseManager().GetDatabases(context);
|
||||
|
||||
const auto &databases = db.GetDatabaseManager().GetDatabases(context);
|
||||
std::set<idx_t> db_oids;
|
||||
|
||||
// Check currently attached databases
|
||||
@@ -204,7 +242,11 @@ void HttpServer::WatchForCatalogUpdate(CatalogState &last_state) {
|
||||
void HttpServer::Watch() {
|
||||
CatalogState last_state;
|
||||
while (watcher_should_run) {
|
||||
WatchForCatalogUpdate(last_state);
|
||||
auto db = ddb_instance.lock();
|
||||
if (!db) {
|
||||
break; // DB went away, nothing to watch
|
||||
}
|
||||
WatchForCatalogUpdate(*db, last_state);
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(watcher_mutex);
|
||||
watcher_cv.wait_for(lock,
|
||||
@@ -256,13 +298,21 @@ void HttpServer::HandleGetLocalEvents(const httplib::Request &req,
|
||||
|
||||
void HttpServer::HandleGetLocalToken(const httplib::Request &req,
|
||||
httplib::Response &res) {
|
||||
if (!ddb_instance->ExtensionIsLoaded("motherduck")) {
|
||||
auto db = ddb_instance.lock();
|
||||
if (!db) {
|
||||
res.status = 500;
|
||||
res.set_content("Database was invalidated, UI needs to be restarted",
|
||||
"text/plain");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!db->ExtensionIsLoaded("motherduck")) {
|
||||
res.set_content("", "text/plain"); // UI expects an empty response if the
|
||||
// extension is not loaded
|
||||
return;
|
||||
}
|
||||
|
||||
Connection connection(*ddb_instance);
|
||||
Connection connection(*db);
|
||||
auto query_res = connection.Query("CALL get_md_token()");
|
||||
if (query_res->HasError()) {
|
||||
res.status = 500;
|
||||
@@ -316,7 +366,14 @@ void HttpServer::HandleInterrupt(const httplib::Request &req,
|
||||
auto description = req.get_header_value("X-MD-Description");
|
||||
auto connection_name = req.get_header_value("X-MD-Connection-Name");
|
||||
|
||||
auto connection = FindConnection(connection_name);
|
||||
auto db = ddb_instance.lock();
|
||||
if (!db) {
|
||||
res.status = 404;
|
||||
return;
|
||||
}
|
||||
|
||||
auto connection =
|
||||
UIStorageExtensionInfo::GetState(*db).FindConnection(connection_name);
|
||||
if (!connection) {
|
||||
res.status = 404;
|
||||
return;
|
||||
@@ -347,13 +404,23 @@ void HttpServer::DoHandleRun(const httplib::Request &req,
|
||||
|
||||
std::string content = ReadContent(content_reader);
|
||||
|
||||
auto connection = FindOrCreateConnection(connection_name);
|
||||
auto db = ddb_instance.lock();
|
||||
if (!db) {
|
||||
SetResponseErrorResult(
|
||||
res, "Database was invalidated, UI needs to be restarted");
|
||||
return;
|
||||
}
|
||||
|
||||
auto connection =
|
||||
UIStorageExtensionInfo::GetState(*db).FindOrCreateConnection(
|
||||
*db, connection_name);
|
||||
|
||||
// Set current database if optional header is provided.
|
||||
if (!database_name.empty()) {
|
||||
connection->context->RunFunctionInTransaction([&] {
|
||||
ddb_instance->GetDatabaseManager().SetDefaultDatabase(
|
||||
*connection->context, database_name);
|
||||
auto &context = *connection->context;
|
||||
context.RunFunctionInTransaction([&] {
|
||||
auto &manager = context.db->GetDatabaseManager();
|
||||
manager.SetDefaultDatabase(context, database_name);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -464,48 +531,6 @@ HttpServer::ReadContent(const httplib::ContentReader &content_reader) {
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
shared_ptr<Connection>
|
||||
HttpServer::FindConnection(const std::string &connection_name) {
|
||||
if (connection_name.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Need to protect access to the connections map because this can be called
|
||||
// from multiple threads.
|
||||
std::lock_guard<std::mutex> guard(connections_mutex);
|
||||
|
||||
auto result = connections.find(connection_name);
|
||||
if (result != connections.end()) {
|
||||
return result->second;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
shared_ptr<Connection>
|
||||
HttpServer::FindOrCreateConnection(const std::string &connection_name) {
|
||||
if (connection_name.empty()) {
|
||||
// If no connection name was provided, create and return a new connection
|
||||
// but don't remember it.
|
||||
return make_shared_ptr<Connection>(*ddb_instance);
|
||||
}
|
||||
|
||||
// Need to protect access to the connections map because this can be called
|
||||
// from multiple threads.
|
||||
std::lock_guard<std::mutex> guard(connections_mutex);
|
||||
|
||||
// If an existing connection with the provided name was found, return it.
|
||||
auto result = connections.find(connection_name);
|
||||
if (result != connections.end()) {
|
||||
return result->second;
|
||||
}
|
||||
|
||||
// Otherwise, create a new one, remember it, and return it.
|
||||
auto connection = make_shared_ptr<Connection>(*ddb_instance);
|
||||
connections[connection_name] = connection;
|
||||
return connection;
|
||||
}
|
||||
|
||||
void HttpServer::SetResponseContent(httplib::Response &res,
|
||||
const MemoryStream &content) {
|
||||
auto data = content.GetData();
|
||||
|
||||
Reference in New Issue
Block a user