-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from owczr/develop
Add scripts to train model and submit a job on Azure ML
- Loading branch information
Showing
7 changed files
with
317 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
numpy==1.23.5
pydicom==2.4.4
scikit-image==0.20.0
tensorflow==2.12.0
tqdm==4.65.0
pytest==7.4.0
click==8.0.4
azure-ai-ml==1.12.1
# Imported by the Azure ML job-submission script (azure.identity, dotenv).
azure-identity==1.15.0
python-dotenv==1.0.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Training image: miniconda base plus the project's conda environment.
FROM continuumio/miniconda3

# -y keeps the update explicitly non-interactive, matching `conda clean -a -y` below.
RUN conda update -n base -c defaults conda -y

COPY conda_dependencies.yaml .
RUN conda env create -f conda_dependencies.yaml -q && \
    rm conda_dependencies.yaml && \
    conda run pip cache purge && \
    conda clean -a -y

# The environment created above is named "conda" (see conda_dependencies.yaml).
# Put it first on PATH so `python` in the job command resolves to the
# interpreter that actually has the pip dependencies installed, not base.
ENV PATH=/opt/conda/envs/conda/bin:$PATH
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
name: conda
channels:
  - conda-forge
  - defaults
  - anaconda
dependencies:
  # Pinned: tensorflow 2.12.0 publishes wheels for Python 3.8-3.11 only, so an
  # unpinned interpreter can resolve to a version pip cannot install for.
  - python=3.10
  - pip
  - pip:
      - numpy==1.23.5
      - pydicom==2.4.4
      - scikit-image==0.20.0
      - tensorflow==2.12.0
      - tqdm==4.65.0
      - pytest==7.4.0
      - click==8.0.4
      - azure-ai-ml==1.12.1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
$schema: https://azuremlschemas.azureedge.net/latest/environment.schema.json
name: cancer-env
build:
  # `path` is the build-context directory; the Dockerfile is located relative
  # to it. NOTE(review): assumes this file sits next to the Dockerfile.
  path: .
  dockerfile_path: Dockerfile
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import os | ||
from datetime import datetime | ||
|
||
import click | ||
from dotenv import load_dotenv | ||
from azure.ai.ml import MLClient, command, Input | ||
from azure.identity import DefaultAzureCredential | ||
from azure.ai.ml.entities import AmlCompute, Model | ||
from azure.ai.ml.constants import AssetTypes | ||
|
||
from src.config import MODELS | ||
|
||
|
||
# Load variables from a .env file so the AZURE_* lookups below can see them.
load_dotenv()


def connect_to_workspace():
    """Create an ``MLClient`` for the workspace described by the environment.

    Reads ``AZURE_SUBSCRIPTION_ID``, ``AZURE_RESOURCE_GROUP`` and
    ``AZURE_WORKSPACE`` and authenticates with ``DefaultAzureCredential``.

    Returns:
        The constructed ``MLClient``.
    """
    return MLClient(
        credential=DefaultAzureCredential(),
        subscription_id=os.getenv("AZURE_SUBSCRIPTION_ID"),
        resource_group_name=os.getenv("AZURE_RESOURCE_GROUP"),
        workspace_name=os.getenv("AZURE_WORKSPACE"),
    )
|
||
|
||
def _env_int(name):
    """Return environment variable *name* as an int, or None when it is unset."""
    raw = os.getenv(name)
    return int(raw) if raw not in (None, "") else None


def get_compute(ml_client):
    """Return the configured compute target, creating it when the lookup fails.

    The target is described by ``AZURE_COMPUTE_TARGET`` (name),
    ``AZURE_COMPUTE_SIZE``, ``AZURE_COMPUTE_MIN_INSTANCES`` and
    ``AZURE_COMPUTE_MAX_INSTANCES``.

    Args:
        ml_client: Client exposing ``compute.get`` and
            ``compute.begin_create_or_update``.

    Returns:
        The object returned by ``compute.get``, or the result of the create
        operation when a new target had to be provisioned.

    Raises:
        ValueError: If ``AZURE_COMPUTE_TARGET`` is unset, or an instance count
            is set but is not an integer.
    """
    target_name = os.getenv("AZURE_COMPUTE_TARGET")
    if not target_name:
        raise ValueError("AZURE_COMPUTE_TARGET is not set")

    try:
        return ml_client.compute.get(target_name)
    except Exception:
        # NOTE(review): any lookup failure (not only "not found") lands here,
        # as in the original; narrowing needs an azure-core exception import.
        click.echo("Creating a new cpu compute target...")

    compute = AmlCompute(
        name=target_name,
        size=os.getenv("AZURE_COMPUTE_SIZE"),
        # os.getenv yields strings; the instance counts must be integers.
        min_instances=_env_int("AZURE_COMPUTE_MIN_INSTANCES"),
        max_instances=_env_int("AZURE_COMPUTE_MAX_INSTANCES"),
    )
    return ml_client.compute.begin_create_or_update(compute).result()
