Commit b8d4fd1 (0 parents)
Showing 166 changed files with 750,068 additions and 0 deletions.
@@ -0,0 +1,31 @@
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon


# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

# IDEs and editors
.idea
.vscode
Empty file.
@@ -0,0 +1,3 @@
wandb/
checkpoints/
bin/
Empty file.
@@ -0,0 +1,34 @@
#!/bin/bash

source ./vars.sh

SUBSET=${1:-test}

# https://github.com/facebookresearch/fairseq/issues/3000
# https://github.com/facebookresearch/fairseq/issues/3103
# https://github.com/facebookresearch/fairseq/issues/808
echo "Generating translations for the '$SUBSET' dataset."
fairseq-generate \
    "$BIN_PATH" \
    --batch-size 1 \
    --beam 5 \
    --bpe sentencepiece \
    --dataset-impl "$DATASET_IMPL" \
    --gen-subset "$SUBSET" \
    --path "$ROOT/checkpoints/checkpoint_best.pt" \
    --required-batch-size-multiple 1 \
    --results-path "$RESULTS_PATH" \
    --sentencepiece-model "$MODEL_PREFIX.model" \
    --source-lang "$SOURCE_LANG" \
    --target-lang "$TARGET_LANG"

grep "^D-" "$RESULTS_PATH/generate-$SUBSET.txt" | LC_ALL=C sort -V | cut -f3 > "$RESULTS_PATH/$TARGET_LANG.hyp"

_PATH=$TEST_PATH

if [ "$SUBSET" == "valid" ]; then
    _PATH=$VALID_PATH
fi

# https://github.com/facebookresearch/flores/blob/main/flores200/README.md#evaluation
sacrebleu -m chrf --chrf-word-order 2 "$_PATH/$TARGET_LANG" < "$RESULTS_PATH/$TARGET_LANG.hyp"
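Note: a minimal usage sketch, assuming the generation script above is saved as generate.sh (hypothetical name) next to vars.sh and a trained checkpoint exists at $ROOT/checkpoints/checkpoint_best.pt; the subset argument defaults to test.

# Hypothetical invocation: decode and score the validation and test splits with chrF++.
bash generate.sh valid
bash generate.sh test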
@@ -0,0 +1,12 @@
#!/bin/bash

source ./vars.sh

fairseq-interactive \
    "$BIN_PATH" \
    --batch-size 1 \
    --beam 5 \
    --path "$ROOT/checkpoints/checkpoint_best.pt" \
    --remove-bpe sentencepiece \
    --source-lang "$SOURCE_LANG" \
    --target-lang "$TARGET_LANG"
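Note: fairseq-interactive reads source sentences from stdin, and since this script passes neither --bpe nor --sentencepiece-model, the input presumably has to be SentencePiece-encoded beforehand. A hedged sketch, assuming the script is saved as interactive.sh (hypothetical name) and the spm_encode binary built in vars.sh is on PATH:

# Hypothetical invocation: encode a raw sentence with the trained SPM model, then translate it.
echo "Hello, how are you?" | spm_encode --model=spm.model | bash interactive.sh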
@@ -0,0 +1,37 @@
#!/bin/bash

source ./vars.sh

# Train the SentencePiece model.
$SPM_TRAIN --input="$TRAIN_PATH/$SOURCE_LANG,$TRAIN_PATH/$TARGET_LANG" \
    --character_coverage=1.0 \
    --model_prefix="$MODEL_PREFIX" \
    --model_type=bpe \
    --num_threads="$(nproc)" \
    --max_sentence_length=256 \
    --vocab_size=8000 \
    --shuffle_input_sentence=true \
    --bos_id=0 --pad_id=1 --eos_id=2 --unk_id=3

# Format the vocabulary file for fairseq.
# https://github.com/facebookresearch/fairseq/issues/459
cut -f1 "$MODEL_PREFIX.vocab" | tail -n +5 | sed "s/$/ 100/g" > "$MODEL_PREFIX.dict"

# Encode the datasets with SentencePiece.
encode "train"
encode "valid"
encode "test"

# Binarize the datasets for fairseq.
fairseq-preprocess \
    --bpe sentencepiece \
    --dataset-impl "$DATASET_IMPL" \
    --destdir "$BIN_PATH" \
    --joined-dictionary \
    --source-lang "$SOURCE_LANG" \
    --srcdict "$MODEL_PREFIX.dict" \
    --target-lang "$TARGET_LANG" \
    --testpref "$ROOT/test.spm" \
    --trainpref "$ROOT/train.spm" \
    --validpref "$ROOT/valid.spm" \
    --workers "$(nproc)"
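Note: the cut/tail/sed pipeline drops the first four entries of the SentencePiece vocabulary (the bos/pad/eos/unk symbols defined at training time) and appends a dummy frequency to every remaining piece, matching the "symbol count" format that fairseq-preprocess expects for --srcdict. For illustration only (hypothetical pieces), the resulting spm.dict lines look like:

▁the 100
▁de 100
ing 100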
@@ -0,0 +1,49 @@
#!/bin/bash

source ./vars.sh

ulimit -n 500000

# Based on bilingual experiments from https://aclanthology.org/2023.acl-long.154/.

CUDA_VISIBLE_DEVICES=0 fairseq-train \
    "$BIN_PATH" \
    --adam-betas '(0.9, 0.98)' \
    --adam-eps 1e-06 \
    --arch transformer \
    --attention-dropout 0.2 \
    --best-checkpoint-metric bleu \
    --bpe sentencepiece \
    --clip-norm 0.0 \
    --criterion label_smoothed_cross_entropy \
    --decoder-ffn-embed-dim 4096 \
    --decoder-normalize-before \
    --dropout 0.3 \
    --encoder-ffn-embed-dim 4096 \
    --encoder-normalize-before \
    --eval-bleu \
    --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
    --eval-bleu-detok space \
    --eval-bleu-print-samples \
    --eval-bleu-remove-bpe sentencepiece \
    --keep-last-epochs 1 \
    --label-smoothing 0.1 \
    --log-format json \
    --log-interval 100 \
    --lr 0.001 \
    --lr-scheduler inverse_sqrt \
    --max-epoch 2000 \
    --max-tokens 8000 \
    --maximize-best-checkpoint-metric \
    --optimizer adam \
    --relu-dropout 0.2 \
    --save-interval 10 \
    --seed 2 \
    --sentencepiece-model "$MODEL_PREFIX.model" \
    --share-all-embeddings \
    --update-freq 16 \
    --validate-interval 5 \
    --wandb-project "$WANDB_PROJECT" \
    --warmup-init-lr 1e-07 \
    --warmup-updates 400 \
    --weight-decay 0.0001
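Note: with --max-tokens 8000 and --update-freq 16 on a single GPU, each optimizer step accumulates gradients over roughly 8000 * 16 = 128,000 tokens. Validation BLEU is computed every 5 epochs, checkpoints are written every 10 epochs, and only the most recent epoch checkpoint is kept alongside the best one by BLEU.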
@@ -0,0 +1,63 @@
#!/bin/bash

DEBUG=0

# Languages
SOURCE_LANG=eng
TARGET_LANG=spa

# Paths
ROOT=$(dirname "$0")
TRAIN_PATH=$ROOT/train
VALID_PATH=$ROOT/valid
TEST_PATH=$ROOT/test
RESULTS_PATH=$ROOT/results
MODEL_PREFIX=$ROOT/spm
BIN_PATH=$ROOT/bin/

# WandB
WANDB_CONSOLE=off
WANDB_PROJECT=wmt24-oldi

# Fairseq Variables
LANGS="$SOURCE_LANG,$TARGET_LANG"
LANG_PAIRS="$SOURCE_LANG-$TARGET_LANG,$TARGET_LANG-$SOURCE_LANG"

DATASET_IMPL="mmap"
if [ "$DEBUG" -eq 1 ]; then
    DATASET_IMPL="raw"
fi

# SentencePiece Variables
SPM=$ROOT/../../sentencepiece/build/src
SPM_TRAIN="$SPM/spm_train"
SPM_ENCODE="$SPM/spm_encode"

set -e

encode() {
    local _DATA_TYPE=$1
    local _PATH=$ROOT

    case $_DATA_TYPE in
        train)
            _PATH=$TRAIN_PATH
            ;;
        valid)
            _PATH=$VALID_PATH
            ;;
        test)
            _PATH=$TEST_PATH
            ;;
        *)
            echo "Invalid data type: $_DATA_TYPE"
            exit 1
            ;;
    esac

    for LANG in $SOURCE_LANG $TARGET_LANG; do
        $SPM_ENCODE \
            --model="$MODEL_PREFIX.model" \
            --output_format=piece < "$_PATH/$LANG" > "$ROOT/$_DATA_TYPE.spm.$LANG"
    done
}
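Note: the encode helper writes SentencePiece-encoded copies of each split to the repository root, which is what the --trainpref/--validpref/--testpref arguments of fairseq-preprocess point at. With the eng-spa pair configured above, the expected output files are:

train.spm.eng, train.spm.spa
valid.spm.eng, valid.spm.spa
test.spm.eng, test.spm.spa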
@@ -0,0 +1,90 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Flask stuff:
instance/
.webassets-cache

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
Empty file.
@@ -0,0 +1,64 @@
import argparse
from concurrent.futures import ThreadPoolExecutor

from google.cloud import storage


def save_prov(blob, directory, filter_range):
    path_parts = blob.name.split("/")
    index = int(path_parts[-2])

    if index < filter_range[0] or index > filter_range[1]:
        return None

    try:
        content = blob.download_as_text()
        with open(f"{directory}/{index}.json", "w") as f:
            f.write(content)

        print(f"Saved content from {blob.name}")

    except Exception as e:
        print(f"Failed to download {blob.name}: {str(e)}")
        return None


def main(bucket_name, prefix, directory, filter_range):
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)

    blobs = [
        blob
        for blob in bucket.list_blobs(prefix=f"{prefix}/")
        if blob.name.endswith("prov.json")
    ]

    print(f"Found {len(blobs)} provenance files")

    with ThreadPoolExecutor(max_workers=20) as executor:
        executor.map(lambda blob: save_prov(blob, directory, filter_range), blobs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Download provenance files from Google Cloud Storage."
    )
    parser.add_argument("--prefix", type=str, help="GCS path prefix.")
    parser.add_argument("--bucket", type=str, help="GCS bucket name.")
    parser.add_argument(
        "--directory",
        type=str,
        default="prov-json",
        help="Local target directory to save files.",
    )
    parser.add_argument(
        "--range",
        type=int,
        nargs=2,
        default=(1, 6193),
        help="Range of indices to process (inclusive).",
    )

    args = parser.parse_args()

    main(args.bucket, args.prefix, args.directory, tuple(args.range))
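Note: a hedged invocation sketch, assuming the script is saved as download_prov.py (hypothetical name), google-cloud-storage is installed with Application Default Credentials configured, and the target directory already exists (the script does not create it):

# Hypothetical bucket and prefix values.
mkdir -p prov-json
python download_prov.py --bucket my-bucket --prefix my-prefix --directory prov-json --range 1 6193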