Skip to content

Commit

Permalink
slight reformatting
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewpeng02 committed May 1, 2024
1 parent 38b79da commit 0724b98
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions frontend/src/features/Train/types/trainTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ export interface ConfusionMatrixChart {

// more detailed information, used when viewing a run
export interface DetailedTrainResultsData {
basicInfo: TrainResultsData
basic_info: TrainResultsData

allMetrics: Chart[]
all_metrics: Chart[]
}

export interface FileUploadData {
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/pages/train/[train_space_id].tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const mapTrainResultsDataToCharts = (
detailedTrainResultsData: DetailedTrainResultsData
) => {
// sort by graph_index asc and ignore negative graph indices
const sortedData = detailedTrainResultsData.allMetrics
const sortedData = detailedTrainResultsData.all_metrics
.filter((metric) => metric.graph_index >= 0)
.sort((a, b) => a.graph_index - b.graph_index);
const charts = [];
Expand Down
20 changes: 10 additions & 10 deletions training/training/core/celery/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def saveDetailedTrainResultsDataToS3(
):
s3 = boto3.resource("s3")
s3.Object(
"dlp-executions", f"{detailedTrainResultsData.basicInfo.trainspaceId}.json"
"dlp-executions", f"{detailedTrainResultsData.basic_info.trainspaceId}.json"
).put(Body=detailedTrainResultsData.json())


def collectClassificationTrainingResults(trainer, basicInfo):
def collectClassificationTrainingResults(trainer, basic_info):
trainTestLoss = [
{
"x_name": "Epoch",
Expand All @@ -63,8 +63,8 @@ def collectClassificationTrainingResults(trainer, basicInfo):

detailedTrainResultsData = DetailedTrainResultsData(
**{
"basicInfo": basicInfo,
"allMetrics": [
"basic_info": basic_info,
"all_metrics": [
{
"name": "Train and test loss vs epoch",
"time_series": trainTestLoss,
Expand Down Expand Up @@ -92,7 +92,7 @@ def collectClassificationTrainingResults(trainer, basicInfo):
@celery_app.task(name="tabularTrainTask")
def tabularTrainTask(input: dict, trainspaceId: str, uid: str):
tabularParams = TabularParams(**input)
basicInfo = TrainResultsData(
basic_info = TrainResultsData(
**{
"name": tabularParams.name,
"trainspaceId": trainspaceId,
Expand Down Expand Up @@ -137,7 +137,7 @@ def tabularTrainTask(input: dict, trainspaceId: str, uid: str):
)

detailedTrainResultsData = collectClassificationTrainingResults(
trainer, basicInfo
trainer, basic_info
)

# save detailedTrainResultsData
Expand Down Expand Up @@ -174,8 +174,8 @@ def tabularTrainTask(input: dict, trainspaceId: str, uid: str):

detailedTrainResultsData = DetailedTrainResultsData(
**{
"basicInfo": basicInfo,
"allMetrics": [
"basic_info": basic_info,
"all_metrics": [
{
"name": "Train and test loss vs epoch",
"time_series": trainTestLoss,
Expand All @@ -193,7 +193,7 @@ def tabularTrainTask(input: dict, trainspaceId: str, uid: str):
@celery_app.task(name="imageTrainTask")
def imageTrainTask(input: dict, trainspaceId: str, uid: str):
imageParams = ImageParams(**input)
basicInfo = TrainResultsData(
basic_info = TrainResultsData(
**{
"name": imageParams.name,
"trainspaceId": trainspaceId,
Expand Down Expand Up @@ -223,7 +223,7 @@ def imageTrainTask(input: dict, trainspaceId: str, uid: str):
dataCreator.getCategoryList(),
)
detailedTrainResultsData = collectClassificationTrainingResults(
trainer, basicInfo
trainer, basic_info
)

# save detailedTrainResultsData
Expand Down
2 changes: 1 addition & 1 deletion training/training/routes/training/results/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def getDetailedTrainResultsData(request: Request, trainspace_id: str):
file_content = content_object.get()["Body"].read().decode("utf-8")
json_content = json.loads(file_content)
detailedTrainResultsData = DetailedTrainResultsData(**json_content)
if request.auth["uid"] != detailedTrainResultsData.basicInfo.uid:
if request.auth["uid"] != detailedTrainResultsData.basic_info.uid:
raise AuthenticationError("Invalid authorization")

except botocore.exceptions.ClientError as e:
Expand Down
4 changes: 2 additions & 2 deletions training/training/routes/training/results/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,6 @@ class ConfusionMatrixChart(Schema):


class DetailedTrainResultsData(Schema):
basicInfo: TrainResultsData
basic_info: TrainResultsData

allMetrics: List[Chart]
all_metrics: List[Chart]

0 comments on commit 0724b98

Please sign in to comment.