Initial release
josecols committed Aug 22, 2024
0 parents commit b8d4fd1
Showing 166 changed files with 750,068 additions and 0 deletions.
31 changes: 31 additions & 0 deletions .gitignore
@@ -0,0 +1,31 @@
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon


# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

# IDEs and editors
.idea
.vscode
427 changes: 427 additions & 0 deletions LICENSE

Large diffs are not rendered by default.

Empty file added README.md
Empty file.
3 changes: 3 additions & 0 deletions nmt/.gitignore
@@ -0,0 +1,3 @@
wandb/
checkpoints/
bin/
Empty file added nmt/README.md
Empty file.
34 changes: 34 additions & 0 deletions nmt/eval.sh
@@ -0,0 +1,34 @@
#!/bin/bash

source ./vars.sh

SUBSET=${1:-test}

# https://github.com/facebookresearch/fairseq/issues/3000
# https://github.com/facebookresearch/fairseq/issues/3103
# https://github.com/facebookresearch/fairseq/issues/808
echo "Generating translations for the '$SUBSET' dataset."
fairseq-generate \
"$BIN_PATH" \
--batch-size 1 \
--beam 5 \
--bpe sentencepiece \
--dataset-impl "$DATASET_IMPL" \
--gen-subset "$SUBSET" \
--path "$ROOT/checkpoints/checkpoint_best.pt" \
--required-batch-size-multiple 1 \
--results-path "$RESULTS_PATH" \
--sentencepiece-model "$MODEL_PREFIX.model" \
--source-lang "$SOURCE_LANG" \
--target-lang "$TARGET_LANG"

grep "^D-" "$RESULTS_PATH/generate-$SUBSET.txt" | LC_ALL=C sort -V | cut -f3 > "$RESULTS_PATH/$TARGET_LANG.hyp"

_PATH=$TEST_PATH

if [ "$SUBSET" == "valid" ]; then
_PATH=$VALID_PATH
fi

# https://github.com/facebookresearch/flores/blob/main/flores200/README.md#evaluation
sacrebleu -m chrf --chrf-word-order 2 "$_PATH/$TARGET_LANG" < "$RESULTS_PATH/$TARGET_LANG.hyp"
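
A usage sketch (hypothetical invocation; assumes prepare.sh and train.sh have already produced the binarized data and checkpoints/checkpoint_best.pt):

# Score the test set (the default), then the validation set.
bash eval.sh
bash eval.sh valid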
12 changes: 12 additions & 0 deletions nmt/interactive.sh
@@ -0,0 +1,12 @@
#!/bin/bash

source ./vars.sh

fairseq-interactive \
"$BIN_PATH" \
--batch-size 1 \
--beam 5 \
--path "$ROOT/checkpoints/checkpoint_best.pt" \
--remove-bpe sentencepiece \
--source-lang "$SOURCE_LANG" \
--target-lang "$TARGET_LANG"
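
Because the script passes --remove-bpe sentencepiece for the output but no --bpe/--sentencepiece-model for the input, source sentences must be piped in already SentencePiece-encoded. A minimal sketch, assuming the SentencePiece build path from vars.sh and that the command is run from the nmt/ directory:

# Encode a raw source sentence into pieces, then translate it.
echo "Hello, world." \
  | ../../sentencepiece/build/src/spm_encode --model=spm.model --output_format=piece \
  | bash interactive.sh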
37 changes: 37 additions & 0 deletions nmt/prepare.sh
@@ -0,0 +1,37 @@
#!/bin/bash

source ./vars.sh

# Train the SentencePiece model.
$SPM_TRAIN --input="$TRAIN_PATH/$SOURCE_LANG,$TRAIN_PATH/$TARGET_LANG" \
--character_coverage=1.0 \
--model_prefix="$MODEL_PREFIX" \
--model_type=bpe \
--num_threads="$(nproc)" \
--max_sentence_length=256 \
--vocab_size=8000 \
--shuffle_input_sentence=true \
--bos_id=0 --pad_id=1 --eos_id=2 --unk_id=3

# Format the vocabulary file for fairseq: drop the four special-token rows
# (<s>, <pad>, </s>, <unk>) and append a dummy frequency of 100 to each piece.
# https://github.com/facebookresearch/fairseq/issues/459
cut -f1 "$MODEL_PREFIX.vocab" | tail -n +5 | sed "s/$/ 100/g" > "$MODEL_PREFIX.dict"

# Encode the datasets with SentencePiece.
encode "train"
encode "valid"
encode "test"

# Binarize the datasets for fairseq.
fairseq-preprocess \
--bpe sentencepiece \
--dataset-impl "$DATASET_IMPL" \
--destdir "$BIN_PATH" \
--joined-dictionary \
--source-lang "$SOURCE_LANG" \
--srcdict "$MODEL_PREFIX.dict" \
--target-lang "$TARGET_LANG" \
--testpref "$ROOT/test.spm" \
--trainpref "$ROOT/train.spm" \
--validpref "$ROOT/valid.spm" \
--workers "$(nproc)"
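
For reference, the vocabulary reformatting above converts SentencePiece's tab-separated .vocab rows into fairseq's "token count" format; a sketch of the transformation (the piece and score shown are illustrative):

# spm.vocab (piece<TAB>score; the first four rows are <s>, <pad>, </s>, <unk>):
#   ▁de    -3.21
# spm.dict, after cut -f1 | tail -n +5 | sed "s/$/ 100/g":
#   ▁de 100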
49 changes: 49 additions & 0 deletions nmt/train.sh
@@ -0,0 +1,49 @@
#!/bin/bash

source ./vars.sh

# Raise the open-file limit to avoid "Too many open files" during data loading.
ulimit -n 500000

# Based on bilingual experiments from https://aclanthology.org/2023.acl-long.154/.

CUDA_VISIBLE_DEVICES=0 fairseq-train \
"$BIN_PATH" \
--adam-betas '(0.9, 0.98)' \
--adam-eps 1e-06 \
--arch transformer \
--attention-dropout 0.2 \
--best-checkpoint-metric bleu \
--bpe sentencepiece \
--clip-norm 0.0 \
--criterion label_smoothed_cross_entropy \
--decoder-ffn-embed-dim 4096 \
--decoder-normalize-before \
--dropout 0.3 \
--encoder-ffn-embed-dim 4096 \
--encoder-normalize-before \
--eval-bleu \
--eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
--eval-bleu-detok space \
--eval-bleu-print-samples \
--eval-bleu-remove-bpe sentencepiece \
--keep-last-epochs 1 \
--label-smoothing 0.1 \
--log-format json \
--log-interval 100 \
--lr 0.001 \
--lr-scheduler inverse_sqrt \
--max-epoch 2000 \
--max-tokens 8000 \
--maximize-best-checkpoint-metric \
--optimizer adam \
--relu-dropout 0.2 \
--save-interval 10 \
--seed 2 \
--sentencepiece-model "$MODEL_PREFIX.model" \
--share-all-embeddings \
--update-freq 16 \
--validate-interval 5 \
--wandb-project "$WANDB_PROJECT" \
--warmup-init-lr 1e-07 \
--warmup-updates 400 \
--weight-decay 0.0001
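
With --max-tokens 8000 and --update-freq 16, the effective batch size is roughly 128,000 tokens per optimizer step on the single visible GPU. A launch sketch (assumes the binarized data from prepare.sh and a configured wandb login; the log filename is illustrative):

# Run from the nmt/ directory; checkpoints go to $ROOT/checkpoints and
# JSON-formatted metrics stream to stdout and the W&B project.
bash train.sh | tee train.log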
6,193 changes: 6,193 additions & 0 deletions nmt/train/spa

Large diffs are not rendered by default.

63 changes: 63 additions & 0 deletions nmt/vars.sh
@@ -0,0 +1,63 @@
#!/bin/bash

DEBUG=0

# Languages
SOURCE_LANG=eng
TARGET_LANG=spa

# Paths
ROOT=$(dirname "$0")
TRAIN_PATH=$ROOT/train
VALID_PATH=$ROOT/valid
TEST_PATH=$ROOT/test
RESULTS_PATH=$ROOT/results
MODEL_PREFIX=$ROOT/spm
BIN_PATH=$ROOT/bin/

# WandB
# Exported so the setting reaches the fairseq-train child process.
export WANDB_CONSOLE=off
WANDB_PROJECT=wmt24-oldi

# Fairseq Variables
LANGS="$SOURCE_LANG,$TARGET_LANG"
LANG_PAIRS="$SOURCE_LANG-$TARGET_LANG,$TARGET_LANG-$SOURCE_LANG"

DATASET_IMPL="mmap"
if [ "$DEBUG" -eq 1 ]; then
DATASET_IMPL="raw"
fi

# SentencePiece Variables
SPM=$ROOT/../../sentencepiece/build/src
SPM_TRAIN="$SPM/spm_train"
SPM_ENCODE="$SPM/spm_encode"

set -e

encode() {
local _DATA_TYPE=$1
local _PATH=$ROOT

case $_DATA_TYPE in
train)
_PATH=$TRAIN_PATH
;;
valid)
_PATH=$VALID_PATH
;;
test)
_PATH=$TEST_PATH
;;
*)
echo "Invalid data type: $_DATA_TYPE"
exit 1
;;
esac

for LANG in $SOURCE_LANG $TARGET_LANG; do
$SPM_ENCODE \
--model="$MODEL_PREFIX.model" \
--output_format=piece < "$_PATH/$LANG" > "$ROOT/$_DATA_TYPE.spm.$LANG"
done
}
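
The encode helper can also be called standalone once the SentencePiece model exists; a minimal sketch, assuming it is run from the nmt/ directory:

# Writes train.spm.eng and train.spm.spa next to vars.sh.
source ./vars.sh
encode "train"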
90 changes: 90 additions & 0 deletions scripts/.gitignore
@@ -0,0 +1,90 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Flask stuff:
instance/
.webassets-cache

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
Empty file added scripts/README.md
Empty file.
64 changes: 64 additions & 0 deletions scripts/gcs_download_prov.py
@@ -0,0 +1,64 @@
import argparse
import os
from concurrent.futures import ThreadPoolExecutor

from google.cloud import storage


def save_prov(blob, directory, filter_range):
path_parts = blob.name.split("/")
index = int(path_parts[-2])

if index < filter_range[0] or index > filter_range[1]:
return None

try:
content = blob.download_as_text()
with open(f"{directory}/{index}.json", "w") as f:
f.write(content)

print(f"Saved content from {blob.name}")

except Exception as e:
print(f"Failed to download {blob.name}: {str(e)}")
return None


def main(bucket_name, prefix, directory, filter_range):
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)

    # Make sure the local target directory exists before worker threads write to it.
    os.makedirs(directory, exist_ok=True)

blobs = [
blob
for blob in bucket.list_blobs(prefix=f"{prefix}/")
if blob.name.endswith("prov.json")
]

print(f"Found {len(blobs)} provenance files")

with ThreadPoolExecutor(max_workers=20) as executor:
executor.map(lambda blob: save_prov(blob, directory, filter_range), blobs)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Download provenance files from Google Cloud Storage."
)
parser.add_argument("--prefix", type=str, help="GCS path prefix.")
parser.add_argument("--bucket", type=str, help="GCS bucket name.")
parser.add_argument(
"--directory",
type=str,
default="prov-json",
help="Local target directory to save files.",
)
parser.add_argument(
"--range",
type=int,
nargs=2,
default=(1, 6193),
help="Range of indices to process (inclusive).",
)

args = parser.parse_args()

main(args.bucket, args.prefix, args.directory, tuple(args.range))
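
An invocation sketch (the bucket and prefix values are hypothetical; requires google-cloud-storage and application-default credentials):

# Download prov.json files with indices 1-100 into ./prov-json.
python gcs_download_prov.py --bucket my-bucket --prefix my-prefix --range 1 100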
