Skip to content

Commit

Permalink
reduce duplicated code
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewpeng02 committed May 1, 2024
1 parent a0f4272 commit d9c11d8
Showing 1 changed file with 36 additions and 72 deletions.
108 changes: 36 additions & 72 deletions training/training/core/celery/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def saveDetailedTrainResultsDataToS3(
).put(Body=detailedTrainResultsData.json())


def collectClassificationTrainingResults(trainer, basic_info):
def collectTrainingResults(trainer, basic_info, is_classification):
trainTestLoss = [
{
"x_name": "Epoch",
Expand All @@ -58,33 +58,37 @@ def collectClassificationTrainingResults(trainer, basic_info):
trainTestLoss[0]["y_values"].append(epoch_result.train_loss)
trainTestLoss[1]["x_values"].append(epoch_result.epoch_num)
trainTestLoss[1]["y_values"].append(epoch_result.test_loss)
confusionMatrix = trainer.generate_confusion_matrix()
aucRocCurve = trainer.generate_AUC_ROC_CURVE()

detailedTrainResultsData = DetailedTrainResultsData(
**{
"basic_info": basic_info,
"all_metrics": [
{
"name": "Train and test loss vs epoch",
"time_series": trainTestLoss,
"graph_index": 0,
"chart_type": "LINE",
},
{
"name": "Confusion matrix",
"values": confusionMatrix.tolist(),
"chart_type": "CONFUSION_MATRIX",
"graph_index": 1,
},
{
"name": "AUC/ROC curve",
"values": aucRocCurve,
"chart_type": "AUC/ROC",
"graph_index": 2,
},
],
all_metrics = [
{
"name": "Train and test loss vs epoch",
"time_series": trainTestLoss,
"graph_index": 0,
"chart_type": "LINE",
}
]
if is_classification:
confusionMatrix = trainer.generate_confusion_matrix()
aucRocCurve = trainer.generate_AUC_ROC_CURVE()
all_metrics.append(
{
"name": "Confusion matrix",
"values": confusionMatrix.tolist(),
"chart_type": "CONFUSION_MATRIX",
"graph_index": 1,
}
)
all_metrics.append(
{
"name": "AUC/ROC curve",
"values": aucRocCurve,
"chart_type": "AUC/ROC",
"graph_index": 2,
}
)

detailedTrainResultsData = DetailedTrainResultsData(
**{"basic_info": basic_info, "all_metrics": all_metrics}
)
return detailedTrainResultsData

Expand Down Expand Up @@ -135,13 +139,6 @@ def tabularTrainTask(input: dict, trainspaceId: str, uid: str):
tabularParams.epochs,
dataCreator.getCategoryList(),
)

detailedTrainResultsData = collectClassificationTrainingResults(
trainer, basic_info
)

# save detailedTrainResultsData
saveDetailedTrainResultsDataToS3(detailedTrainResultsData)
else:
trainer = RegressionTrainer(
train_loader,
Expand All @@ -151,43 +148,12 @@ def tabularTrainTask(input: dict, trainspaceId: str, uid: str):
criterionHandler,
tabularParams.epochs,
)
detailedTrainResultsData = collectTrainingResults(
trainer, basic_info, tabularParams.problem_type == "CLASSIFICATION"
)

trainTestLoss = [
{
"x_name": "Epoch",
"y_name": "Train loss",
"x_values": [],
"y_values": [],
},
{
"x_name": "Epoch",
"y_name": "Test loss",
"x_values": [],
"y_values": [],
},
]
for epoch_result in trainer:
trainTestLoss[0]["x_values"].append(epoch_result.epoch_num)
trainTestLoss[0]["y_values"].append(epoch_result.train_loss)
trainTestLoss[1]["x_values"].append(epoch_result.epoch_num)
trainTestLoss[1]["y_values"].append(epoch_result.test_loss)

detailedTrainResultsData = DetailedTrainResultsData(
**{
"basic_info": basic_info,
"all_metrics": [
{
"name": "Train and test loss vs epoch",
"time_series": trainTestLoss,
"graph_index": 0,
"chart_type": "LINE",
}
],
}
)

# save detailedTrainResultsData
saveDetailedTrainResultsDataToS3(detailedTrainResultsData)
# save detailedTrainResultsData
saveDetailedTrainResultsDataToS3(detailedTrainResultsData)


@celery_app.task(name="imageTrainTask")
Expand Down Expand Up @@ -222,9 +188,7 @@ def imageTrainTask(input: dict, trainspaceId: str, uid: str):
imageParams.epochs,
dataCreator.getCategoryList(),
)
detailedTrainResultsData = collectClassificationTrainingResults(
trainer, basic_info
)
detailedTrainResultsData = collectTrainingResults(trainer, basic_info, True)

# save detailedTrainResultsData
saveDetailedTrainResultsDataToS3(detailedTrainResultsData)

0 comments on commit d9c11d8

Please sign in to comment.