-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
(shortfin-sd) Cleanup fiber distribution, logging, error handling. #555
Changes from 5 commits
1956991
1f87373
a6dfbca
0022e7f
62b051f
a08c4b8
a857653
edf0350
538f4a2
370766a
917e1d4
d217b7d
553028f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,6 @@ | |
from pathlib import Path | ||
import sys | ||
import os | ||
import io | ||
import copy | ||
import subprocess | ||
|
||
|
@@ -35,7 +34,6 @@ | |
|
||
logger = logging.getLogger("shortfin-sd") | ||
logger.addHandler(native_handler) | ||
logger.setLevel(logging.INFO) | ||
logger.propagate = False | ||
|
||
THIS_DIR = Path(__file__).resolve().parent | ||
|
@@ -46,16 +44,16 @@ async def lifespan(app: FastAPI): | |
sysman.start() | ||
try: | ||
for service_name, service in services.items(): | ||
logging.info("Initializing service '%s':", service_name) | ||
logging.info(str(service)) | ||
logger.info("Initializing service '%s':", service_name) | ||
logger.info(str(service)) | ||
service.start() | ||
except: | ||
sysman.shutdown() | ||
raise | ||
yield | ||
try: | ||
for service_name, service in services.items(): | ||
logging.info("Shutting down service '%s'", service_name) | ||
logger.info("Shutting down service '%s'", service_name) | ||
service.shutdown() | ||
finally: | ||
sysman.shutdown() | ||
|
@@ -83,11 +81,14 @@ async def generate_request(gen_req: GenerateReqInput, request: Request): | |
app.put("/generate")(generate_request) | ||
|
||
|
||
def configure(args) -> SystemManager: | ||
def configure_sys(args) -> SystemManager: | ||
# Setup system (configure devices, etc). | ||
model_config, topology_config, flagfile, tuning_spec, args = get_configs(args) | ||
sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations) | ||
return sysman, model_config, flagfile, tuning_spec | ||
|
||
|
||
def configure_service(args, sysman, model_config, flagfile, tuning_spec): | ||
# Setup each service we are hosting. | ||
tokenizers = [] | ||
for idx, tok_name in enumerate(args.tokenizers): | ||
|
@@ -135,7 +136,7 @@ def get_configs(args): | |
f"--model={modelname}", | ||
f"--topology={topology_inp}", | ||
] | ||
outs = subprocess.check_output(cfg_builder_args).decode() | ||
outs = subprocess.check_output(cfg_builder_args, stderr=subprocess.DEVNULL).decode() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This makes debugging impossible... We really need to finish the build API and get this stuff more for free. |
||
outs_paths = outs.splitlines() | ||
for i in outs_paths: | ||
if "sdxl_config" in i and not args.model_config: | ||
|
@@ -158,18 +159,19 @@ def get_configs(args): | |
arglist = spec.strip("--").split("=") | ||
arg = arglist[0] | ||
if len(arglist) > 2: | ||
print(arglist) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Stray print? |
||
value = arglist[1:] | ||
for val in value: | ||
try: | ||
val = int(val) | ||
except ValueError: | ||
continue | ||
val = val | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are you doing val=val and value=value below? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is mostly for the device ids arg, where the user can pass either device indexes (ints) or device UIDs (strings). |
||
elif len(arglist) == 2: | ||
value = arglist[-1] | ||
try: | ||
value = int(value) | ||
except ValueError: | ||
continue | ||
value = value | ||
else: | ||
# It's a boolean arg. | ||
value = True | ||
|
@@ -178,7 +180,6 @@ def get_configs(args): | |
# It's an env var. | ||
arglist = spec.split("=") | ||
os.environ[arglist[0]] = arglist[1] | ||
|
||
return model_config, topology_config, flagfile, tuning_spec, args | ||
|
||
|
||
|
@@ -222,7 +223,14 @@ def get_modules(args, model_config, flagfile, td_spec): | |
f"--iree-hip-target={args.target}", | ||
f"--iree-compile-extra-args={' '.join(ireec_args)}", | ||
] | ||
output = subprocess.check_output(builder_args).decode() | ||
logger.info(f"Preparing runtime artifacts for {modelname}...") | ||
logger.debug( | ||
f"COMMAND LINE EQUIVALENT: " | ||
+ " ".join([str(argn) for argn in builder_args]) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If using an f-string, don't also concatenate like this. |
||
) | ||
output = subprocess.check_output( | ||
builder_args, stderr=subprocess.DEVNULL | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ditto. Debugging not good. |
||
).decode() | ||
|
||
output_paths = output.splitlines() | ||
filenames.extend(output_paths) | ||
|
@@ -257,7 +265,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): | |
type=str, | ||
required=False, | ||
default="gfx942", | ||
choices=["gfx942", "gfx1100"], | ||
choices=["gfx942", "gfx1100", "gfx90a"], | ||
help="Primary inferencing device LLVM target arch.", | ||
) | ||
parser.add_argument( | ||
|
@@ -297,7 +305,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): | |
parser.add_argument( | ||
"--isolation", | ||
type=str, | ||
default="per_fiber", | ||
default="per_call", | ||
choices=["per_fiber", "per_call", "none"], | ||
help="Concurrency control -- How to isolate programs.", | ||
) | ||
|
@@ -371,9 +379,12 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): | |
home = Path.home() | ||
artdir = home / ".cache" / "shark" | ||
args.artifacts_dir = str(artdir) | ||
else: | ||
args.artifacts_dir = Path(args.artifacts_dir).resolve() | ||
|
||
global sysman | ||
sysman = configure(args) | ||
sysman, model_config, flagfile, tuning_spec = configure_sys(args) | ||
configure_service(args, sysman, model_config, flagfile, tuning_spec) | ||
uvicorn.run( | ||
app, | ||
host=args.host, | ||
|
@@ -393,8 +404,11 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): | |
"disable_existing_loggers": False, | ||
"formatters": { | ||
"default": { | ||
"format": "%(asctime)s - %(levelname)s - %(message)s", | ||
"()": "uvicorn.logging.DefaultFormatter", | ||
"format": "[{asctime}] {message}", | ||
"datefmt": "%Y-%m-%d %H:%M:%S", | ||
"style": "{", | ||
"use_colors": True, | ||
}, | ||
}, | ||
"handlers": { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please move this to the module registration function in lib_ext.cpp. It is a global.
Agreed this thing is flaky. ASAN does a more reliable job.