-
Notifications
You must be signed in to change notification settings - Fork 110
/
api.py
111 lines (82 loc) · 3.33 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import hmac
import importlib
import json
import logging
import os
from functools import wraps

import aiohttp
from dotenv import load_dotenv
from markupsafe import escape
from quart import Quart, g, request, jsonify, abort
from quart_compress import Compress
# Load variables from a local .env file into the process environment before
# any configuration below is read through os.getenv.
load_dotenv()
# The Quart application object; the route and lifecycle decorators further
# down in this module attach to it.
app = Quart(__name__)
# Hand the app to quart_compress (presumably enables response compression
# with the extension's defaults -- confirm against its documentation).
Compress(app)
# Build the list of paths stored under the ``extra_files`` config key: the
# watched directories themselves plus every regular file found beneath them
# (presumably consumed by a reloader -- confirm where the key is read).
extra_dirs = ['src']
extra_files = list(extra_dirs)
for extra_dir in extra_dirs:
    for dirname, dirs, files in os.walk(extra_dir):
        candidates = (os.path.join(dirname, name) for name in files)
        # Keep only entries that resolve to real files.
        extra_files.extend(path for path in candidates if os.path.isfile(path))
app.config["extra_files"] = extra_files
# Model catalogue, loaded once at import time from the working directory.
# The encoding is pinned to UTF-8 so the file decodes the same way on every
# platform instead of depending on the locale's default encoding.
with open('repository_data.json', encoding='utf-8') as f:
    repository_data = json.load(f)
# Shared-secret authentication settings taken from the environment:
# AUTH_HEADER_KEY is the name of the request header, AUTH_HEADER the value it
# must carry. Either is None when the variable is not set.
AUTH_HEADER_KEY = os.environ.get("AUTH_HEADER_KEY")
AUTH_HEADER = os.environ.get("AUTH_HEADER")
def verify_auth_header(auth_header_key, expected_value):
    """Build a decorator that rejects requests lacking the shared-secret header.

    Args:
        auth_header_key: Name of the request header that carries the secret.
        expected_value: Value the header must match exactly.

    Returns:
        A decorator for async view functions. The wrapped view aborts with
        HTTP 401 when the header is missing, empty, or different from
        ``expected_value``; otherwise it awaits the original view and returns
        its result.
    """
    # Encode the secret once. An unset or empty secret can never be matched,
    # so a misconfigured deployment fails closed instead of open.
    if isinstance(expected_value, str) and expected_value:
        expected_bytes = expected_value.encode("utf-8")
    else:
        expected_bytes = None

    def decorator(f):
        @wraps(f)
        async def decorated_function(*args, **kwargs):
            # Without a configured header name there is nothing to look up.
            auth_header = (
                request.headers.get(auth_header_key) if auth_header_key else None
            )
            # compare_digest takes time independent of where the values
            # differ, so response latency does not leak the secret's prefix.
            # Bytes are compared because the str form only accepts ASCII.
            authorized = (
                bool(auth_header)
                and expected_bytes is not None
                and hmac.compare_digest(auth_header.encode("utf-8"), expected_bytes)
            )
            if not authorized:
                logging.getLogger(__name__).warning("Unauthorized access")
                abort(401)  # Unauthorized
            return await f(*args, **kwargs)
        return decorated_function
    return decorator
@app.route("/")
def welcome():
    """Serve the static greeting shown at the service root."""
    greeting = "<p>Welcome!</p>"
    return greeting
@app.route("/repository")
def repository():
    """Expose the loaded repository data (available models and their configurations) as JSON."""
    catalogue_response = jsonify(repository_data)
    return catalogue_response
def json_to_object(request_class, json_str):
    """Instantiate ``request_class`` from the top-level keys of a JSON document.

    The decoded JSON object is passed to ``request_class`` as keyword
    arguments; nested values are handed over as-is (no recursive conversion).
    """
    return request_class(**json.loads(json_str))
def get_model_config(use_case, provider, mode):
    """Look up the model configuration for a use case, provider and mode.

    Args:
        use_case: Key under ``repository_data['use_cases']``.
        provider: Key under the selected use case.
        mode: Key under the selected provider.

    Returns:
        ``(config, 200)`` when all three levels exist, otherwise
        ``(message, 400)`` where the message names the first level that is
        missing (the offending value is HTML-escaped).
    """
    # Tolerate repository data without a 'use_cases' section: every use case
    # is then simply reported as unavailable.
    use_cases = repository_data.get('use_cases') or {}
    use_case_data = use_cases.get(use_case)
    if use_case_data is None:
        return f'{escape(use_case)} Use case is not available', 400
    provider_data = use_case_data.get(provider)
    if provider_data is None:
        return f'{escape(provider)} Provider is not available', 400
    # Keep the lookup result under its own name: rebinding ``mode`` made the
    # error message below report "None" instead of the requested mode.
    mode_config = provider_data.get(mode)
    if mode_config is None:
        return f'{escape(mode)} Mode is not available', 400
    return mode_config, 200
@app.route("/<use_case>/<provider>/<mode>", methods=['POST'])
@verify_auth_header(AUTH_HEADER_KEY, AUTH_HEADER)
async def transformer(use_case, provider, mode):
    """Run inference with the model registered for use_case/provider/mode.

    The three path segments must match an entry in the repository data. The
    JSON request body supplies the keyword arguments for the model's request
    class.

    Returns:
        The value returned by the model's ``inference`` call, or a
        ``(message, 400)`` tuple when the model is not registered or the
        body is not a JSON object.
    """
    model_config, status = get_model_config(use_case, provider, mode)
    if status != 200:
        return model_config, status

    # Reject a missing or non-object body up front; otherwise it would fail
    # later while being expanded into keyword arguments and surface as a 500.
    payload = await request.json
    if not isinstance(payload, dict):
        return 'Request body must be a JSON object', 400

    # The path segments were validated against the repository data above, so
    # only registered modules under ``src`` are imported here.
    module = importlib.import_module(f"src.{use_case}.{provider}.{mode}")
    model = getattr(module, model_config.get('model_class'))(app)
    model_request = getattr(module, model_config.get('request_class'))
    request_object = json_to_object(model_request, json.dumps(payload))

    # The repository entry declares whether ``inference`` is a coroutine.
    if model_config.get("__is_async"):
        return await model.inference(request_object)
    return model.inference(request_object)
@app.before_serving
async def startup():
    """Create the aiohttp session stored on ``app.client`` before serving starts."""
    app.client = aiohttp.ClientSession()


@app.after_serving
async def shutdown():
    """Close the aiohttp session so its connections are released on shutdown."""
    # The attribute is absent when startup never ran or failed part-way.
    client = getattr(app, "client", None)
    if client is not None:
        await client.close()
# quart --app api --debug run
# hypercorn api -b 0.0.0.0:8000