Skip to content

Commit

Permalink
Refactor task status methods into get_task_status
Browse files Browse the repository at this point in the history
- Consolidated `get_dubbing_task_status` and `get_tts_task_status` into
  `get_task_status`.
- Added support for "tts", "dubbing", and "transcription" tasks in
  get_task_status.
- Updated method calls to use `get_task_status` with appropriate task
  type.
  • Loading branch information
cr2007 committed Jun 7, 2024
1 parent 978670f commit a747391
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions cambai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,11 @@ def start_dubbing(self, *, video_url: str, source_language: int = 1,
return response.json()


def get_dubbing_task_status(self, task_id: str) -> TaskStatus:
def get_task_status(
self,
task: Literal["tts", "dubbing", "transcription"],
task_id: str
) -> TaskStatus:
"""
Retrieves the status of a specific dubbing task.
Expand All @@ -289,8 +293,14 @@ def get_dubbing_task_status(self, task_id: str) -> TaskStatus:
HTTPError: If the GET request to the API endpoint fails.
"""

if task == "dubbing":
url: str = self.create_api_endpoint(f"end_to_end_dubbing/{task_id}")
if task == "tts":
url: str = self.create_api_endpoint(f"tts/{task_id}")
if task == "transcription":
url: str = self.create_api_endpoint(f"create_transcription/{task_id}")

# Construct the API endpoint URL for retrieving the status of a specific dubbing task
url: str = self.create_api_endpoint(f"end_to_end_dubbing/{task_id}")

# Send a GET request to the API endpoint
response = self.session.get(url)
Expand Down Expand Up @@ -379,7 +389,7 @@ def dub(self, *, video_url: str, source_language: int = 1, target_language: int,
# Enter a loop to periodically check the status of the dubbing task
while True:
# Get the current status of the dubbing task
task = self.get_dubbing_task_status(task_id)
task = self.get_task_status("dubbing", task_id)

# If debug is True, print the task status and run ID
if debug:
Expand All @@ -402,7 +412,7 @@ def dub(self, *, video_url: str, source_language: int = 1, target_language: int,
sleep(1)

# Get the final status of the dubbing task
task = self.get_dubbing_task_status(task_id)
task = self.get_task_status("dubbing", task_id)

# If the run ID is None, raise an APIError
if task["run_id"] is None:
Expand Down Expand Up @@ -481,17 +491,7 @@ def get_tts_status(self, task_id: str) -> TaskStatus:
TaskStatus: The status of the TTS task.
"""

# Create the API endpoint URL using the provided task ID
url: str = self.create_api_endpoint(f"tts/{task_id}")

# Send the GET request to the API
response = self.session.get(url)

# Raise an exception if the request was unsuccessful
response.raise_for_status()

# Return the JSON response from the API, which includes the task status
return response.json()
return self.get_task_status("tts", task_id)


def get_tts_result(self, run_id: int, output_directory: Optional[str]) -> None:
Expand Down

0 comments on commit a747391

Please sign in to comment.