Commit

fix device
Giuseppe5 committed Jul 18, 2024
1 parent 6d1dc35 commit 3dcfa1f
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 5 additions & 0 deletions src/brevitas_examples/stable_diffusion/main.py
@@ -846,6 +846,11 @@ def input_zp_stats_type():
         'vae-fp16-fix',
         default=False,
         help='Rescale the VAE to not go NaN with FP16. Default: Disabled')
+    add_bool_arg(
+        parser,
+        'share-qkv-quant',
+        default=False,
+        help='Share QKV/KV quantization. Default: Disabled')
     args = parser.parse_args()
     print("Args: " + str(vars(args)))
     main(args)
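
The new flag is registered through the same add_bool_arg helper used for 'vae-fp16-fix' directly above it. Below is a minimal sketch of what such a helper typically looks like, assuming the common argparse pattern of paired --flag/--no-flag options; the actual helper in main.py is not shown in this diff and may differ.

import argparse

def add_bool_arg(parser, name, default, help):
    # Registers paired --<name> / --no-<name> flags that toggle a single boolean destination.
    dest = name.replace('-', '_')
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--' + name, dest=dest, action='store_true', help=help)
    group.add_argument('--no-' + name, dest=dest, action='store_false', help='Disable: ' + help)
    parser.set_defaults(**{dest: default})

parser = argparse.ArgumentParser()
add_bool_arg(
    parser,
    'share-qkv-quant',
    default=False,
    help='Share QKV/KV quantization. Default: Disabled')

# Default stays False; passing --share-qkv-quant flips it on.
print(parser.parse_args([]).share_qkv_quant)                     # False
print(parser.parse_args(['--share-qkv-quant']).share_qkv_quant)  # True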
3 changes: 3 additions & 2 deletions (second changed file; path not shown in this view)
@@ -602,6 +602,7 @@ def compute_mlperf_fid(
         model_to_replace=None,
         samples_to_evaluate=500,
         output_dir=None,
+        device='cpu',
         vae_force_upcast=True):
 
     assert os.path.isfile(path_to_coco + '/tools/val2014.npz'), "Val2014.npz file required. Check the MLPerf directory for instructions"
@@ -611,13 +612,13 @@
     dtype = next(iter(model_to_replace.unet.parameters())).dtype
     res_dict = {}
     model = BackendPytorch(
-        path_to_sdxl, 'xl', steps=20, batch_size=1, device='cpu', precision=dtype)
+        path_to_sdxl, 'xl', steps=20, batch_size=1, device=device, precision=dtype)
     model.load()
 
     if model_to_replace is not None:
         model.pipe.unet = model_to_replace.unet
         if not vae_force_upcast:
-            model.pipe.vae = model.pipe.vae
+            model.pipe.vae = model_to_replace.vae
 
     model.pipe.vae.config.force_upcast = vae_force_upcast
     ds = Coco(
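
Hardcoding device='cpu' in BackendPytorch meant the MLPerf FID evaluation always built the reference pipeline on CPU even when the quantized model lived on GPU; the new parameter threads the caller's device through, and the second change swaps in the caller's VAE instead of the previous no-op self-assignment. A hypothetical invocation sketch follows, using the keyword names visible in the diff; the paths, the quantized pipeline variable, and the import location are placeholders, not part of this commit.

import torch

# from <module containing compute_mlperf_fid> import compute_mlperf_fid  # import path not shown in this view

device = 'cuda' if torch.cuda.is_available() else 'cpu'

compute_mlperf_fid(
    path_to_sdxl='/models/sdxl-base-1.0',   # placeholder checkpoint path
    path_to_coco='/data/mlperf/coco',       # placeholder; must contain tools/val2014.npz
    model_to_replace=quantized_pipe,        # placeholder: pipeline whose unet (and vae) replace the baseline
    samples_to_evaluate=500,
    output_dir='results/',
    device=device,                          # new argument, forwarded to BackendPytorch instead of the hardcoded 'cpu'
    vae_force_upcast=False)                 # with the fix, model_to_replace.vae is actually swapped in here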
