diff --git a/neuralspace/constants.py b/neuralspace/constants.py index d3dff1f..fc883cb 100644 --- a/neuralspace/constants.py +++ b/neuralspace/constants.py @@ -38,6 +38,9 @@ if env.VOCAB_ADAPT_URL is not None: VOCAB_ADAPT_URL = env.VOCAB_ADAPT_URL +AMA_URL = 'api/v2/prompts' +if env.AMA_URL is not None: + AMA_URL = env.AMA_URL # full url formation @@ -46,6 +49,7 @@ FULL_LANGS_URL = f'{BASE_URL.rstrip("/")}/{LANGS_URL.strip("/")}' FULL_VOICES_URL = f'{BASE_URL.rstrip("/")}/{VOICES_URL.strip("/")}' FULL_TOKEN_URL = f'{BASE_URL.rstrip("/")}/{TOKEN_URL.strip("/")}' +FULL_AMA_URL = f'{BASE_URL.rstrip("/")}/{AMA_URL.strip("/")}' FULL_TTS_URL = f'{BASE_URL.rstrip("/")}/{TTS_URL.strip("/")}' FULL_VOCAB_ADAPT_URL = f'{BASE_URL.rstrip("/")}/{VOCAB_ADAPT_URL.strip("/")}' diff --git a/neuralspace/env.py b/neuralspace/env.py index e55777d..e5a5c06 100644 --- a/neuralspace/env.py +++ b/neuralspace/env.py @@ -10,6 +10,7 @@ STREAM_URL = get('NS_STREAM_URL') LANGS_URL = get('NS_LANGS_URL') TOKEN_URL = get('NS_TOKEN_URL') +AMA_URL = get('NS_AMA_URL') TIMEOUT_SEC = get('NS_TIMEOUT_SEC') TTS_URL = get('NS_TTS_URL') VOICES_URL = get('NS_VOICES_URL') diff --git a/neuralspace/voice_ai.py b/neuralspace/voice_ai.py index 1722fa0..1995799 100644 --- a/neuralspace/voice_ai.py +++ b/neuralspace/voice_ai.py @@ -270,6 +270,36 @@ def poll_until_complete(self, job_id: str, poll_schedule: Optional[List[float]] return result + def ama(self, job_id: str, prompts: List[str]) -> Dict[str, Any]: + """ + Create and send a request for an AMA job. + + Parameters + ---------- + job_id: str + The job ID for the transcript on which to run AMA. + prompts: List[str] + List of prompts for the AMA. + + Returns + ------- + result: dict + The response from the server. + """ + url = f'{K.FULL_AMA_URL.rstrip("/")}' + hdrs = self._create_headers() + hdrs['Content-Type'] = 'application/json' + data = json.dumps({ + "jobId": job_id, + "prompts": prompts + }) + + sess = self._get_session() + r = sess.post(url, headers=hdrs, data=data) + resp = utils.get_json_resp(r) + return resp + + def _get_short_lived_token(self, timeout): url = f'{K.FULL_TOKEN_URL}?duration={timeout}' hdrs = self._create_headers()