-
Notifications
You must be signed in to change notification settings - Fork 43
/
create_model.py
executable file
·128 lines (112 loc) · 4.2 KB
/
create_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python3
import copy
from typing import Any, Dict, List, Optional

import jsonargparse
import torch.nn as nn
from jsonargparse.typing import NonNegativeInt
from pytorch_lightning import seed_everything

import laia.common.logging as log
from laia.common.arguments import CommonArgs, CreateCRNNArgs
from laia.common.saver import ModelSaver
from laia.models.htr import LaiaCRNN
from laia.scripts.htr import common_main
from laia.utils import SymbolsTable
def run(
    syms: str,
    fixed_input_height: Optional[NonNegativeInt] = 0,
    adaptive_pooling: str = "avgpool-16",
    common: CommonArgs = CommonArgs(),
    crnn: CreateCRNNArgs = CreateCRNNArgs(),
    save_model: bool = False,
) -> LaiaCRNN:
    """Build a ``LaiaCRNN`` model and optionally save its constructor arguments.

    Args:
        syms: Symbols table passed to ``SymbolsTable``; the number of entries
            becomes the model's ``num_output_labels``.
        fixed_input_height: If truthy, a fixed ``"none-<size>"`` image
            sequencer is used, where ``<size>`` is the CNN output size computed
            for a square input of this side. If 0 or None, ``adaptive_pooling``
            is used as the image sequencer instead.
        adaptive_pooling: Image sequencer spec used when ``fixed_input_height``
            is falsy.
        common: Common arguments; ``seed``, ``train_path`` and
            ``model_filename`` are read from it.
        crnn: Model hyperparameters. Neither the caller's object nor the
            shared default instance is modified; a shallow copy is completed
            and used instead.
        save_model: If True, save ``LaiaCRNN`` and its constructor keyword
            arguments with ``ModelSaver``.

    Returns:
        The constructed ``LaiaCRNN`` model.

    Raises:
        AssertionError: If the CNN output size for ``fixed_input_height`` is
            not positive.
    """
    seed_everything(common.seed)

    # Work on a copy. The fields below are overwritten in place, and
    # ``rnn_type`` / ``cnn_activation`` are turned from names into torch.nn
    # classes. Mutating the argument directly would corrupt the caller's object
    # and the default instance shared by every call; a second call would then
    # fail inside ``getattr(nn, ...)``.
    crnn = copy.copy(crnn)
    crnn.num_output_labels = len(SymbolsTable(syms))

    if fixed_input_height:
        conv_output_size = LaiaCRNN.get_conv_output_size(
            size=(fixed_input_height, fixed_input_height),
            cnn_kernel_size=crnn.cnn_kernel_size,
            cnn_stride=crnn.cnn_stride,
            cnn_dilation=crnn.cnn_dilation,
            cnn_poolsize=crnn.cnn_poolsize,
        )
        # Index 1 for vertical text, 0 otherwise.
        # NOTE(review): assumes conv_output_size is ordered (height, width);
        # confirm against LaiaCRNN.get_conv_output_size.
        fixed_size_after_conv = conv_output_size[1 if crnn.vertical_text else 0]
        if fixed_size_after_conv <= 0:
            # Explicit raise (same exception type as the former ``assert``) so
            # the check is not stripped when running under ``python -O``.
            raise AssertionError(
                "The image size is too small for the CNN architecture"
            )
        crnn.image_sequencer = f"none-{fixed_size_after_conv}"
    else:
        crnn.image_sequencer = adaptive_pooling

    # Resolve the configured class names (strings) to torch.nn classes.
    crnn.rnn_type = getattr(nn, crnn.rnn_type)
    crnn.cnn_activation = [getattr(nn, act) for act in crnn.cnn_activation]

    model = LaiaCRNN(**vars(crnn))
    log.info(
        "Model has {} parameters",
        sum(param.numel() for param in model.parameters()),
    )
    if save_model:
        ModelSaver(common.train_path, common.model_filename).save(
            LaiaCRNN, **vars(crnn)
        )
    return model
def get_args(argv: Optional[List[str]] = None) -> Dict[str, Any]:
    """Parse the arguments of the model-creation script.

    Args:
        argv: Arguments to parse. If None, jsonargparse falls back to the
            process command line (``sys.argv``), as argparse does.

    Returns:
        A dict of parsed arguments (the parser uses ``parse_as_dict=True``).
        The ``common`` and ``crnn`` groups are converted into ``CommonArgs``
        and ``CreateCRNNArgs`` instances. The ``logging`` group stays a dict
        holding the arguments of ``log.config``.
    """
    parser = jsonargparse.ArgumentParser(
        parse_as_dict=True,
        description=(
            "Create a model for HTR composed of a set of convolutional blocks, followed"
            " by a set of bidirectional RNN layers, and a final linear layer. Each"
            " convolutional block is composed of a 2D convolutional layer, an optional"
            " batch normalization layer, a non-linear activation function, and an"
            " optional 2D max-pooling layer. A dropout layer might precede each"
            " block, rnn layer, and the final linear layer."
        ),
    )
    parser.add_argument(
        "--config", action=jsonargparse.ActionConfigFile, help="Configuration file"
    )
    parser.add_argument(
        "syms",
        type=str,
        help=(
            "Mapping from strings to integers. "
            "The CTC symbol must be mapped to integer 0"
        ),
    )
    parser.add_argument(
        "--fixed_input_height",
        type=NonNegativeInt,
        default=0,
        help=(
            "Height of the input images. If 0, a variable height model "
            "will be used (see `adaptive_pooling`). This will be used to compute the "
            "model output height at the end of the convolutional layers"
        ),
    )
    parser.add_argument(
        "--adaptive_pooling",
        type=str,
        default="avgpool-16",
        help=(
            "Use our custom adaptive pooling layers. This option allows training with"
            " variable height images. Takes into account the size of each individual"
            " image within the batch (before padding). (allowed: {avg,max}pool-N)"
        ),
    )
    parser.add_argument(
        "--save_model",
        type=bool,
        default=True,
        help="Whether to save the model to a file",
    )
    # Grouped arguments: "common.*", "logging.*" (log.config's parameters),
    # and "crnn.*".
    parser.add_class_arguments(CommonArgs, "common")
    parser.add_function_arguments(log.config, "logging")
    parser.add_class_arguments(CreateCRNNArgs, "crnn")

    args = parser.parse_args(argv, with_meta=False)
    # With parse_as_dict=True, the groups come back as plain dicts; rebuild
    # the typed argument objects that ``run`` expects.
    args["common"] = CommonArgs(**args["common"])
    args["crnn"] = CreateCRNNArgs(**args["crnn"])
    return args
def main():
    """Script entry point: parse the arguments, apply common setup, build the model."""
    run(**common_main(get_args()))


if __name__ == "__main__":
    # Only build the model when executed as a script, not when imported.
    main()