Skip to content

Commit

Permalink
update core
Browse files Browse the repository at this point in the history
  • Loading branch information
bmandracchia committed Oct 15, 2024
1 parent e4d7a9e commit 2708a1c
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 41 deletions.
51 changes: 51 additions & 0 deletions nbs/00_core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,57 @@
" print('Inferred learning rate: ', lr)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class visionTrainer(Learner):\n",
" \"\"\"\n",
" A custom implementation of the FastAI Learner class for training models in bioinformatics applications.\n",
"\n",
" \"\"\"\n",
" \n",
" def __init__(self, \n",
" dataloaders: DataLoaders, # The DataLoader objects containing training and validation datasets.\n",
" model: callable, # A callable model that will be trained on the dataset.\n",
" loss_fn: Any | None = None, # The loss function to optimize during training. If None, defaults to a suitable default.\n",
" optimizer: Optimizer | OptimWrapper = Adam, # The optimizer function to use. Defaults to Adam if not specified.\n",
" lr: float | slice = 1e-3, # Learning rate for the optimizer. Can be a float or a slice object for learning rate scheduling.\n",
" splitter: callable = trainable_params, # \n",
" callbacks: Callback | MutableSequence | None = None, # A callable that determines which parameters of the model should be updated during training.\n",
" metrics: Any | MutableSequence | None = None, # Optional list of callback functions to customize training behavior.\n",
" csv_log: bool = False, # Metrics to evaluate the performance of the model during training.\n",
" show_graph: bool = True, # Whether to log training history to a CSV file. If True, logs will be appended to 'history.csv'.\n",
" show_summary: bool = False, # The base directory where models are saved or loaded from. Defaults to None.\n",
" find_lr: bool = False, # Subdirectory within the base path where trained models are stored. Default is 'models'.\n",
" find_lr_fn = valley, # Weight decay factor for optimization. Defaults to None.\n",
" path: str | Path | None = None, # Whether to apply weight decay to batch normalization and bias parameters.\n",
" model_dir: str | Path = 'models', # Whether to update the batch normalization statistics during training.\n",
" wd: float | int | None = None, \n",
" wd_bn_bias: bool = False, \n",
" train_bn: bool = True, \n",
" moms: tuple = ..., # Tuple of tuples representing the momentum values for different layers in the model. Defaults to FastAI's default settings if not specified.\n",
" default_cbs: bool = True, # Automatically include default callbacks such as ShowGraphCallback and CSVLogger.\n",
" ):\n",
" cbs = callbacks if callbacks is not None else [] # Ensure cbs is a list\n",
" if default_cbs:\n",
" if show_graph:\n",
" cbs.append(ShowGraphCallback())\n",
" if csv_log:\n",
" cbs.append(CSVLogger(fname='history.csv', append=False))\n",
" \n",
" super().__init__(dataloaders, model, loss_fn, optimizer, lr, splitter, cbs, metrics, path, model_dir, wd, wd_bn_bias, train_bn, moms)\n",
" \n",
" if show_summary:\n",
" print(self.summary())\n",
" if find_lr:\n",
" self.lr_find(suggest_funcs=find_lr_fn)\n",
" lr = float('%.1g'%(lr))\n",
" print('Inferred learning rate: ', lr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
9 changes: 6 additions & 3 deletions nbs/07_callbacks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
}
],
"source": [
"show_doc(ShortEpochCallback)\n"
"show_doc(ShortEpochCallback)"
]
},
{
Expand Down Expand Up @@ -124,7 +124,6 @@
}
],
"source": [
"\n",
"show_doc(GradientAccumulation)"
]
},
Expand Down Expand Up @@ -535,9 +534,13 @@
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"display_name": "bioMONAI-env",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.5"
}
},
"nbformat": 4,
Expand Down
6 changes: 5 additions & 1 deletion nbs/901_demo_RI2FL_2d.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,13 @@
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"display_name": "bioMONAI-env",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.5"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 2708a1c

Please sign in to comment.