Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add capability to deploy models remotely #652

Merged
merged 21 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,4 @@ tabpy/tabpy_server/staging
# etc
setup.bat
*~
tabpy_log.log.1
tabpy_log.log.*
7 changes: 7 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## v2.13.0

### Improvements

- Add support for deploying functions to a remote TabPy server by setting
`remote_server=True` when creating the Client instance.

## v2.12.0

### Improvements
Expand Down
10 changes: 10 additions & 0 deletions docs/tabpy-tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ The URL and port are where the Tableau-Python-Server process has been started -
more info can be found in the
[Starting TabPy](server-install.md#starting-tabpy) section of the documentation.

When connecting to a remote TabPy server, configure the following parameters:

- Set `remote_server` to `True` to indicate a remote connection
- Set `localhost_endpoint` to the specific localhost address used by the remote server
- **Note:** The protocol and port may differ from the main endpoint

```python
client = Client('https://example.com:443/', remote_server=True, localhost_endpoint='http://localhost:9004/')
```

## Authentication

When TabPy is configured with the authentication feature on, client code
Expand Down
2 changes: 1 addition & 1 deletion tabpy/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.12.1
2.13.0
6 changes: 4 additions & 2 deletions tabpy/tabpy_server/handlers/endpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,13 @@ def delete(self, name):

# delete files
if endpoint_info["type"] != "alias":
delete_path = get_query_object_path(
query_path = get_query_object_path(
self.settings["state_file_path"], name, None
)
staging_path = query_path.replace("/query_objects/", "/staging/endpoints/")
try:
yield self._delete_po_future(delete_path)
yield self._delete_po_future(query_path)
yield self._delete_po_future(staging_path)
except Exception as e:
self.error_out(400, f"Error while deleting: {e}")
self.finish()
Expand Down
4 changes: 2 additions & 2 deletions tabpy/tabpy_server/management/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def add_endpoint(
Name of the endpoint
description : str, optional
Description of this endpoint
doc_string : str, optional
docstring : str, optional
The doc string for this endpoint, if needed.
endpoint_type : str
The endpoint type (model, alias)
Expand Down Expand Up @@ -309,7 +309,7 @@ def update_endpoint(
Name of the endpoint
description : str, optional
Description of this endpoint
doc_string : str, optional
docstring : str, optional
The doc string for this endpoint, if needed.
endpoint_type : str, optional
The endpoint type (model, alias)
Expand Down
88 changes: 87 additions & 1 deletion tabpy/tabpy_tools/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import inspect
from re import compile
import time
import requests
Expand Down Expand Up @@ -49,7 +50,9 @@ def _check_endpoint_name(name):


class Client:
def __init__(self, endpoint, query_timeout=1000):
def __init__(
self, endpoint, query_timeout=1000, remote_server=False, localhost_endpoint=None
):
"""
Connects to a running server.

Expand All @@ -63,10 +66,19 @@ def __init__(self, endpoint, query_timeout=1000):

query_timeout : float, optional
The timeout for query operations.

remote_server : bool, optional
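Whether the client is connecting to a remote TabPy server. When True, deploy() sends the function to the server through the /evaluate endpoint.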

localhost_endpoint : str, optional
The localhost endpoint with potentially different protocol and
port compared to the main endpoint parameter.
"""
_check_hostname(endpoint)

self._endpoint = endpoint
self._remote_server = remote_server
self._localhost_endpoint = localhost_endpoint

session = requests.session()
session.verify = False
Expand Down Expand Up @@ -232,6 +244,12 @@ def deploy(self, name, obj, description="", schema=None, override=False, is_publ
--------
remove, get_endpoints
"""
if self._remote_server:
return self._remote_deploy(
name, obj,
description=description, schema=schema, override=override, is_public=is_public
)

endpoint = self.get_endpoints().get(name)
version = 1
if endpoint:
Expand Down Expand Up @@ -390,6 +408,7 @@ def _gen_endpoint(self, name, obj, description, version=1, schema=None, is_publi
"methods": endpoint_object.get_methods(),
"required_files": [],
"required_packages": [],
"docstring": endpoint_object.get_docstring(),
"schema": copy.copy(schema),
"is_public": is_public,
}
Expand Down Expand Up @@ -419,6 +438,7 @@ def _wait_for_endpoint_deployment(
logger.info(
f"Waiting for endpoint {endpoint_name} to deploy to " f"version {version}"
)
time.sleep(interval)
start = time.time()
while True:
ep_status = self.get_status()
Expand Down Expand Up @@ -447,6 +467,72 @@ def _wait_for_endpoint_deployment(
logger.info(f"Sleeping {interval}...")
time.sleep(interval)

def _remote_deploy(
self, name, obj, description="", schema=None, override=False, is_public=False
):
"""
Remotely deploy a Python function using the /evaluate endpoint. Takes the same inputs
as deploy.
"""
remote_script = self._gen_remote_script()
remote_script += f"{inspect.getsource(obj)}\n"

remote_script += (
f"client.deploy("
f"'{name}', {obj.__name__}, '{description}', "
f"override={override}, is_public={is_public}, schema={schema}"
f")"
)

return self._evaluate_remote_script(remote_script)

def _gen_remote_script(self):
"""
Generates a remote script for TabPy client connection with credential handling.

Returns:
str: A Python script to establish a TabPy client connection
"""
remote_script = [
"from tabpy.tabpy_tools.client import Client",
f"client = Client('{self._localhost_endpoint or self._endpoint}')"
]

remote_script.append(
f"client.set_credentials('{auth.username}', '{auth.password}')"
) if (auth := self._service.service_client.network_wrapper.auth) else None

return "\n".join(remote_script) + "\n"

def _evaluate_remote_script(self, remote_script):
    """
    Uses TabPy /evaluate endpoint to execute a remote TabPy client script.

    Parameters
    ----------
    remote_script : str
        The script to execute remotely.

    Returns
    -------
    str
        ``"<HTTP status code> - <message>\\n"``, where the message is the
        response body, or ``Success`` when the body is exactly ``null``.
        The same text is printed to stdout.
    """
    # NOTE(review): the script can contain the credentials added by
    # _gen_remote_script, so this print exposes them on stdout.
    print(f"Remote script:\n{remote_script}\n")
    url = f"{self._endpoint}evaluate"
    headers = {"Content-Type": "application/json"}
    payload = {"data": {}, "script": remote_script}

    response = requests.post(
        url,
        headers=headers,
        auth=self._service.service_client.network_wrapper.auth,
        json=payload
    )

    # Presumably a script that returns nothing evaluates to JSON `null`.
    # Only translate a body that is exactly `null`; a blanket replace would
    # also rewrite the word "null" inside genuine error messages.
    body = response.text
    msg = "Success" if body.strip() == "null" else body
    if "Ad-hoc scripts have been disabled" in msg:
        msg += "\n[Remote TabPy client not allowed.]"

    status_message = f"{response.status_code} - {msg}\n"
    print(status_message)
    return status_message

def set_credentials(self, username, password):
"""
Set credentials for all the TabPy client-server communication
Expand Down
16 changes: 11 additions & 5 deletions tabpy/tabpy_tools/custom_query_object.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
import platform
import sys
from .query_object import QueryObject as _QueryObject


Expand Down Expand Up @@ -69,12 +71,16 @@ def query(self, *args, **kwargs):
)
raise

def get_doc_string(self):
def get_docstring(self):
"""Get doc string from customized query"""
if self.custom_query.__doc__ is not None:
return self.custom_query.__doc__
else:
return "-- no docstring found in query function --"
default_docstring = "-- no docstring found in query function --"

# TODO: fix docstring parsing on Windows systems
if sys.platform == 'win32':
return default_docstring

ds = getattr(self.custom_query, '__doc__', None)
return ds if ds and isinstance(ds, str) else default_docstring

def get_methods(self):
return [self.get_query_method()]
Expand Down
4 changes: 2 additions & 2 deletions tabpy/tabpy_tools/query_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class QueryObject(abc.ABC):
"""
Derived class needs to implement the following interface:
* query() -- given input, return query result
* get_doc_string() -- returns documentation for the Query Object
* get_docstring() -- returns documentation for the Query Object
"""

def __init__(self, description=""):
Expand All @@ -30,7 +30,7 @@ def query(self, input):
pass

@abc.abstractmethod
def get_doc_string(self):
def get_docstring(self):
"""Returns documentation for the query object

By default, this method returns the docstring for 'query' method
Expand Down
2 changes: 2 additions & 0 deletions tabpy/tabpy_tools/rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Endpoint(RESTObject):
version = RESTProperty(int)
description = RESTProperty(str)
dependencies = RESTProperty(list)
docstring = RESTProperty(str)
methods = RESTProperty(list)
creation_time = RESTProperty(datetime, from_epoch, to_epoch)
last_modified_time = RESTProperty(datetime, from_epoch, to_epoch)
Expand All @@ -64,6 +65,7 @@ def __eq__(self, other):
and self.version == other.version
and self.description == other.description
and self.dependencies == other.dependencies
and self.docstring == other.docstring
and self.methods == other.methods
and self.evaluator == other.evaluator
and self.schema_version == other.schema_version
Expand Down
31 changes: 28 additions & 3 deletions tests/unit/tools_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,25 @@ def setUp(self):

def test_init(self):
client = Client("http://example.com:9004")

self.assertEqual(client._endpoint, "http://example.com:9004")
self.assertEqual(client._remote_server, False)

client = Client("http://example.com/", 10.0)

self.assertEqual(client._endpoint, "http://example.com/")

client = Client(endpoint="https://example.com/", query_timeout=-10.0)

self.assertEqual(client._endpoint, "https://example.com/")
self.assertEqual(client.query_timeout, 0.0)

client = Client(
"http://example.com:442/",
remote_server=True,
localhost_endpoint="http://localhost:9004/"
)
self.assertEqual(client._endpoint, "http://example.com:442/")
self.assertEqual(client._remote_server, True)
self.assertEqual(client._localhost_endpoint, "http://localhost:9004/")

# valid name tests
with self.assertRaises(ValueError):
Client("")
Expand Down Expand Up @@ -90,3 +97,21 @@ def test_check_invalid_endpoint_name(self):
f"endpoint name {endpoint_name } can only contain: "
"a-z, A-Z, 0-9, underscore, hyphens and spaces.",
)

def test_deploy_with_remote_server(self):
    """deploy() on a remote client must submit one generated script."""

    # A named function is required: inspect.getsource() on a lambda returns
    # the enclosing statement and its __name__ is "<lambda>", neither of
    # which can form a valid remote script.
    def add_one(x):
        return x + 1

    client = Client("http://example.com:9004/", remote_server=True)
    client._evaluate_remote_script = Mock()

    client.deploy("name", add_one, "description")

    client._evaluate_remote_script.assert_called_once()
    script = client._evaluate_remote_script.call_args[0][0]
    self.assertIn("def add_one(x):", script)
    self.assertIn(
        "client.deploy('name', add_one, 'description', "
        "override=False, is_public=False, schema=None)",
        script,
    )

def test_gen_remote_script(self):
    """The generated preamble adds credentials only once they are set."""
    client = Client("http://example.com:9004/", remote_server=True)

    # Without credentials: import + client construction only.
    anonymous_script = client._gen_remote_script()
    self.assertIn("from tabpy.tabpy_tools.client import Client", anonymous_script)
    self.assertIn("client = Client('http://example.com:9004/')", anonymous_script)
    self.assertNotIn("client.set_credentials", anonymous_script)

    # With credentials: the same credentials are replayed in the script.
    client.set_credentials("username", "password")
    authenticated_script = client._gen_remote_script()
    self.assertIn(
        "client.set_credentials('username', 'password')", authenticated_script
    )
Loading