Skip to content

Commit

Permalink
Parser now takes an io.ReadCloser (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
btamadio authored Feb 25, 2021
1 parent 541aa5d commit c45f46e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 19 deletions.
41 changes: 28 additions & 13 deletions http_reward_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,8 @@ func (h *HTTPSource) GetRewards(ctx context.Context, banditContext interface{})
if err != nil {
return nil, err
}
defer resp.Body.Close()

respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}

return h.parser.Parse(respBody)
return h.parser.Parse(resp.Body)
}

// HTTPDoer is a basic interface for making HTTP requests. The net/http Client can be used or you can bring your own.
Expand All @@ -81,7 +75,7 @@ type HttpDoer interface {

// RewardParser will be called to convert the response from the reward service to a slice of distributions.
type RewardParser interface {
Parse([]byte) ([]Dist, error)
Parse(io.ReadCloser) ([]Dist, error)
}

// ContextMarshaler is called on the banditContext and the result will become the body of the request to the bandit service.
Expand All @@ -100,9 +94,9 @@ func WithContextMarshaler(m ContextMarshaler) HTTPSourceOption {
}

// ParseFunc is an adapter to allow a normal function to be used as a RewardParser
type ParseFunc func([]byte) ([]Dist, error)
type ParseFunc func(io.ReadCloser) ([]Dist, error)

func (p ParseFunc) Parse(b []byte) ([]Dist, error) { return p(b) }
func (p ParseFunc) Parse(rc io.ReadCloser) ([]Dist, error) { return p(rc) }

// MarshalFunc is an adapter to allow a normal function to be used as a ContextMarshaler
type MarshalFunc func(banditContext interface{}) ([]byte, error)
Expand All @@ -114,7 +108,14 @@ func (m MarshalFunc) Marshal(banditContext interface{}) ([]byte, error) { return
// `[{"alpha": 123, "beta": 456}, {"alpha": 3.1415, "beta": 9.999}]`
// Returns an error if alpha or beta value are missing or less than 1 for any arm.
// Any additional keys are ignored.
func BetaFromJSON(data []byte) ([]Dist, error) {
func BetaFromJSON(rc io.ReadCloser) ([]Dist, error) {
defer rc.Close()

data, err := ioutil.ReadAll(rc)
if err != nil {
return nil, fmt.Errorf("failed to read data: %w", err)
}

var resp []struct {
Alpha *float64 `json:"alpha"`
Beta *float64 `json:"beta"`
Expand Down Expand Up @@ -149,7 +150,14 @@ func BetaFromJSON(data []byte) ([]Dist, error) {
// `[{"mu": 123, "sigma": 456}, {"mu": 3.1415, "sigma": 9.999}]`
// Returns an error if mu or sigma value are missing or sigma is less than 0 for any arm.
// Any additional keys are ignored.
func NormalFromJSON(data []byte) ([]Dist, error) {
func NormalFromJSON(rc io.ReadCloser) ([]Dist, error) {
defer rc.Close()

data, err := ioutil.ReadAll(rc)
if err != nil {
return nil, fmt.Errorf("failed to read data: %w", err)
}

var resp []struct {
Mu *float64 `json:"mu"`
Sigma *float64 `json:"sigma"`
Expand Down Expand Up @@ -180,7 +188,14 @@ func NormalFromJSON(data []byte) ([]Dist, error) {
// Expects the JSON data to be in the form:
// `[{"mu": 123}, {"mu": 3.1415}]`
// Returns an error if mu value is missing for any arm. Any additional keys are ignored.
func PointFromJSON(data []byte) ([]Dist, error) {
func PointFromJSON(rc io.ReadCloser) ([]Dist, error) {
defer rc.Close()

data, err := ioutil.ReadAll(rc)
if err != nil {
return nil, fmt.Errorf("failed to read data: %w", err)
}

var resp []struct {
Mu *float64
}
Expand Down
14 changes: 8 additions & 6 deletions mab_test/http_reward_source_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package mab

import (
"bytes"
"io/ioutil"
"testing"

"github.com/stitchfix/mab"
Expand Down Expand Up @@ -47,7 +49,7 @@ func TestBetaFromJSON(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual, err := mab.BetaFromJSON(test.data)
actual, err := mab.BetaFromJSON(ioutil.NopCloser(bytes.NewReader(test.data)))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -94,7 +96,7 @@ func TestBetaFromJSONError(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
_, err := mab.BetaFromJSON(test.data)
_, err := mab.BetaFromJSON(ioutil.NopCloser(bytes.NewReader(test.data)))
if err == nil {
t.Error("expected error but didn't get one")
}
Expand Down Expand Up @@ -147,7 +149,7 @@ func TestNormalFromJSON(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual, err := mab.NormalFromJSON(test.data)
actual, err := mab.NormalFromJSON(ioutil.NopCloser(bytes.NewReader(test.data)))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -190,7 +192,7 @@ func TestNormalFromJSONError(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
_, err := mab.NormalFromJSON(test.data)
_, err := mab.NormalFromJSON(ioutil.NopCloser(bytes.NewReader(test.data)))
if err == nil {
t.Error("expected error but didn't get one")
}
Expand Down Expand Up @@ -243,7 +245,7 @@ func TestPointFromJSON(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual, err := mab.PointFromJSON(test.data)
actual, err := mab.PointFromJSON(ioutil.NopCloser(bytes.NewReader(test.data)))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -278,7 +280,7 @@ func TestPointFromJSONError(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
_, err := mab.PointFromJSON(test.data)
_, err := mab.PointFromJSON(ioutil.NopCloser(bytes.NewReader(test.data)))
if err == nil {
t.Error("expected error but didn't get one")
}
Expand Down

0 comments on commit c45f46e

Please sign in to comment.