Skip to content

Commit

Permalink
chore: attach stream to response (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis authored Jan 27, 2024
1 parent da56d3b commit 07713f3
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 47 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ func (r *MyResponse) SetStatusCode(code int) error {
func (r *MyResponse) AcceptContentType() string {
// Return the accepted content type of the response
}

// Optional. Implement this method if you want to stream the response body.
func (r *MyResponse) StreamCallback() StreamCallback {
// Return the stream callback if any.
}
```

## Usage
Expand Down
43 changes: 22 additions & 21 deletions examples/cmd/stream/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ func (r *createPostRequest) ContentType() string {
}

type CreatePostResponse struct {
HTTPStatusCode int `json:"-"`
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
HTTPStatusCode int `json:"-"`
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
StreamCallbackFn restclientgo.StreamCallback `json:"-"`
}

func (r *CreatePostResponse) Decode(body io.Reader) error {
Expand All @@ -61,34 +62,34 @@ func (r *CreatePostResponse) SetStatusCode(code int) error {
}

func (r *CreatePostResponse) SetHeaders(headers restclientgo.Headers) error { return nil }
func (r *CreatePostResponse) StreamCallback() restclientgo.StreamCallback {
return r.StreamCallbackFn
}

func main() {

var response string
restClient := restclientgo.New("http://localhost:11434/api")

restClient.SetStreamCallback(
func(data []byte) error {
var createPostResponse CreatePostResponse

err := json.Unmarshal(data, &createPostResponse)
if err != nil {
return err
}

response += createPostResponse.Response
fmt.Printf(createPostResponse.Response)

return nil
},
)

restClient.SetRequestModifier(func(req *http.Request) *http.Request {
req.Header.Set("Accept", "application/json")
return req
})

var createPostResponse CreatePostResponse
createPostResponse.StreamCallbackFn = func(data []byte) error {
var createPostResponse CreatePostResponse

err := json.Unmarshal(data, &createPostResponse)
if err != nil {
return err
}

response += createPostResponse.Response
fmt.Printf(createPostResponse.Response)

return nil
}

err := restClient.Post(
context.Background(),
Expand Down
54 changes: 28 additions & 26 deletions restclientgo.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ type RestClient struct {
endpoint string
requestModifier func(*http.Request) *http.Request
forceDecodeOnError bool
streamCallback StreamCallback
}

type Error string
Expand Down Expand Up @@ -69,6 +68,11 @@ type Response interface {
SetHeaders(headers Headers) error
}

type Streamable interface {
// StreamCallback get the stream callback if any.
StreamCallback() StreamCallback
}

// New creates a new RestClient.
func New(endpoint string) *RestClient {
return &RestClient{
Expand Down Expand Up @@ -105,10 +109,6 @@ func (r *RestClient) WithDecodeOnError(decodeOnError bool) *RestClient {
return r
}

func (r *RestClient) SetStreamCallback(streamCallback StreamCallback) {
r.streamCallback = streamCallback
}

func (r *RestClient) SetEndpoint(endpoint string) {
r.endpoint = endpoint
}
Expand Down Expand Up @@ -142,6 +142,7 @@ func (r *RestClient) Patch(ctx context.Context, request Request, response Respon
return r.do(ctx, methodPatch, request, response)
}

//nolint:gocognit
func (r *RestClient) do(ctx context.Context, method httpMethod, request Request, response Response) error {
requestPath, err := request.Path()
if err != nil {
Expand Down Expand Up @@ -206,40 +207,25 @@ func (r *RestClient) do(ctx context.Context, method httpMethod, request Request,
return nil
}

err = r.matchContentType(httpResponse, response)
err = matchContentType(httpResponse, response)
if err != nil {
return err
}

if r.streamCallback == nil {
err = response.Decode(httpResponse.Body)
if streamable, isStreamable := response.(Streamable); isStreamable && streamable.StreamCallback() != nil {
err = stream(streamable.StreamCallback(), httpResponse.Body)
} else {
err = r.decodeBody(httpResponse.Body)
err = response.Decode(httpResponse.Body)
}

if err != nil {
return fmt.Errorf("%w: %w", ErrResponseDecode, err)
}

return nil
}

func (r *RestClient) decodeBody(body io.Reader) error {
scanner := bufio.NewScanner(body)

scanBuf := make([]byte, 0, maxStreamBufferSize)
scanner.Buffer(scanBuf, maxStreamBufferSize)

for scanner.Scan() {
err := r.streamCallback(scanner.Bytes())
if err != nil {
return err
}
}

return nil
}

func (r *RestClient) matchContentType(httpResponse *http.Response, response Response) error {
func matchContentType(httpResponse *http.Response, response Response) error {
contentTypeToMatch := response.AcceptContentType()
contentType := httpResponse.Header.Get("Content-Type")

Expand All @@ -255,3 +241,19 @@ func (r *RestClient) matchContentType(httpResponse *http.Response, response Resp

return ErrNoContentType
}

func stream(streamCallback StreamCallback, body io.Reader) error {
scanner := bufio.NewScanner(body)

scanBuf := make([]byte, 0, maxStreamBufferSize)
scanner.Buffer(scanBuf, maxStreamBufferSize)

for scanner.Scan() {
err := streamCallback(scanner.Bytes())
if err != nil {
return err
}
}

return nil
}

0 comments on commit 07713f3

Please sign in to comment.