From 06e5a05d995b268ad7eb85910eb031c5a7bb78bf Mon Sep 17 00:00:00 2001 From: Daniel Levi-Minzi Date: Wed, 13 Mar 2024 22:05:40 -0400 Subject: [PATCH] message sse stuff is done --- .../{basic_example.go => completion_basic.go} | 0 ...stream_example.go => completion_stream.go} | 0 .../{basic_message.go => message_basic.go} | 0 examples/messages_stream/message_stream.go | 66 ++++++ message_sse_decoder.go | 192 ++++++++++++------ messages.go | 36 ++++ 6 files changed, 227 insertions(+), 67 deletions(-) rename examples/completion/{basic_example.go => completion_basic.go} (100%) rename examples/completion_stream/{stream_example.go => completion_stream.go} (100%) rename examples/messages/{basic_message.go => message_basic.go} (100%) create mode 100644 examples/messages_stream/message_stream.go diff --git a/examples/completion/basic_example.go b/examples/completion/completion_basic.go similarity index 100% rename from examples/completion/basic_example.go rename to examples/completion/completion_basic.go diff --git a/examples/completion_stream/stream_example.go b/examples/completion_stream/completion_stream.go similarity index 100% rename from examples/completion_stream/stream_example.go rename to examples/completion_stream/completion_stream.go diff --git a/examples/messages/basic_message.go b/examples/messages/message_basic.go similarity index 100% rename from examples/messages/basic_message.go rename to examples/messages/message_basic.go diff --git a/examples/messages_stream/message_stream.go b/examples/messages_stream/message_stream.go new file mode 100644 index 0000000..69cc511 --- /dev/null +++ b/examples/messages_stream/message_stream.go @@ -0,0 +1,66 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "io" + "log" + "os" + "strings" + + "github.com/dleviminzi/anthrogo" +) + +func main() { + c, err := anthrogo.NewClient() + if err != nil { + log.Fatal(err) + os.Exit(1) + } + + systemPrompt := "you are an expert in all things bananas" + + // Read user input for the prompt + reader := bufio.NewReader(os.Stdin) + fmt.Print("Enter your prompt: ") + userPrompt, _ := reader.ReadString('\n') + userPrompt = strings.TrimSuffix(userPrompt, "\n") + + r, _, err := c.MessageStreamRequest(context.Background(), anthrogo.MessagePayload{ + Model: anthrogo.ModelClaude3Opus, + Messages: []anthrogo.Message{{ + Role: anthrogo.RoleTypeUser, + Content: []anthrogo.MessageContent{{ + Type: anthrogo.ContentTypeText, + Text: &userPrompt, + }}, + }}, + System: &systemPrompt, + MaxTokens: 1000, + }) + if err != nil { + log.Fatal(err) + os.Exit(1) + } + defer r.Close() + + // Create an SSEDecoder + decoder := anthrogo.NewMessageSSEDecoder(r) + for { + message, err := decoder.Decode(anthrogo.DecodeOptions{ContentOnly: true}) + if err != nil { + if err == io.EOF { + break + } + fmt.Print(err) + continue + } + + if message.Event == "message_stop" { + break + } + + fmt.Print(message.Data.Content) + } +} diff --git a/message_sse_decoder.go b/message_sse_decoder.go index c7d9cbb..e0a028d 100644 --- a/message_sse_decoder.go +++ b/message_sse_decoder.go @@ -8,12 +8,22 @@ import ( "strings" ) -type MessageSSEPayload struct { - Event string `json:"event"` - Data interface{} `json:"data"` +// MessageEventPayload is the decoded event from anthropic +type MessageEventPayload struct { + Event string + Data EventData } -type MessageStartData struct { +// EventData contains content which will be whatever the model output +// and Data which is the full data from the event +type EventData struct { + Content string + Data any +} + +// MessageStart is one of the data types for events and it represents the start of a +// a stream of messages. It contains metadata about the request. +type MessageStart struct { Type string `json:"type"` Message struct { ID string `json:"id"` @@ -30,7 +40,8 @@ type MessageStartData struct { } `json:"message"` } -type ContentBlockStartData struct { +// ContentBlockStart marks the start of a new content block in the message stream. +type ContentBlockStart struct { Type string `json:"type"` Index int `json:"index"` ContentBlock struct { @@ -39,27 +50,29 @@ type ContentBlockStartData struct { } `json:"content_block"` } +// PingData is a ping event type PingData struct { Type string `json:"type"` } -type ContentBlockDeltaData struct { - Type string `json:"type"` - Index int `json:"index"` - Delta TextDelta `json:"delta"` -} - -type TextDelta struct { - Type string `json:"type"` - Text string `json:"text"` +// ContentBlockDelta carries new content for a content block in the message stream. +type ContentBlockDelta struct { + Type string `json:"type"` + Index int `json:"index"` + Delta struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"delta"` } -type ContentBlockStopData struct { +// ContentBlockStop marks the end of a content block in the message stream. +type ContentBlockStop struct { Type string `json:"type"` Index int `json:"index"` } -type MessageDeltaData struct { +// MessageDelta events indicate top-level changes to the final message. +type MessageDelta struct { Type string `json:"type"` Delta interface{} `json:"delta"` Usage struct { @@ -67,10 +80,12 @@ type MessageDeltaData struct { } `json:"usage"` } +// MessageStopData is the final event in a message stream. type MessageStopData struct { Type string `json:"type"` } +// ErrorData is the event type for errors. type ErrorData struct { Type string `json:"type"` Error struct { @@ -79,19 +94,41 @@ type ErrorData struct { } `json:"error"` } +// MessageEvent is the event type for messages. It contains the message payload +// and an error if one occurred. +type MessageEvent struct { + Message *MessageEventPayload + Err *error +} + +// MessageSSEDecoder is a decoder for the SSE stream from the message endpoint. type MessageSSEDecoder struct { reader *bufio.Reader content []string } -func NewSSEDecoder(reader io.Reader) *MessageSSEDecoder { +// DecodeOptions are options for decoding the SSE stream. +type DecodeOptions struct { + ContentOnly bool +} + +// NewMessageSSEDecoder creates a new MessageSSEDecoder. +func NewMessageSSEDecoder(reader io.Reader) *MessageSSEDecoder { return &MessageSSEDecoder{ reader: bufio.NewReader(reader), content: make([]string, 0), } } -func (d *MessageSSEDecoder) Decode() (*MessageSSEPayload, error) { +// Decode reads the next event from the SSE stream. +func (d *MessageSSEDecoder) Decode(opts ...DecodeOptions) (*MessageEventPayload, error) { + var options DecodeOptions + if len(opts) > 1 { + return nil, fmt.Errorf("too many options provided, expected at most one") + } else if len(opts) == 1 { + options = opts[0] + } + line, err := d.reader.ReadString('\n') if err != nil { if err == io.EOF { @@ -103,7 +140,7 @@ func (d *MessageSSEDecoder) Decode() (*MessageSSEPayload, error) { line = strings.TrimSpace(line) if line == "" { // Recursively call Decode to read the next event - return d.Decode() + return d.Decode(opts...) } parts := strings.SplitN(line, ":", 2) @@ -115,38 +152,24 @@ func (d *MessageSSEDecoder) Decode() (*MessageSSEPayload, error) { value := strings.TrimSpace(parts[1]) if field == "event" { - data := d.decodeData(value) - - // Update the content array based on the event type - switch value { - case "content_block_start": - contentBlockStartData := data.(ContentBlockStartData) - index := contentBlockStartData.Index - if index >= len(d.content) { - d.content = append(d.content, make([]string, index-len(d.content)+1)...) - } - d.content[index] = contentBlockStartData.ContentBlock.Text - case "content_block_delta": - contentBlockDeltaData := data.(ContentBlockDeltaData) - index := contentBlockDeltaData.Index - if index >= len(d.content) { - d.content = append(d.content, make([]string, index-len(d.content)+1)...) - } - d.content[index] += contentBlockDeltaData.Delta.Text + data, err := d.decodeData(value) + if err != nil { + return nil, err } - return &MessageSSEPayload{ - Event: value, - Data: data, - }, nil + if data.Content != "" || !options.ContentOnly || value == "message_stop" { + return &MessageEventPayload{ + Event: value, + Data: data, + }, nil + } } // Recursively call Decode to read the next event if we didn't have one here - return d.Decode() + return d.Decode(opts...) } -// TODO: check for errors and return them here. -func (d *MessageSSEDecoder) decodeData(event string) any { - var data any +func (d *MessageSSEDecoder) decodeData(event string) (EventData, error) { + var eventData EventData for { line, err := d.reader.ReadString('\n') @@ -164,40 +187,75 @@ func (d *MessageSSEDecoder) decodeData(event string) any { switch event { case "message_start": - var messageStartData MessageStartData - json.Unmarshal([]byte(jsonData), &messageStartData) - data = messageStartData + var messageStartData MessageStart + err := json.Unmarshal([]byte(jsonData), &messageStartData) + if err != nil { + return eventData, err + } + eventData.Data = messageStartData case "content_block_start": - var contentBlockStartData ContentBlockStartData - json.Unmarshal([]byte(jsonData), &contentBlockStartData) - data = contentBlockStartData + var contentBlockStartData ContentBlockStart + err := json.Unmarshal([]byte(jsonData), &contentBlockStartData) + if err != nil { + return eventData, err + } + eventData.Data = contentBlockStartData + eventData.Content = contentBlockStartData.ContentBlock.Text + d.updateContent(contentBlockStartData.Index, contentBlockStartData.ContentBlock.Text) case "ping": var pingData PingData - json.Unmarshal([]byte(jsonData), &pingData) - data = pingData + err := json.Unmarshal([]byte(jsonData), &pingData) + if err != nil { + return eventData, err + } + eventData.Data = pingData case "content_block_delta": - var contentBlockDeltaData ContentBlockDeltaData - json.Unmarshal([]byte(jsonData), &contentBlockDeltaData) - data = contentBlockDeltaData + var contentBlockDeltaData ContentBlockDelta + err := json.Unmarshal([]byte(jsonData), &contentBlockDeltaData) + if err != nil { + return eventData, err + } + eventData.Data = contentBlockDeltaData + eventData.Content = contentBlockDeltaData.Delta.Text + d.updateContent(contentBlockDeltaData.Index, contentBlockDeltaData.Delta.Text) case "content_block_stop": - var contentBlockStopData ContentBlockStopData - json.Unmarshal([]byte(jsonData), &contentBlockStopData) - data = contentBlockStopData + var contentBlockStopData ContentBlockStop + err := json.Unmarshal([]byte(jsonData), &contentBlockStopData) + if err != nil { + return eventData, err + } + eventData.Data = contentBlockStopData case "message_delta": - var messageDeltaData MessageDeltaData - json.Unmarshal([]byte(jsonData), &messageDeltaData) - data = messageDeltaData + var messageDeltaData MessageDelta + err := json.Unmarshal([]byte(jsonData), &messageDeltaData) + if err != nil { + return eventData, err + } + eventData.Data = messageDeltaData case "message_stop": var messageStopData MessageStopData - json.Unmarshal([]byte(jsonData), &messageStopData) - data = messageStopData + err := json.Unmarshal([]byte(jsonData), &messageStopData) + if err != nil { + return eventData, err + } + eventData.Data = messageStopData case "error": var errorData ErrorData - json.Unmarshal([]byte(jsonData), &errorData) - data = errorData + err := json.Unmarshal([]byte(jsonData), &errorData) + if err != nil { + return eventData, err + } + return eventData, fmt.Errorf("error(%s) - %s", errorData.Error.Type, errorData.Error.Message) } } } - return data + return eventData, nil +} + +func (d *MessageSSEDecoder) updateContent(index int, content string) { + if index >= len(d.content) { + d.content = append(d.content, make([]string, index-len(d.content)+1)...) + } + d.content[index] += content } diff --git a/messages.go b/messages.go index 32d39fc..770b21e 100644 --- a/messages.go +++ b/messages.go @@ -102,6 +102,8 @@ type Usage struct { // MessageRequest sends a message to the model and returns the response. func (c *Client) MessageRequest(ctx context.Context, payload MessagePayload) (MessageResponse, error) { var resp MessageResponse + stream := false + payload.Stream = &stream req, cancel, err := c.createRequest(ctx, payload, RequestTypeMessages) if err != nil { @@ -136,3 +138,37 @@ func (c *Client) MessageRequest(ctx context.Context, payload MessagePayload) (Me return resp, nil } + +// MessageStreamRequest sends a message to the model and returns the body for the user to consume +func (c *Client) MessageStreamRequest(ctx context.Context, payload MessagePayload) (io.ReadCloser, context.CancelFunc, error) { + stream := true + payload.Stream = &stream + + req, cancel, err := c.createRequest(ctx, payload, RequestTypeMessages) + if err != nil { + return nil, nil, err + } + + res, err := c.doRequestWithRetries(req) + if err != nil { + return nil, nil, err + } + + if res.StatusCode != http.StatusOK { + var errorResponse ErrorResponse + + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, nil, err + } + + err = json.Unmarshal(body, &errorResponse) + if err != nil { + return nil, nil, err + } + + return nil, nil, fmt.Errorf("%s: %s", errorResponse.Error.Type, errorResponse.Error.Message) + } + + return res.Body, cancel, nil +}