Skip to content

Commit

Permalink
Merge pull request ClickHouse#56488 from lingtaolf/feature/getHTTPHeader
Browse files Browse the repository at this point in the history
add function getClientHTTPHeader
  • Loading branch information
antonio2368 authored Nov 28, 2023
2 parents ae09d04 + ed7f19c commit a61f328
Show file tree
Hide file tree
Showing 16 changed files with 299 additions and 5 deletions.
38 changes: 38 additions & 0 deletions docs/en/sql-reference/functions/other-functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,45 @@ WHERE macro = 'test';
│ test │ Value │
└───────┴──────────────┘
```

## getClientHTTPHeader
Returns the value of specified http header.If there is no such header or the request method is not http, it will throw an exception.

**Syntax**

```sql
getClientHTTPHeader(name);
```

**Arguments**

- `name` — HTTP header name .[String](../../sql-reference/data-types/string.md#string)

**Returned value**

Value of the specified header.
Type:[String](../../sql-reference/data-types/string.md#string).


When we use `clickhouse-client` to execute this function, we'll always get empty string, because client doesn't use http protocol.
```sql
SELECT getCientHTTPHeader('test')
```
result:

```text
┌─getClientHTTPHeader('test')─┐
│ │
└────────────------───────────┘
```
Try to use http request:
```shell
echo "select getClientHTTPHeader('X-Clickhouse-User')" | curl -H 'X-ClickHouse-User: default' -H 'X-ClickHouse-Key: ' 'http://localhost:8123/' -d @-

