
Run Translation - Refactor command-line script into a Python module #8

Closed · wants to merge 12 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -97,4 +97,5 @@ sweep*/
core*

features_outputs

*.pth
37 changes: 37 additions & 0 deletions tests/translation_demo.py
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import sys
import os
from omegaconf import OmegaConf

# Add the parent directory to sys.path so it can locate the `translation` folder
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
Contributor: I'd suggest removing this, and just specifying a fully qualified path
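For illustration, one reading of this suggestion is to resolve the repository root to an absolute path once, rather than building it from a relative join; a hypothetical sketch, not part of this PR (REPO_ROOT is an invented name):

from pathlib import Path

# Resolve the repo root as a fully qualified (absolute) path and put it
# on sys.path; parents[1] assumes this file lives directly under tests/.
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))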


# Now, import the modules
from translation.run_translation_module import Config, run_translation, ModelConfig, DataConfig, CommonConfig

# Define the translation configuration
translation_config = Config(
    common=CommonConfig(
        eval=True,
        load_model="translation/signhiera_mock.pth"
    ),
    data=DataConfig(
        val_data_dir="features_outputs/0"
    ),
    model=ModelConfig(
        name_or_path="google-t5/t5-base",
        feature_dim=1024
    )
)

# Convert it to DictConfig
translation_dict_config = OmegaConf.structured(translation_config)

# Run translation with the DictConfig instance
run_translation(translation_dict_config)
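As a usage note, since translation_dict_config is a structured DictConfig, ad-hoc overrides can be merged in before running; a minimal sketch using only OmegaConf's public API (the override values are illustrative):

# Merge dotlist-style overrides into the structured config, then inspect it.
overrides = OmegaConf.from_dotlist(["model.num_beams=3", "common.device=cpu"])
merged = OmegaConf.merge(translation_dict_config, overrides)
print(OmegaConf.to_yaml(merged))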
53 changes: 53 additions & 0 deletions tests/translation_module_test.py
@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import unittest
from unittest.mock import MagicMock, patch
from translation.run_translation_module import TranslationModule, Config
import numpy as np

class TestTranslationModule(unittest.TestCase):
    def setUp(self):
        # Basic setup for testing TranslationModule
        self.config = Config()
        self.config.model.name_or_path = "translation/signhiera_mock.pth"
        self.translator = TranslationModule(self.config)

    @patch("translation.run_translation_module.TranslationModule.run_translation")
    def test_translation_with_mock_features(self, mock_run_translation):
        # Mock feature array that simulates extracted features
        mock_features = np.random.rand(10, 512)  # 10 timesteps, 512-dim features

        # Mock translation return value
        mock_run_translation.return_value = "This is a test translation."

        # Run translation with mocked features
        result = self.translator.run_translation(mock_features)

        # Assertions
        self.assertEqual(result, "This is a test translation.")
        self.assertTrue(mock_run_translation.called)
        mock_run_translation.assert_called_with(mock_features)

    def test_configuration_loading(self):
        # Ensure the configuration fields are loaded as expected
        self.assertEqual(self.config.model.name_or_path, "translation/signhiera_mock.pth")

    @patch("translation.run_translation_module.TranslationModule.run_translation")
    def test_translation_output_type(self, mock_run_translation):
        # Mock feature array for translation
        mock_features = np.random.rand(10, 512)

        # Mock output for translation to simulate text output
        mock_run_translation.return_value = "Translation successful."

        # Perform translation
        output = self.translator.run_translation(mock_features)

        # Assertions
        self.assertIsInstance(output, str)  # Check output type
        self.assertTrue(mock_run_translation.called)
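The suite runs under the standard unittest runner (for example, python -m unittest tests.translation_module_test from the repository root, given the import above). Adding the usual entry point also makes the file directly executable; a small addition, not in the original diff:

if __name__ == "__main__":
    unittest.main()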
2 changes: 1 addition & 1 deletion translation/engine_translation.py
@@ -26,7 +26,7 @@
from ssvp_slt.modeling.fairseq_model import FairseqTokenizer
from transformers import PreTrainedTokenizerFast

-from utils_translation import (compute_accuracy, compute_bleu,
+from translation.utils_translation import (compute_accuracy, compute_bleu,
                                create_dataloader, postprocess_text)

EVAL_BLEU_ORDER = 4
24 changes: 22 additions & 2 deletions translation/main_translation.py
@@ -12,11 +12,31 @@
import torch
from omegaconf import DictConfig, OmegaConf

-from engine_translation import evaluate, evaluate_full, train_one_epoch
-from utils_translation import (create_dataloader, create_model_and_tokenizer,
+from translation.engine_translation import evaluate, evaluate_full, train_one_epoch
+from translation.utils_translation import (create_dataloader, create_model_and_tokenizer,
                                create_optimizer_and_loss_scaler)


def eval(cfg: DictConfig):
    """
    Function to handle the evaluation of the model.
    """
    device = torch.device(cfg.common.device)
    model, tokenizer = create_model_and_tokenizer(cfg)

    # Load model for finetuning or eval
    if (misc.get_last_checkpoint(cfg) is None or cfg.common.eval) and cfg.common.load_model:
        misc.load_model(model, cfg.common.load_model)

    evaluate_full(cfg, model.to(device), tokenizer, device)
    # Create validation data loader and evaluate the model
    dataloader_val = create_dataloader("val", cfg, tokenizer)
    val_stats, _, _ = evaluate(cfg, dataloader_val, model.to(device), tokenizer, device)

    # Optionally, print or log val_stats for evaluation feedback
    print("Validation Stats:", val_stats)


def main(cfg: DictConfig):
    misc.init_distributed_mode(cfg)

158 changes: 158 additions & 0 deletions translation/run_translation_module.py
@@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
Contributor: Let's be consistent about the license header here (it's important because those headers are automatically parsed by our bot, and it will raise a warning / error if the format is not recognized)
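For reference, the header format used by the other files in this diff (e.g. tests/translation_demo.py) is the block below; matching it exactly, including the separator line, is the likeliest way to satisfy the bot:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------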


import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from omegaconf import II, MISSING, DictConfig, OmegaConf
from ssvp_slt.util.misc import reformat_logger
# from main_translation import main as translate
from translation.main_translation import eval as translate

logger = logging.getLogger(__name__)

# Configuration classes for the module
@dataclass
class CommonConfig:
    output_dir: str = "./translation_output"
    log_dir: str = "./translation_logs"
    resume: Optional[str] = None
    load_model: Optional[str] = None
    seed: int = 42
    device: str = "cuda"
    fp16: bool = True
    eval: bool = False
    dist_eval: bool = True
    pin_mem: bool = True
    num_workers: int = 10
    eval_print_samples: bool = False
    max_checkpoints: int = 3
    eval_steps: Optional[int] = None
    eval_best_model_after_training: bool = True
    overwrite_output_dir: bool = False
    compute_bleurt: bool = False

@dataclass
class ModelConfig:
    name_or_path: Optional[str] = None
    feature_dim: int = 512
    from_scratch: bool = False
    dropout: float = 0.3
    num_beams: int = 5
    lower_case: bool = False

    # Fairseq-specific fields for model compatibility
    min_source_positions: int = 0
    max_source_positions: int = 1024
    max_target_positions: int = 1024
    feats_type: Optional[str] = "hiera"
    activation_fn: Optional[str] = "relu"
    encoder_normalize_before: Optional[bool] = True
    encoder_embed_dim: Optional[int] = 768
    encoder_ffn_embed_dim: Optional[int] = 3072
    encoder_attention_heads: Optional[int] = 12
    encoder_layerdrop: Optional[float] = 0.1
    encoder_layers: Optional[int] = 12
    decoder_normalize_before: Optional[bool] = True
    decoder_embed_dim: Optional[int] = 768
    decoder_ffn_embed_dim: Optional[int] = 3072
    decoder_attention_heads: Optional[int] = 12
    decoder_layerdrop: Optional[float] = 0.1
    decoder_layers: Optional[int] = 12
    decoder_output_dim: Optional[int] = 768
    classifier_dropout: Optional[float] = 0.1
    attention_dropout: Optional[float] = 0.1
    activation_dropout: Optional[float] = 0.1
    layernorm_embedding: Optional[bool] = True
    no_scale_embedding: Optional[bool] = False
    share_decoder_input_output_embed: Optional[bool] = True
    num_hidden_layers: Optional[int] = 12

@dataclass
class DataConfig:
    train_data_dirs: str = MISSING
    val_data_dir: str = MISSING
    num_epochs_extracted: int = 1
    min_source_positions: int = 0
    max_source_positions: int = 1024
    max_target_positions: int = 1024

@dataclass
class CriterionConfig:
    label_smoothing: float = 0.2

@dataclass
class OptimizationConfig:
    clip_grad: float = 1.0
    lr: float = 0.001
    min_lr: float = 1e-4
    weight_decay: float = 1e-1
    start_epoch: int = 0
    epochs: int = 200
    warmup_epochs: int = 10
    train_batch_size: int = 32
    val_batch_size: int = 64
    gradient_accumulation_steps: int = 1
    early_stopping: bool = True
    patience: int = 10
    epoch_offset: Optional[int] = 0

@dataclass
class WandbConfig:
    enabled: bool = True
    project: Optional[str] = None
    entity: Optional[str] = None
    name: Optional[str] = None
    run_id: Optional[str] = None
    log_code: bool = True

@dataclass
class DistConfig:
    world_size: int = 1
    port: int = 1
    local_rank: int = -1
    enabled: bool = False
    rank: Optional[int] = None
    dist_url: Optional[str] = None
    gpu: Optional[int] = None
    dist_backend: Optional[str] = None

@dataclass
class Config:
    common: CommonConfig = field(default_factory=CommonConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    data: DataConfig = field(default_factory=DataConfig)
    criterion: CriterionConfig = field(default_factory=CriterionConfig)
    optim: OptimizationConfig = field(default_factory=OptimizationConfig)
    dist: DistConfig = field(default_factory=DistConfig)
    wandb: WandbConfig = field(default_factory=WandbConfig)

    debug: bool = False
    fairseq: bool = False
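Note that DataConfig.train_data_dirs and val_data_dir default to MISSING, so OmegaConf raises if they are read before being set; a quick illustration using OmegaConf's public API (the path is the one from tests/translation_demo.py):

from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.structured(Config)
try:
    _ = cfg.data.val_data_dir  # mandatory value, not set yet
except MissingMandatoryValue:
    cfg.data.val_data_dir = "features_outputs/0"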

def run_translation(cfg: DictConfig):
    # Process configuration and start translation
    OmegaConf.resolve(cfg)
    reformat_logger()

    if cfg.debug:
        print("Running in debug mode")

    # If evaluating without training
    if cfg.common.eval:
        if cfg.common.load_model is None:
            raise RuntimeError("Evaluation mode requires a specified model.")
        cfg.common.output_dir = os.path.dirname(cfg.common.load_model)
        cfg.common.log_dir = os.path.join(cfg.common.output_dir, "logs")
    else:
        Path(cfg.common.output_dir).mkdir(parents=True, exist_ok=True)
        Path(cfg.common.log_dir).mkdir(parents=True, exist_ok=True)

    translate(cfg)
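In eval mode, run_translation derives the output and log directories from the checkpoint location instead of creating fresh ones; traced here with the mock checkpoint path used in tests/translation_demo.py (os is already imported at the top of this module):

load_model = "translation/signhiera_mock.pth"
output_dir = os.path.dirname(load_model)    # "translation"
log_dir = os.path.join(output_dir, "logs")  # "translation/logs"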