From 2be7241710cd221fddcccee956245a2c360248ea Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Mon, 25 Sep 2023 08:42:46 +0200 Subject: [PATCH 1/2] fix: use custom marshaler for n_epochs --- fine_tuning_job.go | 35 ++++++++++++++++++++++++++++++++++- fine_tuning_job_test.go | 23 +++++++++++++++++++++-- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/fine_tuning_job.go b/fine_tuning_job.go index a840b7ec..7b4bb418 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -2,6 +2,7 @@ package openai import ( "context" + "encoding/json" "fmt" "net/http" "net/url" @@ -23,8 +24,40 @@ type FineTuningJob struct { TrainedTokens int `json:"trained_tokens"` } +type HyperparameterNEpochs struct { + IntValue *int `json:"-"` + StringValue *string `json:"-"` +} + +func (h *HyperparameterNEpochs) UnmarshalJSON(data []byte) error { + var intValue int + var stringValue string + + if err := json.Unmarshal(data, &intValue); err == nil { + h.IntValue = &intValue + return nil + } + + if err := json.Unmarshal(data, &stringValue); err != nil { + return err + } + + h.StringValue = &stringValue + return nil +} + +func (h *HyperparameterNEpochs) MarshalJSON() ([]byte, error) { + if h.IntValue != nil { + return json.Marshal(*h.IntValue) + } else if h.StringValue != nil { + return json.Marshal(*h.StringValue) + } + + return nil, fmt.Errorf("invalid hyperparameter n_epochs") +} + type Hyperparameters struct { - Epochs int `json:"n_epochs"` + Epochs *HyperparameterNEpochs `json:"n_epochs,omitempty"` } type FineTuningJobRequest struct { diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index 519c6cd2..bea20302 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -21,8 +21,27 @@ func TestFineTuningJob(t *testing.T) { server.RegisterHandler( "/v1/fine_tuning/jobs", func(w http.ResponseWriter, r *http.Request) { - var resBytes []byte - resBytes, _ = json.Marshal(FineTuningJob{}) + nEpochs := "auto" + resBytes, _ := json.Marshal(FineTuningJob{ + Object: "fine_tuning.job", + ID: testFineTuninigJobID, + Model: "davinci-002", + CreatedAt: 1692661014, + FinishedAt: 1692661190, + FineTunedModel: "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", + OrganizationID: "org-123", + ResultFiles: []string{"file-abc123"}, + Status: "succeeded", + ValidationFile: "", + TrainingFile: "file-abc123", + Hyperparameters: Hyperparameters{ + Epochs: &HyperparameterNEpochs{ + IntValue: nil, + StringValue: &nEpochs, + }, + }, + TrainedTokens: 5768, + }) fmt.Fprintln(w, string(resBytes)) }, ) From 581e933f4b9d52398a541e79e028a3e318bf5dbd Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 5 Oct 2023 22:46:00 +0200 Subject: [PATCH 2/2] chore: use any for n_epochs --- fine_tuning_job.go | 35 +---------------------------------- fine_tuning_job_test.go | 6 +----- 2 files changed, 2 insertions(+), 39 deletions(-) diff --git a/fine_tuning_job.go b/fine_tuning_job.go index 7b4bb418..07b0c337 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -2,7 +2,6 @@ package openai import ( "context" - "encoding/json" "fmt" "net/http" "net/url" @@ -24,40 +23,8 @@ type FineTuningJob struct { TrainedTokens int `json:"trained_tokens"` } -type HyperparameterNEpochs struct { - IntValue *int `json:"-"` - StringValue *string `json:"-"` -} - -func (h *HyperparameterNEpochs) UnmarshalJSON(data []byte) error { - var intValue int - var stringValue string - - if err := json.Unmarshal(data, &intValue); err == nil { - h.IntValue = &intValue - return nil - } - - if err := json.Unmarshal(data, &stringValue); err != nil { - return err - } - - h.StringValue = &stringValue - return nil -} - -func (h *HyperparameterNEpochs) MarshalJSON() ([]byte, error) { - if h.IntValue != nil { - return json.Marshal(*h.IntValue) - } else if h.StringValue != nil { - return json.Marshal(*h.StringValue) - } - - return nil, fmt.Errorf("invalid hyperparameter n_epochs") -} - type Hyperparameters struct { - Epochs *HyperparameterNEpochs `json:"n_epochs,omitempty"` + Epochs any `json:"n_epochs,omitempty"` } type FineTuningJobRequest struct { diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index bea20302..f6d41c33 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -21,7 +21,6 @@ func TestFineTuningJob(t *testing.T) { server.RegisterHandler( "/v1/fine_tuning/jobs", func(w http.ResponseWriter, r *http.Request) { - nEpochs := "auto" resBytes, _ := json.Marshal(FineTuningJob{ Object: "fine_tuning.job", ID: testFineTuninigJobID, @@ -35,10 +34,7 @@ func TestFineTuningJob(t *testing.T) { ValidationFile: "", TrainingFile: "file-abc123", Hyperparameters: Hyperparameters{ - Epochs: &HyperparameterNEpochs{ - IntValue: nil, - StringValue: &nEpochs, - }, + Epochs: "auto", }, TrainedTokens: 5768, })