-
Notifications
You must be signed in to change notification settings - Fork 18
/
predict.py
40 lines (31 loc) · 1.14 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import pandas as pd
import feast
from joblib import load
class DriverRankingModel:
    """Rank drivers by scoring their Feast online features with a saved model.

    The model is deserialized with ``joblib.load`` and must expose a
    ``predict`` method that accepts a pandas DataFrame and returns one score
    per row.
    """

    def __init__(self, model_path="driver_model.bin", repo_path="driver_ranking/"):
        """Load the trained model and connect to the feature store.

        Args:
            model_path: Path of the joblib-serialized model file.
            repo_path: Path of the Feast feature repository.
        """
        # SECURITY: joblib.load unpickles the file, which can execute
        # arbitrary code -- only point model_path at trusted model files.
        self.model = load(model_path)

        # Set up feature store
        self.fs = feast.FeatureStore(repo_path=repo_path)

    def predict(self, driver_ids):
        """Return the id of the driver with the highest predicted score.

        Args:
            driver_ids: Iterable of candidate driver ids. It is consumed
                once, so generators are accepted.

        Returns:
            The ``driver_id`` value of the row whose prediction is largest.
            On ties the first such row wins (``Series.argmax`` semantics).

        Raises:
            ValueError: If ``driver_ids`` is empty.
        """
        # Materialize once so emptiness can be checked without exhausting a
        # generator or tripping over array truthiness.
        driver_ids = list(driver_ids)
        if not driver_ids:
            raise ValueError("driver_ids must contain at least one driver id")

        # Read features from Feast
        driver_features = self.fs.get_online_features(
            entity_rows=[{"driver_id": driver_id} for driver_id in driver_ids],
            feature_refs=[
                "driver_hourly_stats:conv_rate",
                "driver_hourly_stats:acc_rate",
                "driver_hourly_stats:avg_daily_trips",
            ],
        )
        df = pd.DataFrame.from_dict(driver_features.to_dict())

        # Score every returned column, ordered by column name.
        # NOTE(review): presumably this matches the column set and order the
        # model was trained with -- confirm against the training script.
        df["prediction"] = self.model.predict(df[sorted(df)])

        # argmax is positional, so pair it with .iloc rather than label lookup.
        best_driver_id = df["driver_id"].iloc[df["prediction"].argmax()]

        return best_driver_id
if __name__ == "__main__":
    # Demo run: rank a fixed set of candidate drivers and print the winner.
    candidate_ids = [1001, 1002, 1003, 1004]
    ranker = DriverRankingModel()
    print(ranker.predict(candidate_ids))