🎨✅ Fix cuda related test failure and uncommit runtime_config.yaml
Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>
gkumbhat committed Sep 19, 2023
1 parent 65471a2 commit 8b73e37
Showing 2 changed files with 20 additions and 3 deletions.
4 changes: 2 additions & 2 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -1184,9 +1184,9 @@ def convert_peft_model_to_type(
         # then move the peft model to that type on our training device.
         torch_dtype = get_torch_dtype(torch_dtype)
         # If our requested dtype is bfloat16 & we don't support it, fall back to float32
-        if (
+        if torch_dtype == torch.bfloat16 and (
             device == "cpu" or not torch.cuda.is_bf16_supported()
-        ) and torch_dtype == torch.bfloat16:
+        ):
             log.warning(
                 "<NLP18555772W>",
                 "Requested data type torch.bfloat16 is unsupported; falling back to torch.float32",
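
The fix is purely about evaluation order: the bfloat16 check now runs first, so the CUDA capability probe is short-circuited when it is not needed. Below is a minimal standalone Python sketch of the corrected guard under that reading; resolve_dtype is a hypothetical name used for illustration, not the module's API (the real logic lives inside convert_peft_model_to_type, which also moves the PEFT model).

import torch


def resolve_dtype(torch_dtype: torch.dtype, device: str) -> torch.dtype:
    """Hypothetical standalone version of the guard shown in the diff above."""
    # Check the requested dtype first: if it is not bfloat16, the parenthesized
    # expression is never evaluated, so torch.cuda.is_bf16_supported() is not
    # called on CPU-only machines (presumably the CUDA-related test failure
    # this commit fixes).
    if torch_dtype == torch.bfloat16 and (
        device == "cpu" or not torch.cuda.is_bf16_supported()
    ):
        return torch.float32
    return torch_dtype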
19 changes: 18 additions & 1 deletion runtime_config.yaml
@@ -12,6 +12,10 @@ model_management:
   finders:
     default:
       type: LOCAL
+    remote_tgis:
+      type: TGIS-AUTO
+      config:
+        test_connection: true
   initializers:
     default:
       type: LOCAL
@@ -23,4 +27,17 @@ model_management:
                 load_timeout: 120
                 grpc_port: null
                 http_port: null
-                health_poll_delay: 1.0
+                health_poll_delay: 1.0
+              remote_models:
+                flan-t5-xl:
+                  hostname: localhost:8033
+                  prompt_dir: tgis_prompts
+                llama-70b:
+                  hostname: localhost:8034
+                  prompt_dir: tgis_prompts
+
+              connection:
+                hostname: "foo.{model_id}:1234"
+                ca_cert_file: null
+                client_cert_file: null
+                client_key_file: null
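The added remote_models and connection blocks appear to map model ids to remote TGIS servers (with a per-model hostname template in connection). As a quick sanity check, the sketch below reads those entries back out of runtime_config.yaml with PyYAML; it searches the parsed tree recursively rather than committing to the exact nesting, since only the changed lines are visible in the hunk above. It is an illustrative helper, not part of caikit-nlp.

import yaml  # PyYAML


def find_remote_models(node):
    """Recursively search a parsed YAML tree for a `remote_models` mapping."""
    if isinstance(node, dict):
        if isinstance(node.get("remote_models"), dict):
            return node["remote_models"]
        for value in node.values():
            found = find_remote_models(value)
            if found is not None:
                return found
    elif isinstance(node, list):
        for item in node:
            found = find_remote_models(item)
            if found is not None:
                return found
    return None


with open("runtime_config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# With the entries added in this commit, this prints:
#   flan-t5-xl -> localhost:8033 (prompts in tgis_prompts)
#   llama-70b -> localhost:8034 (prompts in tgis_prompts)
for name, opts in (find_remote_models(config) or {}).items():
    print(f"{name} -> {opts['hostname']} (prompts in {opts['prompt_dir']})")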