Skip to content
This repository has been archived by the owner on Sep 11, 2022. It is now read-only.

Commit

Permalink
Merge pull request #182 from yt605155624/pwg_benchmark
Browse files Browse the repository at this point in the history
fix type
  • Loading branch information
yt605155624 authored Sep 24, 2021
2 parents 4ce6342 + c64b515 commit 062b0bd
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions examples/parallelwave_gan/baker/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,9 @@ def str2bool(str):
benchmark_group = parser.add_argument_group(
'benchmark', 'arguments related to benchmark.')
benchmark_group.add_argument(
"--batch-size", type=str, default="8", help="batch size.")
"--batch-size", type=int, default=8, help="batch size.")
benchmark_group.add_argument(
"--max-iter", type=str, default="400000", help="train max steps.")
"--max-iter", type=int, default=400000, help="train max steps.")

benchmark_group.add_argument(
"--run-benchmark",
Expand All @@ -250,8 +250,8 @@ def str2bool(str):

# 增加 --batch_size --max_iter 用于 benchmark 调用
if args.run_benchmark:
config.batch_size = int(args.batch_size)
config.train_max_steps = int(args.max_iter)
config.batch_size = args.batch_size
config.train_max_steps = args.max_iter

print("========Args========")
print(yaml.safe_dump(vars(args)))
Expand Down

0 comments on commit 062b0bd

Please sign in to comment.