From e095df5325a39ed94940dbe3882d2aa14eb64ad0 Mon Sep 17 00:00:00 2001 From: floodwm Date: Fri, 20 Sep 2024 23:54:25 +0300 Subject: [PATCH] run_id string Optional (#855) Filter messages by the run ID that generated them. Co-authored-by: wappi --- .zshrc | 0 client_test.go | 2 +- messages.go | 5 +++++ messages_test.go | 5 +++-- 4 files changed, 9 insertions(+), 3 deletions(-) create mode 100644 .zshrc diff --git a/.zshrc b/.zshrc new file mode 100644 index 000000000..e69de29bb diff --git a/client_test.go b/client_test.go index 7119d8a7e..3f27b9dd7 100644 --- a/client_test.go +++ b/client_test.go @@ -340,7 +340,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { return client.CreateMessage(ctx, "", MessageRequest{}) }}, {"ListMessage", func() (any, error) { - return client.ListMessage(ctx, "", nil, nil, nil, nil) + return client.ListMessage(ctx, "", nil, nil, nil, nil, nil) }}, {"RetrieveMessage", func() (any, error) { return client.RetrieveMessage(ctx, "", "") diff --git a/messages.go b/messages.go index 1fddd6314..eefc29a36 100644 --- a/messages.go +++ b/messages.go @@ -100,6 +100,7 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, order *string, after *string, before *string, + runID *string, ) (messages MessagesList, err error) { urlValues := url.Values{} if limit != nil { @@ -114,6 +115,10 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, if before != nil { urlValues.Add("before", *before) } + if runID != nil { + urlValues.Add("run_id", *runID) + } + encodedValues := "" if len(urlValues) > 0 { encodedValues = "?" + urlValues.Encode() diff --git a/messages_test.go b/messages_test.go index 71ceb4d3a..b25755f98 100644 --- a/messages_test.go +++ b/messages_test.go @@ -208,7 +208,7 @@ func TestMessages(t *testing.T) { } var msgs openai.MessagesList - msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil) + msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil, nil) checks.NoError(t, err, "ListMessages error") if len(msgs.Messages) != 1 { t.Fatalf("unexpected length of fetched messages") @@ -219,7 +219,8 @@ func TestMessages(t *testing.T) { order := "desc" after := "obj_foo" before := "obj_bar" - msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before) + runID := "run_abc123" + msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before, &runID) checks.NoError(t, err, "ListMessages error") if len(msgs.Messages) != 1 { t.Fatalf("unexpected length of fetched messages")