From 0584c938a194e634a6e8ed50e63360e7a000f743 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Can=20G=C3=BCney=20Aksakalli?= Date: Wed, 20 Sep 2023 15:00:07 +0200 Subject: [PATCH] add Trino Python Client version to user agent --- tests/unit/test_client.py | 6 ++++-- trino/client.py | 4 +++- trino/constants.py | 2 ++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index bbd953e8..d1b23ae7 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -38,7 +38,7 @@ _get_token_requests, _post_statement_requests, ) -from trino import constants +from trino import __version__, constants from trino.auth import KerberosAuthentication, _OAuth2TokenBearer from trino.client import ( ClientSession, @@ -125,6 +125,7 @@ def assert_headers(headers): assert headers[constants.HEADER_SOURCE] == source assert headers[constants.HEADER_USER] == user assert headers[constants.HEADER_SESSION] == "" + assert headers[constants.HEADER_TRANSACTION] is None assert headers[constants.HEADER_TIMEZONE] == timezone assert headers[constants.HEADER_CLIENT_CAPABILITIES] == "PARAMETRIC_DATETIME" assert headers[accept_encoding_header] == accept_encoding_value @@ -135,7 +136,8 @@ def assert_headers(headers): "catalog1=NONE," "catalog2=" + urllib.parse.quote("ROLE{catalog2_role}") ) - assert len(headers.keys()) == 11 + assert headers["User-Agent"] == f"{constants.CLIENT_NAME}/{__version__}" + assert len(headers.keys()) == 12 req.post("URL") _, post_kwargs = post.call_args diff --git a/trino/client.py b/trino/client.py index 9e9ce3ee..e262f626 100644 --- a/trino/client.py +++ b/trino/client.py @@ -62,6 +62,7 @@ import trino.logging from trino import constants, exceptions +from trino._version import __version__ __all__ = ["ClientSession", "TrinoQuery", "TrinoRequest", "PROXIES"] @@ -447,7 +448,7 @@ def transaction_id(self, value): @property def http_headers(self) -> Dict[str, str]: - headers = {} + headers = requests.structures.CaseInsensitiveDict() headers[constants.HEADER_CATALOG] = self._client_session.catalog headers[constants.HEADER_SCHEMA] = self._client_session.schema @@ -455,6 +456,7 @@ def http_headers(self) -> Dict[str, str]: headers[constants.HEADER_USER] = self._client_session.user headers[constants.HEADER_TIMEZONE] = self._client_session.timezone headers[constants.HEADER_CLIENT_CAPABILITIES] = 'PARAMETRIC_DATETIME' + headers["user-agent"] = f"{constants.CLIENT_NAME}/{__version__}" if len(self._client_session.roles.values()): headers[constants.HEADER_ROLE] = ",".join( # ``name`` must not contain ``=`` diff --git a/trino/constants.py b/trino/constants.py index c9527a3b..2199105c 100644 --- a/trino/constants.py +++ b/trino/constants.py @@ -26,6 +26,8 @@ URL_STATEMENT_PATH = "/v1/statement" +CLIENT_NAME = "Trino Python Client" + HEADER_CATALOG = "X-Trino-Catalog" HEADER_SCHEMA = "X-Trino-Schema" HEADER_SOURCE = "X-Trino-Source"