diff --git a/training/training/core/celery/worker.py b/training/training/core/celery/worker.py index fddf5788..071112b3 100644 --- a/training/training/core/celery/worker.py +++ b/training/training/core/celery/worker.py @@ -38,7 +38,7 @@ def saveDetailedTrainResultsDataToS3( ).put(Body=detailedTrainResultsData.json()) -def collectClassificationTrainingResults(trainer, basic_info): +def collectTrainingResults(trainer, basic_info, is_classification): trainTestLoss = [ { "x_name": "Epoch", @@ -58,33 +58,37 @@ def collectClassificationTrainingResults(trainer, basic_info): trainTestLoss[0]["y_values"].append(epoch_result.train_loss) trainTestLoss[1]["x_values"].append(epoch_result.epoch_num) trainTestLoss[1]["y_values"].append(epoch_result.test_loss) - confusionMatrix = trainer.generate_confusion_matrix() - aucRocCurve = trainer.generate_AUC_ROC_CURVE() - detailedTrainResultsData = DetailedTrainResultsData( - **{ - "basic_info": basic_info, - "all_metrics": [ - { - "name": "Train and test loss vs epoch", - "time_series": trainTestLoss, - "graph_index": 0, - "chart_type": "LINE", - }, - { - "name": "Confusion matrix", - "values": confusionMatrix.tolist(), - "chart_type": "CONFUSION_MATRIX", - "graph_index": 1, - }, - { - "name": "AUC/ROC curve", - "values": aucRocCurve, - "chart_type": "AUC/ROC", - "graph_index": 2, - }, - ], + all_metrics = [ + { + "name": "Train and test loss vs epoch", + "time_series": trainTestLoss, + "graph_index": 0, + "chart_type": "LINE", } + ] + if is_classification: + confusionMatrix = trainer.generate_confusion_matrix() + aucRocCurve = trainer.generate_AUC_ROC_CURVE() + all_metrics.append( + { + "name": "Confusion matrix", + "values": confusionMatrix.tolist(), + "chart_type": "CONFUSION_MATRIX", + "graph_index": 1, + } + ) + all_metrics.append( + { + "name": "AUC/ROC curve", + "values": aucRocCurve, + "chart_type": "AUC/ROC", + "graph_index": 2, + } + ) + + detailedTrainResultsData = DetailedTrainResultsData( + **{"basic_info": basic_info, "all_metrics": all_metrics} ) return detailedTrainResultsData @@ -135,13 +139,6 @@ def tabularTrainTask(input: dict, trainspaceId: str, uid: str): tabularParams.epochs, dataCreator.getCategoryList(), ) - - detailedTrainResultsData = collectClassificationTrainingResults( - trainer, basic_info - ) - - # save detailedTrainResultsData - saveDetailedTrainResultsDataToS3(detailedTrainResultsData) else: trainer = RegressionTrainer( train_loader, @@ -151,43 +148,12 @@ def tabularTrainTask(input: dict, trainspaceId: str, uid: str): criterionHandler, tabularParams.epochs, ) + detailedTrainResultsData = collectTrainingResults( + trainer, basic_info, tabularParams.problem_type == "CLASSIFICATION" + ) - trainTestLoss = [ - { - "x_name": "Epoch", - "y_name": "Train loss", - "x_values": [], - "y_values": [], - }, - { - "x_name": "Epoch", - "y_name": "Test loss", - "x_values": [], - "y_values": [], - }, - ] - for epoch_result in trainer: - trainTestLoss[0]["x_values"].append(epoch_result.epoch_num) - trainTestLoss[0]["y_values"].append(epoch_result.train_loss) - trainTestLoss[1]["x_values"].append(epoch_result.epoch_num) - trainTestLoss[1]["y_values"].append(epoch_result.test_loss) - - detailedTrainResultsData = DetailedTrainResultsData( - **{ - "basic_info": basic_info, - "all_metrics": [ - { - "name": "Train and test loss vs epoch", - "time_series": trainTestLoss, - "graph_index": 0, - "chart_type": "LINE", - } - ], - } - ) - - # save detailedTrainResultsData - saveDetailedTrainResultsDataToS3(detailedTrainResultsData) + # save detailedTrainResultsData + saveDetailedTrainResultsDataToS3(detailedTrainResultsData) @celery_app.task(name="imageTrainTask") @@ -222,9 +188,7 @@ def imageTrainTask(input: dict, trainspaceId: str, uid: str): imageParams.epochs, dataCreator.getCategoryList(), ) - detailedTrainResultsData = collectClassificationTrainingResults( - trainer, basic_info - ) + detailedTrainResultsData = collectTrainingResults(trainer, basic_info, True) # save detailedTrainResultsData saveDetailedTrainResultsDataToS3(detailedTrainResultsData)