Skip to content

Commit

Permalink
fix missing categories
Browse files (browse the repository at this point in the history)
  • Loading branch information
BSalita committed Sep 5, 2024
1 parent fe87f20 commit a305346
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions app.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -864,10 +864,19 @@ def Predict_Game_Results(df):
learn = mlBridgeAi.load_model(predicted_contracts_model_file)
print_to_log_debug('isna:',df.isna().sum())
contracts_all = ['PASS']+[str(level+1)+strain+dbl+direction for level in range(7) for strain in 'CDHSN' for dbl in ['','X','XX'] for direction in 'NESW']
# Check if all contracts in df are in contracts_all
unknown_contracts = set(df['Contract'].unique()) - set(contracts_all)
if unknown_contracts:
print(f"Warning: Unknown contracts found: {unknown_contracts}")
# Add unknown contracts to contracts_all
contracts_all.extend(unknown_contracts)
df['Contract'] = pd.Categorical(df['Contract'], categories=contracts_all)
#df['Contract'] = df['Contract'].astype('string')
print(df['Contract'])
#df = df.drop(df[~df['Contract'].isin(learn.dls.vocab)].index)
# Ensure the model's vocabulary matches the current data
if set(df['Contract'].cat.categories) != set(learn.dls.vocab):
print("Warning: Mismatch between model vocabulary and data categories.")
print(f"Model vocab: {learn.dls.vocab}")
print(f"Data categories: {df['Contract'].cat.categories}")
# You might want to retrain the model or adjust the data here
assert df['Contract'].isin(mlBridgeLib.contract_classes).all(), df['Contract'][~df['Contract'].isin(mlBridgeLib.contract_classes)]
#df[learn.dls.y_names[0]] = pd.Categorical(df[learn.dls.y_names[0]], categories=learn.dls.vocab)
#import pickle
Expand Down

0 comments on commit a305346

Please sign in to comment.