From e11686f4a06301144ce618f354fd12d077c8c5a9 Mon Sep 17 00:00:00 2001 From: Eric Liu Date: Sat, 13 Sep 2025 19:02:04 +0000 Subject: [PATCH] feat(http): implement origin and referer validation for HTTP requests Added helper functions to validate request origins and referers, allowing for configurable local URLs and support for environments where the UI is exposed externally. Introduced environment variable `ui_allow_any_origin` for bypassing origin checks. Updated request handling methods to utilize these validations, enhancing security for HTTP interactions. --- src/http_server.cpp | 82 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 74 insertions(+), 8 deletions(-) diff --git a/src/http_server.cpp b/src/http_server.cpp index 13886f1..260fa78 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -22,6 +22,76 @@ namespace ui { unique_ptr HttpServer::server_instance; +// Helpers for validating request origin/referer in deployments where the UI is +// exposed on a non-localhost host (e.g., Docker, k8s, reverse proxies). These +// checks allow either the configured local_url, or the runtime host derived +// from the request headers. They also allow an escape hatch via the +// environment variable `ui_allow_any_origin=1|true`. +namespace { + +// Returns true if the given referer begins with any of the expected base URLs. +static bool RefererStartsWithAny(const std::string &referer, + const std::vector &bases) { + for (const auto &base : bases) { + if (!base.empty() && referer.compare(0, base.size(), base) == 0) { + return true; + } + } + return false; +} + +static std::vector +ExpectedBaseUrls(const httplib::Request &req, const std::string &local_url) { + // Prefer forwarded host if present, otherwise fall back to Host. + auto forwarded_host = req.get_header_value("X-Forwarded-Host"); + auto host = forwarded_host.empty() ? req.get_header_value("Host") + : forwarded_host; + + std::vector bases; + bases.push_back(local_url); + if (!host.empty()) { + bases.push_back(StringUtil::Format("http://%s", host)); + bases.push_back(StringUtil::Format("https://%s", host)); + } + return bases; +} + +static bool IsOriginAllowed(const httplib::Request &req, + const std::string &local_url) { + if (IsEnvEnabled("ui_allow_any_origin")) { + return true; + } + + auto origin = req.get_header_value("Origin"); + if (origin.empty()) { + return false; + } + + auto bases = ExpectedBaseUrls(req, local_url); + for (const auto &base : bases) { + if (origin == base) { + return true; + } + } + return false; +} + +static bool IsRefererAllowed(const httplib::Request &req, + const std::string &local_url) { + if (IsEnvEnabled("ui_allow_any_origin")) { + return true; + } + + auto referer = req.get_header_value("Referer"); + if (referer.empty()) { + return false; + } + + return RefererStartsWithAny(referer, ExpectedBaseUrls(req, local_url)); +} + +} // namespace + HttpServer *HttpServer::GetInstance(ClientContext &context) { if (server_instance) { // We already have an instance, make sure we're running on the right DB @@ -228,8 +298,7 @@ void HttpServer::HandleGetLocalToken(const httplib::Request &req, httplib::Response &res) { // GET requests don't include Origin, so use Referer instead. // Referer includes the path, so only compare the start. - auto referer = req.get_header_value("Referer"); - if (referer.compare(0, local_url.size(), local_url) != 0) { + if (!IsRefererAllowed(req, local_url)) { res.status = 401; return; } @@ -321,8 +390,7 @@ void HttpServer::HandleGet(const httplib::Request &req, void HttpServer::HandleInterrupt(const httplib::Request &req, httplib::Response &res) { - auto origin = req.get_header_value("Origin"); - if (origin != local_url) { + if (!IsOriginAllowed(req, local_url)) { res.status = 401; return; } @@ -361,8 +429,7 @@ void HttpServer::HandleRun(const httplib::Request &req, httplib::Response &res, void HttpServer::DoHandleRun(const httplib::Request &req, httplib::Response &res, const httplib::ContentReader &content_reader) { - auto origin = req.get_header_value("Origin"); - if (origin != local_url) { + if (!IsOriginAllowed(req, local_url)) { res.status = 401; return; } @@ -625,8 +692,7 @@ void HttpServer::DoHandleRun(const httplib::Request &req, void HttpServer::HandleTokenize(const httplib::Request &req, httplib::Response &res, const httplib::ContentReader &content_reader) { - auto origin = req.get_header_value("Origin"); - if (origin != local_url) { + if (!IsOriginAllowed(req, local_url)) { res.status = 401; return; }