Run Translation - Refactor command-line script into a Python module #8

Status: Closed

Changes shown from 7 of 12 commits. Commits (all by JooZef315):
- 9c8df0b Update README.md
- 5b5f2f5 add run translation module & test
- 9ff9970 use default_factory design pattern
- 726c413 rafactor test & add translation_demo
- 2acfaac create new eval method in main_translation
- 388dfc4 add LICENSE header
- d4c81c8 Merge branch 'main' into run_translation
- 18f3d56 use logger instead of print
- fc3fafc reuse evaluate_full & edit tests
- f21cf53 Merge branch 'run_translation' of https://github.com/JooZef315/ssvp_s…
- d265238 remove casting len to float
- 69500f5 remove casting len to float
.gitignore:

```
@@ -97,4 +97,5 @@ sweep*/
core*

features_outputs

*.pth
```
translation_demo.py (new file, 37 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import sys
import os
from omegaconf import OmegaConf

# Add the parent directory to sys.path so it can locate the `translation` folder
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Now, import the modules
from translation.run_translation_module import Config, run_translation, ModelConfig, DataConfig, CommonConfig

# Define the translation configuration
translation_config = Config(
    common=CommonConfig(
        eval=True,
        load_model="translation/signhiera_mock.pth"
    ),
    data=DataConfig(
        val_data_dir="features_outputs/0"
    ),
    model=ModelConfig(
        name_or_path="google-t5/t5-base",
        feature_dim=1024
    )
)

# Convert it to DictConfig
translation_dict_config = OmegaConf.structured(translation_config)

# Run translation with the DictConfig instance
run_translation(translation_dict_config)
```
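For quick experiments it can be convenient to override individual fields from the command line instead of editing the demo. A minimal sketch using OmegaConf's dotlist parsing (`OmegaConf.from_cli` is part of OmegaConf's public API; the field names come from the dataclasses in the new module):

```python
# Sketch: merge command-line overrides into the structured config, e.g.
#   python translation_demo.py common.device=cpu model.num_beams=1
from omegaconf import OmegaConf

base = OmegaConf.structured(Config())   # schema and defaults from the dataclasses
overrides = OmegaConf.from_cli()        # parses sys.argv entries of the form a.b=c
cfg = OmegaConf.merge(base, overrides)  # overrides are type-checked against the schema
run_translation(cfg)
```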
Unit tests for the new module (new file, 53 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import unittest

from unittest.mock import MagicMock, patch
from translation.run_translation_module import TranslationModule, Config
import numpy as np

class TestTranslationModule(unittest.TestCase):
    def setUp(self):
        # Basic setup for testing TranslationModule
        self.config = Config()
        self.config.model.name_or_path = "translation/signhiera_mock.pth"
        self.translator = TranslationModule(self.config)

    # The patch target must be the fully qualified import path of the module
    @patch("translation.run_translation_module.TranslationModule.run_translation")
    def test_translation_with_mock_features(self, mock_run_translation):
        # Mock feature array that simulates extracted features
        mock_features = np.random.rand(10, 512)  # 10 timesteps, 512-dim features

        # Mock translation return value
        mock_run_translation.return_value = "This is a test translation."

        # Run translation with mocked features
        result = self.translator.run_translation(mock_features)

        # Assertions
        self.assertEqual(result, "This is a test translation.")
        self.assertTrue(mock_run_translation.called)
        mock_run_translation.assert_called_with(mock_features)

    def test_configuration_loading(self):
        # Ensure the configuration fields are loaded as expected
        self.assertEqual(self.config.model.name_or_path, "translation/signhiera_mock.pth")

    @patch("translation.run_translation_module.TranslationModule.run_translation")
    def test_translation_output_type(self, mock_run_translation):
        # Mock feature array for translation
        mock_features = np.random.rand(10, 512)

        # Mock output for translation to simulate text output
        mock_run_translation.return_value = "Translation successful."

        # Perform translation
        output = self.translator.run_translation(mock_features)

        # Assertions
        self.assertIsInstance(output, str)  # Check output type
        self.assertTrue(mock_run_translation.called)
```
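Note that `unittest.mock.patch` resolves its target string by importing it, so the target must match the path under which the module is actually imported (`translation.run_translation_module.TranslationModule.run_translation`), not the bare file name. A hypothetical runner, assuming the file is saved as `tests/test_run_translation.py` (the path is not shown in this view):

```python
# Hypothetical: load and run the tests programmatically; equivalent to
#   python -m unittest tests.test_run_translation -v
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName("tests.test_run_translation")
unittest.TextTestRunner(verbosity=2).run(suite)
```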
translation/run_translation_module.py (new file, 158 lines):
Review comment (on the license header): "Let's be consistent about the license header here; it's important because those headers are automatically parsed by our bot, and it will raise a warning or error if the format is not recognized."

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from omegaconf import II, MISSING, DictConfig, OmegaConf
from ssvp_slt.util.misc import reformat_logger
from translation.main_translation import eval as translate

logger = logging.getLogger(__name__)

# Configuration classes for the module
@dataclass
class CommonConfig:
    output_dir: str = "./translation_output"
    log_dir: str = "./translation_logs"
    resume: Optional[str] = None
    load_model: Optional[str] = None
    seed: int = 42
    device: str = "cuda"
    fp16: bool = True
    eval: bool = False
    dist_eval: bool = True
    pin_mem: bool = True
    num_workers: int = 10
    eval_print_samples: bool = False
    max_checkpoints: int = 3
    eval_steps: Optional[int] = None
    eval_best_model_after_training: bool = True
    overwrite_output_dir: bool = False
    compute_bleurt: bool = False

@dataclass
class ModelConfig:
    name_or_path: Optional[str] = None
    feature_dim: int = 512
    from_scratch: bool = False
    dropout: float = 0.3
    num_beams: int = 5
    lower_case: bool = False

    # Fairseq-specific fields for model compatibility
    min_source_positions: int = 0
    max_source_positions: int = 1024
    max_target_positions: int = 1024
    feats_type: Optional[str] = "hiera"
    activation_fn: Optional[str] = "relu"
    encoder_normalize_before: Optional[bool] = True
    encoder_embed_dim: Optional[int] = 768
    encoder_ffn_embed_dim: Optional[int] = 3072
    encoder_attention_heads: Optional[int] = 12
    encoder_layerdrop: Optional[float] = 0.1
    encoder_layers: Optional[int] = 12
    decoder_normalize_before: Optional[bool] = True
    decoder_embed_dim: Optional[int] = 768
    decoder_ffn_embed_dim: Optional[int] = 3072
    decoder_attention_heads: Optional[int] = 12
    decoder_layerdrop: Optional[float] = 0.1
    decoder_layers: Optional[int] = 12
    decoder_output_dim: Optional[int] = 768
    classifier_dropout: Optional[float] = 0.1
    attention_dropout: Optional[float] = 0.1
    activation_dropout: Optional[float] = 0.1
    layernorm_embedding: Optional[bool] = True
    no_scale_embedding: Optional[bool] = False
    share_decoder_input_output_embed: Optional[bool] = True
    num_hidden_layers: Optional[int] = 12

@dataclass
class DataConfig:
    train_data_dirs: str = MISSING
    val_data_dir: str = MISSING
    num_epochs_extracted: int = 1
    min_source_positions: int = 0
    max_source_positions: int = 1024
    max_target_positions: int = 1024

@dataclass
class CriterionConfig:
    label_smoothing: float = 0.2

@dataclass
class OptimizationConfig:
    clip_grad: float = 1.0
    lr: float = 0.001
    min_lr: float = 1e-4
    weight_decay: float = 1e-1
    start_epoch: int = 0
    epochs: int = 200
    warmup_epochs: int = 10
    train_batch_size: int = 32
    val_batch_size: int = 64
    gradient_accumulation_steps: int = 1
    early_stopping: bool = True
    patience: int = 10
    epoch_offset: Optional[int] = 0

@dataclass
class WandbConfig:
    enabled: bool = True
    project: Optional[str] = None
    entity: Optional[str] = None
    name: Optional[str] = None
    run_id: Optional[str] = None
    log_code: bool = True

@dataclass
class DistConfig:
    world_size: int = 1
    port: int = 1
    local_rank: int = -1
    enabled: bool = False
    rank: Optional[int] = None
    dist_url: Optional[str] = None
    gpu: Optional[int] = None
    dist_backend: Optional[str] = None

@dataclass
class Config:
    # Nested configs use default_factory: a plain instance default would be
    # shared by every Config object (and rejected on newer Python versions)
    common: CommonConfig = field(default_factory=CommonConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    data: DataConfig = field(default_factory=DataConfig)
    criterion: CriterionConfig = field(default_factory=CriterionConfig)
    optim: OptimizationConfig = field(default_factory=OptimizationConfig)
    dist: DistConfig = field(default_factory=DistConfig)
    wandb: WandbConfig = field(default_factory=WandbConfig)

    debug: bool = False
    fairseq: bool = False

def run_translation(cfg: DictConfig):
    # Process configuration and start translation
    OmegaConf.resolve(cfg)
    reformat_logger()

    if cfg.debug:
        logger.info("Running in debug mode")

    # If evaluating without training, write outputs next to the loaded model
    if cfg.common.eval:
        if cfg.common.load_model is None:
            raise RuntimeError("Evaluation mode requires a specified model.")
        cfg.common.output_dir = os.path.dirname(cfg.common.load_model)
        cfg.common.log_dir = os.path.join(cfg.common.output_dir, "logs")
    else:
        Path(cfg.common.output_dir).mkdir(parents=True, exist_ok=True)
        Path(cfg.common.log_dir).mkdir(parents=True, exist_ok=True)

    translate(cfg)
```
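The commit "use default_factory design pattern" refers to the nested fields of `Config` above: a dataclass must not use a shared instance as a default value. A minimal illustration, assuming `ModelConfig` from this module is in scope:

```python
from dataclasses import dataclass, field

# model: ModelConfig = ModelConfig() would be a single object shared by every
# Example (and Python 3.11+ rejects unhashable defaults with a ValueError),
# so a factory builds a fresh ModelConfig per instantiation instead.
@dataclass
class Example:
    model: ModelConfig = field(default_factory=ModelConfig)

a, b = Example(), Example()
assert a.model is not b.model  # each Example owns its own ModelConfig
```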
Review comment: "I'd suggest removing this, and just specifying a fully qualified path."
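If this refers to the `sys.path.insert` workaround at the top of the demo (an assumption; the anchored line is not shown in this view), the fully qualified alternative would be to drop the path manipulation and import the module by its package path, launching the demo from the repository root:

```python
# Sketch of the suggested alternative: no sys.path manipulation; run the demo
# from the repository root (e.g. `python -m translation_demo`, module name
# hypothetical) so the `translation` package resolves via its full import path.
from translation.run_translation_module import (
    CommonConfig,
    Config,
    DataConfig,
    ModelConfig,
    run_translation,
)
```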