#result
default
```

## FQDN

Returns the fully qualified domain name of the ClickHouse server.
Expand Down
2 changes: 2 additions & 0 deletions programs/server/Server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,8 @@ try
global_context->setHTTPHeaderFilter(*config);

global_context->setMaxTableSizeToDrop(server_settings_.max_table_size_to_drop);
global_context->setClientHTTPHeaderForbiddenHeaders(server_settings_.get_client_http_header_forbidden_headers);
global_context->setAllowGetHTTPHeaderFunction(server_settings_.allow_get_client_http_header);
global_context->setMaxPartitionSizeToDrop(server_settings_.max_partition_size_to_drop);

ConcurrencyControl::SlotCount concurrent_threads_soft_limit = ConcurrencyControl::Unlimited;
Expand Down
2 changes: 2 additions & 0 deletions src/Core/ServerSettings.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ namespace DB
M(Double, total_memory_tracker_sample_probability, 0, "Collect random allocations and deallocations and write them into system.trace_log with 'MemorySample' trace_type. The probability is for every alloc/free regardless to the size of the allocation (can be changed with `memory_profiler_sample_min_allocation_size` and `memory_profiler_sample_max_allocation_size`). Note that sampling happens only when the amount of untracked memory exceeds 'max_untracked_memory'. You may want to set 'max_untracked_memory' to 0 for extra fine grained sampling.", 0) \
M(UInt64, total_memory_profiler_sample_min_allocation_size, 0, "Collect random allocations of size greater or equal than specified value with probability equal to `total_memory_profiler_sample_probability`. 0 means disabled. You may want to set 'max_untracked_memory' to 0 to make this threshold to work as expected.", 0) \
M(UInt64, total_memory_profiler_sample_max_allocation_size, 0, "Collect random allocations of size less or equal than specified value with probability equal to `total_memory_profiler_sample_probability`. 0 means disabled. You may want to set 'max_untracked_memory' to 0 to make this threshold to work as expected.", 0) \
M(String, get_client_http_header_forbidden_headers, "", "Comma separated list of http header names that will not be returned by function getClientHTTPHeader.", 0) \
M(Bool, allow_get_client_http_header, false, "Allow function getClientHTTPHeader", 0) \
M(Bool, validate_tcp_client_information, false, "Validate client_information in the query packet over the native TCP protocol.", 0) \
M(Bool, storage_metadata_write_full_object_key, false, "Write disk metadata files with VERSION_FULL_OBJECT_KEY format", 0) \

Expand Down
116 changes: 116 additions & 0 deletions src/Functions/getClientHTTPHeader.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeString.h>
#include <Columns/ColumnString.h>
#include <Interpreters/Context.h>
#include <Common/CurrentThread.h>
#include "Disks/DiskType.h"
#include "Interpreters/Context_fwd.h"
#include <Core/Field.h>
#include <Poco/Net/NameValueCollection.h>


namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int FUNCTION_NOT_ALLOWED;
extern const int BAD_ARGUMENTS;
}

namespace
{

/** Get the value of parameter in http headers.
* If there no such parameter or the method of request is not
* http, the function will throw an exception.
*/
class FunctionGetClientHTTPHeader : public IFunction, WithContext
{
private:

public:
explicit FunctionGetClientHTTPHeader(ContextPtr context_): WithContext(context_) {}

static constexpr auto name = "getClientHTTPHeader";

static FunctionPtr create(ContextPtr context_)
{
return std::make_shared<FunctionGetClientHTTPHeader>(context_);
}

bool useDefaultImplementationForConstants() const override { return true; }

String getName() const override { return name; }

bool isDeterministic() const override { return false; }

bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }


size_t getNumberOfArguments() const override
{
return 1;
}

DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!getContext()->allowGetHTTPHeaderFunction())
throw Exception(ErrorCodes::FUNCTION_NOT_ALLOWED, "The function {} is not enabled, you can set allow_get_client_http_header in config file.", getName());

if (!isString(arguments[0]))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The argument of function {} must have String type", getName());
return std::make_shared<DataTypeString>();
}

ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
{
const auto & client_info = getContext()->getClientInfo();
const auto & method = client_info.http_method;
const auto & headers = client_info.headers;
const IColumn * arg_column = arguments[0].column.get();
const ColumnString * arg_string = checkAndGetColumn<ColumnString>(arg_column);

if (!arg_string)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "The argument of function {} must be constant String", getName());

if (method != ClientInfo::HTTPMethod::GET && method != ClientInfo::HTTPMethod::POST)
return result_type->createColumnConstWithDefaultValue(input_rows_count);

auto result_column = ColumnString::create();

const String default_value;
const std::unordered_set<String> & forbidden_header_list = getContext()->getClientHTTPHeaderForbiddenHeaders();

for (size_t row = 0; row < input_rows_count; ++row)
{
auto header_name = arg_string->getDataAt(row).toString();

if (!headers.has(header_name))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "{} is not in HTTP request headers.", header_name);
else
{
auto it = forbidden_header_list.find(header_name);
if (it != forbidden_header_list.end())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "The header {} is in get_client_http_header_forbidden_headers, you can config it in config file.", header_name);

const String & value = headers[header_name];
result_column->insertData(value.data(), value.size());
}
}

return result_column;
}
};

}

REGISTER_FUNCTION(GetHttpHeader)
{
factory.registerFunction<FunctionGetClientHTTPHeader>();
}

}
2 changes: 2 additions & 0 deletions src/Interpreters/ClientInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <Core/UUID.h>
#include <Poco/Net/SocketAddress.h>
#include <Poco/Net/NameValueCollection.h>
#include <base/types.h>
#include <Common/OpenTelemetryTraceContext.h>
#include <Common/VersionNumber.h>
Expand Down Expand Up @@ -96,6 +97,7 @@ class ClientInfo

/// For mysql and postgresql
UInt64 connection_id = 0;
Poco::Net::NameValueCollection headers;

/// Comma separated list of forwarded IP addresses (from X-Forwarded-For for HTTP interface).
/// It's expected that proxy appends the forwarded address to the end of the list.
Expand Down
30 changes: 29 additions & 1 deletion src/Interpreters/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <optional>
#include <memory>
#include <Poco/UUID.h>
#include <Poco/Net/NameValueCollection.h>
#include <Poco/Util/Application.h>
#include <Common/SensitiveDataMasker.h>
#include <Common/Macros.h>
Expand Down Expand Up @@ -322,6 +323,8 @@ struct ContextSharedPart : boost::noncopyable
std::optional<MergeTreeSettings> merge_tree_settings TSA_GUARDED_BY(mutex); /// Settings of MergeTree* engines.
std::optional<MergeTreeSettings> replicated_merge_tree_settings TSA_GUARDED_BY(mutex); /// Settings of ReplicatedMergeTree* engines.
std::atomic_size_t max_table_size_to_drop = 50000000000lu; /// Protects MergeTree tables from accidental DROP (50GB by default)
std::unordered_set<String> get_client_http_header_forbidden_headers;
bool allow_get_client_http_header;
std::atomic_size_t max_partition_size_to_drop = 50000000000lu; /// Protects MergeTree partitions from accidental DROP (50GB by default)
/// No lock required for format_schema_path modified only during initialization
String format_schema_path; /// Path to a directory that contains schema files used by input formats.
Expand Down Expand Up @@ -3950,6 +3953,28 @@ void Context::checkTableCanBeDropped(const String & database, const String & tab
}


void Context::setClientHTTPHeaderForbiddenHeaders(const String & forbidden_headers)
{
std::unordered_set<String> forbidden_header_list;
boost::split(forbidden_header_list, forbidden_headers, [](char c) { return c == ','; });
shared->get_client_http_header_forbidden_headers = forbidden_header_list;
}

void Context::setAllowGetHTTPHeaderFunction(bool allow_get_http_header_function)
{
shared->allow_get_client_http_header= allow_get_http_header_function;
}

const std::unordered_set<String> & Context::getClientHTTPHeaderForbiddenHeaders() const
{
return shared->get_client_http_header_forbidden_headers;
}

bool Context::allowGetHTTPHeaderFunction() const
{
return shared->allow_get_client_http_header;
}

void Context::setMaxPartitionSizeToDrop(size_t max_size)
{
// Is initialized at server startup and updated at config reload
Expand Down Expand Up @@ -4270,12 +4295,15 @@ void Context::setClientConnectionId(uint32_t connection_id_)
client_info.connection_id = connection_id_;
}

void Context::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer)
void Context::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer, const Poco::Net::NameValueCollection & http_headers)
{
client_info.http_method = http_method;
client_info.http_user_agent = http_user_agent;
client_info.http_referer = http_referer;
need_recalculate_access = true;

if (!http_headers.empty())
client_info.headers = http_headers;
}

void Context::setForwardedFor(const String & forwarded_for)
Expand Down
9 changes: 8 additions & 1 deletion src/Interpreters/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <Server/HTTP/HTTPContext.h>
#include <Storages/ColumnsDescription.h>
#include <Storages/IStorage_fwd.h>
#include <Poco/Net/NameValueCollection.h>
#include <Core/Types.h>

#include "config.h"

Expand Down Expand Up @@ -640,7 +642,7 @@ class Context: public ContextData, public std::enable_shared_from_this<Context>
void setClientInterface(ClientInfo::Interface interface);
void setClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version);
void setClientConnectionId(uint32_t connection_id);
void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer);
void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer, const Poco::Net::NameValueCollection & http_headers = {});
void setForwardedFor(const String & forwarded_for);
void setQueryKind(ClientInfo::QueryKind query_kind);
void setQueryKindInitial();
Expand Down Expand Up @@ -1073,6 +1075,11 @@ class Context: public ContextData, public std::enable_shared_from_this<Context>
/// Prevents DROP TABLE if its size is greater than max_size (50GB by default, max_size=0 turn off this check)
void setMaxTableSizeToDrop(size_t max_size);
size_t getMaxTableSizeToDrop() const;
void setClientHTTPHeaderForbiddenHeaders(const String & forbidden_headers);
/// Return the forbiddent headers that users can't get via getClientHTTPHeader function
const std::unordered_set<String> & getClientHTTPHeaderForbiddenHeaders() const;
void setAllowGetHTTPHeaderFunction(bool allow_get_http_header_function);
bool allowGetHTTPHeaderFunction() const;
void checkTableCanBeDropped(const String & database, const String & table, const size_t & table_size) const;

/// Prevents DROP PARTITION if its size is greater than max_size (50GB by default, max_size=0 turn off this check)
Expand Down
4 changes: 3 additions & 1 deletion src/Interpreters/Session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <Interpreters/Cluster.h>

#include <magic_enum.hpp>
#include <Poco/Net/NameValueCollection.h>

#include <atomic>
#include <condition_variable>
Expand Down Expand Up @@ -431,7 +432,7 @@ void Session::setClientConnectionId(uint32_t connection_id)
prepared_client_info->connection_id = connection_id;
}

void Session::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer)
void Session::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer, const Poco::Net::NameValueCollection & http_headers)
{
if (session_context)
{
Expand All @@ -442,6 +443,7 @@ void Session::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String
prepared_client_info->http_method = http_method;
prepared_client_info->http_user_agent = http_user_agent;
prepared_client_info->http_referer = http_referer;
prepared_client_info->headers = http_headers;
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/Interpreters/Session.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <Interpreters/ClientInfo.h>
#include <Interpreters/Context_fwd.h>
#include <Interpreters/SessionTracker.h>
#include <Poco/Net/NameValueCollection.h>

#include <chrono>
#include <memory>
Expand Down Expand Up @@ -64,7 +65,7 @@ class Session
void setClientInterface(ClientInfo::Interface interface);
void setClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version);
void setClientConnectionId(uint32_t connection_id);
void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer);
void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer, const Poco::Net::NameValueCollection & http_headers = {});
void setForwardedFor(const String & forwarded_for);
void setQuotaClientKey(const String & quota_key);
void setConnectionClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version);
Expand Down
3 changes: 2 additions & 1 deletion src/Server/HTTPHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include <Poco/StreamCopier.h>
#include <Poco/String.h>
#include <Poco/Net/SocketAddress.h>
#include <Poco/Net/NameValueCollection.h>

#include <chrono>
#include <sstream>
Expand Down Expand Up @@ -502,7 +503,7 @@ bool HTTPHandler::authenticateUser(
else if (request.getMethod() == HTTPServerRequest::HTTP_POST)
http_method = ClientInfo::HTTPMethod::POST;

session->setHttpClientInfo(http_method, request.get("User-Agent", ""), request.get("Referer", ""));
session->setHttpClientInfo(http_method, request.get("User-Agent", ""), request.get("Referer", ""), request);
session->setForwardedFor(request.get("X-Forwarded-For", ""));
session->setQuotaClientKey(quota_key);

Expand Down
4 changes: 4 additions & 0 deletions tests/config/config.d/forbidden_get_client_http_headers.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
<clickhouse>
<get_client_http_header_forbidden_headers>FORBIDDEN-KEY1,FORBIDDEN-KEY2</get_client_http_header_forbidden_headers>
<allow_get_client_http_header>1</allow_get_client_http_header>
</clickhouse>
1 change: 1 addition & 0 deletions tests/config/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ mkdir -p $DEST_SERVER_PATH/config.d/
mkdir -p $DEST_SERVER_PATH/users.d/
mkdir -p $DEST_CLIENT_PATH

ln -sf $SRC_PATH/config.d/forbidden_get_client_http_headers.xml $DEST_SERVER_PATH/config.d/
ln -sf $SRC_PATH/config.d/zookeeper_write.xml $DEST_SERVER_PATH/config.d/
ln -sf $SRC_PATH/config.d/listen.xml $DEST_SERVER_PATH/config.d/
ln -sf $SRC_PATH/config.d/text_log.xml $DEST_SERVER_PATH/config.d/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ geoDistance
geohashDecode
geohashEncode
geohashesInBox
getClientHTTPHeader
getMacro
getOSKernelVersion
getServerPort
Expand Down
13 changes: 13 additions & 0 deletions tests/queries/0_stateless/02911_getHTTPHeaderFuncion.reference
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
value
value1 value2
value1 value1 value2
NOT-FOUND-KEY is not in HTTP request headers
FORBIDDEN-KEY1 is in get_client_http_header_forbidden_headers
1 row1_value1 row1_value2 row1_value3 row1_value4 row1_value5 row1_value6 row1_value7
2 row2_value1 row2_value2 row2_value3 row2_value4 row2_value5 row2_value6 row2_value7
3
value_from_query_1 value_from_query_2 value_from_query_3 1 row1_value1 row1_value2 row1_value3 row1_value4 row1_value5 row1_value6 row1_value7
value_from_query_1 value_from_query_2 value_from_query_3 2 row2_value1 row2_value2 row2_value3 row2_value4 row2_value5 row2_value6 row2_value7
value_from_query_1 value_from_query_2 value_from_query_3 3
http_value1
http_value2
Loading

0 comments on commit a61f328

Please sign in to comment.