Skip to content

Commit

Permalink
feat: use context (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
scorix authored Dec 3, 2024
1 parent f61effb commit b36e445
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 27 deletions.
15 changes: 8 additions & 7 deletions pkg/grib2/cache/boundary.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cache

import (
"context"
"fmt"

"golang.org/x/sync/singleflight"
Expand All @@ -12,7 +13,7 @@ type boundary struct {
minLon float32
maxLon float32

cache map[int]float32
cache Store
datasource GridDataSource
sfg singleflight.Group
}
Expand All @@ -24,24 +25,24 @@ func NewBoundary(minLat, maxLat, minLon, maxLon float32, datasource GridDataSour
minLon: minLon,
maxLon: maxLon,
datasource: datasource,
cache: make(map[int]float32),
cache: NewMapStore(),
}
}

func (b *boundary) ReadGridAt(grid int, lat, lon float32) (float32, error) {
func (b *boundary) ReadGridAt(ctx context.Context, grid int, lat, lon float32) (float32, error) {
if lat < b.minLat || lat > b.maxLat || lon < b.minLon || lon > b.maxLon {
return b.datasource.ReadGridAt(grid)
return b.datasource.ReadGridAt(ctx, grid)
}

v, err, _ := b.sfg.Do(fmt.Sprintf("%d", grid), func() (interface{}, error) {
vFromCache, ok := b.cache[grid]
vFromCache, ok := b.cache.Get(ctx, fmt.Sprintf("%d", grid))
if !ok {
vFromSource, err := b.datasource.ReadGridAt(grid)
vFromSource, err := b.datasource.ReadGridAt(ctx, grid)
if err != nil {
return 0, err
}

b.cache[grid] = vFromSource
b.cache.Set(ctx, fmt.Sprintf("%d", grid), vFromSource)
vFromCache = vFromSource
}

Expand Down
5 changes: 3 additions & 2 deletions pkg/grib2/cache/boundary_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cache_test

import (
"context"
"testing"

"github.com/scorix/grib-go/pkg/grib2/cache"
Expand All @@ -13,13 +14,13 @@ func TestBoundary(t *testing.T) {
bc := cache.NewBoundary(0, 10, 0, 10, ds)

// first read should be from source
v, err := bc.ReadGridAt(1, 1, 1)
v, err := bc.ReadGridAt(context.TODO(), 1, 1, 1)
require.NoError(t, err)
assert.Equal(t, float32(100), v)
assert.Equal(t, 1, ds.readCount)

// second read should be cached
v, err = bc.ReadGridAt(1, 1, 1)
v, err = bc.ReadGridAt(context.TODO(), 1, 1, 1)
require.NoError(t, err)
assert.Equal(t, float32(100), v)
assert.Equal(t, 1, ds.readCount)
Expand Down
10 changes: 6 additions & 4 deletions pkg/grib2/cache/cache.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package cache

import "context"

type GridDataSource interface {
ReadGridAt(grid int) (float32, error)
ReadGridAt(ctx context.Context, grid int) (float32, error)
}

type GridCache interface {
ReadGridAt(grid int, lat, lon float32) (float32, error)
ReadGridAt(ctx context.Context, grid int, lat, lon float32) (float32, error)
}

type noCache struct {
Expand All @@ -16,6 +18,6 @@ func NewNoCache(datasource GridDataSource) GridCache {
return &noCache{datasource: datasource}
}

func (n *noCache) ReadGridAt(grid int, lat, lon float32) (float32, error) {
return n.datasource.ReadGridAt(grid)
func (n *noCache) ReadGridAt(ctx context.Context, grid int, lat, lon float32) (float32, error) {
return n.datasource.ReadGridAt(ctx, grid)
}
7 changes: 4 additions & 3 deletions pkg/grib2/cache/cache_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cache_test

import (
"context"
"testing"

"github.com/scorix/grib-go/pkg/grib2/cache"
Expand All @@ -13,7 +14,7 @@ type mockGridDataSource struct {
readCount int
}

func (m *mockGridDataSource) ReadGridAt(grid int) (float32, error) {
func (m *mockGridDataSource) ReadGridAt(ctx context.Context, grid int) (float32, error) {
m.readCount++
return m.gridValue, nil
}
Expand All @@ -23,13 +24,13 @@ func TestNoCache(t *testing.T) {
nc := cache.NewNoCache(ds)

// first read should be from source
v, err := nc.ReadGridAt(1, 1, 1)
v, err := nc.ReadGridAt(context.TODO(), 1, 1, 1)
require.NoError(t, err)
assert.Equal(t, float32(100), v)
assert.Equal(t, 1, ds.readCount)

// second read should not be cached
v, err = nc.ReadGridAt(1, 1, 1)
v, err = nc.ReadGridAt(context.TODO(), 1, 1, 1)
require.NoError(t, err)
assert.Equal(t, float32(100), v)
assert.Equal(t, 2, ds.readCount)
Expand Down
35 changes: 35 additions & 0 deletions pkg/grib2/cache/store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package cache

import (
"context"
"sync"
)

type Store interface {
Get(ctx context.Context, key string) (float32, bool)
Set(ctx context.Context, key string, value float32)
}

type mapStore struct {
mu sync.RWMutex
cache map[string]float32
}

func NewMapStore() Store {
return &mapStore{cache: make(map[string]float32)}
}

func (m *mapStore) Get(ctx context.Context, key string) (float32, bool) {
m.mu.RLock()
defer m.mu.RUnlock()

v, ok := m.cache[key]
return v, ok
}

func (m *mapStore) Set(ctx context.Context, key string, value float32) {
m.mu.Lock()
defer m.mu.Unlock()

m.cache[key] = value
}
3 changes: 2 additions & 1 deletion pkg/grib2/drt/grid_point/simple.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gridpoint

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -100,7 +101,7 @@ func NewSimplePackingReader(r io.ReaderAt, start, end int64, sp *SimplePacking)
}
}

func (r *SimplePackingReader) ReadGridAt(n int) (float32, error) {
func (r *SimplePackingReader) ReadGridAt(ctx context.Context, n int) (float32, error) {
if n >= r.sp.NumVals {
return 0, fmt.Errorf("requesting[%d] is out of range, total number of values is %d", n, r.sp.NumVals)
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/grib2/drt/grid_point/simple_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gridpoint_test

import (
"bytes"
"context"
"math"
"testing"

Expand Down Expand Up @@ -44,7 +45,7 @@ func TestSimpleScaleReader(t *testing.T) {
t.Logf("simple packing: %+v", sp)

r := gridpoint.NewSimplePackingReader(bf, 0, 5, sp)
f, err := r.ReadGridAt(2)
f, err := r.ReadGridAt(context.TODO(), 2)
require.NoError(t, err)
assert.InDelta(t, float32(2.9611706734e+02), f, 1e-5)
}
7 changes: 4 additions & 3 deletions pkg/grib2/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package grib2

import (
"bytes"
"context"
"fmt"
"image"
"io"
Expand Down Expand Up @@ -235,7 +236,7 @@ func (m *message) Image() (image.Image, error) {
}

type MessageReader interface {
ReadLL(float32, float32) (float32, float32, float32, error)
ReadLL(ctx context.Context, lat float32, lon float32) (float32, float32, float32, error)
}

type simplePackingMessageReader struct {
Expand Down Expand Up @@ -290,11 +291,11 @@ func NewSimplePackingMessageReaderFromMessageIndex(r io.ReaderAt, mi *MessageInd
return NewSimplePackingMessageReader(r, mi.Offset, mi.Size, mi.DataOffset, sp, mi.GridDefinition, opts...)
}

func (r *simplePackingMessageReader) ReadLL(lat float32, lon float32) (float32, float32, float32, error) {
func (r *simplePackingMessageReader) ReadLL(ctx context.Context, lat float32, lon float32) (float32, float32, float32, error) {
grid := r.gdt.GetGridIndex(lat, lon)
lat, lng := r.gdt.GetGridPoint(grid)

v, err := r.cache.ReadGridAt(grid, lat, lng)
v, err := r.cache.ReadGridAt(ctx, grid, lat, lng)
if err != nil {
return 0, 0, 0, fmt.Errorf("read grid at point %d (lat: %f, lon: %f): %w", grid, lat, lng, err)
}
Expand Down
13 changes: 7 additions & 6 deletions pkg/grib2/message_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package grib2_test

import (
"context"
"errors"
"os"
"testing"
Expand Down Expand Up @@ -82,14 +83,14 @@ func TestMessageReader_ReadLL(t *testing.T) {
require.Equalf(t, i, grd, "expect: (%f,%f,%d), actual: (%f,%f,%d)", lat, lng, i, lat32, lng32, grd)

{
_, _, v, err := reader.ReadLL(lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out
_, _, v, err := reader.ReadLL(context.TODO(), lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out
require.NoError(t, err)
require.InDelta(t, float32(val), float32(v), 1e-5)
}

{
// read again
_, _, v, err := reader.ReadLL(lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out
_, _, v, err := reader.ReadLL(context.TODO(), lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out
require.NoError(t, err)
require.InDelta(t, float32(val), float32(v), 1e-5)
}
Expand Down Expand Up @@ -157,14 +158,14 @@ func TestMessageReader_ReadLL(t *testing.T) {
require.Equalf(t, i, grd, "expect: (%f,%f,%d), actual: (%f,%f,%d)", lat, lng, i, lat32, lng32, grd)

{
_, _, v, err := reader.ReadLL(lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out
_, _, v, err := reader.ReadLL(context.TODO(), lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out
require.NoError(t, err)
require.InDelta(t, float32(val), float32(v), 1e-5)
}

{
// read again
_, _, v, err := reader.ReadLL(lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out
_, _, v, err := reader.ReadLL(context.TODO(), lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out
require.NoError(t, err)
require.InDelta(t, float32(val), float32(v), 1e-5)
}
Expand Down Expand Up @@ -283,7 +284,7 @@ func BenchmarkMessageReader_ReadLL(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
reader.ReadLL(-90, 0)
reader.ReadLL(context.TODO(), -90, 0)
}
})

Expand All @@ -308,7 +309,7 @@ func BenchmarkMessageReader_ReadLL(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
reader.ReadLL(-90, 0)
reader.ReadLL(context.TODO(), -90, 0)
}
})
}
Expand Down

0 comments on commit b36e445

Please sign in to comment.