fix: load model only when cache is not found
Signed-off-by: Yu Fan <fany@buaa.edu.cn>
FuryMartin committed Oct 6, 2024
1 parent 37f4737 commit 5857c58
Showing 6 changed files with 19 additions and 12 deletions.
Changed file 1 of 6

@@ -66,8 +66,6 @@ def load(self, **kwargs):
             self.model = APIBasedLLM(**self.kwargs)
         else:
             raise Exception(f"Backend {self.backend} is not supported")
-
-        self.model.load(model_url=self.model_name)
 
         # TODO cloud service must be configured in JointInference
 
Changed file 2 of 6

@@ -19,11 +19,13 @@ def __init__(self, **kwargs) -> None:
         self.config = kwargs
         self._parse_kwargs(**kwargs)
         self.is_cache_loaded = False
+        self.model_loaded = False
 
     def load(self):
         raise NotImplementedError
 
     def _parse_kwargs(self, **kwargs):
         self.model_name = kwargs.get("model", None)
+        self.quantization = kwargs.get("quantization", "full")
         self.temperature = kwargs.get("temperature", 0.8)
         self.top_p = kwargs.get("top_p", 0.8)

@@ -47,14 +49,17 @@ def inference(self, data):
         if messages[0]['role'] == "system":
             system_prompt = messages[0]["content"]
         else:
-            system_prompt = ""
+            system_prompt = None
 
         question = messages[-1]["content"]
 
         if self.use_cache:
             response = self._try_cache(question, system_prompt)
             if response is not None:
                 return response
 
+        if not self.model_loaded:
+            self.load(self.model_name)
+
         response = self._infer(messages)
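
Taken together, the first two files implement the fix in the commit title: the eager `self.model.load(model_url=self.model_name)` call is removed from the backend `load()` wrapper, and `inference()` now consults the cache first and loads the model only on a cache miss. Below is a minimal sketch of that control flow; it is simplified from the diff (the repository's method takes `data`, and the `_try_cache`/`_infer` bodies plus the point where `model_loaded` flips to True are assumptions, not the repository's exact code).

```python
# Minimal sketch of the lazy-loading flow introduced by this commit.
# Simplified from the diff: _try_cache/_infer bodies and the place where
# model_loaded becomes True are assumptions, not the repository's code.
class BaseLLM:
    def __init__(self, **kwargs):
        self.model_name = kwargs.get("model", None)
        self.use_cache = kwargs.get("use_cache", True)
        self.model_loaded = False

    def load(self, model_url):
        raise NotImplementedError  # overridden by the HuggingFace/vLLM/API backends

    def _try_cache(self, question, system_prompt):
        return None  # placeholder: return a cached answer if one exists

    def _infer(self, messages):
        raise NotImplementedError

    def inference(self, messages):
        system_prompt = messages[0]["content"] if messages[0]["role"] == "system" else None
        question = messages[-1]["content"]

        # 1. Try the cache first; on a hit the model weights are never touched.
        if self.use_cache:
            response = self._try_cache(question, system_prompt)
            if response is not None:
                return response

        # 2. Only on a cache miss is the model actually loaded, and only once.
        if not self.model_loaded:
            self.load(self.model_name)
            self.model_loaded = True

        return self._infer(messages)
```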
Changed file 3 of 6

@@ -16,8 +16,7 @@ def load(self, model_url):
             model_url,
             torch_dtype="auto",
             device_map="auto",
-            trust_remote_code=True,
-            # quantization = self.quantization # Need to align with HF API
+            trust_remote_code=True
         )
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_url,
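
The removed comment in the Hugging Face backend noted that `self.quantization` (now parsed in `_parse_kwargs` with a default of `"full"`) still needs to be aligned with the Hugging Face API. As a hedged illustration only, one plausible mapping would go through `BitsAndBytesConfig`; the `"4-bit"`/`"8-bit"` value names here are assumptions, not values defined by the repository.

```python
# Hedged sketch, not part of this commit: mapping a quantization string onto the
# Hugging Face API via bitsandbytes. Value names other than "full" are assumed.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

def build_quantization_config(quantization: str):
    if quantization == "4-bit":
        return BitsAndBytesConfig(load_in_4bit=True)
    if quantization == "8-bit":
        return BitsAndBytesConfig(load_in_8bit=True)
    return None  # "full" -> no quantization

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-1.5B-Instruct",
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
    quantization_config=build_quantization_config("full"),
)
```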
Changed file 4 of 6

@@ -81,4 +81,5 @@ def cleanup(self):
         destroy_model_parallel()
         destroy_distributed_environment()
 
-        del self.model.llm_engine.model_executor
+        if hasattr(self, "model"):
+            del self.model.llm_engine.model_executor
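
The guard matters because, with lazy loading, `self.model` may never be created at all (for example when every query is served from cache), so an unconditional `del self.model.llm_engine.model_executor` would raise `AttributeError` during `cleanup()`. A short sketch of the pattern, with simplified bodies:

```python
# Sketch of the guarded teardown; bodies are simplified, and only the hasattr
# guard mirrors the actual change in this commit.
import gc

class VllmBackendSketch:
    def load(self, model_url):
        from vllm import LLM  # heavyweight import, deferred until a real load happens
        self.model = LLM(model=model_url)

    def cleanup(self):
        # self.model may not exist if no cache miss ever forced a load.
        if hasattr(self, "model"):
            del self.model.llm_engine.model_executor
            del self.model
        gc.collect()
```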
Changed file 5 of 6

@@ -22,7 +22,9 @@ algorithm:
       # name of the hyperparameter; string type;
       - model:
           values:
-            - "Qwen/Qwen2.5-0.5B-Instruct"
+            - "Qwen/Qwen2.5-1.5B-Instruct"
+            - "Qwen/Qwen2.5-3B-Instruct"
+            - "Qwen/Qwen2.5-7B-Instruct"
       - backend:
           values:
             - "vllm"

@@ -32,7 +34,7 @@ algorithm:
       - top_p:
           values:
             - 0.8
-      - max_token:
+      - max_tokens:
           values:
             - 512
       - repetition_penalty:

@@ -41,7 +43,7 @@
       - tensor_parallel_size:
           values:
             # 1 or total count of gpu
-            - 1
+            - 4
       - gpu_memory_utilization:
           values:
             - 0.9

@@ -68,7 +70,7 @@
       - top_p:
           values:
             - 0.8
-      - max_token:
+      - max_tokens:
           values:
             - 512
       - repetition_penalty:

@@ -80,7 +82,7 @@
 
   - type: "hard_example_mining"
     # name of python module; string type;
-    # BERTRouter, EdgeOnly, CloudOnly, RandomRouter
+    # BERTRouter, EdgeOnly, CloudOnly, RandomRouter, OracleRouter
     name: "OracleRouter"
     # the url address of python module; string type;
     url: "./examples/cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/hard_sample_mining.py"
Changed file 6 of 6

@@ -8,5 +8,7 @@ def edge_rate(y_true, y_pred):
 
     y_pred = [pred.is_hard_example for pred in infer_res]
 
-    return round(1 - sum(y_pred) / len(y_pred),2)
+    edge_rate = 1 - sum(y_pred) / len(y_pred)
+
+    return round(edge_rate * 100,2)
 
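
The metric now reports a percentage instead of a fraction. A small worked example with illustrative values:

```python
# Illustrative values only: four queries, two flagged as hard examples (sent to cloud).
y_pred = [True, False, False, True]        # is_hard_example flags per query

edge_rate = 1 - sum(y_pred) / len(y_pred)  # fraction answered at the edge: 0.5
print(round(edge_rate * 100, 2))           # reported as a percentage: 50.0
```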