diff --git a/pkg/grib2/cache/boundary.go b/pkg/grib2/cache/boundary.go index 542a35d..4426edb 100644 --- a/pkg/grib2/cache/boundary.go +++ b/pkg/grib2/cache/boundary.go @@ -1,6 +1,7 @@ package cache import ( + "context" "fmt" "golang.org/x/sync/singleflight" @@ -12,7 +13,7 @@ type boundary struct { minLon float32 maxLon float32 - cache map[int]float32 + cache Store datasource GridDataSource sfg singleflight.Group } @@ -24,24 +25,24 @@ func NewBoundary(minLat, maxLat, minLon, maxLon float32, datasource GridDataSour minLon: minLon, maxLon: maxLon, datasource: datasource, - cache: make(map[int]float32), + cache: NewMapStore(), } } -func (b *boundary) ReadGridAt(grid int, lat, lon float32) (float32, error) { +func (b *boundary) ReadGridAt(ctx context.Context, grid int, lat, lon float32) (float32, error) { if lat < b.minLat || lat > b.maxLat || lon < b.minLon || lon > b.maxLon { - return b.datasource.ReadGridAt(grid) + return b.datasource.ReadGridAt(ctx, grid) } v, err, _ := b.sfg.Do(fmt.Sprintf("%d", grid), func() (interface{}, error) { - vFromCache, ok := b.cache[grid] + vFromCache, ok := b.cache.Get(ctx, fmt.Sprintf("%d", grid)) if !ok { - vFromSource, err := b.datasource.ReadGridAt(grid) + vFromSource, err := b.datasource.ReadGridAt(ctx, grid) if err != nil { return 0, err } - b.cache[grid] = vFromSource + b.cache.Set(ctx, fmt.Sprintf("%d", grid), vFromSource) vFromCache = vFromSource } diff --git a/pkg/grib2/cache/boundary_test.go b/pkg/grib2/cache/boundary_test.go index 5e4afd7..661ad47 100644 --- a/pkg/grib2/cache/boundary_test.go +++ b/pkg/grib2/cache/boundary_test.go @@ -1,6 +1,7 @@ package cache_test import ( + "context" "testing" "github.com/scorix/grib-go/pkg/grib2/cache" @@ -13,13 +14,13 @@ func TestBoundary(t *testing.T) { bc := cache.NewBoundary(0, 10, 0, 10, ds) // first read should be from source - v, err := bc.ReadGridAt(1, 1, 1) + v, err := bc.ReadGridAt(context.TODO(), 1, 1, 1) require.NoError(t, err) assert.Equal(t, float32(100), v) assert.Equal(t, 1, ds.readCount) // second read should be cached - v, err = bc.ReadGridAt(1, 1, 1) + v, err = bc.ReadGridAt(context.TODO(), 1, 1, 1) require.NoError(t, err) assert.Equal(t, float32(100), v) assert.Equal(t, 1, ds.readCount) diff --git a/pkg/grib2/cache/cache.go b/pkg/grib2/cache/cache.go index 2ed5c6d..c0c57ad 100644 --- a/pkg/grib2/cache/cache.go +++ b/pkg/grib2/cache/cache.go @@ -1,11 +1,13 @@ package cache +import "context" + type GridDataSource interface { - ReadGridAt(grid int) (float32, error) + ReadGridAt(ctx context.Context, grid int) (float32, error) } type GridCache interface { - ReadGridAt(grid int, lat, lon float32) (float32, error) + ReadGridAt(ctx context.Context, grid int, lat, lon float32) (float32, error) } type noCache struct { @@ -16,6 +18,6 @@ func NewNoCache(datasource GridDataSource) GridCache { return &noCache{datasource: datasource} } -func (n *noCache) ReadGridAt(grid int, lat, lon float32) (float32, error) { - return n.datasource.ReadGridAt(grid) +func (n *noCache) ReadGridAt(ctx context.Context, grid int, lat, lon float32) (float32, error) { + return n.datasource.ReadGridAt(ctx, grid) } diff --git a/pkg/grib2/cache/cache_test.go b/pkg/grib2/cache/cache_test.go index 81079b9..e7f898a 100644 --- a/pkg/grib2/cache/cache_test.go +++ b/pkg/grib2/cache/cache_test.go @@ -1,6 +1,7 @@ package cache_test import ( + "context" "testing" "github.com/scorix/grib-go/pkg/grib2/cache" @@ -13,7 +14,7 @@ type mockGridDataSource struct { readCount int } -func (m *mockGridDataSource) ReadGridAt(grid int) (float32, error) { +func (m *mockGridDataSource) ReadGridAt(ctx context.Context, grid int) (float32, error) { m.readCount++ return m.gridValue, nil } @@ -23,13 +24,13 @@ func TestNoCache(t *testing.T) { nc := cache.NewNoCache(ds) // first read should be from source - v, err := nc.ReadGridAt(1, 1, 1) + v, err := nc.ReadGridAt(context.TODO(), 1, 1, 1) require.NoError(t, err) assert.Equal(t, float32(100), v) assert.Equal(t, 1, ds.readCount) // second read should not be cached - v, err = nc.ReadGridAt(1, 1, 1) + v, err = nc.ReadGridAt(context.TODO(), 1, 1, 1) require.NoError(t, err) assert.Equal(t, float32(100), v) assert.Equal(t, 2, ds.readCount) diff --git a/pkg/grib2/cache/store.go b/pkg/grib2/cache/store.go new file mode 100644 index 0000000..07081c3 --- /dev/null +++ b/pkg/grib2/cache/store.go @@ -0,0 +1,35 @@ +package cache + +import ( + "context" + "sync" +) + +type Store interface { + Get(ctx context.Context, key string) (float32, bool) + Set(ctx context.Context, key string, value float32) +} + +type mapStore struct { + mu sync.RWMutex + cache map[string]float32 +} + +func NewMapStore() Store { + return &mapStore{cache: make(map[string]float32)} +} + +func (m *mapStore) Get(ctx context.Context, key string) (float32, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + v, ok := m.cache[key] + return v, ok +} + +func (m *mapStore) Set(ctx context.Context, key string, value float32) { + m.mu.Lock() + defer m.mu.Unlock() + + m.cache[key] = value +} diff --git a/pkg/grib2/drt/grid_point/simple.go b/pkg/grib2/drt/grid_point/simple.go index 3adad3b..391ff6a 100644 --- a/pkg/grib2/drt/grid_point/simple.go +++ b/pkg/grib2/drt/grid_point/simple.go @@ -1,6 +1,7 @@ package gridpoint import ( + "context" "errors" "fmt" "io" @@ -100,7 +101,7 @@ func NewSimplePackingReader(r io.ReaderAt, start, end int64, sp *SimplePacking) } } -func (r *SimplePackingReader) ReadGridAt(n int) (float32, error) { +func (r *SimplePackingReader) ReadGridAt(ctx context.Context, n int) (float32, error) { if n >= r.sp.NumVals { return 0, fmt.Errorf("requesting[%d] is out of range, total number of values is %d", n, r.sp.NumVals) } diff --git a/pkg/grib2/drt/grid_point/simple_test.go b/pkg/grib2/drt/grid_point/simple_test.go index a6c463a..dec7f01 100644 --- a/pkg/grib2/drt/grid_point/simple_test.go +++ b/pkg/grib2/drt/grid_point/simple_test.go @@ -2,6 +2,7 @@ package gridpoint_test import ( "bytes" + "context" "math" "testing" @@ -44,7 +45,7 @@ func TestSimpleScaleReader(t *testing.T) { t.Logf("simple packing: %+v", sp) r := gridpoint.NewSimplePackingReader(bf, 0, 5, sp) - f, err := r.ReadGridAt(2) + f, err := r.ReadGridAt(context.TODO(), 2) require.NoError(t, err) assert.InDelta(t, float32(2.9611706734e+02), f, 1e-5) } diff --git a/pkg/grib2/message.go b/pkg/grib2/message.go index f5e48df..f2d429f 100644 --- a/pkg/grib2/message.go +++ b/pkg/grib2/message.go @@ -2,6 +2,7 @@ package grib2 import ( "bytes" + "context" "fmt" "image" "io" @@ -235,7 +236,7 @@ func (m *message) Image() (image.Image, error) { } type MessageReader interface { - ReadLL(float32, float32) (float32, float32, float32, error) + ReadLL(ctx context.Context, lat float32, lon float32) (float32, float32, float32, error) } type simplePackingMessageReader struct { @@ -290,11 +291,11 @@ func NewSimplePackingMessageReaderFromMessageIndex(r io.ReaderAt, mi *MessageInd return NewSimplePackingMessageReader(r, mi.Offset, mi.Size, mi.DataOffset, sp, mi.GridDefinition, opts...) } -func (r *simplePackingMessageReader) ReadLL(lat float32, lon float32) (float32, float32, float32, error) { +func (r *simplePackingMessageReader) ReadLL(ctx context.Context, lat float32, lon float32) (float32, float32, float32, error) { grid := r.gdt.GetGridIndex(lat, lon) lat, lng := r.gdt.GetGridPoint(grid) - v, err := r.cache.ReadGridAt(grid, lat, lng) + v, err := r.cache.ReadGridAt(ctx, grid, lat, lng) if err != nil { return 0, 0, 0, fmt.Errorf("read grid at point %d (lat: %f, lon: %f): %w", grid, lat, lng, err) } diff --git a/pkg/grib2/message_test.go b/pkg/grib2/message_test.go index c0999c5..0baf19a 100644 --- a/pkg/grib2/message_test.go +++ b/pkg/grib2/message_test.go @@ -1,6 +1,7 @@ package grib2_test import ( + "context" "errors" "os" "testing" @@ -82,14 +83,14 @@ func TestMessageReader_ReadLL(t *testing.T) { require.Equalf(t, i, grd, "expect: (%f,%f,%d), actual: (%f,%f,%d)", lat, lng, i, lat32, lng32, grd) { - _, _, v, err := reader.ReadLL(lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out + _, _, v, err := reader.ReadLL(context.TODO(), lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out require.NoError(t, err) require.InDelta(t, float32(val), float32(v), 1e-5) } { // read again - _, _, v, err := reader.ReadLL(lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out + _, _, v, err := reader.ReadLL(context.TODO(), lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out require.NoError(t, err) require.InDelta(t, float32(val), float32(v), 1e-5) } @@ -157,14 +158,14 @@ func TestMessageReader_ReadLL(t *testing.T) { require.Equalf(t, i, grd, "expect: (%f,%f,%d), actual: (%f,%f,%d)", lat, lng, i, lat32, lng32, grd) { - _, _, v, err := reader.ReadLL(lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out + _, _, v, err := reader.ReadLL(context.TODO(), lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out require.NoError(t, err) require.InDelta(t, float32(val), float32(v), 1e-5) } { // read again - _, _, v, err := reader.ReadLL(lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out + _, _, v, err := reader.ReadLL(context.TODO(), lat32, lng32) // grib_get -l 90,0 pkg/testdata/hpbl.grib2.out require.NoError(t, err) require.InDelta(t, float32(val), float32(v), 1e-5) } @@ -283,7 +284,7 @@ func BenchmarkMessageReader_ReadLL(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - reader.ReadLL(-90, 0) + reader.ReadLL(context.TODO(), -90, 0) } }) @@ -308,7 +309,7 @@ func BenchmarkMessageReader_ReadLL(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - reader.ReadLL(-90, 0) + reader.ReadLL(context.TODO(), -90, 0) } }) }