diff --git a/src/http_server.cpp b/src/http_server.cpp index 13886f1..260fa78 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -22,6 +22,76 @@ namespace ui { unique_ptr HttpServer::server_instance; +// Helpers for validating request origin/referer in deployments where the UI is +// exposed on a non-localhost host (e.g., Docker, k8s, reverse proxies). These +// checks allow either the configured local_url, or the runtime host derived +// from the request headers. They also allow an escape hatch via the +// environment variable `ui_allow_any_origin=1|true`. +namespace { + +// Returns true if the given referer begins with any of the expected base URLs. +static bool RefererStartsWithAny(const std::string &referer, + const std::vector &bases) { + for (const auto &base : bases) { + if (!base.empty() && referer.compare(0, base.size(), base) == 0) { + return true; + } + } + return false; +} + +static std::vector +ExpectedBaseUrls(const httplib::Request &req, const std::string &local_url) { + // Prefer forwarded host if present, otherwise fall back to Host. + auto forwarded_host = req.get_header_value("X-Forwarded-Host"); + auto host = forwarded_host.empty() ? req.get_header_value("Host") + : forwarded_host; + + std::vector bases; + bases.push_back(local_url); + if (!host.empty()) { + bases.push_back(StringUtil::Format("http://%s", host)); + bases.push_back(StringUtil::Format("https://%s", host)); + } + return bases; +} + +static bool IsOriginAllowed(const httplib::Request &req, + const std::string &local_url) { + if (IsEnvEnabled("ui_allow_any_origin")) { + return true; + } + + auto origin = req.get_header_value("Origin"); + if (origin.empty()) { + return false; + } + + auto bases = ExpectedBaseUrls(req, local_url); + for (const auto &base : bases) { + if (origin == base) { + return true; + } + } + return false; +} + +static bool IsRefererAllowed(const httplib::Request &req, + const std::string &local_url) { + if (IsEnvEnabled("ui_allow_any_origin")) { + return true; + } + + auto referer = req.get_header_value("Referer"); + if (referer.empty()) { + return false; + } + + return RefererStartsWithAny(referer, ExpectedBaseUrls(req, local_url)); +} + +} // namespace + HttpServer *HttpServer::GetInstance(ClientContext &context) { if (server_instance) { // We already have an instance, make sure we're running on the right DB @@ -228,8 +298,7 @@ void HttpServer::HandleGetLocalToken(const httplib::Request &req, httplib::Response &res) { // GET requests don't include Origin, so use Referer instead. // Referer includes the path, so only compare the start. - auto referer = req.get_header_value("Referer"); - if (referer.compare(0, local_url.size(), local_url) != 0) { + if (!IsRefererAllowed(req, local_url)) { res.status = 401; return; } @@ -321,8 +390,7 @@ void HttpServer::HandleGet(const httplib::Request &req, void HttpServer::HandleInterrupt(const httplib::Request &req, httplib::Response &res) { - auto origin = req.get_header_value("Origin"); - if (origin != local_url) { + if (!IsOriginAllowed(req, local_url)) { res.status = 401; return; } @@ -361,8 +429,7 @@ void HttpServer::HandleRun(const httplib::Request &req, httplib::Response &res, void HttpServer::DoHandleRun(const httplib::Request &req, httplib::Response &res, const httplib::ContentReader &content_reader) { - auto origin = req.get_header_value("Origin"); - if (origin != local_url) { + if (!IsOriginAllowed(req, local_url)) { res.status = 401; return; } @@ -625,8 +692,7 @@ void HttpServer::DoHandleRun(const httplib::Request &req, void HttpServer::HandleTokenize(const httplib::Request &req, httplib::Response &res, const httplib::ContentReader &content_reader) { - auto origin = req.get_header_value("Origin"); - if (origin != local_url) { + if (!IsOriginAllowed(req, local_url)) { res.status = 401; return; }