Skip to content

Commit

Permalink
ai changes
Browse files Browse the repository at this point in the history
  • Loading branch information
BSalita committed Jan 26, 2024
1 parent 8cb3d99 commit 23831a6
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions mlBridgeLib/mlBridgeAi.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,15 @@ def train_classification(dls, epochs=3, monitor='accuracy', min_delta=0.001, pat
learn = tabular_learner(dls, metrics=accuracy)

# Train the model
learn.fit_one_cycle(epochs) #, cbs=EarlyStoppingCallback(monitor=monitor, min_delta=min_delta, patience=patience)) # sometimes only a couple epochs is optimal
# error: Can't get attribute 'AMPMode' on <module 'fastai.callback.fp16'
#learn.to_fp16() # to_fp32() or to_bf16()

# Use one cycle policy for training with early stopping
learn.fit_one_cycle(epochs, cbs=EarlyStoppingCallback(monitor=monitor, min_delta=min_delta, patience=patience)) # sometimes only a couple epochs is optimal

return learn

def train_regression(dls, epochs=20, layers=[200]*10, y_range=(0,1), monitor='valid_loss', min_delta=0.001, patience=3):
def train_regression(dls, epochs=20, layers=[200]*10, y_range=(0,1), monitor='valid_loss', min_delta=0.001, patience=3):
"""
Train a tabular model for regression.
"""
Expand Down

0 comments on commit 23831a6

Please sign in to comment.