Skip to content

Commit

Permalink
Add tests for cancelations
Browse files Browse the repository at this point in the history
  • Loading branch information
umpc committed Jul 4, 2017
1 parent fad4b9f commit ac52001
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 21 deletions.
18 changes: 14 additions & 4 deletions iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,18 @@ func (iterCh *IterChCloser) Close() error {
case iterCh.canceled <- struct{}{}:
default:
}

return nil
}

func (iterCh *IterChCloser) Records() <-chan Record {
return iterCh.ch
select {
case <-iterCh.canceled:
iterCh.canceled <- struct{}{}
return nil
default:
return iterCh.ch
}
}

// IterChParams contains configurable settings for CustomIterCh.
Expand Down Expand Up @@ -59,16 +66,19 @@ func (sm *SortedMap) recordFromIdx(i int) Record {
}

func (sm *SortedMap) sendRecord(iterCh IterChCloser, sendTimeout time.Duration, i int) bool {
select {
case <-iterCh.canceled:
iterCh.canceled <- struct{}{}
return false
default:
}

if sendTimeout <= time.Duration(0) {
iterCh.ch <- sm.recordFromIdx(i)
return true
}

select {
case <-iterCh.canceled:
return false

case iterCh.ch <- sm.recordFromIdx(i):
return true

Expand Down
44 changes: 44 additions & 0 deletions iter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,39 @@ func TestCustomIterCh(t *testing.T) {
}()
}

func TestCancelCustomIterCh(t *testing.T) {
sm, _, err := newSortedMapFromRandRecords(1000)
if err != nil {
t.Fatal(err)
}

earlierDate := time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC)
laterDate := time.Now()

func() {
params := IterChParams{
LowerBound: earlierDate,
UpperBound: laterDate,
}

ch, err := sm.CustomIterCh(params)
if err != nil {
t.Fatal(err)
}
defer ch.Close()

ch.Close()

if err := verifyRecords(ch.Records(), params.Reversed); err != nil {
if err.Error() != "Channel was nil." {
t.Fatal(err)
}
} else {
t.Fatal("Channel was not closed.")
}
}()
}

func TestIterFunc(t *testing.T) {
sm, _, err := newSortedMapFromRandRecords(1000)
if err != nil {
Expand Down Expand Up @@ -362,6 +395,17 @@ func TestBoundedIterFunc(t *testing.T) {
}
}

func TestTestBoundedIterFuncWithNoBoundsReturned(t *testing.T) {
sm, _, err := newSortedMapFromRandRecords(1000)
if err != nil {
t.Fatal(err)
}

if _, err := sm.BoundedKeys(time.Now().Add(-1*time.Microsecond), time.Now()); err == nil {
t.Fatal("Values fall between or are equal to the given bounds when it should not have returned bounds.")
}
}

func TestReversedBoundedIterFunc(t *testing.T) {
sm, _, err := newSortedMapFromRandRecords(1000)
if err != nil {
Expand Down
10 changes: 10 additions & 0 deletions keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,13 @@ func TestBoundedKeys(t *testing.T) {
t.Fatal("The returned slice was empty.")
}
}

func TestBoundedKeysWithNoBoundsReturned(t *testing.T) {
sm, _, err := newSortedMapFromRandRecords(300)
if err != nil {
t.Fatal(err)
}
if val, err := sm.BoundedKeys(time.Now().Add(-1*time.Microsecond), time.Now()); err == nil {
t.Fatalf("Values fall between or are equal to the given bounds when it should not have returned bounds: %+v", sm.idx[val[0]])
}
}
38 changes: 21 additions & 17 deletions testing_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,30 @@ func randRecords(n int) []Record {
func verifyRecords(ch <-chan Record, reverse bool) error {
previousRec := Record{}

for rec := range ch {
if previousRec.Key != nil {
switch reverse {
case false:
if previousRec.Val.(time.Time).After(rec.Val.(time.Time)) {
return fmt.Errorf("%v %v",
unsortedErr,
fmt.Sprintf("prev: %+v, current: %+v.", previousRec, rec),
)
}
case true:
if previousRec.Val.(time.Time).Before(rec.Val.(time.Time)) {
return fmt.Errorf("%v %v",
unsortedErr,
fmt.Sprintf("prev: %+v, current: %+v.", previousRec, rec),
)
if ch != nil {
for rec := range ch {
if previousRec.Key != nil {
switch reverse {
case false:
if previousRec.Val.(time.Time).After(rec.Val.(time.Time)) {
return fmt.Errorf("%v %v",
unsortedErr,
fmt.Sprintf("prev: %+v, current: %+v.", previousRec, rec),
)
}
case true:
if previousRec.Val.(time.Time).Before(rec.Val.(time.Time)) {
return fmt.Errorf("%v %v",
unsortedErr,
fmt.Sprintf("prev: %+v, current: %+v.", previousRec, rec),
)
}
}
}
previousRec = rec
}
previousRec = rec
} else {
return fmt.Errorf("Channel was nil.")
}

return nil
Expand Down

0 comments on commit ac52001

Please sign in to comment.