feat(http): implement origin and referer validation for HTTP requests
Some checks failed
Build linux_amd64 extension and upload to Packages / build-linux-amd64 (push) Failing after 25m14s

Added helper functions to validate request origins and referers, allowing for configurable local URLs and support for environments where the UI is exposed externally. Introduced environment variable `ui_allow_any_origin` for bypassing origin checks. Updated request handling methods to utilize these validations, enhancing security for HTTP interactions.
This commit is contained in:
2025-09-13 19:02:04 +00:00
parent 5821ee7fc8
commit e11686f4a0

View File

@@ -22,6 +22,76 @@ namespace ui {
unique_ptr<HttpServer> HttpServer::server_instance; unique_ptr<HttpServer> HttpServer::server_instance;
// Helpers for validating request origin/referer in deployments where the UI is
// exposed on a non-localhost host (e.g., Docker, k8s, reverse proxies). These
// checks allow either the configured local_url, or the runtime host derived
// from the request headers. They also allow an escape hatch via the
// environment variable `ui_allow_any_origin=1|true`.
namespace {
// Returns true if the given referer begins with any of the expected base URLs.
static bool RefererStartsWithAny(const std::string &referer,
const std::vector<std::string> &bases) {
for (const auto &base : bases) {
if (!base.empty() && referer.compare(0, base.size(), base) == 0) {
return true;
}
}
return false;
}
static std::vector<std::string>
ExpectedBaseUrls(const httplib::Request &req, const std::string &local_url) {
// Prefer forwarded host if present, otherwise fall back to Host.
auto forwarded_host = req.get_header_value("X-Forwarded-Host");
auto host = forwarded_host.empty() ? req.get_header_value("Host")
: forwarded_host;
std::vector<std::string> bases;
bases.push_back(local_url);
if (!host.empty()) {
bases.push_back(StringUtil::Format("http://%s", host));
bases.push_back(StringUtil::Format("https://%s", host));
}
return bases;
}
static bool IsOriginAllowed(const httplib::Request &req,
const std::string &local_url) {
if (IsEnvEnabled("ui_allow_any_origin")) {
return true;
}
auto origin = req.get_header_value("Origin");
if (origin.empty()) {
return false;
}
auto bases = ExpectedBaseUrls(req, local_url);
for (const auto &base : bases) {
if (origin == base) {
return true;
}
}
return false;
}
static bool IsRefererAllowed(const httplib::Request &req,
const std::string &local_url) {
if (IsEnvEnabled("ui_allow_any_origin")) {
return true;
}
auto referer = req.get_header_value("Referer");
if (referer.empty()) {
return false;
}
return RefererStartsWithAny(referer, ExpectedBaseUrls(req, local_url));
}
} // namespace
HttpServer *HttpServer::GetInstance(ClientContext &context) { HttpServer *HttpServer::GetInstance(ClientContext &context) {
if (server_instance) { if (server_instance) {
// We already have an instance, make sure we're running on the right DB // We already have an instance, make sure we're running on the right DB
@@ -228,8 +298,7 @@ void HttpServer::HandleGetLocalToken(const httplib::Request &req,
httplib::Response &res) { httplib::Response &res) {
// GET requests don't include Origin, so use Referer instead. // GET requests don't include Origin, so use Referer instead.
// Referer includes the path, so only compare the start. // Referer includes the path, so only compare the start.
auto referer = req.get_header_value("Referer"); if (!IsRefererAllowed(req, local_url)) {
if (referer.compare(0, local_url.size(), local_url) != 0) {
res.status = 401; res.status = 401;
return; return;
} }
@@ -321,8 +390,7 @@ void HttpServer::HandleGet(const httplib::Request &req,
void HttpServer::HandleInterrupt(const httplib::Request &req, void HttpServer::HandleInterrupt(const httplib::Request &req,
httplib::Response &res) { httplib::Response &res) {
auto origin = req.get_header_value("Origin"); if (!IsOriginAllowed(req, local_url)) {
if (origin != local_url) {
res.status = 401; res.status = 401;
return; return;
} }
@@ -361,8 +429,7 @@ void HttpServer::HandleRun(const httplib::Request &req, httplib::Response &res,
void HttpServer::DoHandleRun(const httplib::Request &req, void HttpServer::DoHandleRun(const httplib::Request &req,
httplib::Response &res, httplib::Response &res,
const httplib::ContentReader &content_reader) { const httplib::ContentReader &content_reader) {
auto origin = req.get_header_value("Origin"); if (!IsOriginAllowed(req, local_url)) {
if (origin != local_url) {
res.status = 401; res.status = 401;
return; return;
} }
@@ -625,8 +692,7 @@ void HttpServer::DoHandleRun(const httplib::Request &req,
void HttpServer::HandleTokenize(const httplib::Request &req, void HttpServer::HandleTokenize(const httplib::Request &req,
httplib::Response &res, httplib::Response &res,
const httplib::ContentReader &content_reader) { const httplib::ContentReader &content_reader) {
auto origin = req.get_header_value("Origin"); if (!IsOriginAllowed(req, local_url)) {
if (origin != local_url) {
res.status = 401; res.status = 401;
return; return;
} }