Skip to content

Commit

Permalink
Update eval_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
i-colbert committed Sep 13, 2023
1 parent 90637c7 commit 1789bee
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/brevitas_examples/super_resolution/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,22 @@

parser = argparse.ArgumentParser(description='PyTorch BSD300 Validation')
parser.add_argument('--data_root', help='Path to folder containing BSD300 val folder')
parser.add_argument('--model_path', default=None, help='Path to PyTorch checkpoint')
parser.add_argument('--model_path', default=None, help='Path to PyTorch checkpoint. Default = None')
parser.add_argument(
'--save_path', type=str, default='outputs/', help='Save path for exported model')
'--save_path',
type=str,
default='outputs/',
help='Save path for exported model. Default = outputs/')
parser.add_argument(
'--model', type=str, default='quant_espcn_x2_w8a8_base', help='Name of the model configuration')
parser.add_argument('--workers', type=int, default=0, help='Number of data loading workers')
parser.add_argument('--batch_size', type=int, default=16, help='Minibatch size')
'--model',
type=str,
default='quant_espcn_x2_w8a8_base',
help='Name of the model configuration. Default = quant_espcn_x2_w8a8_base')
parser.add_argument(
'--workers', type=int, default=0, help='Number of data loading workers. Default = 0')
parser.add_argument('--batch_size', type=int, default=16, help='Minibatch size. Default = 16')
parser.add_argument(
'--crop_size', type=int, default=512, help='The size to crop the image. Default = 512')
parser.add_argument('--use_pretrained', action='store_true', default=False)
parser.add_argument('--eval_acc_bw', action='store_true', default=False)
parser.add_argument('--save_model_io', action='store_true', default=False)
Expand All @@ -60,6 +69,7 @@ def main():
num_workers=args.workers,
batch_size=args.batch_size,
upscale_factor=model.upscale_factor,
crop_size=args.crop_size,
download=True)

test_psnr = evaluate_avg_psnr(testloader, model)
Expand Down

0 comments on commit 1789bee

Please sign in to comment.