# demo.py (forked from keirp/automatic_prompt_engineer)
import openai
import numpy as np
import pandas as pd
import gradio as gr
from experiments.data.instruction_induction.load_data import load_data
from automatic_prompt_engineer.ape import get_simple_prompt_gen_template
from automatic_prompt_engineer import ape, evaluate, config, template, llm
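# Note (sketch of an assumption, not part of the original file): the demo expects
# openai.api_key to be set before launch; the openai package typically picks this up
# from the OPENAI_API_KEY environment variable if it is present.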
model_types = ['text-ada-001', 'text-babbage-001', 'text-curie-001', 'text-davinci-002']
mode_types = ['forward', 'insert']
eval_types = ['likelihood', 'bandits']
task_types = ['antonyms', 'cause_and_effect', 'common_concept', 'diff', 'first_word_letter',
'informal_to_formal', 'larger_animal', 'letters_list', 'taxonomy_animal', 'negation', 'num_to_verbal',
'active_to_passive', 'singular_to_plural', 'rhymes',
'second_word_letter', 'sentence_similarity', 'sentiment', 'orthography_starts_with',
'sum', 'synonyms', 'translation_en-de', 'translation_en-es',
'translation_en-fr', 'word_in_context']
def load_task(task):
inputs, outputs = load_data('induce', task)
train_data = '\n'.join([f'{inp} >>> {out[0] if len(out) == 1 else out}' for inp, out in zip(inputs, outputs)])
inputs, outputs = load_data('execute', task)
test_data = '\n'.join([f'{inp} >>> {out[0] if len(out) == 1 else out}' for inp, out in zip(inputs, outputs)])
return train_data, test_data
# Problematic datasets: common_concept, larger_animal, taxonomy_animal, rhymes, sentiment, orthography_starts_with
def parse_data(dataset):
"""Parses the data from the text of a csv into two lists of strings.
The CSV has no header and the first column is the input and the second column is the output."""
dataset = dataset.split('\n')
dataset = [line.split(' >>> ') for line in dataset]
dataset = [inp for inp, _ in dataset], [out for inp, out in dataset]
return dataset
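# Illustrative example of the format parse_data expects (each line is "input >>> output"):
#
#   parse_data("hot >>> cold\nbig >>> small")
#   # -> (['hot', 'big'], ['cold', 'small'])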
def prompts_to_df(prompts, scores):
"""Converts a list of prompts into a dataframe."""
df = pd.DataFrame()
df['Prompt'] = prompts
df['log(p)'] = scores
df['log(p)'] = df['log(p)'].apply(lambda x: round(x, 3)) # Round the scores to 3 decimal places
df = df.head(15) # Only show the top 15 prompts
return df
def prompt_overview():
gen_prompt = ''
eval_prompt = ''
return gen_prompt, eval_prompt
def run_ape(prompt_gen_data, eval_data,
eval_template='Instruction: [PROMPT]\nInput: [INPUT]\nOutput: [OUTPUT]',
prompt_gen_template=None,
demos_template='Input: [INPUT]\nOutput: [OUTPUT]',
eval_model='text-davinci-002',
prompt_gen_model='text-davinci-002',
prompt_gen_mode='forward',
num_prompts=50, eval_rounds=10, prompt_gen_batch_size=200, eval_batch_size=500, # basic
num_subsamples=None, num_demos=None, # advanced
num_samples=None, num_few_shot=None # advanced
):
"""
Function that wraps the find_prompts function to make it easier to use.
Design goals: include default values for most parameters, and automatically
fill out the config dict for the user in a way that fits almost all use cases.
The main shortcuts this function takes are:
- Uses the same dataset for prompt generation, evaluation, and few shot demos
- Uses UCB algorithm for evaluation
- Fixes the number of prompts per round to num_prompts // 3 (so the first three rounds will
sample every prompt once)
- Fixes the number of samples per prompt per round to 5
Parameters:
prompt_gen_data: The dataset to use for prompt generation.
eval_data: The dataset to use for evaluation.
eval_template: The template for the evaluation queries.
prompt_gen_template: The template to use for prompt generation.
demos_template: The template for the demos.
eval_model: The model to use for evaluation.
prompt_gen_model: The model to use for prompt generation.
prompt_gen_mode: The mode to use for prompt generation.
num_prompts: The number of prompts to generate during the search.
eval_rounds: The number of evaluation rounds to run.
prompt_gen_batch_size: The batch size to use for prompt generation.
eval_batch_size: The batch size to use for evaluation.
        num_subsamples: The number of demo subsets sampled for instruction generation.
        num_demos: The number of demos included in each subset when generating an instruction.
num_samples: The number of samples at each round for evaluation.
num_few_shot: The number of few shot examples to use for evaluation.
Returns:
        A dataframe of the top prompts and their scores, sample generation and evaluation
        queries, the best prompt (returned twice, to fill the deployment and scoring boxes),
        and its score.
"""
prompt_gen_data = parse_data(prompt_gen_data)
eval_data = parse_data(eval_data)
prompt_gen_template = get_simple_prompt_gen_template(prompt_gen_template, prompt_gen_mode)
if demos_template is None:
demos_template = "Input: [INPUT]\nOutput: [OUTPUT]"
conf = config.simple_config(
eval_model, prompt_gen_model, prompt_gen_mode, num_prompts, eval_rounds, prompt_gen_batch_size, eval_batch_size)
if num_subsamples is not None:
conf['generation']['num_subsamples'] = num_subsamples
conf['generation']['num_prompts_per_subsample'] = num_prompts // num_subsamples
if num_demos is not None:
conf['generation']['num_demos'] = num_demos
if num_samples is not None:
if conf['evaluation']['method'] == 'bandits':
conf['evaluation']['base_eval_config']['num_samples'] = num_samples
else:
conf['evaluation']['num_samples'] = num_samples
if num_few_shot is not None:
if conf['evaluation']['method'] == 'bandits':
conf['evaluation']['base_eval_config']['num_few_shot'] = num_few_shot
else:
conf['evaluation']['num_few_shot'] = num_few_shot
res, demo_fn = ape.find_prompts(eval_template, demos_template, prompt_gen_data, eval_data, conf,
prompt_gen_template=prompt_gen_template)
prompts, scores = res.sorted()
df = prompts_to_df(prompts, scores)
generation_query = ape.get_generation_query(eval_template, demos_template, conf, prompt_gen_data,
prompt_gen_template)[0]
evaluation_query = ape.get_evaluation_query(eval_template, demos_template, conf, eval_data, prompt_gen_data)[0]
return df, generation_query, evaluation_query, prompts[0], prompts[0], scores[0]
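# Minimal usage sketch (illustrative only; assumes the OpenAI key is set and the data
# strings use the "input >>> output" lines produced by load_task above):
#
#   train_data, test_data = load_task('antonyms')
#   df, gen_query, eval_query, best_prompt, _, best_score = run_ape(
#       train_data, test_data, num_prompts=20, eval_rounds=3)
#   print(best_prompt, best_score)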
def basic_ape(prompt_gen_data, eval_data,
eval_template='Instruction: [PROMPT]\nInput: [INPUT]\nOutput: [OUTPUT]',
eval_model='text-davinci-002',
prompt_gen_model='text-davinci-002',
prompt_gen_mode='forward',
num_prompts=50, eval_rounds=10, prompt_gen_batch_size=200, eval_batch_size=500):
return run_ape(prompt_gen_data, eval_data, eval_template,
eval_model=eval_model, prompt_gen_model=prompt_gen_model, prompt_gen_mode=prompt_gen_mode,
num_prompts=num_prompts, eval_rounds=eval_rounds, prompt_gen_batch_size=prompt_gen_batch_size,
eval_batch_size=eval_batch_size)
def advance_ape(prompt_gen_data, eval_data,
eval_template, prompt_gen_template, demos_template,
eval_model='text-davinci-002',
prompt_gen_model='text-davinci-002',
prompt_gen_mode='forward',
num_prompts=50, eval_rounds=10, prompt_gen_batch_size=200, eval_batch_size=500, # basic
num_subsamples=None, num_demos=None, # advanced
num_samples=None, num_few_shot=None # advanced
):
if prompt_gen_mode == 'forward':
if prompt_gen_template[-5:] != '[APE]':
raise ValueError('The prompt_gen_template must end with [APE] for forward mode.')
return run_ape(prompt_gen_data, eval_data, eval_template, prompt_gen_template, demos_template,
eval_model=eval_model, prompt_gen_model=prompt_gen_model, prompt_gen_mode=prompt_gen_mode,
num_prompts=num_prompts, eval_rounds=eval_rounds, prompt_gen_batch_size=prompt_gen_batch_size,
eval_batch_size=eval_batch_size, num_subsamples=num_subsamples, num_demos=num_demos,
num_samples=num_samples, num_few_shot=num_few_shot)
def estimate_cost(prompt_gen_data, eval_data,
eval_template='Instruction: [PROMPT]\nInput: [INPUT]\nOutput: [OUTPUT]',
prompt_gen_template=None,
demos_template='Input: [INPUT]\nOutput: [OUTPUT]',
eval_model='text-davinci-002',
prompt_gen_model='text-davinci-002',
prompt_gen_mode='forward',
num_prompts=50, eval_rounds=10, prompt_gen_batch_size=200, eval_batch_size=500, # basic
num_subsamples=None, num_demos=None, # advanced
num_samples=None, num_few_shot=None # advanced
):
"""
    Function that wraps the estimate_cost function to make it easier to use.
Design goals: include default values for most parameters, and automatically
fill out the config dict for the user in a way that fits almost all use cases.
The main shortcuts this function takes are:
- Uses the same dataset for prompt generation, evaluation, and few shot demos
- Uses UCB algorithm for evaluation
- Fixes the number of prompts per round to num_prompts // 3 (so the first three rounds will
sample every prompt once)
- Fixes the number of samples per prompt per round to 5
Parameters:
prompt_gen_data: The dataset to use for prompt generation.
eval_data: The dataset to use for evaluation.
eval_template: The template for the evaluation queries.
prompt_gen_template: The template to use for prompt generation.
demos_template: The template for the demos.
eval_model: The model to use for evaluation.
prompt_gen_model: The model to use for prompt generation.
prompt_gen_mode: The mode to use for prompt generation.
num_prompts: The number of prompts to generate during the search.
eval_rounds: The number of evaluation rounds to run.
prompt_gen_batch_size: The batch size to use for prompt generation.
eval_batch_size: The batch size to use for evaluation.
        num_subsamples: The number of demo subsets sampled for instruction generation.
        num_demos: The number of demos included in each subset when generating an instruction.
num_samples: The number of samples at each round for evaluation.
num_few_shot: The number of few shot examples to use for evaluation.
Returns:
        The estimated API cost in dollars, along with sample generation and evaluation queries.
"""
prompt_gen_data = parse_data(prompt_gen_data)
eval_data = parse_data(eval_data)
prompt_gen_template = get_simple_prompt_gen_template(prompt_gen_template, prompt_gen_mode)
if demos_template is None:
demos_template = "Input: [INPUT]\nOutput: [OUTPUT]"
conf = config.simple_config(
eval_model, prompt_gen_model, prompt_gen_mode, num_prompts, eval_rounds, prompt_gen_batch_size, eval_batch_size)
if num_subsamples is not None:
conf['generation']['num_subsamples'] = num_subsamples
conf['generation']['num_prompts_per_subsample'] = num_prompts // num_subsamples
if num_demos is not None:
conf['generation']['num_demos'] = num_demos
if num_samples is not None:
if conf['evaluation']['method'] == 'bandits':
conf['evaluation']['base_eval_config']['num_samples'] = num_samples
else:
conf['evaluation']['num_samples'] = num_samples
if num_few_shot is not None:
if conf['evaluation']['method'] == 'bandits':
conf['evaluation']['base_eval_config']['num_few_shot'] = num_few_shot
else:
conf['evaluation']['num_few_shot'] = num_few_shot
cost = ape.estimate_cost(eval_template, demos_template, prompt_gen_data, eval_data, conf,
prompt_gen_template=prompt_gen_template)
generation_query = ape.get_generation_query(eval_template, demos_template, conf, prompt_gen_data,
prompt_gen_template)[0]
evaluation_query = ape.get_evaluation_query(eval_template, demos_template, conf, eval_data, prompt_gen_data)[0]
return cost, generation_query, evaluation_query
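# Illustrative sketch (continuing the run_ape example above): the same arguments can be
# passed to estimate_cost first to get a dollar estimate, plus sample queries, before
# committing to a full run:
#
#   cost, gen_query, eval_query = estimate_cost(train_data, test_data)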
def basic_estimate_cost(prompt_gen_data,
eval_data,
eval_template='Instruction: [PROMPT]\nInput: [INPUT]\nOutput: [OUTPUT]',
eval_model='text-davinci-002',
prompt_gen_model='text-davinci-002',
prompt_gen_mode='forward',
num_prompts=50, eval_rounds=10, prompt_gen_batch_size=200, eval_batch_size=500):
return estimate_cost(prompt_gen_data, eval_data, eval_template,
eval_model=eval_model, prompt_gen_model=prompt_gen_model, prompt_gen_mode=prompt_gen_mode,
num_prompts=num_prompts, eval_rounds=eval_rounds, prompt_gen_batch_size=prompt_gen_batch_size,
eval_batch_size=eval_batch_size)
def advance_estimate_cost(prompt_gen_data, eval_data,
eval_template, prompt_gen_template, demos_template,
eval_model='text-davinci-002',
prompt_gen_model='text-davinci-002',
num_prompts=50, eval_rounds=10, prompt_gen_batch_size=200, eval_batch_size=500, # basic
num_subsamples=None, num_demos=None, # advanced
num_samples=None, num_few_shot=None # advanced
):
if prompt_gen_template[-5:] == '[APE]':
prompt_gen_mode = 'forward'
else:
prompt_gen_mode = 'insert'
return estimate_cost(prompt_gen_data, eval_data, eval_template, prompt_gen_template, demos_template,
eval_model=eval_model, prompt_gen_model=prompt_gen_model, prompt_gen_mode=prompt_gen_mode,
num_prompts=num_prompts, eval_rounds=eval_rounds, prompt_gen_batch_size=prompt_gen_batch_size,
eval_batch_size=eval_batch_size, num_subsamples=num_subsamples, num_demos=num_demos,
num_samples=num_samples, num_few_shot=num_few_shot)
def compute_score(prompt,
eval_data,
eval_template,
demos_template,
eval_model='text-davinci-002',
num_few_shot=None # advanced
):
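    """Scores a single prompt on the scoring dataset with the evaluation template, using the
    default evaluation config, and returns its score (log(p) in the UI) rounded to 3 decimal places."""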
eval_data = parse_data(eval_data)
eval_template = template.EvalTemplate(eval_template)
demos_template = template.DemosTemplate(demos_template)
conf = config.update_config({}, 'configs/default.yaml')
conf['evaluation']['model']['gpt_config']['model'] = eval_model
conf['evaluation']['num_samples'] = min(len(eval_data[0]), 50)
conf['evaluation']['num_few_shot'] = num_few_shot
res = evaluate.evalute_prompts([prompt], eval_template, eval_data, demos_template, eval_data,
conf['evaluation']['method'], conf['evaluation'])
return round(res.sorted()[1][0], 3)
def run_prompt(prompt, inputs,
eval_template='Instruction: [PROMPT]\nInput: [INPUT]\nOutput: [OUTPUT]',
eval_model='text-davinci-002',
prompt_gen_model='text-davinci-002',
prompt_gen_mode='forward',
num_prompts=50, eval_rounds=10, prompt_gen_batch_size=200, eval_batch_size=500):
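    """Fills the evaluation template with the prompt (and each input, if provided), queries
    the evaluation model, and returns the first completion, stripped of whitespace."""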
conf = config.simple_config(
eval_model, prompt_gen_model, prompt_gen_mode, num_prompts, eval_rounds, prompt_gen_batch_size, eval_batch_size)
eval_template = template.EvalTemplate(eval_template)
model = llm.model_from_config(conf['evaluation']['base_eval_config']['model'])
if inputs == '':
queries = [prompt]
else:
if not isinstance(inputs, list):
inputs = [inputs]
queries = []
for input_ in inputs:
query = eval_template.fill(prompt=prompt, input=input_)
queries.append(query)
outputs = model.generate_text(queries, n=1)
return outputs[0].strip()
def get_demo():
    assert openai.api_key, 'Please set your OpenAI API key first.'
with gr.Blocks(title="Automatic Prompt Engineer", css=None, ) as demo:
gr.Markdown("# Automatic Prompt Engineer")
gr.Markdown("""This WebUI demonstrates how to use Automatic Prompt Engineer [APE](arxiv link) to optimize
prompts for text generation. In its simplest form, APE takes as input a dataset (a list of inputs and a
list of outputs), a prompt template, and optimizes this prompt template so that it generates the outputs
given the inputs. APE accomplishes this in two steps. First, it uses a language model to generate a set of
candidate prompts. Then, it uses a prompt evaluation function to evaluate the quality of each candidate
prompt. Finally, it returns the prompt with the highest evaluation score. Let's try it out!""")
with gr.Row():
with gr.Column(scale=3):
gr.Markdown("## APE")
with gr.Tab("Basic"):
prompt_gen_mode = gr.Dropdown(label="Prompt Generation Mode", choices=mode_types,
value="forward")
with gr.Row():
eval_template = gr.Textbox(lines=5,
value="Instruction: [PROMPT]\nInput: [INPUT]\nOutput: [OUTPUT]",
label="Evaluation Template")
with gr.Row():
                        basic_cost = gr.Textbox(lines=1, value="", label="Estimated Cost ($)", interactive=False)
basic_cost_button = gr.Button("Estimate Cost")
basic_ape_button = gr.Button("APE")
with gr.Tab("Advanced"):
with gr.Row():
prompt_gen_template = gr.Textbox(lines=6,
value="""I gave a friend an instruction. Based on the instruction they produced the following input-output pairs:\n\n[full_DEMO]\n\nThe instruction was to [APE]""",
label="Prompt Generation Template")
demos_template = gr.Textbox(lines=6, value="Input: [INPUT]\nOutput: [OUTPUT]",
label="Demos Template")
with gr.Row():
                        cost = gr.Textbox(lines=1, value="", label="Estimated Cost ($)", interactive=False)
cost_button = gr.Button("Estimate Cost")
ape_button = gr.Button("APE")
with gr.Tab("Dataset"):
with gr.Row():
prompt_gen_data = gr.Textbox(lines=10, value=load_task('antonyms')[0],
label="Data for prompt generation")
eval_data = gr.Textbox(lines=10, value=load_task('antonyms')[1],
label="Data for scoring")
with gr.Row():
task = gr.Dropdown(label="Task", choices=task_types, value="antonyms")
load_task_button = gr.Button("Load Task")
with gr.Column(scale=2):
gr.Markdown("## Configuration")
with gr.Tab("Basic"):
with gr.Row():
prompt_gen_model = gr.Dropdown(label="Prompt Generation Model", choices=model_types,
value="text-davinci-002")
eval_model = gr.Dropdown(label="Evaluation Model", choices=model_types,
value="text-davinci-002")
with gr.Row():
num_prompts = gr.Slider(label="Number of Prompts", minimum=1, maximum=250, step=10, value=50)
eval_rounds = gr.Slider(label="Rounds of Evaluation", minimum=1, maximum=20, step=1, value=5)
with gr.Row():
prompt_gen_batch_size = gr.Slider(label="Batch Size for prompt generation",
minimum=1, maximum=500, step=10, value=200)
eval_batch_size = gr.Slider(label="Batch Size for Evaluation",
minimum=4, maximum=1000, step=1, value=500)
with gr.Tab("Advanced"):
with gr.Column(scale=1):
with gr.Row():
num_subsamples = gr.Slider(minimum=1, maximum=10, value=5, step=1,
label="Number of subsamples for instruction generation")
num_demos = gr.Slider(minimum=1, maximum=25, value=5, step=1,
label="Number of demos for instruction generation")
with gr.Row():
num_samples = gr.Slider(minimum=1, maximum=50, value=10, step=1,
label="Number of evaluation samples at each round")
num_few_shot = gr.Slider(minimum=0, maximum=25, value=0, step=1,
label="Number of few-shot examples used in evaluation")
gr.Markdown("## Results")
with gr.Row():
with gr.Tab("APE Results"):
                output_df = gr.DataFrame(type='pandas', headers=['Prompt', 'log(p)'], wrap=True, interactive=False)
with gr.Tab("Prompt Overview"):
with gr.Row():
                    generation_prompt_sample = gr.Textbox(lines=8, value="",
                                                          label="Instruction Generation Prompts",
                                                          interactive=False)
                    evaluation_prompt_sample = gr.Textbox(lines=8, value="",
                                                          label="Evaluation Prompts",
                                                          interactive=False)
with gr.Tab("Prompt Deployment"):
with gr.Row():
with gr.Column(scale=1):
test_prompt = gr.Textbox(lines=4, value="Please evaluate the following expression.",
label="Prompt")
test_inputs = gr.Textbox(lines=1, value="",
label="Input (If empty, prompt is executed directly)")
answer_button = gr.Button("Submit")
with gr.Column(scale=1):
test_output = gr.Textbox(lines=9, value="", label="Model Output")
with gr.Tab("Prompt Score"):
with gr.Row():
with gr.Column(scale=1):
score_prompt = gr.Textbox(lines=3, value="Please evaluate the following expression.",
label="Prompt (Evaluate on scoring dataset using Evaluation Template)")
compute_score_button = gr.Button("Compute Score")
with gr.Column(scale=1):
                        test_score = gr.Textbox(lines=1, value="", label="Log(p)", interactive=False)
##############################
# Button Callbacks
##############################
load_task_button.click(load_task, inputs=[task], outputs=[prompt_gen_data, eval_data])
basic_ape_button.click(basic_ape,
inputs=[prompt_gen_data, eval_data, # Data
eval_template, # Templates
eval_model, prompt_gen_model, prompt_gen_mode, num_prompts, eval_rounds,
prompt_gen_batch_size, eval_batch_size, # Basic Configuration
],
outputs=[output_df, generation_prompt_sample, evaluation_prompt_sample,
test_prompt, score_prompt, test_score])
ape_button.click(advance_ape,
inputs=[prompt_gen_data, eval_data, # Data
eval_template, prompt_gen_template, demos_template, # Templates
eval_model, prompt_gen_model, prompt_gen_mode, num_prompts, eval_rounds,
prompt_gen_batch_size, eval_batch_size, # Basic
num_subsamples, num_demos, num_samples, num_few_shot # Advanced
],
outputs=[output_df, generation_prompt_sample, evaluation_prompt_sample,
test_prompt, score_prompt, test_score])
basic_cost_button.click(basic_estimate_cost,
inputs=[prompt_gen_data, eval_data, # Data
eval_template, # Templates
eval_model, prompt_gen_model, prompt_gen_mode, num_prompts, eval_rounds,
prompt_gen_batch_size, eval_batch_size, # Basic Configuration
],
outputs=[basic_cost, generation_prompt_sample, evaluation_prompt_sample])
cost_button.click(advance_estimate_cost,
inputs=[prompt_gen_data, eval_data, # Data
eval_template, prompt_gen_template, demos_template, # Templates
eval_model, prompt_gen_model, num_prompts, eval_rounds,
prompt_gen_batch_size, eval_batch_size, # Basic
num_subsamples, num_demos, num_samples, num_few_shot # Advanced
],
outputs=[cost, generation_prompt_sample, evaluation_prompt_sample])
compute_score_button.click(compute_score,
inputs=[score_prompt,
eval_data,
eval_template,
demos_template,
eval_model,
num_few_shot],
outputs=[test_score])
answer_button.click(run_prompt,
inputs=[test_prompt, test_inputs, eval_template, eval_model,
prompt_gen_model, prompt_gen_mode,
num_prompts, eval_rounds, prompt_gen_batch_size, eval_batch_size],
outputs=[test_output])
return demo
if __name__ == "__main__":
demo = get_demo()
demo.launch(show_error=True)