|
||
|
||
def submit_job(ml_client, model, optimizer, loss, metric, epochs, batch_size):
    """Build and submit the Azure ML command job that trains *model*.

    Job settings come from ``AZURE_CODE_PATH``, ``AZURE_ENVIRONMENT``,
    ``AZURE_STORAGE_TYPE``, ``AZURE_STORAGE_PATH`` and
    ``AZURE_COMPUTE_TARGET``. The train and test inputs are the ``train`` and
    ``test`` children of ``AZURE_STORAGE_PATH``.

    Args:
        ml_client: Client whose ``jobs.create_or_update`` submits the job.
        model: Model name; selects the ``train_<model>`` module and is also
            forwarded as ``--model``.
        optimizer: Value forwarded as ``--optimizer``.
        loss: Value forwarded as ``--loss``.
        metric: Value forwarded as ``--metric``.
        epochs: Value forwarded as ``--epochs``.
        batch_size: Value forwarded as ``--batch_size``.

    Returns:
        The job object returned by ``ml_client.jobs.create_or_update``.

    Raises:
        ValueError: If ``AZURE_STORAGE_PATH`` is unset.
    """
    storage_root = os.getenv("AZURE_STORAGE_PATH")
    if not storage_root:
        raise ValueError("AZURE_STORAGE_PATH is not set")

    # Join with "/" explicitly: the storage path may be a remote URI, and
    # os.path.join would insert backslashes when submitting from Windows.
    storage_root = storage_root.rstrip("/")
    input_type = os.getenv("AZURE_STORAGE_TYPE")

    command_job = command(
        code=os.getenv("AZURE_CODE_PATH"),
        command=(
            f"python -m src.scripts.azure.machine_learning.train_{model}"
            " --train ${{inputs.train}} --test ${{inputs.test}}"
            " --epochs ${{inputs.epochs}} --optimizer ${{inputs.optimizer}}"
            " --loss ${{inputs.loss}} --metric ${{inputs.metric}}"
            " --batch_size ${{inputs.batch_size}} --model ${{inputs.model}}"
        ),
        environment=os.getenv("AZURE_ENVIRONMENT"),
        inputs={
            "train": Input(type=input_type, path=f"{storage_root}/train"),
            "test": Input(type=input_type, path=f"{storage_root}/test"),
            "optimizer": optimizer,
            "loss": loss,
            "metric": metric,
            "epochs": epochs,
            "batch_size": batch_size,
            "model": model,
        },
        compute=os.getenv("AZURE_COMPUTE_TARGET"),
        # Timestamp suffix gives each submission a distinct job name.
        name=f"train_{model}_{datetime.now().strftime('%Y%m%d%H%M%S')}",
    )

    return ml_client.jobs.create_or_update(command_job)
|
||
|
||
def register_model(ml_client, returned_job, run_name, run_description):
    """Register the ``model`` output folder of a job as an MLflow model asset.

    Args:
        ml_client: Client whose ``models.create_or_update`` performs the
            registration.
        returned_job: Job object; its ``name`` locates the output artifacts.
        run_name: Name given to the registered model.
        run_description: Description stored with the registered model.
    """
    artifact_uri = (
        f"azureml://jobs/{returned_job.name}/outputs/artifacts/paths/model/"
    )
    ml_client.models.create_or_update(
        Model(
            path=artifact_uri,
            name=run_name,
            description=run_description,
            type=AssetTypes.MLFLOW_MODEL,
        )
    )
|
||
|
||
@click.command()
@click.option(
    "--model",
    type=click.Choice(MODELS),
    # Without this, omitting --model surfaced as a ValueError traceback
    # instead of a regular click usage error.
    required=True,
    help="Model to train",
)
@click.option(
    "--optimizer",
    type=click.Choice(["adam", "sgd"]),
    default="adam",
    help="Optimizer to use",
)
@click.option(
    "--loss",
    type=click.Choice(["binary_crossentropy", "categorical_crossentropy"]),
    default="binary_crossentropy",
    help="Loss function to use",
)
@click.option(
    "--metric",
    type=click.Choice(["accuracy", "f1"]),
    default="accuracy",
    help="Metrics to use",
)
@click.option("--epochs", type=int, default=10, help="Number of epochs to train for")
@click.option("--batch_size", type=int, default=32, help="Batch size to use")
def run(model, optimizer, loss, metric, epochs, batch_size):
    """Submit a training job for MODEL settings to Azure ML and print its details.

    Connects to the workspace, ensures the compute target exists, submits the
    job and echoes the job id, name and studio URL.

    Raises:
        ValueError: If *model* is not one of ``MODELS``.
    """
    # Kept for callers that bypass click's own Choice validation.
    if model not in MODELS:
        raise ValueError(f"Model {model} not supported")

    ml_client = connect_to_workspace()
    get_compute(ml_client=ml_client)

    returned_job = submit_job(
        ml_client=ml_client,
        model=model,
        optimizer=optimizer,
        loss=loss,
        metric=metric,
        epochs=epochs,
        batch_size=batch_size,
    )

    click.echo("Job created with:")
    click.echo(f" - id: {returned_job.id}")
    click.echo(f" - name: {returned_job.name}")
    click.echo(f" - url: {returned_job.studio_url}")


