Skip to content

Commit

Permalink
Use OAuth2 if externalAuthentication is present in connection url
Browse files Browse the repository at this point in the history
After this change if 'externalAuthentication' is passed as a parameter
on the connection url we automatically set `http_schema` to `http` and
use `OAuth2Authentication`.
  • Loading branch information
Pablo Takara authored and hashhar committed Sep 22, 2023
1 parent 0584c93 commit d9d46b0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
11 changes: 10 additions & 1 deletion tests/unit/sqlalchemy/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from sqlalchemy.engine.url import URL, make_url

from trino.auth import BasicAuthentication
from trino.auth import BasicAuthentication, OAuth2Authentication
from trino.dbapi import Connection
from trino.sqlalchemy import URL as trino_url
from trino.sqlalchemy.dialect import (
Expand Down Expand Up @@ -296,3 +296,12 @@ def test_trino_connection_certificate_auth():
assert isinstance(cparams['auth'], CertificateAuthentication)
assert cparams['auth']._cert == cert
assert cparams['auth']._key == key


def test_trino_connection_oauth2_auth():
dialect = TrinoDialect()
url = make_url('trino://host/?externalAuthentication=true')
_, cparams = dialect.create_connect_args(url)

assert cparams['http_scheme'] == "https"
assert isinstance(cparams['auth'], OAuth2Authentication)
11 changes: 10 additions & 1 deletion trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@

from trino import dbapi as trino_dbapi
from trino import logging
from trino.auth import BasicAuthentication, CertificateAuthentication, JWTAuthentication
from trino.auth import (
BasicAuthentication,
CertificateAuthentication,
JWTAuthentication,
OAuth2Authentication,
)
from trino.dbapi import Cursor
from trino.sqlalchemy import compiler, datatype, error

Expand Down Expand Up @@ -113,6 +118,10 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any
kwargs["http_scheme"] = "https"
kwargs["auth"] = CertificateAuthentication(unquote_plus(url.query['cert']), unquote_plus(url.query['key']))

if "externalAuthentication" in url.query:
kwargs["http_scheme"] = "https"
kwargs["auth"] = OAuth2Authentication()

if "source" in url.query:
kwargs["source"] = unquote_plus(url.query["source"])
else:
Expand Down

0 comments on commit d9d46b0

Please sign in to comment.