Skip to content

Commit

Permalink
Merge branch 'master' into vocab_adapt
Browse files Browse the repository at this point in the history
  • Loading branch information
vik-rant authored Jan 16, 2024
2 parents 1fd7751 + 2ce94e1 commit 30dd916
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 0 deletions.
4 changes: 4 additions & 0 deletions neuralspace/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
if env.VOCAB_ADAPT_URL is not None:
VOCAB_ADAPT_URL = env.VOCAB_ADAPT_URL

AMA_URL = 'api/v2/prompts'
if env.AMA_URL is not None:
AMA_URL = env.AMA_URL

# full url formation

Expand All @@ -46,6 +49,7 @@
FULL_LANGS_URL = f'{BASE_URL.rstrip("/")}/{LANGS_URL.strip("/")}'
FULL_VOICES_URL = f'{BASE_URL.rstrip("/")}/{VOICES_URL.strip("/")}'
FULL_TOKEN_URL = f'{BASE_URL.rstrip("/")}/{TOKEN_URL.strip("/")}'
FULL_AMA_URL = f'{BASE_URL.rstrip("/")}/{AMA_URL.strip("/")}'
FULL_TTS_URL = f'{BASE_URL.rstrip("/")}/{TTS_URL.strip("/")}'
FULL_VOCAB_ADAPT_URL = f'{BASE_URL.rstrip("/")}/{VOCAB_ADAPT_URL.strip("/")}'

Expand Down
1 change: 1 addition & 0 deletions neuralspace/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
STREAM_URL = get('NS_STREAM_URL')
LANGS_URL = get('NS_LANGS_URL')
TOKEN_URL = get('NS_TOKEN_URL')
AMA_URL = get('NS_AMA_URL')
TIMEOUT_SEC = get('NS_TIMEOUT_SEC')
TTS_URL = get('NS_TTS_URL')
VOICES_URL = get('NS_VOICES_URL')
Expand Down
30 changes: 30 additions & 0 deletions neuralspace/voice_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,36 @@ def poll_until_complete(self, job_id: str, poll_schedule: Optional[List[float]]
return result


def ama(self, job_id: str, prompts: List[str]) -> Dict[str, Any]:
"""
Create and send a request for an AMA job.
Parameters
----------
job_id: str
The job ID for the transcript on which to run AMA.
prompts: List[str]
List of prompts for the AMA.
Returns
-------
result: dict
The response from the server.
"""
url = f'{K.FULL_AMA_URL.rstrip("/")}'
hdrs = self._create_headers()
hdrs['Content-Type'] = 'application/json'
data = json.dumps({
"jobId": job_id,
"prompts": prompts
})

sess = self._get_session()
r = sess.post(url, headers=hdrs, data=data)
resp = utils.get_json_resp(r)
return resp


def _get_short_lived_token(self, timeout):
url = f'{K.FULL_TOKEN_URL}?duration={timeout}'
hdrs = self._create_headers()
Expand Down

0 comments on commit 30dd916

Please sign in to comment.