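"""End-to-end Elpis example: preprocess a local dataset, fine-tune a
wav2vec2 model, then transcribe a test recording and export the result
as plain text and as an ELAN (.eaf) file."""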
import shutil
from pathlib import Path
from loguru import logger
from tqdm import tqdm
from transformers import TrainingArguments
from elpis.datasets import Dataset
from elpis.datasets.dataset import CleaningOptions
from elpis.datasets.preprocessing import process_batch
from elpis.models import ElanOptions, ElanTierSelector
from elpis.models.job import DataArguments, Job, ModelArguments
from elpis.trainer.trainer import run_job
from elpis.transcriber.results import build_elan, build_text
from elpis.transcriber.transcribe import build_pipeline, transcribe
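
# Paths to the example datasets and the test audio to transcribe.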
DATASETS_PATH = Path(__file__).parent.parent.parent / "datasets"
TIMIT_PATH = DATASETS_PATH / "timit"
DIGITS_PATH = DATASETS_PATH / "digits-preview"

TRAINING_FILES = list((DIGITS_PATH / "train").rglob("*.*"))
TRANSCRIBE_AUDIO = DIGITS_PATH / "test/audio2.wav"
TIMIT_TIER_NAME = "default"
TIER_NAME = "tx"

print("------ Training files ------")
# print(TRAINING_FILES)

DIGITS_DATASET = Dataset(
    name="my_dataset",
    files=TRAINING_FILES,
    cleaning_options=CleaningOptions(),  # Default cleaning options
    # Elan data extraction info - required if dataset includes .eaf files.
    elan_options=ElanOptions(
        selection_mechanism=ElanTierSelector.NAME, selection_value=TIER_NAME
    ),
)

TIMIT_DATASET = Dataset(
    name="my_dataset",
    files=list(TIMIT_PATH.rglob("*.*")),
    cleaning_options=CleaningOptions(),  # Default cleaning options
    # Elan data extraction info - required if dataset includes .eaf files.
    elan_options=ElanOptions(
        selection_mechanism=ElanTierSelector.NAME, selection_value=TIMIT_TIER_NAME
    ),
)
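
# Select which of the two datasets to train on.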
dataset = TIMIT_DATASET

# Setup
tmp_path = Path("testdir")
# tmp_path = Path("/tmp") / "testscript"
dataset_dir = tmp_path / "dataset"
model_dir = tmp_path / "model"
output_dir = tmp_path / "output"
cache_dir = tmp_path / "cache"

# Reset dir between runs
if tmp_path.exists():
    shutil.rmtree(tmp_path)

# Make all directories
for directory in dataset_dir, model_dir, output_dir:
    directory.mkdir(exist_ok=True, parents=True)

# Preprocessing
logger.info("Creating batches")
batches = dataset.to_batches()

logger.info("Processing batches")
for batch in tqdm(batches, unit="batch"):
    process_batch(batch, dataset_dir)

# Train the model
job = Job(
    model_args=ModelArguments(
        "facebook/wav2vec2-base",
        # ctc_zero_infinity=True,
        attention_dropout=0.1,
        # hidden_dropout=0.1,
        # mask_time_prob=0.05,
        layerdrop=0.0,
        freeze_feature_encoder=True,
    ),
    data_args=DataArguments(
        dataset_name_or_path=str(dataset_dir),
        text_column_name="transcript",
        min_duration_in_seconds=1,
        max_duration_in_seconds=5,
    ),
    training_args=TrainingArguments(
        output_dir=str(model_dir),
        overwrite_output_dir=True,
        # evaluation_strategy="epoch",
        do_train=True,
        do_eval=True,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=1,
        num_train_epochs=20,
        learning_rate=1e-4,
        group_by_length=True,
        weight_decay=0.005,
        warmup_steps=1000,
        logging_steps=10,
        eval_steps=100,
        save_steps=400,
        save_total_limit=2,
    ),
)
print("------ JOB ------")
print(job)
run_job(job)
print("------ TRAINED ------")
# Perform inference with pipeline
print("------ INFER ------")
asr = build_pipeline(
    pretrained_location=str(model_dir.absolute()),
)
annotations = transcribe(TRANSCRIBE_AUDIO, asr)
print(build_text(annotations))

# Build output files
print("------ OUTPUT ------")
text_file = output_dir / "test.txt"
with open(text_file, "w") as output_file:
    output_file.write(build_text(annotations))

elan_file = output_dir / "test.eaf"
eaf = build_elan(annotations)
eaf.to_file(str(elan_file))

print("voila ;)")