From 4e96696d091f2f658e34ea8958bb59833035a7f2 Mon Sep 17 00:00:00 2001 From: Yves Date: Mon, 17 Feb 2025 16:02:40 +0100 Subject: [PATCH] Simplify TF registration --- src/include/utils/helpers.hpp | 54 ++++++++++++++++++++++++++++++++++- src/ui_extension.cpp | 44 +++++++--------------------- src/utils/helpers.cpp | 8 ++---- 3 files changed, 65 insertions(+), 41 deletions(-) diff --git a/src/include/utils/helpers.hpp b/src/include/utils/helpers.hpp index 70c7bc0..33c9e7c 100644 --- a/src/include/utils/helpers.hpp +++ b/src/include/utils/helpers.hpp @@ -1,9 +1,13 @@ #pragma once #include +#include +#include namespace duckdb { +typedef std::string (*simple_tf_t) (ClientContext &); + struct RunOnceTableFunctionState : GlobalTableFunctionState { RunOnceTableFunctionState() : run(false) {}; std::atomic run; @@ -20,8 +24,56 @@ T GetSetting(const ClientContext &context, const char *setting_name, const T def return context.TryGetCurrentSetting(setting_name, value) ? value.GetValue() : default_value; } +namespace internal { +unique_ptr ResultBind(ClientContext &, TableFunctionBindInput &, + vector &, + vector &); + bool ShouldRun(TableFunctionInput &input); -void RegisterTF(DatabaseInstance &instance, const char* name, table_function_t func); +template +struct CallFunctionHelper; + +template <> +struct CallFunctionHelper { + static std::string call(ClientContext &context, TableFunctionInput &input, std::string(*f)()) { + return f(); + } +}; + +template <> +struct CallFunctionHelper { + static std::string call(ClientContext &context, TableFunctionInput &input, std::string(*f)(ClientContext &)) { + return f(context); + } +}; + +template <> +struct CallFunctionHelper { + static std::string call(ClientContext &context, TableFunctionInput &input, std::string(*f)(ClientContext &, TableFunctionInput &)) { + return f(context, input); + } +}; + +template +void TableFunc(ClientContext &context, TableFunctionInput &input, DataChunk &output) { + if (!ShouldRun(input)) { + return; + } + + const std::string result = CallFunctionHelper::call(context, input, func); + output.SetCardinality(1); + output.SetValue(0, 0, result); +} + +template +void RegisterTF(DatabaseInstance &instance, const char* name) { + TableFunction tf(name, {}, internal::TableFunc, internal::ResultBind, RunOnceTableFunctionState::Init); + ExtensionUtil::RegisterFunction(instance, tf); +} + +} + +#define RESISTER_TF(name, func) internal::RegisterTF(instance, name) } // namespace duckdb diff --git a/src/ui_extension.cpp b/src/ui_extension.cpp index b66f8f9..e6a82b5 100644 --- a/src/ui_extension.cpp +++ b/src/ui_extension.cpp @@ -40,47 +40,23 @@ std::string GetHttpServerLocalURL() { } // namespace internal -void OutputResult(const std::string &result, DataChunk &out_chunk) { - out_chunk.SetCardinality(1); - out_chunk.SetValue(0, 0, result); -} - -void StartUIFunction(ClientContext &context, TableFunctionInput &input, - DataChunk &out_chunk) { - if (!ShouldRun(input)) { - return; - } - +std::string StartUIFunction(ClientContext &context) { internal::StartHttpServer(context); auto local_url = internal::GetHttpServerLocalURL(); const std::string command = StringUtil::Format("%s %s", OPEN_COMMAND, local_url); - std::string result = system(command.c_str()) ? + return system(command.c_str()) ? StringUtil::Format("Navigate browser to %s", local_url) // open command failed : StringUtil::Format("MotherDuck UI started at %s", local_url); - OutputResult(result, out_chunk); } -void StartUIServerFunction(ClientContext &context, TableFunctionInput &input, - DataChunk &out_chunk) { - if (!ShouldRun(input)) { - return; - } - - const bool already = internal::StartHttpServer(context); - const char* already_str = already ? "already " : ""; - auto result = StringUtil::Format("MotherDuck UI server %sstarted at %s", already_str, internal::GetHttpServerLocalURL()); - OutputResult(result, out_chunk); +std::string StartUIServerFunction(ClientContext &context) { + const char* already = internal::StartHttpServer(context) ? "already " : ""; + return StringUtil::Format("MotherDuck UI server %sstarted at %s", already, internal::GetHttpServerLocalURL()); } -void StopUIServerFunction(ClientContext &, TableFunctionInput &input, - DataChunk &out_chunk) { - if (!ShouldRun(input)) { - return; - } - - auto result = ui::HttpServer::instance()->Stop() ? "UI server stopped" : "UI server already stopped"; - OutputResult(result, out_chunk); +std::string StopUIServerFunction() { + return ui::HttpServer::instance()->Stop() ? "UI server stopped" : "UI server already stopped"; } // FIXME @@ -103,9 +79,9 @@ static void LoadInternal(DatabaseInstance &instance) { UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DESCRIPTION, LogicalType::VARCHAR, Value(GetEnvOrDefault(UI_REMOTE_URL_SETTING_NAME, UI_REMOTE_URL_SETTING_DEFAULT))); - RegisterTF(instance, "start_ui", StartUIFunction); - RegisterTF(instance, "start_ui_server", StartUIServerFunction); - RegisterTF(instance, "stop_ui_server", StopUIServerFunction); + RESISTER_TF("start_ui", StartUIFunction); + RESISTER_TF("start_ui_server", StartUIServerFunction); + RESISTER_TF("stop_ui_server", StopUIServerFunction); } void UiExtension::Load(DuckDB &db) { diff --git a/src/utils/helpers.cpp b/src/utils/helpers.cpp index 71c21d1..c359306 100644 --- a/src/utils/helpers.cpp +++ b/src/utils/helpers.cpp @@ -1,7 +1,7 @@ #include "utils/helpers.hpp" -#include namespace duckdb { +namespace internal { bool ShouldRun(TableFunctionInput &input) { auto state = dynamic_cast(input.global_state.get()); @@ -22,9 +22,5 @@ unique_ptr ResultBind(ClientContext &, TableFunctionBindInput &, return nullptr; } -void RegisterTF(DatabaseInstance &instance, const char* name, table_function_t func) { - TableFunction tf(name, {}, func, ResultBind, RunOnceTableFunctionState::Init); - ExtensionUtil::RegisterFunction(instance, tf); -} - +} // namespace internal } // namespace duckdb