diff --git a/src/http_server.cpp b/src/http_server.cpp index 092f73b..e960702 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -195,12 +195,11 @@ void HttpServer::SendEvent(const std::string &message) { } } -void HttpServer::WatchForCatalogUpdate(DatabaseInstance &db, - CatalogState &last_state) { +bool WasCatalogUpdated(DatabaseInstance &db, Connection &connection, + CatalogState &last_state) { bool has_change = false; - duckdb::Connection con{db}; - auto &context = *con.context; - con.BeginTransaction(); + auto &context = *connection.context; + connection.BeginTransaction(); const auto &databases = db.GetDatabaseManager().GetDatabases(context); std::set db_oids; @@ -234,20 +233,61 @@ void HttpServer::WatchForCatalogUpdate(DatabaseInstance &db, } } - // If there was any change, notify the UI - if (has_change) { - SendCatalogChangedEvent(); + connection.Rollback(); + return has_change; +} + +std::string GetMDToken(Connection &connection) { + auto query_res = connection.Query("CALL GET_MD_TOKEN()"); + if (query_res->HasError()) { + query_res->ThrowError(); + return ""; // unreachable } + + auto chunk = query_res->Fetch(); + return chunk->GetValue(0, 0).GetValue(); +} + +bool IsMDConnected(Connection &con) { + if (!con.context->db->ExtensionIsLoaded("motherduck")) { + return false; + } + auto query_res = con.Query("CALL MD_IS_CONNECTED()"); + if (query_res->HasError()) { + std::cerr << "Error in IsMDConnected: " << query_res->GetError() + << std::endl; + return false; + } + + auto chunk = query_res->Fetch(); + return chunk->GetValue(0, 0).GetValue(); } void HttpServer::Watch() { CatalogState last_state; + bool is_md_connected = false; while (watcher_should_run) { auto db = ddb_instance.lock(); if (!db) { break; // DB went away, nothing to watch } - WatchForCatalogUpdate(*db, last_state); + + try { + duckdb::Connection con{*db}; + if (WasCatalogUpdated(*db, con, last_state)) { + SendCatalogChangedEvent(); + } + + if (!is_md_connected && IsMDConnected(con)) { + is_md_connected = true; + SendConnectedEvent(GetMDToken(con)); + } + } catch (std::exception &ex) { + // Do not crash with uncaught exception, but quit. + std::cerr << "Error in watcher: " << ex.what() << std::endl; + std::cerr << "Will now terminate." << std::endl; + return; + } { std::unique_lock lock(watcher_mutex); watcher_cv.wait_for(lock, @@ -313,27 +353,23 @@ void HttpServer::HandleGetLocalToken(const httplib::Request &req, return; } - Connection connection(*db); - auto query_res = connection.Query("CALL get_md_token()"); - if (query_res->HasError()) { + Connection connection{*db}; + try { + auto token = GetMDToken(connection); + res.status = 200; + res.set_content(token, "text/plain"); + } catch (std::exception &ex) { if (StringUtil::Contains( - query_res->GetError(), - "GET_MD_TOKEN will be available after you connect")) { + ex.what(), "GET_MD_TOKEN will be available after you connect")) { // UI expects an empty response if MD isn't connected + res.status = 200; res.set_content("", "text/plain"); - return; + } else { + res.status = 500; + res.set_content("Could not get token: " + std::string(ex.what()), + "text/plain"); } - - res.status = 500; - res.set_content("Could not get token: " + query_res->GetError(), - "text/plain"); - return; } - - auto chunk = query_res->Fetch(); - auto token = chunk->GetValue(0, 0).GetValue(); - res.status = 200; - res.set_content(token, "text/plain"); } void HttpServer::HandleGet(const httplib::Request &req, diff --git a/src/include/http_server.hpp b/src/include/http_server.hpp index 52d0f60..9706a7f 100644 --- a/src/include/http_server.hpp +++ b/src/include/http_server.hpp @@ -80,9 +80,6 @@ private: void SetResponseEmptyResult(httplib::Response &res); void SetResponseErrorResult(httplib::Response &res, const std::string &error); - // Watchers - void WatchForCatalogUpdate(DatabaseInstance &, CatalogState &last_state); - uint16_t local_port; std::string remote_url; weak_ptr ddb_instance;