Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(shortfin-sd) Cleanup fiber distribution, logging, error handling. #555

Merged
merged 13 commits into from
Nov 18, 2024
Merged
1 change: 1 addition & 0 deletions shortfin/python/lib_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ template <typename CppType, typename KeepAlivePatient, typename... Args>
inline py::object custom_new_keep_alive(py::handle py_type,
KeepAlivePatient &keep_alive,
Args &&...args) {
py::set_leak_warnings(false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move this to the module registration function in lib_ext.cpp. it is a global.

Agreed this thing is flaky. ASAN does a more reliable job.

py::object self = custom_new<CppType>(py_type, std::forward<Args>(args)...);
py::detail::keep_alive(
self.ptr(),
Expand Down
8 changes: 2 additions & 6 deletions shortfin/python/shortfin_apps/sd/components/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,13 @@ def needs_file(filename, ctx, namespace=FileNamespace.GEN):


def needs_compile(filename, target, ctx):
device = "amdgpu" if "gfx" in target else "llvmcpu"
vmfb_name = f"{filename}_{device}-{target}.vmfb"
vmfb_name = f"{filename}_{target}.vmfb"
namespace = FileNamespace.BIN
return needs_file(vmfb_name, ctx, namespace)


def get_cached_vmfb(filename, target, ctx):
device = "amdgpu" if "gfx" in target else "llvmcpu"
vmfb_name = f"{filename}_{device}-{target}.vmfb"
namespace = FileNamespace.BIN
vmfb_name = f"{filename}_{target}.vmfb"
return ctx.file(vmfb_name)


Expand Down Expand Up @@ -250,7 +247,6 @@ def sdxl(
params_filenames = get_params_filenames(model_params, model=model, splat=splat)
params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET)
for f, url in params_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if needs_file(f, ctx):
fetch_http(name=f, url=url)
filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames]
Expand Down
1 change: 0 additions & 1 deletion shortfin/python/shortfin_apps/sd/components/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def __init__(
self.batcher = service.batcher
self.complete_infeed = self.system.create_queue()

@measure(type="throughput", num_items="num_output_images", freq=1, label="samples")
async def run(self):
logger.debug("Started ClientBatchGenerateProcess: %r", self)
try:
Expand Down
5 changes: 3 additions & 2 deletions shortfin/python/shortfin_apps/sd/components/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, device="local-task", device_ids=None, async_allocs=True):
sb.visible_devices = sb.available_devices
sb.visible_devices = get_selected_devices(sb, device_ids)
self.ls = sb.create_system()
logging.info(f"Created local system with {self.ls.device_names} devices")
logger.info(f"Created local system with {self.ls.device_names} devices")
# TODO: Come up with an easier bootstrap thing than manually
# running a thread.
self.t = threading.Thread(target=lambda: self.ls.run(self.run()))
Expand All @@ -39,9 +39,10 @@ def start(self):
def shutdown(self):
logger.info("Shutting down system manager")
self.command_queue.close()
self.ls.shutdown()

async def run(self):
reader = self.command_queue.reader()
while command := await reader():
...
logging.info("System manager command processor stopped")
logger.info("System manager command processor stopped")
67 changes: 36 additions & 31 deletions shortfin/python/shortfin_apps/sd/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
from .tokenizer import Tokenizer
from .metrics import measure


logger = logging.getLogger("shortfin-sd.service")
logger.setLevel(logging.DEBUG)

prog_isolations = {
"none": sf.ProgramIsolation.NONE,
Expand Down Expand Up @@ -79,23 +77,32 @@ def __init__(

self.workers = []
self.fibers = []
self.fiber_status = []
self.idle_fibers = set()
for idx, device in enumerate(self.sysman.ls.devices):
for i in range(self.workers_per_device):
worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}")
self.workers.append(worker)
for idx, device in enumerate(self.sysman.ls.devices):
for i in range(self.fibers_per_device):
fiber = sysman.ls.create_fiber(
self.workers[i % len(self.workers)], devices=[device]
)
tgt_worker = self.workers[i % len(self.workers)]
fiber = sysman.ls.create_fiber(tgt_worker, devices=[device])
self.fibers.append(fiber)
self.fiber_status.append(0)
self.idle_fibers.add(fiber)
for idx in range(len(self.workers)):
self.inference_programs[idx] = {}
self.inference_functions[idx] = {}
# Scope dependent objects.
self.batcher = BatcherProcess(self)

def get_worker_index(self, fiber):
if fiber not in self.fibers:
raise ValueError("A worker was requested from a rogue fiber.")
fiber_idx = self.fibers.index(fiber)
worker_idx = int(
(fiber_idx - fiber_idx % self.fibers_per_worker) / self.fibers_per_worker
)
return worker_idx

def load_inference_module(self, vmfb_path: Path, component: str = None):
if not self.inference_modules.get(component):
self.inference_modules[component] = []
Expand All @@ -112,7 +119,7 @@ def load_inference_parameters(
):
p = sf.StaticProgramParameters(self.sysman.ls, parameter_scope=parameter_scope)
for path in paths:
logging.info("Loading parameter fiber '%s' from: %s", parameter_scope, path)
logger.info("Loading parameter fiber '%s' from: %s", parameter_scope, path)
p.load(path, format=format)
if not self.inference_parameters.get(component):
self.inference_parameters[component] = []
Expand All @@ -121,6 +128,7 @@ def load_inference_parameters(
def start(self):
# Initialize programs.
for component in self.inference_modules:
logger.info(f"Loading component: {component}")
component_modules = [
sf.ProgramModule.parameter_provider(
self.sysman.ls, *self.inference_parameters.get(component, [])
Expand All @@ -141,7 +149,6 @@ def start(self):
isolation=self.prog_isolation,
trace_execution=self.trace_execution,
)
logger.info("Program loaded.")

for worker_idx, worker in enumerate(self.workers):
self.inference_functions[worker_idx]["encode"] = {}
Expand Down Expand Up @@ -270,14 +277,17 @@ def board_flights(self):
return
self.strobes = 0
batches = self.sort_batches()
for idx, batch in batches.items():
for fidx, status in enumerate(self.service.fiber_status):
if (
status == 0
or self.service.prog_isolation == sf.ProgramIsolation.PER_CALL
):
self.board(batch["reqs"], index=fidx)
break
for batch in batches.values():
# Assign the batch to the next idle fiber.
if len(self.service.idle_fibers) == 0:
return
fiber = self.service.idle_fibers.pop()
fiber_idx = self.service.fibers.index(fiber)
worker_idx = self.service.get_worker_index(fiber)
logger.debug(f"Sending batch to fiber {fiber_idx} (worker {worker_idx})")
self.board(batch["reqs"], fiber=fiber)
if self.service.prog_isolation != sf.ProgramIsolation.PER_FIBER:
self.service.idle_fibers.add(fiber)

def sort_batches(self):
"""Files pending requests into sorted batches suitable for program invocations."""
Expand Down Expand Up @@ -310,20 +320,18 @@ def sort_batches(self):
}
return batches

def board(self, request_bundle, index):
def board(self, request_bundle, fiber):
pending = request_bundle
if len(pending) == 0:
return
exec_process = InferenceExecutorProcess(self.service, index)
exec_process = InferenceExecutorProcess(self.service, fiber)
for req in pending:
if len(exec_process.exec_requests) >= self.ideal_batch_size:
break
exec_process.exec_requests.append(req)
if exec_process.exec_requests:
for flighted_request in exec_process.exec_requests:
self.pending_requests.remove(flighted_request)
if self.service.prog_isolation != sf.ProgramIsolation.PER_CALL:
self.service.fiber_status[index] = 1
exec_process.launch()


Expand All @@ -338,15 +346,11 @@ class InferenceExecutorProcess(sf.Process):
def __init__(
self,
service: GenerateService,
index: int,
fiber,
):
super().__init__(fiber=service.fibers[index])
super().__init__(fiber=fiber)
self.service = service
self.fiber_index = index
self.worker_index = int(
(index - index % self.service.fibers_per_worker)
/ self.service.fibers_per_worker
)
self.worker_index = self.service.get_worker_index(fiber)
self.exec_requests: list[InferenceExecRequest] = []

@measure(type="exec", task="inference process")
Expand All @@ -360,7 +364,7 @@ async def run(self):
phase = req.phase
phases = self.exec_requests[0].phases
req_count = len(self.exec_requests)
device0 = self.service.fibers[self.fiber_index].device(0)
device0 = self.fiber.device(0)
if phases[InferencePhase.PREPARE]["required"]:
await self._prepare(device=device0, requests=self.exec_requests)
if phases[InferencePhase.ENCODE]["required"]:
Expand All @@ -375,7 +379,8 @@ async def run(self):
for i in range(req_count):
req = self.exec_requests[i]
req.done.set_success()
self.service.fiber_status[self.fiber_index] = 0
if self.service.prog_isolation == sf.ProgramIsolation.PER_FIBER:
self.service.idle_fibers.add(self.fiber)

except Exception:
logger.exception("Fatal error in image generation")
Expand Down Expand Up @@ -574,7 +579,7 @@ async def _denoise(self, device, requests):
for i, t in tqdm(
enumerate(range(step_count)),
disable=(not self.service.show_progress),
desc=f"Worker #{self.worker_index} DENOISE (bs{req_bs})",
desc=f"DENOISE (bs{req_bs})",
):
step = sfnp.device_array.for_device(device, [1], sfnp.sint64)
s_host = step.for_transfer()
Expand Down
44 changes: 29 additions & 15 deletions shortfin/python/shortfin_apps/sd/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from pathlib import Path
import sys
import os
import io
import copy
import subprocess

Expand All @@ -35,7 +34,6 @@

logger = logging.getLogger("shortfin-sd")
logger.addHandler(native_handler)
logger.setLevel(logging.INFO)
logger.propagate = False

THIS_DIR = Path(__file__).resolve().parent
Expand All @@ -46,16 +44,16 @@ async def lifespan(app: FastAPI):
sysman.start()
try:
for service_name, service in services.items():
logging.info("Initializing service '%s':", service_name)
logging.info(str(service))
logger.info("Initializing service '%s':", service_name)
logger.info(str(service))
service.start()
except:
sysman.shutdown()
raise
yield
try:
for service_name, service in services.items():
logging.info("Shutting down service '%s'", service_name)
logger.info("Shutting down service '%s'", service_name)
service.shutdown()
finally:
sysman.shutdown()
Expand Down Expand Up @@ -83,11 +81,14 @@ async def generate_request(gen_req: GenerateReqInput, request: Request):
app.put("/generate")(generate_request)


def configure(args) -> SystemManager:
def configure_sys(args) -> SystemManager:
# Setup system (configure devices, etc).
model_config, topology_config, flagfile, tuning_spec, args = get_configs(args)
sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations)
return sysman, model_config, flagfile, tuning_spec


def configure_service(args, sysman, model_config, flagfile, tuning_spec):
# Setup each service we are hosting.
tokenizers = []
for idx, tok_name in enumerate(args.tokenizers):
Expand Down Expand Up @@ -135,7 +136,7 @@ def get_configs(args):
f"--model={modelname}",
f"--topology={topology_inp}",
]
outs = subprocess.check_output(cfg_builder_args).decode()
outs = subprocess.check_output(cfg_builder_args, stderr=subprocess.DEVNULL).decode()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes debugging impossible... We really need to finish the build API and get this stuff more for free

outs_paths = outs.splitlines()
for i in outs_paths:
if "sdxl_config" in i and not args.model_config:
Expand All @@ -158,18 +159,19 @@ def get_configs(args):
arglist = spec.strip("--").split("=")
arg = arglist[0]
if len(arglist) > 2:
print(arglist)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stray print?

value = arglist[1:]
for val in value:
try:
val = int(val)
except ValueError:
continue
val = val
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you doing val=val and value=value below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mostly for the device ids arg, where user can pass either device indexes (ints) or device UIDs (strings).
The logic being: if we can't convert them to ints, then we pass them as they were recieved.
doing continue here breaks parsing of non-list/int arguments e.g. isolation.

elif len(arglist) == 2:
value = arglist[-1]
try:
value = int(value)
except ValueError:
continue
value = value
else:
# It's a boolean arg.
value = True
Expand All @@ -178,7 +180,6 @@ def get_configs(args):
# It's an env var.
arglist = spec.split("=")
os.environ[arglist[0]] = arglist[1]

return model_config, topology_config, flagfile, tuning_spec, args


Expand Down Expand Up @@ -222,7 +223,14 @@ def get_modules(args, model_config, flagfile, td_spec):
f"--iree-hip-target={args.target}",
f"--iree-compile-extra-args={' '.join(ireec_args)}",
]
output = subprocess.check_output(builder_args).decode()
logger.info(f"Preparing runtime artifacts for {modelname}...")
logger.debug(
f"COMMAND LINE EQUIVALENT: "
+ " ".join([str(argn) for argn in builder_args])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If using an fstring, don't also concat like this.

)
output = subprocess.check_output(
builder_args, stderr=subprocess.DEVNULL
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto. Debugging not good.

).decode()

output_paths = output.splitlines()
filenames.extend(output_paths)
Expand Down Expand Up @@ -257,7 +265,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
type=str,
required=False,
default="gfx942",
choices=["gfx942", "gfx1100"],
choices=["gfx942", "gfx1100", "gfx90a"],
help="Primary inferencing device LLVM target arch.",
)
parser.add_argument(
Expand Down Expand Up @@ -297,7 +305,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
parser.add_argument(
"--isolation",
type=str,
default="per_fiber",
default="per_call",
choices=["per_fiber", "per_call", "none"],
help="Concurrency control -- How to isolate programs.",
)
Expand Down Expand Up @@ -371,9 +379,12 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
home = Path.home()
artdir = home / ".cache" / "shark"
args.artifacts_dir = str(artdir)
else:
args.artifacts_dir = Path(args.artifacts_dir).resolve()

global sysman
sysman = configure(args)
sysman, model_config, flagfile, tuning_spec = configure_sys(args)
configure_service(args, sysman, model_config, flagfile, tuning_spec)
uvicorn.run(
app,
host=args.host,
Expand All @@ -393,8 +404,11 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(asctime)s - %(levelname)s - %(message)s",
"()": "uvicorn.logging.DefaultFormatter",
"format": "[{asctime}] {message}",
"datefmt": "%Y-%m-%d %H:%M:%S",
"style": "{",
"use_colors": True,
},
},
"handlers": {
Expand Down
Loading
Loading