-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_pipeline_config.py
134 lines (125 loc) · 4.28 KB
/
get_pipeline_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
def get_unet_scheduler(i: int):
    """Return the pbtxt text for one unet step followed by one scheduler step.

    The unet entry reads ``sample{i}`` / ``timestep{i}`` plus the shared
    ``encoder_hidden_states`` and writes ``noise{i}``.  The scheduler entry
    reads ``sample{i}`` / ``noise{i}`` / ``i{i}`` and writes
    ``sample{i+1}`` / ``timestep{i+1}`` / ``i{i+1}``, which are the inputs of
    the fragment for index ``i + 1``.

    The returned text starts with a newline and each of the two entries ends
    with ``},`` so consecutive fragments can be pasted into a ``step: [...]``
    list.
    """
    nxt = i + 1
    fragment_lines = [
        "",
        # --- unet: predict noise for the current sample/timestep ---
        "{",
        'model_name: "unet",',
        "model_version: -1,",
        "input_map: [",
        f'{{ key: "sample", value: "sample{i}" }},',
        f'{{ key: "timestep", value: "timestep{i}" }},',
        '{ key: "encoder_hidden_states", value: "encoder_hidden_states"}',
        "],",
        f'output_map: [{{ key: "noise", value: "noise{i}" }}]',
        "},",
        # --- scheduler: produce the tensors consumed by step i + 1 ---
        "{",
        'model_name: "scheduler",',
        "model_version: -1,",
        "input_map: [",
        f'{{ key: "sample", value: "sample{i}" }},',
        f'{{ key: "noise", value: "noise{i}" }},',
        f'{{ key: "i", value: "i{i}" }}',
        "],",
        "output_map: [",
        f'{{ key: "sample", value: "sample{nxt}" }},',
        f'{{ key: "timestep", value: "timestep{nxt}" }},',
        f'{{ key: "i", value: "i{nxt}" }}',
        "]",
        "},",
        "",
    ]
    return "\n".join(fragment_lines)
def pbtxt(num_infer: int = 25, img_size: int = 512) -> str:
    """Build the ensemble ``config.pbtxt`` text for the "pipeline" model.

    The ensemble (``platform: "ensemble"``) runs ``tokenizer`` ->
    ``text_encoder``, then ``num_infer`` unet/scheduler step pairs, then
    ``vae``.  Tensor names carry a step suffix so each pair feeds the next:

    * the first pair reads the un-suffixed ``sample`` / ``timestep`` / ``i``
      emitted by the tokenizer and writes ``sample0`` / ``timestep0`` / ``i0``;
    * the ``num_infer - 2`` middle pairs come from ``get_unet_scheduler``;
    * the last pair writes ``sample{num_infer-1}``, which is mapped to the
      ``vae`` model's ``z`` input.  The vae's ``image`` output becomes the
      ensemble output ``output``.

    Args:
        num_infer: Total number of unet/scheduler pairs.  Must be >= 2
            because the first and last pairs are always emitted.
        img_size: Height and width used in the output ``dims``.  Must be >= 1.

    Returns:
        The full config file contents.

    Raises:
        ValueError: If ``num_infer < 2`` or ``img_size < 1``.
    """
    if num_infer < 2:
        # With num_infer < 2 the final pair would read ``sample-1`` /
        # ``timestep-1``, tensors no step produces, and the step count would
        # not match num_infer anyway.
        raise ValueError(f"num_infer must be >= 2, got {num_infer}")
    if img_size < 1:
        raise ValueError(f"img_size must be >= 1, got {img_size}")
    # Middle steps 0 .. num_infer-3; each consumes sample{i} and produces
    # sample{i+1}, so the last one produces sample{num_infer-2}.
    _unet_schedulers = "\n".join(
        get_unet_scheduler(i) for i in range(num_infer - 2)
    )
    return f'''name: "pipeline"
max_batch_size: 0
platform: "ensemble"
input: [
{{ name: "prompts", data_type: TYPE_STRING, dims: [-1, 1] }}
]
output: [
{{
name: "output",
data_type: TYPE_FP32,
dims: [-1, 3, {img_size}, {img_size}]
}}
]
ensemble_scheduling: {{
step: [
{{
model_name: "tokenizer",
model_version: -1,
input_map: [{{ key: "prompts", value: "prompts" }}],
output_map: [
{{ key: "input_ids", value: "input_ids" }},
{{ key: "attention_mask", value: "attention_mask" }},
{{ key: "sample", value: "sample" }},
{{ key: "timestep", value: "timestep" }},
{{ key: "i", value: "i"}}
]
}},
{{
model_name: "text_encoder",
model_version: -1,
input_map: [
{{ key: "input_ids", value: "input_ids" }},
{{ key: "attention_mask", value: "attention_mask" }}
],
output_map: [{{ key: "encoder_hidden_states", value: "encoder_hidden_states"}}]
}},
{{
model_name: "unet",
model_version: -1,
input_map: [
{{ key: "sample", value: "sample" }},
{{ key: "timestep", value: "timestep" }},
{{ key: "encoder_hidden_states", value: "encoder_hidden_states"}}
],
output_map: [{{ key: "noise", value: "noise" }}]
}},
{{
model_name: "scheduler",
model_version: -1,
input_map: [
{{ key: "sample", value: "sample" }},
{{ key: "noise", value: "noise" }},
{{ key: "i", value: "i" }}
],
output_map: [
{{ key: "sample", value: "sample0" }},
{{ key: "timestep", value: "timestep0" }},
{{ key: "i", value: "i0" }}
]
}},
{_unet_schedulers}
{{
model_name: "unet",
model_version: -1,
input_map: [
{{ key: "sample", value: "sample{num_infer-2}" }},
{{ key: "timestep", value: "timestep{num_infer-2}" }},
{{ key: "encoder_hidden_states", value: "encoder_hidden_states"}}
],
output_map: [{{ key: "noise", value: "noise{num_infer-2}" }}]
}},
{{
model_name: "scheduler",
model_version: -1,
input_map: [
{{ key: "sample", value: "sample{num_infer-2}" }},
{{ key: "noise", value: "noise{num_infer-2}" }},
{{ key: "i", value: "i{num_infer-2}" }}
],
output_map: [
{{ key: "sample", value: "sample{num_infer-1}" }}
]
}},
{{
model_name: "vae",
model_version: -1,
input_map: [{{ key: "z", value: "sample{num_infer-1}" }}],
output_map: [{{ key: "image", value: "output" }}]
}}
]
}}
'''
if __name__ == "__main__":
    # Render first, then open: opening with "w" truncates the file, so doing
    # it before generation would leave an empty config if pbtxt() raised.
    config_text = pbtxt(num_infer=25, img_size=512)
    with open("./repo/pipeline/config.pbtxt", "w", encoding="utf-8") as out_file:
        out_file.write(config_text)