diff --git a/src/http_server.cpp b/src/http_server.cpp index 1e530cb..531d8da 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -21,22 +21,22 @@ constexpr const char *EMPTY_SSE_MESSAGE = ":\r\r"; constexpr idx_t EMPTY_SSE_MESSAGE_LENGTH = 3; bool EventDispatcher::WaitEvent(httplib::DataSink *sink) { - std::unique_lock lock(mutex_); + std::unique_lock lock(mutex); // Don't allow too many simultaneous waits, because each consumes a thread in // the httplib thread pool, and also browsers limit the number of server-sent // event connections. - if (closed_ || wait_count_ >= MAX_EVENT_WAIT_COUNT) { + if (closed || wait_count >= MAX_EVENT_WAIT_COUNT) { return false; } - int target_id = next_id_; - wait_count_++; - cv_.wait_for(lock, std::chrono::seconds(5)); - wait_count_--; - if (closed_) { + int target_id = next_id; + wait_count++; + cv.wait_for(lock, std::chrono::seconds(5)); + wait_count--; + if (closed) { return false; } - if (current_id_ == target_id) { - sink->write(message_.data(), message_.size()); + if (current_id == target_id) { + sink->write(message.data(), message.size()); } else { // Our wait timer expired. Write an empty, no-op message. // This enables detecting when the client is gone. @@ -45,82 +45,86 @@ bool EventDispatcher::WaitEvent(httplib::DataSink *sink) { return true; } -void EventDispatcher::SendEvent(const std::string &message) { - std::lock_guard guard(mutex_); - if (closed_) { +void EventDispatcher::SendEvent(const std::string &_message) { + std::lock_guard guard(mutex); + if (closed) { return; } - current_id_ = next_id_++; - message_ = message; - cv_.notify_all(); + current_id = next_id++; + message = _message; + cv.notify_all(); } void EventDispatcher::Close() { - std::lock_guard guard(mutex_); - if (closed_) { + std::lock_guard guard(mutex); + if (closed) { return; } - current_id_ = next_id_++; - closed_ = true; - cv_.notify_all(); + + current_id = next_id++; + closed = true; + cv.notify_all(); } -unique_ptr HttpServer::instance_; +unique_ptr HttpServer::server_instance; HttpServer *HttpServer::instance() { - if (!instance_) { - instance_ = make_uniq(); + if (!server_instance) { + server_instance = make_uniq(); std::atexit(HttpServer::StopInstance); } - return instance_.get(); + return server_instance.get(); } -bool HttpServer::Started() { return instance_ && instance_->thread_; } +bool HttpServer::Started() { + return server_instance && server_instance->main_thread; +} void HttpServer::StopInstance() { - if (instance_) { - instance_->Stop(); + if (server_instance) { + server_instance->Stop(); } } -bool HttpServer::Start(const uint16_t local_port, const std::string &remote_url, - const shared_ptr &ddb_instance) { - if (thread_) { +bool HttpServer::Start(const uint16_t _local_port, + const std::string &_remote_url, + const shared_ptr &_ddb_instance) { + if (main_thread) { return false; } - local_port_ = local_port; - remote_url_ = remote_url; - ddb_instance_ = ddb_instance; + local_port = _local_port; + remote_url = _remote_url; + ddb_instance = _ddb_instance; #ifndef UI_EXTENSION_GIT_SHA #error "UI_EXTENSION_GIT_SHA must be defined" #endif - user_agent_ = StringUtil::Format("duckdb-ui/%s(%s)", UI_EXTENSION_GIT_SHA, - DuckDB::Platform()); - event_dispatcher_ = make_uniq(); - thread_ = make_uniq(&HttpServer::Run, this); + user_agent = StringUtil::Format("duckdb-ui/%s(%s)", UI_EXTENSION_GIT_SHA, + DuckDB::Platform()); + event_dispatcher = make_uniq(); + main_thread = make_uniq(&HttpServer::Run, this); return true; } bool HttpServer::Stop() { - if (!thread_) { + if (!main_thread) { return false; } - event_dispatcher_->Close(); - server_.stop(); - thread_->join(); - thread_.reset(); - event_dispatcher_.reset(); - connections_.clear(); - ddb_instance_.reset(); - remote_url_ = ""; - local_port_ = 0; + event_dispatcher->Close(); + server.stop(); + main_thread->join(); + main_thread.reset(); + event_dispatcher.reset(); + connections.clear(); + ddb_instance.reset(); + remote_url = ""; + local_port = 0; return true; } -uint16_t HttpServer::LocalPort() { return local_port_; } +uint16_t HttpServer::LocalPort() { return local_port; } void HttpServer::SendConnectedEvent(const std::string &token) { SendEvent(StringUtil::Format("event: ConnectedEvent\ndata: %s\n\n", token)); @@ -131,45 +135,45 @@ void HttpServer::SendCatalogChangedEvent() { } void HttpServer::SendEvent(const std::string &message) { - if (event_dispatcher_) { - event_dispatcher_->SendEvent(message); + if (event_dispatcher) { + event_dispatcher->SendEvent(message); } } void HttpServer::Run() { - server_.Get("/localEvents", - [&](const httplib::Request &req, httplib::Response &res) { - HandleGetLocalEvents(req, res); - }); - server_.Get("/localToken", - [&](const httplib::Request &req, httplib::Response &res) { - HandleGetLocalToken(req, res); - }); - server_.Get("/.*", [&](const httplib::Request &req, httplib::Response &res) { + server.Get("/localEvents", + [&](const httplib::Request &req, httplib::Response &res) { + HandleGetLocalEvents(req, res); + }); + server.Get("/localToken", + [&](const httplib::Request &req, httplib::Response &res) { + HandleGetLocalToken(req, res); + }); + server.Get("/.*", [&](const httplib::Request &req, httplib::Response &res) { HandleGet(req, res); }); - server_.Post("/ddb/interrupt", - [&](const httplib::Request &req, httplib::Response &res) { - HandleInterrupt(req, res); - }); - server_.Post("/ddb/run", - [&](const httplib::Request &req, httplib::Response &res, - const httplib::ContentReader &content_reader) { - HandleRun(req, res, content_reader); - }); - server_.Post("/ddb/tokenize", - [&](const httplib::Request &req, httplib::Response &res, - const httplib::ContentReader &content_reader) { - HandleTokenize(req, res, content_reader); - }); - server_.listen("localhost", local_port_); + server.Post("/ddb/interrupt", + [&](const httplib::Request &req, httplib::Response &res) { + HandleInterrupt(req, res); + }); + server.Post("/ddb/run", + [&](const httplib::Request &req, httplib::Response &res, + const httplib::ContentReader &content_reader) { + HandleRun(req, res, content_reader); + }); + server.Post("/ddb/tokenize", + [&](const httplib::Request &req, httplib::Response &res, + const httplib::ContentReader &content_reader) { + HandleTokenize(req, res, content_reader); + }); + server.listen("localhost", local_port); } void HttpServer::HandleGetLocalEvents(const httplib::Request &req, httplib::Response &res) { res.set_chunked_content_provider( "text/event-stream", [&](size_t /*offset*/, httplib::DataSink &sink) { - if (event_dispatcher_->WaitEvent(&sink)) { + if (event_dispatcher->WaitEvent(&sink)) { return true; } sink.done(); @@ -179,13 +183,13 @@ void HttpServer::HandleGetLocalEvents(const httplib::Request &req, void HttpServer::HandleGetLocalToken(const httplib::Request &req, httplib::Response &res) { - if (!ddb_instance_->ExtensionIsLoaded("motherduck")) { + if (!ddb_instance->ExtensionIsLoaded("motherduck")) { res.set_content("", "text/plain"); // UI expects an empty response if the // extension is not loaded return; } - Connection connection(*ddb_instance_); + Connection connection(*ddb_instance); auto query_res = connection.Query("CALL get_md_token()"); if (query_res->HasError()) { res.status = 500; @@ -204,7 +208,7 @@ void HttpServer::HandleGet(const httplib::Request &req, httplib::Response &res) { // Create HTTP client to remote URL // TODO: Can this be created once and shared? - httplib::Client client(remote_url_); + httplib::Client client(remote_url); client.set_keep_alive(true); // Provide a way to turn on or off server certificate verification, at least @@ -218,7 +222,7 @@ void HttpServer::HandleGet(const httplib::Request &req, } // forward GET to remote URL - auto result = client.Get(req.path, req.params, {{"User-Agent", user_agent_}}); + auto result = client.Get(req.path, req.params, {{"User-Agent", user_agent}}); if (!result) { res.status = 500; return; @@ -275,7 +279,7 @@ void HttpServer::DoHandleRun(const httplib::Request &req, // Set current database if optional header is provided. if (!database_name.empty()) { connection->context->RunFunctionInTransaction([&] { - ddb_instance_->GetDatabaseManager().SetDefaultDatabase( + ddb_instance->GetDatabaseManager().SetDefaultDatabase( *connection->context, database_name); }); } @@ -395,10 +399,10 @@ HttpServer::FindConnection(const std::string &connection_name) { // Need to protect access to the connections map because this can be called // from multiple threads. - std::lock_guard guard(connections_mutex_); + std::lock_guard guard(connections_mutex); - auto result = connections_.find(connection_name); - if (result != connections_.end()) { + auto result = connections.find(connection_name); + if (result != connections.end()) { return result->second; } @@ -410,22 +414,22 @@ HttpServer::FindOrCreateConnection(const std::string &connection_name) { if (connection_name.empty()) { // If no connection name was provided, create and return a new connection // but don't remember it. - return make_shared_ptr(*ddb_instance_); + return make_shared_ptr(*ddb_instance); } // Need to protect access to the connections map because this can be called // from multiple threads. - std::lock_guard guard(connections_mutex_); + std::lock_guard guard(connections_mutex); // If an existing connection with the provided name was found, return it. - auto result = connections_.find(connection_name); - if (result != connections_.end()) { + auto result = connections.find(connection_name); + if (result != connections.end()) { return result->second; } // Otherwise, create a new one, remember it, and return it. - auto connection = make_shared_ptr(*ddb_instance_); - connections_[connection_name] = connection; + auto connection = make_shared_ptr(*ddb_instance); + connections[connection_name] = connection; return connection; } diff --git a/src/include/http_server.hpp b/src/include/http_server.hpp index 4fe1d1c..59fbaf7 100644 --- a/src/include/http_server.hpp +++ b/src/include/http_server.hpp @@ -25,13 +25,13 @@ public: void Close(); private: - std::mutex mutex_; - std::condition_variable cv_; - std::atomic_int next_id_{0}; - std::atomic_int current_id_{-1}; - std::atomic_int wait_count_{0}; - std::string message_; - std::atomic_bool closed_{false}; + std::mutex mutex; + std::condition_variable cv; + std::atomic_int next_id{0}; + std::atomic_int current_id{-1}; + std::atomic_int wait_count{0}; + std::string message; + std::atomic_bool closed{false}; }; class HttpServer { @@ -70,17 +70,17 @@ private: void SetResponseEmptyResult(httplib::Response &res); void SetResponseErrorResult(httplib::Response &res, const std::string &error); - uint16_t local_port_; - std::string remote_url_; - shared_ptr ddb_instance_; - std::string user_agent_; - httplib::Server server_; - unique_ptr thread_; - std::mutex connections_mutex_; - std::unordered_map> connections_; - unique_ptr event_dispatcher_; + uint16_t local_port; + std::string remote_url; + shared_ptr ddb_instance; + std::string user_agent; + httplib::Server server; + unique_ptr main_thread; + std::mutex connections_mutex; + std::unordered_map> connections; + unique_ptr event_dispatcher; - static unique_ptr instance_; + static unique_ptr server_instance; }; ;