Skip to content

Commit

Permalink
[FSTORE-1453] Move client, decorators, variable_api and constants to …
Browse files Browse the repository at this point in the history
…hopsworks_common (logicalclocks#229)

* Move client

* Move decorators, adapt client

* Create aliases for client and decorators

* Remove _python_version from client

* Merge hsfs client

* Move online_store_rest_client

* Adapt online_store_rest_client

* Create aliases for variable_api

* Fix test_online_store_rest_client

* Fix __all__ in client/__init__

* Merge hsfs/decorators

* Move constants to hopsworks_common

* Create alias for constants

* Fix mistype in decorators

* Move constants to core

* Make alias for decorators in hsml
  • Loading branch information
aversey committed Jul 18, 2024
1 parent 0790dab commit d77b1dd
Show file tree
Hide file tree
Showing 32 changed files with 2,188 additions and 2,396 deletions.
83 changes: 25 additions & 58 deletions python/hopsworks/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright 2022 Logical Clocks AB
# Copyright 2024 Hopsworks AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,60 +14,27 @@
# limitations under the License.
#

from hopsworks.client import external, hopsworks


_client = None
_python_version = None


def init(
client_type,
host=None,
port=None,
project=None,
hostname_verification=None,
trust_store_path=None,
cert_folder=None,
api_key_file=None,
api_key_value=None,
):
global _client
if not _client:
if client_type == "hopsworks":
_client = hopsworks.Client()
elif client_type == "external":
_client = external.Client(
host,
port,
project,
hostname_verification,
trust_store_path,
cert_folder,
api_key_file,
api_key_value,
)


def get_instance():
global _client
if _client:
return _client
raise Exception("Couldn't find client. Try reconnecting to Hopsworks.")


def get_python_version():
global _python_version
return _python_version


def set_python_version(python_version):
global _python_version
_python_version = python_version


def stop():
global _client
if _client:
_client._close()
_client = None
from hopsworks_common.client import (
auth,
base,
exceptions,
external,
get_instance,
hopsworks,
init,
online_store_rest_client,
stop,
)


__all__ = [
auth,
base,
exceptions,
external,
get_instance,
hopsworks,
init,
online_store_rest_client,
stop,
]
33 changes: 11 additions & 22 deletions python/hopsworks/client/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright 2022 Logical Clocks AB
# Copyright 2024 Hopsworks AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,26 +14,15 @@
# limitations under the License.
#

import requests
from hopsworks_common.client.auth import (
ApiKeyAuth,
BearerAuth,
OnlineStoreKeyAuth,
)


class BearerAuth(requests.auth.AuthBase):
"""Class to encapsulate a Bearer token."""

def __init__(self, token):
self._token = token

def __call__(self, r):
r.headers["Authorization"] = "Bearer " + self._token.strip()
return r


class ApiKeyAuth(requests.auth.AuthBase):
"""Class to encapsulate an API key."""

def __init__(self, token):
self._token = token

def __call__(self, r):
r.headers["Authorization"] = "ApiKey " + self._token.strip()
return r
__all__ = [
ApiKeyAuth,
BearerAuth,
OnlineStoreKeyAuth,
]
175 changes: 7 additions & 168 deletions python/hopsworks/client/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright 2022 Logical Clocks AB
# Copyright 2024 Hopsworks AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,172 +14,11 @@
# limitations under the License.
#

import os
from abc import ABC, abstractmethod
from hopsworks_common.client.base import (
Client,
)

import furl
import requests
import urllib3
from hopsworks.client import auth, exceptions
from hopsworks.decorators import connected


urllib3.disable_warnings(urllib3.exceptions.SecurityWarning)
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


class Client(ABC):
TOKEN_FILE = "token.jwt"
APIKEY_FILE = "api.key"
REST_ENDPOINT = "REST_ENDPOINT"
HOPSWORKS_PUBLIC_HOST = "HOPSWORKS_PUBLIC_HOST"

@abstractmethod
def __init__(self):
"""To be implemented by clients."""
pass

def _get_verify(self, verify, trust_store_path):
"""Get verification method for sending HTTP requests to Hopsworks.
Credit to https://gist.github.com/gdamjan/55a8b9eec6cf7b771f92021d93b87b2c
:param verify: perform hostname verification, 'true' or 'false'
:type verify: str
:param trust_store_path: path of the truststore locally if it was uploaded manually to
the external environment such as AWS Sagemaker
:type trust_store_path: str
:return: if verify is true and the truststore is provided, then return the trust store location
if verify is true but the truststore wasn't provided, then return true
if verify is false, then return false
:rtype: str or boolean
"""
if verify == "true":
if trust_store_path is not None:
return trust_store_path
else:
return True

return False

def _get_host_port_pair(self):
"""
Removes "http or https" from the rest endpoint and returns a list
[endpoint, port], where endpoint is on the format /path.. without http://
:return: a list [endpoint, port]
:rtype: list
"""
endpoint = self._base_url
if "http" in endpoint:
last_index = endpoint.rfind("/")
endpoint = endpoint[last_index + 1 :]
host, port = endpoint.split(":")
return host, port

def _read_jwt(self):
"""Retrieve jwt from local container."""
return self._read_file(self.TOKEN_FILE)

def _read_apikey(self):
"""Retrieve apikey from local container."""
return self._read_file(self.APIKEY_FILE)

def _read_file(self, secret_file):
"""Retrieve secret from local container."""
with open(os.path.join(self._secrets_dir, secret_file), "r") as secret:
return secret.read()

def _get_credentials(self, project_id):
"""Makes a REST call to hopsworks for getting the project user certificates needed to connect to services such as Hive
:param project_id: id of the project
:type project_id: int
:return: JSON response with credentials
:rtype: dict
"""
return self._send_request("GET", ["project", project_id, "credentials"])

def _write_pem_file(self, content: str, path: str) -> None:
with open(path, "w") as f:
f.write(content)

@connected
def _send_request(
self,
method,
path_params,
query_params=None,
headers=None,
data=None,
stream=False,
files=None,
with_base_path_params=True,
):
"""Send REST request to Hopsworks.
Uses the client it is executed from. Path parameters are url encoded automatically.
:param method: 'GET', 'PUT' or 'POST'
:type method: str
:param path_params: a list of path params to build the query url from starting after
the api resource, for example `["project", 119, "featurestores", 67]`.
:type path_params: list
:param query_params: A dictionary of key/value pairs to be added as query parameters,
defaults to None
:type query_params: dict, optional
:param headers: Additional header information, defaults to None
:type headers: dict, optional
:param data: The payload as a python dictionary to be sent as json, defaults to None
:type data: dict, optional
:param stream: Set if response should be a stream, defaults to False
:type stream: boolean, optional
:param files: dictionary for multipart encoding upload
:type files: dict, optional
:raises RestAPIError: Raised when request wasn't correctly received, understood or accepted
:return: Response json
:rtype: dict
"""
f_url = furl.furl(self._base_url)
if with_base_path_params:
base_path_params = ["hopsworks-api", "api"]
f_url.path.segments = base_path_params + path_params
else:
f_url.path.segments = path_params
url = str(f_url)

request = requests.Request(
method,
url=url,
headers=headers,
data=data,
params=query_params,
auth=self._auth,
files=files,
)

prepped = self._session.prepare_request(request)
response = self._session.send(prepped, verify=self._verify, stream=stream)

if response.status_code == 401 and self.REST_ENDPOINT in os.environ:
# refresh token and retry request - only on hopsworks
self._auth = auth.BearerAuth(self._read_jwt())
# Update request with the new token
request.auth = self._auth
prepped = self._session.prepare_request(request)
response = self._session.send(prepped, verify=self._verify, stream=stream)

if response.status_code // 100 != 2:
raise exceptions.RestAPIError(url, response)

if stream:
return response
else:
# handle different success response codes
if len(response.content) == 0:
return None
return response.json()

def _close(self):
"""Closes a client. Can be implemented for clean up purposes, not mandatory."""
self._connected = False
__all__ = [
Client,
]
Loading

0 comments on commit d77b1dd

Please sign in to comment.