-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_mmm_async.py
125 lines (100 loc) · 3.54 KB
/
test_mmm_async.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import requests
import json
import time
import pandas as pd
import numpy as np
import sys
def create_payload(n_weeks=52, seed=None):
    """Build a request payload from synthetic weekly marketing data.

    A DataFrame with one row per week (starting from 2023-01-01) is filled
    with random integer values for sales and for two spend channels, then
    serialized to JSON and wrapped in the dictionary posted to the service.

    Args:
        n_weeks: Number of weekly rows to generate. Defaults to 52.
        seed: Optional seed for the random generator. Passing the same seed
            yields the same payload; ``None`` draws fresh random data.

    Returns:
        dict: Payload with the keys ``df`` (the DataFrame as a
        ``orient="split"`` JSON string), ``date_column``,
        ``channel_columns``, ``adstock_max_lag`` and ``yearly_seasonality``.
    """
    # A local generator keeps the data reproducible on request without
    # touching NumPy's global random state.
    rng = np.random.default_rng(seed)

    # Weekly dates
    dates = pd.date_range(start='2023-01-01', periods=n_weeks, freq='W')

    # Synthetic sales and marketing spend (upper bounds are exclusive)
    sales = rng.integers(1000, 5000, size=n_weeks)
    tv_spend = rng.integers(1000, 3000, size=n_weeks)
    online_spend = rng.integers(500, 2000, size=n_weeks)

    data = pd.DataFrame({
        'date': dates,
        'sales': sales,
        'tv': tv_spend,
        'online': online_spend,
    })

    # The DataFrame travels inside the payload as a JSON string
    data_json = data.to_json(orient="split")

    return {
        "df": data_json,
        "date_column": "date",
        "channel_columns": ["tv", "online"],
        "adstock_max_lag": 2,
        "yearly_seasonality": 2,
    }
def create_payload_csv(csv_path='mmm_example.csv'):
    """Build a request payload from a CSV file of marketing data.

    The CSV must contain the columns ``date_week``, ``y``, ``x1`` and
    ``x2``; any other columns are dropped. ``y`` is renamed to ``sales`` and
    ``date_week`` is normalized to ``YYYY-MM-DD`` strings before the frame
    is serialized.

    Args:
        csv_path: Path of the CSV file to read. Defaults to
            ``'mmm_example.csv'``, resolved against the current working
            directory.

    Returns:
        dict: Payload with the keys ``df`` (the selected columns as an
        ``orient="split"`` JSON string without the index), ``date_column``,
        ``channel_columns``, ``adstock_max_lag`` and ``yearly_seasonality``.
    """
    raw = pd.read_csv(csv_path)

    # Rename the 'y' column to 'sales' and keep only the relevant columns.
    # The explicit copy makes the frame independent of `raw`, so the column
    # assignment below is not a write into a slice of another frame.
    mmm_data = raw.rename(columns={'y': 'sales'})[
        ['date_week', 'sales', 'x1', 'x2']
    ].copy()

    # Parse the dates and store them back as 'YYYY-MM-DD' strings so the
    # JSON carries plain date text.
    mmm_data['date_week'] = (
        pd.to_datetime(mmm_data['date_week']).dt.strftime('%Y-%m-%d')
    )

    # Serialize without the index; only columns and data are sent
    data_json = mmm_data.to_json(orient="split", index=False)

    return {
        "df": data_json,
        "date_column": "date_week",
        "channel_columns": ["x1", "x2"],
        "adstock_max_lag": 2,
        "yearly_seasonality": 8,
    }
def test_async_mmm_run(base_url, poll_interval=10, max_wait=None,
                       request_timeout=60):
    """Start an asynchronous model run and poll until it finishes.

    Posts the CSV-based payload to ``{base_url}/run_mmm_async``, reads the
    ``task_id`` from the JSON response, then polls
    ``{base_url}/get_results`` until the reported ``status`` is
    ``"completed"`` or ``"failed"``.

    Args:
        base_url: Root URL of the service, without a trailing slash.
        poll_interval: Seconds to sleep between polls. Defaults to 10.
        max_wait: Optional overall limit, in seconds, for the polling phase.
            ``None`` (the default) polls without a time limit.
        request_timeout: Timeout in seconds applied to each HTTP request.

    Returns:
        dict: The JSON body of the final ``"completed"`` response.

    Raises:
        AssertionError: If the initial request does not return HTTP 200 or
            the task reports ``"failed"``.
        TimeoutError: If ``max_wait`` elapses before the task finishes.
        requests.HTTPError: If a polling request returns an error status.
    """
    # Payload that includes data
    payload = create_payload_csv()

    # Initiate the model run; `json=` serializes the payload and sets the
    # JSON content type header.
    run_url = f"{base_url}/run_mmm_async"
    response = requests.post(run_url, json=payload, timeout=request_timeout)
    print(response)

    # Assert the status code for initiation
    assert response.status_code == 200, (
        f"Unexpected status {response.status_code} from {run_url}: "
        f"{response.text}"
    )

    task_id = response.json()["task_id"]
    print(f"Got task_id {task_id}")

    # Passing the id through `params` lets requests URL-encode it.
    results_url = f"{base_url}/get_results"
    query = {"task_id": task_id}
    deadline = None if max_wait is None else time.monotonic() + max_wait

    while True:
        result_response = requests.get(
            results_url, params=query, timeout=request_timeout
        )
        result_response.raise_for_status()
        result_data = result_response.json()
        status = result_data.get("status")

        if status == "completed":
            print("Task completed:", result_data)
            return result_data

        if status == "failed":
            # A failed task must fail the test rather than end it quietly.
            print("Task failed:", result_data)
            raise AssertionError(f"Task {task_id} failed: {result_data}")

        if status == "pending":
            print("Pending...")
        else:
            print(f"Unexpected status {status!r}; polling again...")

        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Task {task_id} did not finish within {max_wait} seconds"
            )

        # Sleep on every non-terminal status so an unexpected value cannot
        # turn the loop into a busy-wait against the server.
        time.sleep(poll_interval)
if __name__ == "__main__":
    # Command-line entry point: the single argument selects which
    # deployment the async test runs against.
    base_urls = {
        "local": "http://localhost:5001",
        "deployed": "https://nextgen-mmm.pymc-labs.com",
    }

    if len(sys.argv) < 2:
        # Report the script's actual invocation name instead of a
        # hard-coded file name.
        print(f"Usage: python {sys.argv[0]} [local|deployed]")
        sys.exit(1)

    environment = sys.argv[1]
    base_url = base_urls.get(environment)
    if base_url is None:
        print("Invalid argument. Use 'local' or 'deployed'.")
        sys.exit(1)

    test_async_mmm_run(base_url)