Watch for MD connection

This commit is contained in:
Yves
2025-02-21 10:50:03 +01:00
parent f206212575
commit 642988fa94
2 changed files with 61 additions and 28 deletions

View File

@@ -195,12 +195,11 @@ void HttpServer::SendEvent(const std::string &message) {
} }
} }
void HttpServer::WatchForCatalogUpdate(DatabaseInstance &db, bool WasCatalogUpdated(DatabaseInstance &db, Connection &connection,
CatalogState &last_state) { CatalogState &last_state) {
bool has_change = false; bool has_change = false;
duckdb::Connection con{db}; auto &context = *connection.context;
auto &context = *con.context; connection.BeginTransaction();
con.BeginTransaction();
const auto &databases = db.GetDatabaseManager().GetDatabases(context); const auto &databases = db.GetDatabaseManager().GetDatabases(context);
std::set<idx_t> db_oids; std::set<idx_t> db_oids;
@@ -234,20 +233,61 @@ void HttpServer::WatchForCatalogUpdate(DatabaseInstance &db,
} }
} }
// If there was any change, notify the UI connection.Rollback();
if (has_change) { return has_change;
SendCatalogChangedEvent(); }
std::string GetMDToken(Connection &connection) {
auto query_res = connection.Query("CALL GET_MD_TOKEN()");
if (query_res->HasError()) {
query_res->ThrowError();
return ""; // unreachable
} }
auto chunk = query_res->Fetch();
return chunk->GetValue(0, 0).GetValue<std::string>();
}
bool IsMDConnected(Connection &con) {
if (!con.context->db->ExtensionIsLoaded("motherduck")) {
return false;
}
auto query_res = con.Query("CALL MD_IS_CONNECTED()");
if (query_res->HasError()) {
std::cerr << "Error in IsMDConnected: " << query_res->GetError()
<< std::endl;
return false;
}
auto chunk = query_res->Fetch();
return chunk->GetValue(0, 0).GetValue<bool>();
} }
void HttpServer::Watch() { void HttpServer::Watch() {
CatalogState last_state; CatalogState last_state;
bool is_md_connected = false;
while (watcher_should_run) { while (watcher_should_run) {
auto db = ddb_instance.lock(); auto db = ddb_instance.lock();
if (!db) { if (!db) {
break; // DB went away, nothing to watch break; // DB went away, nothing to watch
} }
WatchForCatalogUpdate(*db, last_state);
try {
duckdb::Connection con{*db};
if (WasCatalogUpdated(*db, con, last_state)) {
SendCatalogChangedEvent();
}
if (!is_md_connected && IsMDConnected(con)) {
is_md_connected = true;
SendConnectedEvent(GetMDToken(con));
}
} catch (std::exception &ex) {
// Do not crash with uncaught exception, but quit.
std::cerr << "Error in watcher: " << ex.what() << std::endl;
std::cerr << "Will now terminate." << std::endl;
return;
}
{ {
std::unique_lock<std::mutex> lock(watcher_mutex); std::unique_lock<std::mutex> lock(watcher_mutex);
watcher_cv.wait_for(lock, watcher_cv.wait_for(lock,
@@ -313,27 +353,23 @@ void HttpServer::HandleGetLocalToken(const httplib::Request &req,
return; return;
} }
Connection connection(*db); Connection connection{*db};
auto query_res = connection.Query("CALL get_md_token()"); try {
if (query_res->HasError()) { auto token = GetMDToken(connection);
if (StringUtil::Contains(
query_res->GetError(),
"GET_MD_TOKEN will be available after you connect")) {
// UI expects an empty response if MD isn't connected
res.set_content("", "text/plain");
return;
}
res.status = 500;
res.set_content("Could not get token: " + query_res->GetError(),
"text/plain");
return;
}
auto chunk = query_res->Fetch();
auto token = chunk->GetValue(0, 0).GetValue<std::string>();
res.status = 200; res.status = 200;
res.set_content(token, "text/plain"); res.set_content(token, "text/plain");
} catch (std::exception &ex) {
if (StringUtil::Contains(
ex.what(), "GET_MD_TOKEN will be available after you connect")) {
// UI expects an empty response if MD isn't connected
res.status = 200;
res.set_content("", "text/plain");
} else {
res.status = 500;
res.set_content("Could not get token: " + std::string(ex.what()),
"text/plain");
}
}
} }
void HttpServer::HandleGet(const httplib::Request &req, void HttpServer::HandleGet(const httplib::Request &req,

View File

@@ -80,9 +80,6 @@ private:
void SetResponseEmptyResult(httplib::Response &res); void SetResponseEmptyResult(httplib::Response &res);
void SetResponseErrorResult(httplib::Response &res, const std::string &error); void SetResponseErrorResult(httplib::Response &res, const std::string &error);
// Watchers
void WatchForCatalogUpdate(DatabaseInstance &, CatalogState &last_state);
uint16_t local_port; uint16_t local_port;
std::string remote_url; std::string remote_url;
weak_ptr<DatabaseInstance> ddb_instance; weak_ptr<DatabaseInstance> ddb_instance;