diff --git a/prestodb/client.py b/prestodb/client.py index 0b10e1b..255930c 100644 --- a/prestodb/client.py +++ b/prestodb/client.py @@ -229,6 +229,7 @@ def __init__( # mypy cannot follow module import self._http_session = self.http.Session() # type: ignore self._http_session.headers.update(self.http_headers) + self._http_cookies = {} self._auth = auth if self._auth: if http_scheme == constants.HTTP: @@ -304,9 +305,19 @@ def max_attempts(self, value): ), max_attempts=self._max_attempts, ) - self._get = with_retry(self._http_session.get) - self._post = with_retry(self._http_session.post) - self._delete = with_retry(self._http_session.delete) + def get(*args, **kwargs): + return self._http_session.get( + *args, cookies=self._http_cookies, **kwargs) + def post(*args, **kwargs): + return self._http_session.post( + *args, cookies=self._http_cookies, **kwargs) + def delete(*args, **kwargs): + return self._http_session.delete( + *args, cookies=self._http_cookies, **kwargs) + + self._get = with_retry(get) + self._post = with_retry(post) + self._delete = with_retry(delete) def get_url(self, path): # type: (Text) -> Text @@ -418,6 +429,9 @@ def process(self, http_response): ): self._client_session.properties[key] = value + if http_response.cookies: + self._http_cookies.update(http_response.cookies) + self._next_uri = response.get('nextUri') return PrestoStatus( diff --git a/tests/test_client.py b/tests/test_client.py index 6892f0f..303121c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -420,6 +420,28 @@ def test_presto_fetch_request(monkeypatch): assert status.rows == RESP_DATA_GET_0['data'] +def test_presto_fetch_request_with_cookie(monkeypatch): + monkeypatch.setattr(PrestoRequest.http.Response, 'json', get_json_get_0) + + req = PrestoRequest( + host='coordinator', + port=8080, + user='test', + source='test', + catalog='test', + schema='test', + http_scheme='http', + session_properties={}, + ) + + http_resp = PrestoRequest.http.Response() + http_resp.status_code = 200 + http_resp.cookies['key'] = 'value' + req.process(http_resp) + + assert req._http_cookies['key'] == 'value' + + def test_presto_fetch_error(monkeypatch): monkeypatch.setattr( PrestoRequest.http.Response,