SD PNDMScheduler + Unet example through Turbine (nod-ai#403)
TODO: The rest of the schedulers in diffusers upstream still need to be updated for the e2e test to work; the test is xfailed for now.
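
For context, the xfail mentioned here would typically be expressed with a pytest marker; a minimal sketch only (the test name and reason string are illustrative, not taken from this commit's test file, which is not shown below):

import pytest

@pytest.mark.xfail(
    reason="remaining diffusers schedulers need upstream updates for e2e to pass"
)
def test_scheduler_e2e():
    ...
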
aviator19941 committed Feb 17, 2024
1 parent f1c3d16 commit fabd52c
Showing 6 changed files with 438 additions and 2 deletions.
1 change: 1 addition & 0 deletions core/shark_turbine/dynamo/passes.py
@@ -48,6 +48,7 @@
    torch.ops.aten._log_softmax_backward_data,
    torch.ops.aten.lift_fresh_copy.default,
    torch.ops.aten._unsafe_index.Tensor,
    torch.ops.aten.unbind.int,
    # decompositions added manually in this file
    torch.ops.aten._scaled_dot_product_flash_attention.default,
]
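The one-line change above registers torch.ops.aten.unbind.int for decomposition. As a hedged illustration of what that op does (plain PyTorch, independent of Turbine):

import torch

x = torch.arange(6).reshape(2, 3)
# aten.unbind.int splits a tensor along a dimension into a tuple of views;
# decomposing it exposes simple slice ops to the downstream lowering.
a, b = torch.ops.aten.unbind.int(x, 0)  # equivalent to torch.unbind(x, dim=0)
assert torch.equal(a, x[0]) and torch.equal(b, x[1])
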
178 changes: 178 additions & 0 deletions models/turbine_models/custom_models/sd_inference/schedulers.py
@@ -0,0 +1,178 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import sys

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from shark_turbine.aot import *
from iree import runtime as ireert
import iree.compiler as ireec
from iree.compiler.ir import Context
import numpy as np

from turbine_models.custom_models.sd_inference import utils
from diffusers import (
    UNet2DConditionModel,
)

import safetensors
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hf_auth_token", type=str, help="The Hugging Face auth token, required"
)
parser.add_argument(
    "--hf_model_name",
    type=str,
    help="HF model name",
    default="CompVis/stable-diffusion-v1-4",
)
parser.add_argument(
    "--scheduler_id",
    type=str,
    help="Scheduler ID",
    default="PNDM",
)
parser.add_argument(
    "--num_inference_steps", type=int, default=50, help="Number of inference steps"
)
parser.add_argument(
    "--batch_size", type=int, default=1, help="Batch size for inference"
)
parser.add_argument(
    "--height", type=int, default=512, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
parser.add_argument("--external_weight_path", type=str, default="")
parser.add_argument(
    "--external_weights",
    type=str,
    default=None,
    help="saves ir/vmfb without global weights for size and readability, options [safetensors]",
)
parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm")
# TODO: Bring in detection for target triple
parser.add_argument(
    "--iree_target_triple",
    type=str,
    default="",
    help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")


class Scheduler(torch.nn.Module):
    def __init__(self, hf_model_name, num_inference_steps, scheduler):
        super().__init__()
        self.scheduler = scheduler
        self.scheduler.set_timesteps(num_inference_steps)
        self.unet = UNet2DConditionModel.from_pretrained(
            hf_model_name,
            subfolder="unet",
        )
        self.guidance_scale = 7.5

    def forward(self, latents, encoder_hidden_states) -> torch.FloatTensor:
        latents = latents * self.scheduler.init_noise_sigma
        for t in self.scheduler.timesteps:
            # Duplicate the latents so the unconditional and text-conditioned
            # branches of classifier-free guidance share one UNet call.
            latent_model_input = torch.cat([latents] * 2)
            t = t.unsqueeze(0)
            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, timestep=t
            )
            unet_out = self.unet.forward(
                latent_model_input, t, encoder_hidden_states, return_dict=False
            )[0]
            noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
            # Classifier-free guidance: push the noise prediction toward the
            # text-conditioned branch by guidance_scale.
            noise_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        return latents
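
As a usage sketch only (mirroring the runner script later in this commit; the scheduler construction is an assumption based on get_schedulers in utils.py):

from diffusers import PNDMScheduler

scheduler = PNDMScheduler.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="scheduler"
)
module = Scheduler("CompVis/stable-diffusion-v1-4", 50, scheduler)
latents = torch.rand(1, 4, 64, 64, dtype=torch.float32)  # SD 1.4 latent shape
text_embeds = torch.rand(2, 77, 768, dtype=torch.float32)  # uncond + cond
denoised = module.forward(latents, text_embeds)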


def export_scheduler(
    scheduler,
    hf_model_name,
    batch_size,
    height,
    width,
    hf_auth_token=None,
    compile_to="torch",
    external_weights=None,
    external_weight_path=None,
    device=None,
    target_triple=None,
    max_alloc=None,
):
    mapper = {}
    utils.save_external_weights(
        mapper, scheduler, external_weights, external_weight_path
    )

    encoder_hidden_states_sizes = (2, 77, 768)
    if hf_model_name == "stabilityai/stable-diffusion-2-1-base":
        encoder_hidden_states_sizes = (2, 77, 1024)

    sample = (batch_size, 4, height // 8, width // 8)

    class CompiledScheduler(CompiledModule):
        if external_weights:
            params = export_parameters(
                scheduler, external=True, external_scope="", name_mapper=mapper.get
            )
        else:
            params = export_parameters(scheduler)

        def main(
            self,
            sample=AbstractTensor(*sample, dtype=torch.float32),
            encoder_hidden_states=AbstractTensor(
                *encoder_hidden_states_sizes, dtype=torch.float32
            ),
        ):
            return jittable(scheduler.forward)(sample, encoder_hidden_states)

    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
    inst = CompiledScheduler(context=Context(), import_to=import_to)

    module_str = str(CompiledModule.get_mlir_module(inst))
    safe_name = utils.create_safe_name(hf_model_name, "-scheduler")
    if compile_to != "vmfb":
        return module_str
    else:
        utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
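
The export pattern above is the core of the file: a CompiledModule subclass declares its parameters and an abstract-shaped main, and jittable traces the eager forward into MLIR. A stripped-down sketch of the same pattern, assuming jittable accepts any traceable callable:

def double(x):
    return x * 2

class Tiny(CompiledModule):
    def main(self, x=AbstractTensor(1, 4, dtype=torch.float32)):
        return jittable(double)(x)

inst = Tiny(context=Context(), import_to="IMPORT")
print(CompiledModule.get_mlir_module(inst))  # MLIR for the traced function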


if __name__ == "__main__":
    args = parser.parse_args()
    schedulers = utils.get_schedulers(args.hf_model_name)
    scheduler = schedulers[args.scheduler_id]
    scheduler_module = Scheduler(
        args.hf_model_name, args.num_inference_steps, scheduler
    )
    mod_str = export_scheduler(
        scheduler_module,
        args.hf_model_name,
        args.batch_size,
        args.height,
        args.width,
        args.hf_auth_token,
        args.compile_to,
        args.external_weights,
        args.external_weight_path,
        args.device,
        args.iree_target_triple,
        args.vulkan_max_allocation,
    )
    safe_name = utils.create_safe_name(args.hf_model_name, "-scheduler")
    with open(f"{safe_name}.mlir", "w+") as f:
        f.write(mod_str)
    print("Saved to", safe_name + ".mlir")
172 changes: 172 additions & 0 deletions models/turbine_models/custom_models/sd_inference/schedulers_runner.py
@@ -0,0 +1,172 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
from turbine_models.model_runner import vmfbRunner
from iree import runtime as ireert
import torch
from diffusers import (
    PNDMScheduler,
    UNet2DConditionModel,
)

parser = argparse.ArgumentParser()

# TODO move common runner flags to generic flag file
parser.add_argument(
    "--scheduler_id",
    type=str,
    help="Scheduler ID",
    default="PNDM",
)
parser.add_argument(
    "--num_inference_steps", type=int, default=50, help="Number of inference steps"
)
parser.add_argument(
    "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module"
)
parser.add_argument(
    "--external_weight_path",
    type=str,
    default="",
    help="path to external weight parameters if model compiled without them",
)
parser.add_argument(
    "--compare_vs_torch",
    action="store_true",
    help="Runs both turbine vmfb and a torch model to compare results",
)
parser.add_argument(
    "--hf_model_name",
    type=str,
    help="HF model name",
    default="CompVis/stable-diffusion-v1-4",
)
parser.add_argument(
    "--hf_auth_token",
    type=str,
    help="The Hugging Face auth token, required for some models",
)
parser.add_argument(
    "--device",
    type=str,
    default="local-task",
    help="local-sync, local-task, cuda, vulkan, rocm",
)
parser.add_argument(
    "--batch_size", type=int, default=1, help="Batch size for inference"
)
parser.add_argument(
    "--height", type=int, default=512, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")


def run_scheduler(
    device,
    sample,
    encoder_hidden_states,
    vmfb_path,
    hf_model_name,
    hf_auth_token,
    external_weight_path,
):
    runner = vmfbRunner(device, vmfb_path, external_weight_path)

    inputs = [
        ireert.asdevicearray(runner.config.device, sample),
        ireert.asdevicearray(runner.config.device, encoder_hidden_states),
    ]
    results = runner.ctx.modules.compiled_scheduler["main"](*inputs)
    return results
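
A hypothetical invocation (paths and shapes are illustrative; the vmfb filename assumes the default produced by create_safe_name):

sample = torch.rand(1, 4, 64, 64, dtype=torch.float32)
text_embeds = torch.rand(2, 77, 768, dtype=torch.float32)
out = run_scheduler(
    "local-task",
    sample,
    text_embeds,
    "stable_diffusion_v1_4_scheduler.vmfb",  # hypothetical path
    "CompVis/stable-diffusion-v1-4",
    None,
    "",
)
print(out.to_host().shape)  # expected: (1, 4, 64, 64)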


def run_torch_scheduler(
    hf_model_name, scheduler, num_inference_steps, sample, encoder_hidden_states
):
    class Scheduler(torch.nn.Module):
        def __init__(self, hf_model_name, num_inference_steps, scheduler):
            super().__init__()
            self.scheduler = scheduler
            self.scheduler.set_timesteps(num_inference_steps)
            self.unet = UNet2DConditionModel.from_pretrained(
                hf_model_name,
                subfolder="unet",
            )
            self.guidance_scale = 7.5

        def forward(self, latents, encoder_hidden_states) -> torch.FloatTensor:
            latents = latents * self.scheduler.init_noise_sigma
            for t in self.scheduler.timesteps:
                latent_model_input = torch.cat([latents] * 2)
                t = t.unsqueeze(0)
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, timestep=t
                )
                unet_out = self.unet.forward(
                    latent_model_input, t, encoder_hidden_states, return_dict=False
                )[0]
                noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]
            return latents

    scheduler_module = Scheduler(hf_model_name, num_inference_steps, scheduler)
    results = scheduler_module.forward(sample, encoder_hidden_states)
    np_torch_output = results.detach().cpu().numpy()
    return np_torch_output
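
utils.largest_error is not part of this diff; judging from how it is called below, a plausible definition would be the maximum absolute elementwise difference (an assumption, not the committed code):

import numpy as np

def largest_error(expected_output, actual_output):
    # Assumed behavior: max |expected - actual| over all elements.
    return np.max(np.abs(np.asarray(expected_output) - np.asarray(actual_output)))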


if __name__ == "__main__":
    args = parser.parse_args()
    sample = torch.rand(
        args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32
    )
    if args.hf_model_name == "CompVis/stable-diffusion-v1-4":
        encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32)
    elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base":
        encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32)

    turbine_output = run_scheduler(
        args.device,
        sample,
        encoder_hidden_states,
        args.vmfb_path,
        args.hf_model_name,
        args.hf_auth_token,
        args.external_weight_path,
    )
    print(
        "TURBINE OUTPUT:",
        turbine_output.to_host(),
        turbine_output.to_host().shape,
        turbine_output.to_host().dtype,
    )

    if args.compare_vs_torch:
        print("generating torch output: ")
        from turbine_models.custom_models.sd_inference import utils

        schedulers = utils.get_schedulers(args.hf_model_name)
        scheduler = schedulers[args.scheduler_id]
        torch_output = run_torch_scheduler(
            args.hf_model_name,
            scheduler,
            args.num_inference_steps,
            sample,
            encoder_hidden_states,
        )
        print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)
        err = utils.largest_error(torch_output, turbine_output)
        print("Largest Error: ", err)
        assert err < 9e-3

    # TODO: Figure out why we occasionally segfault without unlinking output variables
    turbine_output = None
22 changes: 22 additions & 0 deletions models/turbine_models/custom_models/sd_inference/utils.py
@@ -2,6 +2,9 @@
import numpy as np
import safetensors
import re
from diffusers import (
    PNDMScheduler,
)


def save_external_weights(
@@ -35,6 +38,7 @@ def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name):
"--iree-llvmcpu-target-triple=x86_64-linux-gnu",
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
"--iree-flow-inline-constants-max-byte-length=1",
]
if device == "cpu":
flags.append("--iree-llvmcpu-enable-ukernels=all")
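
The body of compile_to_vmfb is mostly elided from this hunk; for orientation, a flag list like the one above is typically fed to the IREE compiler through extra_args. A sketch only, assuming iree.compiler is imported as ireec and a CPU target (not the committed implementation):

flatbuffer = ireec.compile_str(
    module_str,
    target_backends=["llvm-cpu"],  # assumption: CPU path shown in this hunk
    extra_args=flags,
)
with open(f"{safe_name}.vmfb", "wb+") as f:
    f.write(flatbuffer)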
@@ -86,3 +90,21 @@ def create_safe_name(hf_model_name, model_name_str):
    safe_name = hf_model_name.split("/")[-1].strip() + model_name_str
    safe_name = re.sub("-", "_", safe_name)
    return safe_name


def get_schedulers(model_id):
    # TODO: Robust scheduler setup on pipeline creation -- if we don't set
    # batch_size here, the SHARK schedulers will compile with batch size = 1
    # regardless of whether the model outputs latents of a larger batch size,
    # e.g. SDXL. Searching the base model ID for "xl" is obviously not very
    # robust, however.

    batch_size = 2 if "xl" in model_id.lower() else 1

    schedulers = dict()
    schedulers["PNDM"] = PNDMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    return schedulers
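
Per the commit note, the remaining schedulers still need upstream updates; once they land, the dict could be extended the same way (the scheduler choices here are illustrative, not part of this commit):

from diffusers import DDIMScheduler, EulerDiscreteScheduler

def get_schedulers_extended(model_id):
    # Hypothetical extension; mirrors the PNDM registration above.
    schedulers = get_schedulers(model_id)
    schedulers["DDIM"] = DDIMScheduler.from_pretrained(
        model_id, subfolder="scheduler"
    )
    schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
        model_id, subfolder="scheduler"
    )
    return schedulers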