-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
doc: add comments and update quick guide
Signed-off-by: Yu Fan <fany@buaa.edu.cn>
- Loading branch information
1 parent
0c5f643
commit 520d514
Showing
26 changed files
with
1,082 additions
and
225 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
95 changes: 64 additions & 31 deletions
95
examples/cloud-edge-collaborative-inference-for-llm/README.md
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file modified
BIN
+122 KB
(280%)
examples/cloud-edge-collaborative-inference-for-llm/assets/Oracle Router Demo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
examples/cloud-edge-collaborative-inference-for-llm/performance-cost-plot.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import numpy as np | ||
|
||
import matplotlib.pyplot as plt | ||
from scipy.optimize import curve_fit | ||
|
||
colors = plt.cm.Paired.colors # Set1 调色板 | ||
plt.rcParams["axes.prop_cycle"] = plt.cycler("color", colors) | ||
|
||
# a sigmoid function to fit non-oracle models' performance vs cost | ||
def sigmoid_fit(x, L, k, x0): | ||
return L / (1 + np.exp(-k * (x - x0))) | ||
|
||
def plot_accuracy_cost(models, costs, accuracy, non_oracle_costs, non_oracle_accuracy): | ||
# Fit the sigmoid model | ||
params_sigmoid, _ = curve_fit(sigmoid_fit, non_oracle_costs, non_oracle_accuracy, p0=[100, 1, 0.2]) | ||
|
||
# Generate points for the sigmoid fitted curve | ||
curve_x_sigmoid = np.linspace(min(non_oracle_costs), max(non_oracle_costs), 100) | ||
curve_y_sigmoid = sigmoid_fit(curve_x_sigmoid, *params_sigmoid) | ||
|
||
plt.figure(figsize=(10, 6)) | ||
|
||
# Plot all models | ||
for i in range(len(models)): | ||
if "Oracle" in models[i]: | ||
marker = '^' # Triangle marker for Oracle models | ||
else: | ||
marker = 'o' # Circle marker for non-Oracle models | ||
plt.scatter(costs[i], accuracy[i], label=models[i], marker=marker) | ||
|
||
# Plot the sigmoid fitted curve | ||
plt.plot(curve_x_sigmoid, curve_y_sigmoid, 'gray', linestyle='dashed') # Gray dashed line for the curve | ||
|
||
plt.title('Model Performance vs Cost') | ||
plt.xlabel('Cost($/M token)') | ||
plt.ylabel('Accuracy (%)') | ||
plt.legend(title='Model Name') | ||
plt.grid(True) | ||
plt.savefig('model_performance_sigmoid_fitted_curve.png', dpi=300) | ||
plt.show() | ||
|
||
if __name__ == '__main__': | ||
models = [ | ||
"Oracle-Qwen2.5-7b-instruct + gpt-4o-mini", | ||
"Oracle-Qwen2.5-1.5b-instruct + gpt-4o-mini", | ||
"Oracle-Qwen2.5-3b-instruct + gpt-4o-mini", | ||
"gpt-4o-mini", | ||
"Qwen2.5-7B-Instruct", | ||
"Qwen2.5-3B-Instruct", | ||
"Qwen2.5-1.5B-Instruct" | ||
] | ||
# The Oracle Routed Model's cost is an average weighted by the Edge Ratio between edge model costs and cloud model costs. | ||
# The edge model’s cost is estimated based on its parameter size. | ||
costs = [0.16, 0.18, 0.17, 0.60, 0.10, 0.08, 0.05] | ||
accuracy = [84.22, 82.75, 82.22, 75.99, 71.84, 60.3, 58.35] | ||
|
||
# Non Oracle Models: gpt-4o-mini, Qwen2.5-7B-Instruct, Qwen2.5-3B-Instruct, Qwen2.5-1.5B-Instruct | ||
non_oracle_costs = costs[-4:] # Costs in $/M token | ||
non_oracle_accuracy = accuracy[-4:] # Accuracies in % | ||
|
||
plot_accuracy_cost(models, costs, accuracy, non_oracle_costs, non_oracle_accuracy) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
29 changes: 28 additions & 1 deletion
29
...cloud-edge-collaborative-inference-for-llm/testalgorithms/query-routing/data_processor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,39 @@ | ||
import numpy as np | ||
# Copyright 2024 The KubeEdge Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from sedna.common.class_factory import ClassFactory, ClassType | ||
from sedna.datasources import BaseDataSource | ||
|
||
@ClassFactory.register(ClassType.GENERAL, alias="OracleRouterDatasetProcessor") | ||
class OracleRouterDatasetProcessor: | ||
""" A Customized Dataset Processor for Oracle Router""" | ||
def __init__(self, **kwargs): | ||
pass | ||
|
||
def __call__(self, dataset): | ||
"""Transform the dataset to another format for Oracle Router | ||
Parameters | ||
---------- | ||
dataset : sedna.datasources.BaseDataSource | ||
The dataset loaded by Sedna | ||
Returns | ||
------- | ||
sedna.datasources.BaseDataSource | ||
Transformed dataset | ||
""" | ||
dataset.x = [{"query": x, "gold": y} for x,y in zip(dataset.x, dataset.y)] | ||
return dataset |
Oops, something went wrong.