Skip to content

Commit

Permalink
(shortfin-sd) Program initialization and logging improvements (#444)
Browse files Browse the repository at this point in the history
Fixes program initialization per worker and systembuilder usage/options
  • Loading branch information
monorimet authored Nov 9, 2024
1 parent eefc353 commit 029d35e
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 104 deletions.
11 changes: 7 additions & 4 deletions shortfin/python/shortfin/support/logging_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self):
native_handler.setFormatter(NativeFormatter())

# TODO: Source from env vars.
logger.setLevel(logging.DEBUG)
logger.setLevel(logging.INFO)
logger.addHandler(native_handler)


Expand All @@ -47,7 +47,10 @@ def configure_main_logger(module_suffix: str = "__main__") -> logging.Logger:
Returns a logger that can be used for the main module itself.
"""
logging.root.addHandler(native_handler)
logging.root.setLevel(logging.DEBUG) # TODO: source from env vars
main_module = sys.modules["__main__"]
return logging.getLogger(f"{main_module.__package__}.{module_suffix}")
logging.root.setLevel(logging.INFO)
logger = logging.getLogger(f"{main_module.__package__}.{module_suffix}")
logger.setLevel(logging.INFO)
logger.addHandler(native_handler)

return logger
6 changes: 3 additions & 3 deletions shortfin/python/shortfin_apps/sd/components/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ def needs_file(filename, ctx, namespace=FileNamespace.GEN):
if os.path.exists(out_file):
needed = False
else:
name_path = "bin" if namespace == FileNamespace.BIN else ""
if name_path:
filename = os.path.join(name_path, filename)
# name_path = "bin" if namespace == FileNamespace.BIN else ""
# if name_path:
# filename = os.path.join(name_path, filename)
filekey = os.path.join(ctx.path, filename)
ctx.executor.all[filekey] = None
needed = True
Expand Down
51 changes: 32 additions & 19 deletions shortfin/python/shortfin_apps/sd/components/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,42 @@
logger = logging.getLogger(__name__)


def get_selected_devices(sb: "sf.SystemBuilder", device_ids=None):
    """Resolve user-supplied device ids against the builder's available devices.

    Args:
        sb: A system builder exposing ``available_devices`` (a list of
            device-name strings).
        device_ids: Optional list of selectors. Each entry may be a device
            name present in ``sb.available_devices``, or an integer index
            (given as an int or a numeric string) into that list. ``None``
            selects every available device.

    Returns:
        The list of selected device names, in the order requested.

    Raises:
        ValueError: If more ids are requested than devices exist, if an
            integer index is out of range, or if an id matches neither a
            known device name nor an index.
    """
    available = sb.available_devices
    if device_ids is None:
        # No explicit selection: use everything the system exposes.
        return available
    # NOTE: strictly-more-than — asking for exactly all devices is valid.
    if len(device_ids) > len(available):
        raise ValueError(
            f"Requested more device ids ({device_ids}) than available ({available})."
        )
    selected = []
    for did in device_ids:
        if isinstance(did, str):
            try:
                # Numeric strings (e.g. "0") are treated as indices below.
                did = int(did)
            except ValueError:
                pass  # Keep non-numeric strings as device names.
        if did in available:
            selected.append(did)
        elif isinstance(did, int):
            if not 0 <= did < len(available):
                raise ValueError(
                    f"Device index {did} is out of range for available devices ({available})."
                )
            selected.append(available[did])
        else:
            raise ValueError(f"Device id {did} could not be parsed.")
    return selected


class SystemManager:
def __init__(self, device="local-task", device_ids=None):
def __init__(self, device="local-task", device_ids=None, async_allocs=True):
if any(x in device for x in ["local-task", "cpu"]):
self.ls = sf.host.CPUSystemBuilder().create_system()
elif any(x in device for x in ["hip", "amdgpu"]):
sc_query = sf.amdgpu.SystemBuilder()
available = sc_query.available_devices
selected = []
if device_ids is not None:
if len(device_ids) >= len(available):
raise ValueError(
f"Requested more device ids ({device_ids}) than available ({available})."
)
for did in device_ids:
if did in available:
selected.append(did)
elif isinstance(did, int):
selected.append(available[did])
else:
raise ValueError(f"Device id {did} could not be parsed.")
else:
selected = available
sb = sf.amdgpu.SystemBuilder(amdgpu_visible_devices=";".join(selected))
sb = sf.SystemBuilder(
system_type="amdgpu", amdgpu_async_allocations=async_allocs
)
if device_ids:
sb.visible_devices = sb.available_devices
sb.visible_devices = get_selected_devices(sb, device_ids)
self.ls = sb.create_system()
logger.info(f"Created local system with {self.ls.device_names} devices")
# TODO: Come up with an easier bootstrap thing than manually
Expand Down
126 changes: 73 additions & 53 deletions shortfin/python/shortfin_apps/sd/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,20 @@ def __init__(
self.inference_parameters: dict[str, list[sf.BaseProgramParameters]] = {}
self.inference_modules: dict[str, sf.ProgramModule] = {}
self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {}
self.inference_programs: dict[str, sf.Program] = {}
self.inference_programs: dict[int, dict[str, sf.Program]] = {}
self.trace_execution = trace_execution
self.show_progress = show_progress

self.prog_isolation = prog_isolations[prog_isolation]

self.workers_per_device = workers_per_device
self.fibers_per_device = fibers_per_device
self.prog_isolation = prog_isolations[prog_isolation]
if fibers_per_device % workers_per_device != 0:
raise ValueError(
"Currently, fibers_per_device must be divisible by workers_per_device"
)
self.fibers_per_worker = int(fibers_per_device / workers_per_device)

self.workers = []
self.fibers = []
self.fiber_status = []
Expand All @@ -81,7 +89,9 @@ def __init__(
)
self.fibers.append(fiber)
self.fiber_status.append(0)

for idx in range(len(self.workers)):
self.inference_programs[idx] = {}
self.inference_functions[idx] = {}
# Scope dependent objects.
self.batcher = BatcherProcess(self)

Expand All @@ -108,52 +118,59 @@ def load_inference_parameters(
self.inference_parameters[component].append(p)

def start(self):
for fiber in self.fibers:
for component in self.inference_modules:
component_modules = [
sf.ProgramModule.parameter_provider(
self.sysman.ls, *self.inference_parameters.get(component, [])
),
*self.inference_modules[component],
]
self.inference_programs[component] = sf.Program(
# Initialize programs.
# This can work if we only initialize one set of programs per service, as our programs
# in SDXL are stateless, so fibers on the same worker can share them.
for component in self.inference_modules:
component_modules = [
sf.ProgramModule.parameter_provider(
self.sysman.ls, *self.inference_parameters.get(component, [])
),
*self.inference_modules[component],
]
for worker_idx, worker in enumerate(self.workers):
worker_devices = self.fibers[
worker_idx * (self.fibers_per_worker)
].raw_devices

self.inference_programs[worker_idx][component] = sf.Program(
modules=component_modules,
devices=fiber.raw_devices,
devices=worker_devices,
isolation=self.prog_isolation,
trace_execution=self.trace_execution,
)

# TODO: export vmfbs with multiple batch size entrypoints

self.inference_functions["encode"] = {}
for bs in self.model_params.clip_batch_sizes:
self.inference_functions["encode"][bs] = self.inference_programs["clip"][
f"{self.model_params.clip_module_name}.encode_prompts"
]

self.inference_functions["denoise"] = {}
for bs in self.model_params.unet_batch_sizes:
self.inference_functions["denoise"][bs] = {
"unet": self.inference_programs["unet"][
f"{self.model_params.unet_module_name}.{self.model_params.unet_fn_name}"
],
"init": self.inference_programs["scheduler"][
f"{self.model_params.scheduler_module_name}.run_initialize"
],
"scale": self.inference_programs["scheduler"][
f"{self.model_params.scheduler_module_name}.run_scale"
],
"step": self.inference_programs["scheduler"][
f"{self.model_params.scheduler_module_name}.run_step"
],
}

self.inference_functions["decode"] = {}
for bs in self.model_params.vae_batch_sizes:
self.inference_functions["decode"][bs] = self.inference_programs["vae"][
f"{self.model_params.vae_module_name}.decode"
]

for worker_idx, worker in enumerate(self.workers):
self.inference_functions[worker_idx]["encode"] = {}
for bs in self.model_params.clip_batch_sizes:
self.inference_functions[worker_idx]["encode"][
bs
] = self.inference_programs[worker_idx]["clip"][
f"{self.model_params.clip_module_name}.encode_prompts"
]
self.inference_functions[worker_idx]["denoise"] = {}
for bs in self.model_params.unet_batch_sizes:
self.inference_functions[worker_idx]["denoise"][bs] = {
"unet": self.inference_programs[worker_idx]["unet"][
f"{self.model_params.unet_module_name}.{self.model_params.unet_fn_name}"
],
"init": self.inference_programs[worker_idx]["scheduler"][
f"{self.model_params.scheduler_module_name}.run_initialize"
],
"scale": self.inference_programs[worker_idx]["scheduler"][
f"{self.model_params.scheduler_module_name}.run_scale"
],
"step": self.inference_programs[worker_idx]["scheduler"][
f"{self.model_params.scheduler_module_name}.run_step"
],
}
self.inference_functions[worker_idx]["decode"] = {}
for bs in self.model_params.vae_batch_sizes:
self.inference_functions[worker_idx]["decode"][
bs
] = self.inference_programs[worker_idx]["vae"][
f"{self.model_params.vae_module_name}.decode"
]
# breakpoint()
self.batcher.launch()

def shutdown(self):
Expand Down Expand Up @@ -320,7 +337,11 @@ def __init__(
):
super().__init__(fiber=service.fibers[index])
self.service = service
self.worker_index = index
self.fiber_index = index
self.worker_index = int(
(index - index % self.service.fibers_per_worker)
/ self.service.fibers_per_worker
)
self.exec_requests: list[InferenceExecRequest] = []

@measure(type="exec", task="inference process")
Expand All @@ -335,7 +356,7 @@ async def run(self):
phases = self.exec_requests[0].phases

req_count = len(self.exec_requests)
device0 = self.service.fibers[self.worker_index].device(0)
device0 = self.service.fibers[self.fiber_index].device(0)
if phases[InferencePhase.PREPARE]["required"]:
await self._prepare(device=device0, requests=self.exec_requests)
if phases[InferencePhase.ENCODE]["required"]:
Expand All @@ -346,11 +367,11 @@ async def run(self):
await self._decode(device=device0, requests=self.exec_requests)
if phases[InferencePhase.POSTPROCESS]["required"]:
await self._postprocess(device=device0, requests=self.exec_requests)

await device0
for i in range(req_count):
req = self.exec_requests[i]
req.done.set_success()
self.service.fiber_status[self.worker_index] = 0
self.service.fiber_status[self.fiber_index] = 0

except Exception:
logger.exception("Fatal error in image generation")
Expand Down Expand Up @@ -400,8 +421,7 @@ async def _prepare(self, device, requests):

async def _encode(self, device, requests):
req_bs = len(requests)

entrypoints = self.service.inference_functions["encode"]
entrypoints = self.service.inference_functions[self.worker_index]["encode"]
for bs, fn in entrypoints.items():
if bs >= req_bs:
break
Expand Down Expand Up @@ -454,7 +474,7 @@ async def _denoise(self, device, requests):
step_count = requests[0].steps
cfg_mult = 2 if self.service.model_params.cfg_mode else 1
# Produce denoised latents
entrypoints = self.service.inference_functions["denoise"]
entrypoints = self.service.inference_functions[self.worker_index]["denoise"]
for bs, fns in entrypoints.items():
if bs >= req_bs:
break
Expand Down Expand Up @@ -590,7 +610,7 @@ async def _denoise(self, device, requests):
async def _decode(self, device, requests):
req_bs = len(requests)
# Decode latents to images
entrypoints = self.service.inference_functions["decode"]
entrypoints = self.service.inference_functions[self.worker_index]["decode"]
for bs, fn in entrypoints.items():
if bs >= req_bs:
break
Expand Down
55 changes: 55 additions & 0 deletions shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
{
"prompt": [
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, amateur photo, animal",
" a cat under the snow with blue eyes, covered by snow, cinematic style, wide shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo",
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, amateur photo, animal",
" a cat under the snow with blue eyes, covered by snow, cinematic style, wide shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo",
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, amateur photo, animal",
" a cat under the snow with blue eyes, covered by snow, cinematic style, wide shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo",
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo"
],
"neg_prompt": [
"Watermark, blurry, oversaturated, low resolution, pollution"
],
"height": [
1024
],
"width": [
1024
],
"steps": [
20
],
"guidance_scale": [
7.5
],
"seed": [
0
],
"output_type": [
"base64"
]
}
Loading

0 comments on commit 029d35e

Please sign in to comment.