if __name__ == "__main__":
    run()  # pylint: disable=no-value-for-parameter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import os | ||
import logging | ||
from datetime import datetime | ||
|
||
import click | ||
import tensorflow as tf | ||
|
||
from src.model.builders import ( | ||
ConvNeXtBuilder, | ||
DenseNetBuilder, | ||
EfficientNetBuilder, | ||
EfficientNetV2Builder, | ||
InceptionNetBuilder, | ||
InceptionResNetBuilder, | ||
MobileNetBuilder, | ||
ResNetBuilder, | ||
ResNetV2Builder, | ||
VGGBuilder, | ||
XceptionBuilder, | ||
) | ||
from src.model.director import ModelDirector | ||
from src.dataset.dataset_loader import DatasetLoader | ||
from src.config import EARLY_STOPPING_CONFIG, REDUCE_LR_CONFIG, MODELS | ||
|
||
|
||
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Maps each supported --model value to the builder class that constructs it.
_BUILDERS = {
    "convnext": ConvNeXtBuilder,
    "densenet": DenseNetBuilder,
    "efficientnet": EfficientNetBuilder,
    "efficientnetv2": EfficientNetV2Builder,
    "inceptionnet": InceptionNetBuilder,
    "inceptionresnet": InceptionResNetBuilder,
    "mobilenet": MobileNetBuilder,
    "resnet": ResNetBuilder,
    "resnetv2": ResNetV2Builder,
    "vgg": VGGBuilder,
    "xception": XceptionBuilder,
}


@click.command()
# Was a second `@click.command(...)`, which does not declare a --model option.
@click.option(
    "--model", type=click.Choice(MODELS), default="mobilenet", help="Model to train"
)
@click.option(
    "--train", type=click.Path(exists=True), help="Path to the training dataset"
)
@click.option("--test", type=click.Path(exists=True), help="Path to the test dataset")
@click.option(
    "--optimizer",
    type=click.Choice(["adam", "sgd"]),
    default="adam",
    help="Optimizer to use",
)
@click.option(
    "--loss",
    type=click.Choice(["binary_crossentropy", "categorical_crossentropy"]),
    default="binary_crossentropy",
    help="Loss function to use",
)
@click.option(
    "--metric",
    type=click.Choice(["accuracy", "f1"]),
    default="accuracy",
    help="Metrics to use",
)
@click.option("--epochs", type=int, default=10, help="Number of epochs to train for")
# The Azure ML submission command always passes --batch_size, so it must parse.
@click.option("--batch_size", type=int, default=32, help="Batch size to use")
def run(model, train, test, optimizer, loss, metric, epochs, batch_size):
    """Build, compile, train and evaluate the selected model.

    Args:
        model: Key into the builder table selecting the architecture.
        train: Path handed to ``DatasetLoader`` for the training data.
        test: Path handed to ``DatasetLoader`` for the evaluation data.
        optimizer: Optimizer name passed to ``model.compile``.
        loss: Loss name passed to ``model.compile``.
        metric: Single metric name passed to ``model.compile``.
        epochs: Number of epochs passed to ``model.fit``.
        batch_size: Accepted for command-line compatibility and logged only.
            NOTE(review): not forwarded to ``DatasetLoader``, whose batching
            API is not visible from this script.

    Raises:
        KeyError: If *model* has no entry in the builder table.
    """
    logger.info("Started training run at %s", datetime.now())
    logger.info(
        "Run parameters - optimizer: %s, loss: %s, metrics: %s",
        optimizer,
        loss,
        metric,
    )
    logger.info("Batch size argument: %s (not applied by this script)", batch_size)

    builder = _BUILDERS[model]()
    # Separate name so the `model` string is not rebound to the built network.
    network = ModelDirector(builder).make()
    logger.info("Built model with %s", builder)

    train_dataset = DatasetLoader(train).get_dataset()
    test_dataset = DatasetLoader(test).get_dataset()
    logger.info("Loaded train and test datasets")

    network.compile(optimizer=optimizer, loss=loss, metrics=[metric])
    logger.info("Compiled model")

    early_stopping = tf.keras.callbacks.EarlyStopping(**EARLY_STOPPING_CONFIG)
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(**REDUCE_LR_CONFIG)

    network.fit(train_dataset, epochs=epochs, callbacks=[early_stopping, reduce_lr])
    logger.info("Trained model")

    network.evaluate(test_dataset)
    logger.info("Evaluated model")

    logger.info("Finished training at %s", datetime.now())


if __name__ == "__main__":
    run()  # pylint: disable=no-value-for-parameter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters