Skip to content

Commit

Permalink
message sse stuff is done
Browse files Browse the repository at this point in the history
  • Loading branch information
dleviminzi committed Mar 14, 2024
1 parent e0c1333 commit 06e5a05
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 67 deletions.
File renamed without changes.
File renamed without changes.
66 changes: 66 additions & 0 deletions examples/messages_stream/message_stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package main

import (
"bufio"
"context"
"fmt"
"io"
"log"
"os"
"strings"

"github.com/dleviminzi/anthrogo"
)

func main() {
c, err := anthrogo.NewClient()
if err != nil {
log.Fatal(err)
os.Exit(1)
}

systemPrompt := "you are an expert in all things bananas"

// Read user input for the prompt
reader := bufio.NewReader(os.Stdin)
fmt.Print("Enter your prompt: ")
userPrompt, _ := reader.ReadString('\n')
userPrompt = strings.TrimSuffix(userPrompt, "\n")

r, _, err := c.MessageStreamRequest(context.Background(), anthrogo.MessagePayload{
Model: anthrogo.ModelClaude3Opus,
Messages: []anthrogo.Message{{
Role: anthrogo.RoleTypeUser,
Content: []anthrogo.MessageContent{{
Type: anthrogo.ContentTypeText,
Text: &userPrompt,
}},
}},
System: &systemPrompt,
MaxTokens: 1000,
})
if err != nil {
log.Fatal(err)
os.Exit(1)
}
defer r.Close()

// Create an SSEDecoder
decoder := anthrogo.NewMessageSSEDecoder(r)
for {
message, err := decoder.Decode(anthrogo.DecodeOptions{ContentOnly: true})
if err != nil {
if err == io.EOF {
break
}
fmt.Print(err)
continue
}

if message.Event == "message_stop" {
break
}

fmt.Print(message.Data.Content)
}
}
192 changes: 125 additions & 67 deletions message_sse_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,22 @@ import (
"strings"
)

