Merge pull request #20 from duckdb/jray/check-origin-header

check origin header
This commit is contained in:
Jeff Raymakers
2025-03-11 12:41:09 -07:00
committed by GitHub
2 changed files with 12 additions and 8 deletions

View File

@@ -106,6 +106,7 @@ void HttpServer::DoStart(const uint16_t _local_port,
} }
local_port = _local_port; local_port = _local_port;
local_url = StringUtil::Format("http://localhost:%d", local_port);
remote_url = _remote_url; remote_url = _remote_url;
user_agent = user_agent =
StringUtil::Format("duckdb-ui/%s-%s(%s)", DuckDB::LibraryVersion(), StringUtil::Format("duckdb-ui/%s-%s(%s)", DuckDB::LibraryVersion(),
@@ -211,8 +212,10 @@ void HttpServer::HandleGetLocalEvents(const httplib::Request &req,
void HttpServer::HandleGetLocalToken(const httplib::Request &req, void HttpServer::HandleGetLocalToken(const httplib::Request &req,
httplib::Response &res) { httplib::Response &res) {
auto sec_fetch_site = req.get_header_value("Sec-Fetch-Site"); // GET requests don't include Origin, so use Referer instead.
if (sec_fetch_site == "cross-site") { // Referer includes the path, so only compare the start.
auto referer = req.get_header_value("Referer");
if (referer.compare(0, local_url.size(), local_url) != 0) {
res.status = 401; res.status = 401;
return; return;
} }
@@ -276,8 +279,8 @@ void HttpServer::HandleGet(const httplib::Request &req,
void HttpServer::HandleInterrupt(const httplib::Request &req, void HttpServer::HandleInterrupt(const httplib::Request &req,
httplib::Response &res) { httplib::Response &res) {
auto sec_fetch_site = req.get_header_value("Sec-Fetch-Site"); auto origin = req.get_header_value("Origin");
if (sec_fetch_site == "cross-site") { if (origin != local_url) {
res.status = 401; res.status = 401;
return; return;
} }
@@ -316,8 +319,8 @@ void HttpServer::HandleRun(const httplib::Request &req, httplib::Response &res,
void HttpServer::DoHandleRun(const httplib::Request &req, void HttpServer::DoHandleRun(const httplib::Request &req,
httplib::Response &res, httplib::Response &res,
const httplib::ContentReader &content_reader) { const httplib::ContentReader &content_reader) {
auto sec_fetch_site = req.get_header_value("Sec-Fetch-Site"); auto origin = req.get_header_value("Origin");
if (sec_fetch_site == "cross-site") { if (origin != local_url) {
res.status = 401; res.status = 401;
return; return;
} }
@@ -438,8 +441,8 @@ void HttpServer::DoHandleRun(const httplib::Request &req,
void HttpServer::HandleTokenize(const httplib::Request &req, void HttpServer::HandleTokenize(const httplib::Request &req,
httplib::Response &res, httplib::Response &res,
const httplib::ContentReader &content_reader) { const httplib::ContentReader &content_reader) {
auto sec_fetch_site = req.get_header_value("Sec-Fetch-Site"); auto origin = req.get_header_value("Origin");
if (sec_fetch_site == "cross-site") { if (origin != local_url) {
res.status = 401; res.status = 401;
return; return;
} }

View File

@@ -69,6 +69,7 @@ private:
shared_ptr<DatabaseInstance> LockDatabaseInstance(); shared_ptr<DatabaseInstance> LockDatabaseInstance();
uint16_t local_port; uint16_t local_port;
std::string local_url;
std::string remote_url; std::string remote_url;
weak_ptr<DatabaseInstance> ddb_instance; weak_ptr<DatabaseInstance> ddb_instance;
std::string user_agent; std::string user_agent;