Skip to content

Commit

Permalink
Fix getShardID does not return more than 100 shards (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
lordfarhan40 authored and harlow committed Feb 15, 2019
1 parent 2f58b13 commit 2037463
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 29 deletions.
34 changes: 21 additions & 13 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
// new consumer with no-op checkpoint, counter, and logger
c := &Consumer{
streamName: streamName,
initialShardIteratorType: "TRIM_HORIZON",
initialShardIteratorType: kinesis.ShardIteratorTypeTrimHorizon,
checkpoint: &noopCheckpoint{},
counter: &noopCounter{},
logger: &noopLogger{
Expand Down Expand Up @@ -241,20 +241,28 @@ func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) Scan
}

func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
resp, err := c.client.DescribeStream(
&kinesis.DescribeStreamInput{
StreamName: aws.String(streamName),
},
)
if err != nil {
return nil, fmt.Errorf("describe stream error: %v", err)
var ss []string
var listShardsInput = &kinesis.ListShardsInput{
StreamName: aws.String(streamName),
}
for {
resp, err := c.client.ListShards(listShardsInput)
if err != nil {
return nil, fmt.Errorf("ListShards error: %v", err)
}

var ss []string
for _, shard := range resp.StreamDescription.Shards {
ss = append(ss, *shard.ShardId)
for _, shard := range resp.Shards {
ss = append(ss, *shard.ShardId)
}

if resp.NextToken == nil {
return ss, nil
}

listShardsInput = &kinesis.ListShardsInput{
NextToken: resp.NextToken,
}
}
return ss, nil
}

func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
Expand All @@ -264,7 +272,7 @@ func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*st
}

if lastSeqNum != "" {
params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber)
params.StartingSequenceNumber = aws.String(lastSeqNum)
} else {
params.ShardIteratorType = aws.String(c.initialShardIteratorType)
Expand Down
28 changes: 12 additions & 16 deletions consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,10 @@ func TestConsumer_Scan(t *testing.T) {
Records: records,
}, nil
},
describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
return &kinesis.DescribeStreamOutput{
StreamDescription: &kinesis.StreamDescription{
Shards: []*kinesis.Shard{
{ShardId: aws.String("myShard")},
},
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
return &kinesis.ListShardsOutput{
Shards: []*kinesis.Shard{
{ShardId: aws.String("myShard")},
},
}, nil
},
Expand Down Expand Up @@ -94,11 +92,9 @@ func TestConsumer_Scan(t *testing.T) {

func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
client := &kinesisClientMock{
describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
return &kinesis.DescribeStreamOutput{
StreamDescription: &kinesis.StreamDescription{
Shards: make([]*kinesis.Shard, 0),
},
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
return &kinesis.ListShardsOutput{
Shards: make([]*kinesis.Shard, 0),
}, nil
},
}
Expand Down Expand Up @@ -287,7 +283,11 @@ type kinesisClientMock struct {
kinesisiface.KinesisAPI
getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)
getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error)
describeStreamMock func(*kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error)
listShardsMock func(*kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error)
}

func (c *kinesisClientMock) ListShards(in *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
return c.listShardsMock(in)
}

func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
Expand All @@ -298,10 +298,6 @@ func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput)
return c.getShardIteratorMock(in)
}

func (c *kinesisClientMock) DescribeStream(in *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
return c.describeStreamMock(in)
}

// implementation of checkpoint
type fakeCheckpoint struct {
cache map[string]string
Expand Down

0 comments on commit 2037463

Please sign in to comment.