forked from jpuigcerver/PyLaia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpylaia-htr-netout
executable file
·143 lines (128 loc) · 4.58 KB
/
pylaia-htr-netout
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/usr/bin/env python
from __future__ import absolute_import
import argparse
import os
import sys
from tqdm import tqdm
from typing import Any
import torch
import torch.nn.functional as functional
import laia.common.logging as log
from laia.common.arguments import add_argument, args, add_defaults
from laia.common.loader import ModelLoader, CheckpointLoader
from laia.data import ImageDataLoader, ImageFromListDataset
from laia.engine.feeders import ImageFeeder, ItemFeeder
from laia.experiments import Experiment
from laia.losses.ctc_loss import transform_output
from laia.utils import ImageToTensor
from laia.utils.kaldi import ArchiveLatticeWriter, ArchiveMatrixWriter
if __name__ == "__main__":
    # Script entry point: load a trained model, run it over a list of images
    # and write the raw network output of every sample to Kaldi archives
    # (a matrix archive and/or a lattice archive).

    # ---- Command-line interface -------------------------------------------
    add_defaults(
        "batch_size", "gpu", "train_path", "show_progress_bar", logging_level="WARNING"
    )
    add_argument(
        "img_dirs", type=str, nargs="+", help="Directory containing word images"
    )
    add_argument(
        "img_list",
        type=argparse.FileType("r"),
        help="File or list containing images to decode",
    )
    add_argument(
        "--model_filename", type=str, default="model", help="File name of the model"
    )
    add_argument(
        "--checkpoint",
        type=str,
        default="experiment.ckpt.lowest-valid-cer*",
        help="Name of the model checkpoint to use, can be a glob pattern",
    )
    add_argument(
        "--source",
        type=str,
        default="experiment",
        choices=["experiment", "model"],
        help="Type of class which generated the checkpoint",
    )
    add_argument(
        "--output_transform",
        type=str,
        default=None,
        choices=["softmax", "log_softmax"],
        help="Apply this transformation at the end of the model. "
        'For instance, use "softmax" to get posterior probabilities as the '
        "output of the model",
    )
    add_argument(
        "--output_matrix",
        type=argparse.FileType("wb"),
        default=None,
        help="Path of the Kaldi's archive containing the output matrices "
        "(one for each sample), where each row represents a timestep and "
        "each column represents a CTC label",
    )
    add_argument(
        "--output_lattice",
        type=argparse.FileType("w"),
        default=None,
        help="Path of the Kaldi's archive containing the output lattices "
        "(one for each sample), representing the CTC output",
    )
    add_argument(
        "--digits",
        type=int,
        default=10,
        help="Number of digits to be used for formatting",
    )
    args = args()

    # Fail fast: without any output archive there is nothing to do, so bail
    # out before spending time loading the model and the checkpoint.
    if args.output_matrix is None and args.output_lattice is None:
        log.error(
            "You did not specify any output file! "
            "Use --output_matrix and/or --output_lattice"
        )
        sys.exit(1)

    # ---- Model ------------------------------------------------------------
    # --gpu is 1-based: a falsy value selects the CPU, N selects "cuda:N-1".
    device = torch.device("cuda:{}".format(args.gpu - 1) if args.gpu else "cpu")

    model = ModelLoader(
        args.train_path, filename=args.model_filename, device=device
    ).load()
    if model is None:
        log.error("Could not find the model")
        sys.exit(1)

    state = CheckpointLoader(device=device).load_by(
        os.path.join(args.train_path, args.checkpoint)
    )
    # A checkpoint saved by an Experiment wraps the model state; unwrap it
    # unless the user said the checkpoint comes straight from a model.
    model.load_state_dict(
        state if args.source == "model" else Experiment.get_model_state_dict(state)
    )
    model = model.to(device)
    model.eval()

    # ---- Data -------------------------------------------------------------
    dataset = ImageFromListDataset(
        args.img_list, img_dirs=args.img_dirs, img_transform=ImageToTensor()
    )
    dataset_loader = ImageDataLoader(
        dataset=dataset, image_channels=1, batch_size=args.batch_size, num_workers=8
    )
    batch_input_fn = ImageFeeder(device=device, parent_feeder=ItemFeeder("img"))

    # ---- Output writers ---------------------------------------------------
    archive_writers = []
    if args.output_matrix is not None:
        archive_writers.append(ArchiveMatrixWriter(args.output_matrix))
    if args.output_lattice is not None:
        archive_writers.append(
            ArchiveLatticeWriter(args.output_lattice, digits=args.digits, negate=True)
        )

    if args.show_progress_bar:
        dataset_loader = tqdm(dataset_loader)

    # ---- Inference --------------------------------------------------------
    # Gradients are never needed here; disabling autograd avoids building a
    # graph for every batch and lowers memory usage.
    with torch.no_grad():
        for batch in dataset_loader:
            batch_input = batch_input_fn(batch)
            batch_output = model(batch_input)
            batch_output, batch_sizes = transform_output(batch_output)
            # Swap the first two axes so that iterating the tensor yields one
            # sample at a time (presumably time-major -> batch-major; confirm
            # against transform_output).
            batch_output = batch_output.permute(1, 0, 2)
            if args.output_transform:
                batch_output = getattr(functional, args.output_transform)(
                    batch_output, dim=-1
                )
            batch_output = batch_output.detach().cpu().numpy()
            for key, out, out_size in zip(batch["id"], batch_output, batch_sizes):
                # Keep only the first out_size rows reported for this sample,
                # dropping the padded tail.
                out = out[:out_size, :]
                for writer in archive_writers:
                    writer.write(key, out)