Suppress nanobind leak messages and cleanup logging.

eagarvey-amd committed Nov 17, 2024
1 parent a786114 commit 1956991
Showing 7 changed files with 90 additions and 57 deletions.
1 change: 1 addition & 0 deletions shortfin/python/lib_ext.h
@@ -65,6 +65,7 @@ template <typename CppType, typename KeepAlivePatient, typename... Args>
inline py::object custom_new_keep_alive(py::handle py_type,
KeepAlivePatient &keep_alive,
Args &&...args) {
py::set_leak_warnings(false);
py::object self = custom_new<CppType>(py_type, std::forward<Args>(args)...);
py::detail::keep_alive(
self.ptr(),
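The one-line change above turns off nanobind's instance-leak report, which otherwise prints warnings such as "nanobind: leaked N instances!" at interpreter shutdown when keep-alive relationships hold bound objects past teardown. A minimal sketch of the same call in a hypothetical standalone extension module (not the shortfin build):

```cpp
// example_ext.cpp -- toy nanobind module that suppresses the shutdown
// leak report, as lib_ext.h now does inside custom_new_keep_alive.
#include <nanobind/nanobind.h>

namespace nb = nanobind;

NB_MODULE(example_ext, m) {
  // Keep-alive relationships can hold bound objects past interpreter
  // teardown, which would otherwise be reported as leaks on exit.
  nb::set_leak_warnings(false);

  m.def("hello", []() { return "hello"; });
}
```

The switch is global and idempotent, so issuing it on every custom_new_keep_alive invocation is harmless, just redundant after the first call.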
10 changes: 4 additions & 6 deletions shortfin/python/shortfin_apps/sd/components/builders.py
@@ -164,6 +164,8 @@ def needs_update(ctx):

def needs_file(filename, ctx, namespace=FileNamespace.GEN):
out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path()
print("__________________")
print(out_file)
if os.path.exists(out_file):
needed = False
else:
@@ -177,16 +179,13 @@ def needs_file(filename, ctx, namespace=FileNamespace.GEN):


def needs_compile(filename, target, ctx):
device = "amdgpu" if "gfx" in target else "llvmcpu"
vmfb_name = f"{filename}_{device}-{target}.vmfb"
vmfb_name = f"{filename}_{target}.vmfb"
namespace = FileNamespace.BIN
return needs_file(vmfb_name, ctx, namespace)


def get_cached_vmfb(filename, target, ctx):
device = "amdgpu" if "gfx" in target else "llvmcpu"
vmfb_name = f"{filename}_{device}-{target}.vmfb"
namespace = FileNamespace.BIN
vmfb_name = f"{filename}_{target}.vmfb"
return ctx.file(vmfb_name)


@@ -250,7 +249,6 @@ def sdxl(
params_filenames = get_params_filenames(model_params, model=model, splat=splat)
params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET)
for f, url in params_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if needs_file(f, ctx):
fetch_http(name=f, url=url)
filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames]
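In builders.py the redundant device prefix is dropped from compiled-artifact names, so needs_compile and get_cached_vmfb now derive the same cache key from the filename stem and the IREE target alone. A sketch of the convention, with illustrative values that are not pinned to the repo's real filenames:

```python
# Cache-key convention after this commit: the target suffix alone
# distinguishes vmfbs; the "amdgpu-"/"llvmcpu-" device prefix is gone.
def vmfb_cache_name(filename: str, target: str) -> str:
    return f"{filename}_{target}.vmfb"

# e.g. vmfb_cache_name("sdxl_unet_bs1", "gfx942")
#   -> "sdxl_unet_bs1_gfx942.vmfb"
```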
1 change: 0 additions & 1 deletion shortfin/python/shortfin_apps/sd/components/generate.py
@@ -83,7 +83,6 @@ def __init__(
self.batcher = service.batcher
self.complete_infeed = self.system.create_queue()

@measure(type="throughput", num_items="num_output_images", freq=1, label="samples")
async def run(self):
logger.debug("Started ClientBatchGenerateProcess: %r", self)
try:
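The removed @measure decorator had wrapped ClientBatchGenerateProcess.run to log throughput in samples per second. For readers following along, a hypothetical reconstruction of such a decorator; the real one lives in components/metrics.py and its signature is only inferred from this call site:

```python
import functools
import logging
import time

logger = logging.getLogger("shortfin-sd.metrics")

def measure(type=None, num_items=None, freq=1, label="items"):
    """Hypothetical measure(): wrap an async method and log items/sec.

    Assumes num_items names an integer attribute on the bound object,
    mirroring num_items="num_output_images" at the removed call site.
    """
    def decorator(fn):
        @functools.wraps(fn)
        async def wrapper(self, *args, **kwargs):
            start = time.perf_counter()
            result = await fn(self, *args, **kwargs)
            elapsed = time.perf_counter() - start
            n = getattr(self, num_items, 1) if num_items else 1
            logger.info("%s: %.2f %s/sec", fn.__qualname__, n / elapsed, label)
            return result
        return wrapper
    return decorator
```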
5 changes: 3 additions & 2 deletions shortfin/python/shortfin_apps/sd/components/manager.py
@@ -25,7 +25,7 @@ def __init__(self, device="local-task", device_ids=None, async_allocs=True):
sb.visible_devices = sb.available_devices
sb.visible_devices = get_selected_devices(sb, device_ids)
self.ls = sb.create_system()
logging.info(f"Created local system with {self.ls.device_names} devices")
logger.info(f"Created local system with {self.ls.device_names} devices")
# TODO: Come up with an easier bootstrap thing than manually
# running a thread.
self.t = threading.Thread(target=lambda: self.ls.run(self.run()))
@@ -39,9 +39,10 @@ def start(self):
def shutdown(self):
logger.info("Shutting down system manager")
self.command_queue.close()
self.ls.shutdown()

async def run(self):
reader = self.command_queue.reader()
while command := await reader():
...
logging.info("System manager command processor stopped")
logger.info("System manager command processor stopped")
69 changes: 40 additions & 29 deletions shortfin/python/shortfin_apps/sd/components/service.py
Expand Up @@ -23,9 +23,7 @@
from .tokenizer import Tokenizer
from .metrics import measure


logger = logging.getLogger("shortfin-sd.service")
logger.setLevel(logging.DEBUG)

prog_isolations = {
"none": sf.ProgramIsolation.NONE,
@@ -79,23 +77,35 @@ def __init__(

self.workers = []
self.fibers = []
self.fiber_status = []
self.idle_fibers = set()
for idx, device in enumerate(self.sysman.ls.devices):
for i in range(self.workers_per_device):
worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}")
self.workers.append(worker)
for idx, device in enumerate(self.sysman.ls.devices):
for i in range(self.fibers_per_device):
tgt_worker = self.workers[i % len(self.workers)]
fiber = sysman.ls.create_fiber(
self.workers[i % len(self.workers)], devices=[device]
tgt_worker, devices=[device]
)
self.fibers.append(fiber)
self.fiber_status.append(0)
self.idle_fibers.add(fiber)
for idx in range(len(self.workers)):
self.inference_programs[idx] = {}
self.inference_functions[idx] = {}
# Scope dependent objects.
self.batcher = BatcherProcess(self)

def get_worker_index(self, fiber):
if fiber not in self.fibers:
raise ValueError("A worker was requested from a rogue fiber.")
fiber_idx = self.fibers.index(fiber)
worker_idx = int(
(fiber_idx - fiber_idx % self.fibers_per_worker)
/ self.fibers_per_worker
)
return worker_idx

def load_inference_module(self, vmfb_path: Path, component: str = None):
if not self.inference_modules.get(component):
self.inference_modules[component] = []
@@ -112,7 +122,7 @@ def load_inference_parameters(
):
p = sf.StaticProgramParameters(self.sysman.ls, parameter_scope=parameter_scope)
for path in paths:
logging.info("Loading parameter fiber '%s' from: %s", parameter_scope, path)
logger.info("Loading parameter fiber '%s' from: %s", parameter_scope, path)
p.load(path, format=format)
if not self.inference_parameters.get(component):
self.inference_parameters[component] = []
@@ -121,6 +131,9 @@ def load_inference_parameters(
def start(self):
# Initialize programs.
for component in self.inference_modules:
logger.info(
f"Loading component: {component}"
)
component_modules = [
sf.ProgramModule.parameter_provider(
self.sysman.ls, *self.inference_parameters.get(component, [])
@@ -141,7 +154,6 @@
isolation=self.prog_isolation,
trace_execution=self.trace_execution,
)
logger.info("Program loaded.")

for worker_idx, worker in enumerate(self.workers):
self.inference_functions[worker_idx]["encode"] = {}
@@ -270,14 +282,18 @@ def board_flights(self):
return
self.strobes = 0
batches = self.sort_batches()
for idx, batch in batches.items():
for fidx, status in enumerate(self.service.fiber_status):
if (
status == 0
or self.service.prog_isolation == sf.ProgramIsolation.PER_CALL
):
self.board(batch["reqs"], index=fidx)
break
for batch in batches.values():
# Assign the batch to the next idle fiber.
if len(self.service.idle_fibers) == 0:
return
fiber = self.service.idle_fibers.pop()
fiber_idx = self.service.fibers.index(fiber)
worker_idx = self.service.get_worker_index(fiber)
logger.debug(f"Sending batch to fiber {fiber_idx} (worker {worker_idx})")
self.board(batch["reqs"], fiber=fiber)
if self.service.prog_isolation != sf.ProgramIsolation.PER_FIBER:
self.service.idle_fibers.add(fiber)


def sort_batches(self):
"""Files pending requests into sorted batches suitable for program invocations."""
@@ -310,20 +326,18 @@ def sort_batches(self):
}
return batches

def board(self, request_bundle, index):
def board(self, request_bundle, fiber):
pending = request_bundle
if len(pending) == 0:
return
exec_process = InferenceExecutorProcess(self.service, index)
exec_process = InferenceExecutorProcess(self.service, fiber)
for req in pending:
if len(exec_process.exec_requests) >= self.ideal_batch_size:
break
exec_process.exec_requests.append(req)
if exec_process.exec_requests:
for flighted_request in exec_process.exec_requests:
self.pending_requests.remove(flighted_request)
if self.service.prog_isolation != sf.ProgramIsolation.PER_CALL:
self.service.fiber_status[index] = 1
exec_process.launch()


@@ -338,15 +352,11 @@ class InferenceExecutorProcess(sf.Process):
def __init__(
self,
service: GenerateService,
index: int,
fiber,
):
super().__init__(fiber=service.fibers[index])
super().__init__(fiber=fiber)
self.service = service
self.fiber_index = index
self.worker_index = int(
(index - index % self.service.fibers_per_worker)
/ self.service.fibers_per_worker
)
self.worker_index = self.service.get_worker_index(fiber)
self.exec_requests: list[InferenceExecRequest] = []

@measure(type="exec", task="inference process")
@@ -360,7 +370,7 @@ async def run(self):
phase = req.phase
phases = self.exec_requests[0].phases
req_count = len(self.exec_requests)
device0 = self.service.fibers[self.fiber_index].device(0)
device0 = self.fiber.device(0)
if phases[InferencePhase.PREPARE]["required"]:
await self._prepare(device=device0, requests=self.exec_requests)
if phases[InferencePhase.ENCODE]["required"]:
@@ -375,7 +385,8 @@ async def run(self):
for i in range(req_count):
req = self.exec_requests[i]
req.done.set_success()
self.service.fiber_status[self.fiber_index] = 0
if self.service.prog_isolation == sf.ProgramIsolation.PER_FIBER:
self.service.idle_fibers.add(self.fiber)

except Exception:
logger.exception("Fatal error in image generation")
@@ -574,7 +585,7 @@ async def _denoise(self, device, requests):
for i, t in tqdm(
enumerate(range(step_count)),
disable=(not self.service.show_progress),
desc=f"Worker #{self.worker_index} DENOISE (bs{req_bs})",
desc=f"DENOISE (bs{req_bs})",
):
step = sfnp.device_array.for_device(device, [1], sfnp.sint64)
s_host = step.for_transfer()
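The service.py changes replace the parallel fiber_status list of 0/1 flags with an idle_fibers set: board_flights pops an idle fiber per batch, boards onto it, and returns the fiber to the set immediately unless isolation is PER_FIBER, in which case the executor process re-adds it when its batch completes. get_worker_index recovers the owning worker by integer division, since fibers are created in a fixed order per worker. A toy sketch of the pattern, with standalone types that are not shortfin's:

```python
class ToyScheduler:
    """Toy model of the idle-fiber scheduling this commit introduces."""

    def __init__(self, fibers, fibers_per_worker, per_fiber_isolation=True):
        self.fibers = list(fibers)           # creation order, worker-major
        self.idle_fibers = set(self.fibers)  # every fiber starts idle
        self.fibers_per_worker = fibers_per_worker
        self.per_fiber_isolation = per_fiber_isolation

    def worker_index(self, fiber):
        # (i - i % n) / n in the diff is just floor division by n.
        return self.fibers.index(fiber) // self.fibers_per_worker

    def board(self, batch, launch):
        """Board one batch; returns False when no fiber is idle."""
        if not self.idle_fibers:
            return False
        fiber = self.idle_fibers.pop()
        launch(batch, fiber)  # caller-supplied: start the executor process
        if not self.per_fiber_isolation:
            # PER_CALL / NONE isolation: fiber can take more work at once.
            self.idle_fibers.add(fiber)
        return True

    def on_batch_done(self, fiber):
        if self.per_fiber_isolation:
            self.idle_fibers.add(fiber)
```

Using a set instead of a status list makes "find any idle fiber" a constant-time pop and removes the need to keep index-based state in sync with the fiber list.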
40 changes: 23 additions & 17 deletions shortfin/python/shortfin_apps/sd/server.py
@@ -11,7 +11,6 @@
from pathlib import Path
import sys
import os
import io
import copy
import subprocess

@@ -35,32 +34,29 @@

logger = logging.getLogger("shortfin-sd")
logger.addHandler(native_handler)
logger.setLevel(logging.INFO)
logger.propagate = False

THIS_DIR = Path(__file__).resolve().parent


@asynccontextmanager
async def lifespan(app: FastAPI):
sysman.start()
try:
for service_name, service in services.items():
logging.info("Initializing service '%s':", service_name)
logging.info(str(service))
logger.info("Initializing service '%s':", service_name)
logger.info(str(service))
service.start()
except:
sysman.shutdown()
raise
yield
try:
for service_name, service in services.items():
logging.info("Shutting down service '%s'", service_name)
logger.info("Shutting down service '%s'", service_name)
service.shutdown()
finally:
sysman.shutdown()


sysman: SystemManager
services: dict[str, Any] = {}
app = FastAPI(lifespan=lifespan)
@@ -83,11 +79,13 @@ async def generate_request(gen_req: GenerateReqInput, request: Request):
app.put("/generate")(generate_request)


def configure(args) -> SystemManager:
def configure_sys(args) -> SystemManager:
# Setup system (configure devices, etc).
model_config, topology_config, flagfile, tuning_spec, args = get_configs(args)
sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations)
return sysman, model_config, flagfile, tuning_spec

def configure_service(args, sysman, model_config, flagfile, tuning_spec):
# Setup each service we are hosting.
tokenizers = []
for idx, tok_name in enumerate(args.tokenizers):
@@ -135,7 +133,7 @@ def get_configs(args):
f"--model={modelname}",
f"--topology={topology_inp}",
]
outs = subprocess.check_output(cfg_builder_args).decode()
outs = subprocess.check_output(cfg_builder_args, stderr=subprocess.DEVNULL).decode()
outs_paths = outs.splitlines()
for i in outs_paths:
if "sdxl_config" in i and not args.model_config:
@@ -158,18 +156,19 @@ def get_configs(args):
arglist = spec.strip("--").split("=")
arg = arglist[0]
if len(arglist) > 2:
print(arglist)
value = arglist[1:]
for val in value:
try:
val = int(val)
except ValueError:
continue
val = val
elif len(arglist) == 2:
value = arglist[-1]
try:
value = int(value)
except ValueError:
continue
value = value
else:
# It's a boolean arg.
value = True
Expand All @@ -178,7 +177,6 @@ def get_configs(args):
# It's an env var.
arglist = spec.split("=")
os.environ[arglist[0]] = arglist[1]

return model_config, topology_config, flagfile, tuning_spec, args


@@ -222,7 +220,9 @@ def get_modules(args, model_config, flagfile, td_spec):
f"--iree-hip-target={args.target}",
f"--iree-compile-extra-args={' '.join(ireec_args)}",
]
output = subprocess.check_output(builder_args).decode()
logger.info(f"Preparing runtime artifacts for {modelname}...")
logger.debug(f"COMMAND LINE EQUIVALENT: " + " ".join([str(argn) for argn in builder_args]))
output = subprocess.check_output(builder_args, stderr=subprocess.DEVNULL).decode()

output_paths = output.splitlines()
filenames.extend(output_paths)
@@ -257,7 +257,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
type=str,
required=False,
default="gfx942",
choices=["gfx942", "gfx1100"],
choices=["gfx942", "gfx1100", "gfx90a"],
help="Primary inferencing device LLVM target arch.",
)
parser.add_argument(
@@ -297,7 +297,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
parser.add_argument(
"--isolation",
type=str,
default="per_fiber",
default="per_call",
choices=["per_fiber", "per_call", "none"],
help="Concurrency control -- How to isolate programs.",
)
@@ -371,9 +371,12 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
home = Path.home()
artdir = home / ".cache" / "shark"
args.artifacts_dir = str(artdir)
else:
args.artifacts_dir = Path(args.artifacts_dir).resolve()

global sysman
sysman = configure(args)
sysman, model_config, flagfile, tuning_spec = configure_sys(args)
configure_service(args, sysman, model_config, flagfile, tuning_spec)
uvicorn.run(
app,
host=args.host,
@@ -393,8 +396,11 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(asctime)s - %(levelname)s - %(message)s",
"()": "uvicorn.logging.DefaultFormatter",
"format": "[{asctime}] {message}",
"datefmt": "%Y-%m-%d %H:%M:%S",
"style": "{",
"use_colors": True,
},
},
"handlers": {
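server.py also splits configure() into configure_sys() and configure_service(), silences stderr from the builder subprocesses, and swaps the plain stdlib log format for uvicorn's DefaultFormatter with {}-style placeholders. For context, a self-contained version of that formatter config; the handler block is an assumption matching common uvicorn setups, since it falls outside this hunk:

```python
import uvicorn

# "()" tells logging.config.dictConfig which factory to instantiate for
# the formatter; style="{" switches the format string to str.format fields.
LOGGING_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "()": "uvicorn.logging.DefaultFormatter",
            "format": "[{asctime}] {message}",
            "datefmt": "%Y-%m-%d %H:%M:%S",
            "style": "{",
            "use_colors": True,
        },
    },
    "handlers": {
        "default": {
            "class": "logging.StreamHandler",
            "formatter": "default",
            "stream": "ext://sys.stdout",
        },
    },
    "loggers": {
        "uvicorn": {"handlers": ["default"], "level": "INFO"},
    },
}

# e.g. uvicorn.run(app, log_config=LOGGING_CONFIG)
```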