type MessageSSEPayload struct {
Event string `json:"event"`
Data interface{} `json:"data"`
// MessageEventPayload is the decoded event from anthropic
type MessageEventPayload struct {
Event string
Data EventData
}

type MessageStartData struct {
// EventData contains content which will be whatever the model output
// and Data which is the full data from the event
type EventData struct {
Content string
Data any
}

// MessageStart is one of the data types for events and it represents the start of a
// a stream of messages. It contains metadata about the request.
type MessageStart struct {
Type string `json:"type"`
Message struct {
ID string `json:"id"`
Expand All @@ -30,7 +40,8 @@ type MessageStartData struct {
} `json:"message"`
}

type ContentBlockStartData struct {
// ContentBlockStart marks the start of a new content block in the message stream.
type ContentBlockStart struct {
Type string `json:"type"`
Index int `json:"index"`
ContentBlock struct {
Expand All @@ -39,38 +50,42 @@ type ContentBlockStartData struct {
} `json:"content_block"`
}

// PingData is a ping event
type PingData struct {
Type string `json:"type"`
}

type ContentBlockDeltaData struct {
Type string `json:"type"`
Index int `json:"index"`
Delta TextDelta `json:"delta"`
}

type TextDelta struct {
Type string `json:"type"`
Text string `json:"text"`
// ContentBlockDelta carries new content for a content block in the message stream.
type ContentBlockDelta struct {
Type string `json:"type"`
Index int `json:"index"`
Delta struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"delta"`
}

type ContentBlockStopData struct {
// ContentBlockStop marks the end of a content block in the message stream.
type ContentBlockStop struct {
Type string `json:"type"`
Index int `json:"index"`
}

type MessageDeltaData struct {
// MessageDelta events indicate top-level changes to the final message.
type MessageDelta struct {
Type string `json:"type"`
Delta interface{} `json:"delta"`
Usage struct {
OutputTokens int `json:"output_tokens"`
} `json:"usage"`
}

// MessageStopData is the final event in a message stream.
type MessageStopData struct {
Type string `json:"type"`
}

// ErrorData is the event type for errors.
type ErrorData struct {
Type string `json:"type"`
Error struct {
Expand All @@ -79,19 +94,41 @@ type ErrorData struct {
} `json:"error"`
}

// MessageEvent is the event type for messages. It contains the message payload
// and an error if one occurred.
type MessageEvent struct {
Message *MessageEventPayload
Err *error
}

// MessageSSEDecoder is a decoder for the SSE stream from the message endpoint.
type MessageSSEDecoder struct {
reader *bufio.Reader
content []string
}

func NewSSEDecoder(reader io.Reader) *MessageSSEDecoder {
// DecodeOptions are options for decoding the SSE stream.
type DecodeOptions struct {
ContentOnly bool
}

// NewMessageSSEDecoder creates a new MessageSSEDecoder.
func NewMessageSSEDecoder(reader io.Reader) *MessageSSEDecoder {
return &MessageSSEDecoder{
reader: bufio.NewReader(reader),
content: make([]string, 0),
}
}

func (d *MessageSSEDecoder) Decode() (*MessageSSEPayload, error) {
// Decode reads the next event from the SSE stream.
func (d *MessageSSEDecoder) Decode(opts ...DecodeOptions) (*MessageEventPayload, error) {
var options DecodeOptions
if len(opts) > 1 {
return nil, fmt.Errorf("too many options provided, expected at most one")
} else if len(opts) == 1 {
options = opts[0]
}

line, err := d.reader.ReadString('\n')
if err != nil {
if err == io.EOF {
Expand All @@ -103,7 +140,7 @@ func (d *MessageSSEDecoder) Decode() (*MessageSSEPayload, error) {
line = strings.TrimSpace(line)
if line == "" {
// Recursively call Decode to read the next event
return d.Decode()
return d.Decode(opts...)
}

parts := strings.SplitN(line, ":", 2)
Expand All @@ -115,38 +152,24 @@ func (d *MessageSSEDecoder) Decode() (*MessageSSEPayload, error) {
value := strings.TrimSpace(parts[1])

if field == "event" {
data := d.decodeData(value)

// Update the content array based on the event type
switch value {
case "content_block_start":
contentBlockStartData := data.(ContentBlockStartData)
index := contentBlockStartData.Index
if index >= len(d.content) {
d.content = append(d.content, make([]string, index-len(d.content)+1)...)
}
d.content[index] = contentBlockStartData.ContentBlock.Text
case "content_block_delta":
contentBlockDeltaData := data.(ContentBlockDeltaData)
index := contentBlockDeltaData.Index
if index >= len(d.content) {
d.content = append(d.content, make([]string, index-len(d.content)+1)...)
}
d.content[index] += contentBlockDeltaData.Delta.Text
data, err := d.decodeData(value)
if err != nil {
return nil, err
}

return &MessageSSEPayload{
Event: value,
Data: data,
}, nil
if data.Content != "" || !options.ContentOnly || value == "message_stop" {
return &MessageEventPayload{
Event: value,
Data: data,
}, nil
}
}
// Recursively call Decode to read the next event if we didn't have one here
return d.Decode()
return d.Decode(opts...)
}

// TODO: check for errors and return them here.
func (d *MessageSSEDecoder) decodeData(event string) any {
var data any
func (d *MessageSSEDecoder) decodeData(event string) (EventData, error) {
var eventData EventData

for {
line, err := d.reader.ReadString('\n')
Expand All @@ -164,40 +187,75 @@ func (d *MessageSSEDecoder) decodeData(event string) any {

switch event {
case "message_start":
var messageStartData MessageStartData
json.Unmarshal([]byte(jsonData), &messageStartData)
data = messageStartData
var messageStartData MessageStart
err := json.Unmarshal([]byte(jsonData), &messageStartData)
if err != nil {
return eventData, err
}
eventData.Data = messageStartData
case "content_block_start":
var contentBlockStartData ContentBlockStartData
json.Unmarshal([]byte(jsonData), &contentBlockStartData)
data = contentBlockStartData
var contentBlockStartData ContentBlockStart
err := json.Unmarshal([]byte(jsonData), &contentBlockStartData)
if err != nil {
return eventData, err
}
eventData.Data = contentBlockStartData
eventData.Content = contentBlockStartData.ContentBlock.Text
d.updateContent(contentBlockStartData.Index, contentBlockStartData.ContentBlock.Text)
case "ping":
var pingData PingData
json.Unmarshal([]byte(jsonData), &pingData)
data = pingData
err := json.Unmarshal([]byte(jsonData), &pingData)
if err != nil {
return eventData, err
}
eventData.Data = pingData
case "content_block_delta":
var contentBlockDeltaData ContentBlockDeltaData
json.Unmarshal([]byte(jsonData), &contentBlockDeltaData)
data = contentBlockDeltaData
var contentBlockDeltaData ContentBlockDelta
err := json.Unmarshal([]byte(jsonData), &contentBlockDeltaData)
if err != nil {
return eventData, err
}
eventData.Data = contentBlockDeltaData
eventData.Content = contentBlockDeltaData.Delta.Text
d.updateContent(contentBlockDeltaData.Index, contentBlockDeltaData.Delta.Text)
case "content_block_stop":
var contentBlockStopData ContentBlockStopData
json.Unmarshal([]byte(jsonData), &contentBlockStopData)
data = contentBlockStopData
var contentBlockStopData ContentBlockStop
err := json.Unmarshal([]byte(jsonData), &contentBlockStopData)
if err != nil {
return eventData, err
}
eventData.Data = contentBlockStopData
case "message_delta":
var messageDeltaData MessageDeltaData
json.Unmarshal([]byte(jsonData), &messageDeltaData)
data = messageDeltaData
var messageDeltaData MessageDelta
err := json.Unmarshal([]byte(jsonData), &messageDeltaData)
if err != nil {
return eventData, err
}
eventData.Data = messageDeltaData
case "message_stop":
var messageStopData MessageStopData
json.Unmarshal([]byte(jsonData), &messageStopData)
data = messageStopData
err := json.Unmarshal([]byte(jsonData), &messageStopData)
if err != nil {
return eventData, err
}
eventData.Data = messageStopData
case "error":
var errorData ErrorData
json.Unmarshal([]byte(jsonData), &errorData)
data = errorData
err := json.Unmarshal([]byte(jsonData), &errorData)
if err != nil {
return eventData, err
}
return eventData, fmt.Errorf("error(%s) - %s", errorData.Error.Type, errorData.Error.Message)
}
}
}

return data
return eventData, nil
}

func (d *MessageSSEDecoder) updateContent(index int, content string) {
if index >= len(d.content) {
d.content = append(d.content, make([]string, index-len(d.content)+1)...)
}
d.content[index] += content
}
Loading

0 comments on commit 06e5a05

Please sign in to comment.