From f61effbfa01c5e8db2072e94560e6120f01b05b1 Mon Sep 17 00:00:00 2001 From: scorix Date: Tue, 3 Dec 2024 11:36:24 +0800 Subject: [PATCH] feat: boundary cache (#36) --- pkg/grib2/cache/boundary.go | 52 ++++++++++++++++++++++++++++++++ pkg/grib2/cache/boundary_test.go | 26 ++++++++++++++++ pkg/grib2/cache/cache.go | 21 +++++++++++++ pkg/grib2/cache/cache_test.go | 36 ++++++++++++++++++++++ pkg/grib2/message.go | 47 ++++++++++++++++++++--------- 5 files changed, 168 insertions(+), 14 deletions(-) create mode 100644 pkg/grib2/cache/boundary.go create mode 100644 pkg/grib2/cache/boundary_test.go create mode 100644 pkg/grib2/cache/cache.go create mode 100644 pkg/grib2/cache/cache_test.go diff --git a/pkg/grib2/cache/boundary.go b/pkg/grib2/cache/boundary.go new file mode 100644 index 0000000..542a35d --- /dev/null +++ b/pkg/grib2/cache/boundary.go @@ -0,0 +1,52 @@ +package cache + +import ( + "fmt" + + "golang.org/x/sync/singleflight" +) + +type boundary struct { + minLat float32 + maxLat float32 + minLon float32 + maxLon float32 + + cache map[int]float32 + datasource GridDataSource + sfg singleflight.Group +} + +func NewBoundary(minLat, maxLat, minLon, maxLon float32, datasource GridDataSource) GridCache { + return &boundary{ + minLat: minLat, + maxLat: maxLat, + minLon: minLon, + maxLon: maxLon, + datasource: datasource, + cache: make(map[int]float32), + } +} + +func (b *boundary) ReadGridAt(grid int, lat, lon float32) (float32, error) { + if lat < b.minLat || lat > b.maxLat || lon < b.minLon || lon > b.maxLon { + return b.datasource.ReadGridAt(grid) + } + + v, err, _ := b.sfg.Do(fmt.Sprintf("%d", grid), func() (interface{}, error) { + vFromCache, ok := b.cache[grid] + if !ok { + vFromSource, err := b.datasource.ReadGridAt(grid) + if err != nil { + return 0, err + } + + b.cache[grid] = vFromSource + vFromCache = vFromSource + } + + return vFromCache, nil + }) + + return v.(float32), err +} diff --git a/pkg/grib2/cache/boundary_test.go b/pkg/grib2/cache/boundary_test.go new file mode 100644 index 0000000..5e4afd7 --- /dev/null +++ b/pkg/grib2/cache/boundary_test.go @@ -0,0 +1,26 @@ +package cache_test + +import ( + "testing" + + "github.com/scorix/grib-go/pkg/grib2/cache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBoundary(t *testing.T) { + ds := &mockGridDataSource{gridValue: 100} + bc := cache.NewBoundary(0, 10, 0, 10, ds) + + // first read should be from source + v, err := bc.ReadGridAt(1, 1, 1) + require.NoError(t, err) + assert.Equal(t, float32(100), v) + assert.Equal(t, 1, ds.readCount) + + // second read should be cached + v, err = bc.ReadGridAt(1, 1, 1) + require.NoError(t, err) + assert.Equal(t, float32(100), v) + assert.Equal(t, 1, ds.readCount) +} diff --git a/pkg/grib2/cache/cache.go b/pkg/grib2/cache/cache.go new file mode 100644 index 0000000..2ed5c6d --- /dev/null +++ b/pkg/grib2/cache/cache.go @@ -0,0 +1,21 @@ +package cache + +type GridDataSource interface { + ReadGridAt(grid int) (float32, error) +} + +type GridCache interface { + ReadGridAt(grid int, lat, lon float32) (float32, error) +} + +type noCache struct { + datasource GridDataSource +} + +func NewNoCache(datasource GridDataSource) GridCache { + return &noCache{datasource: datasource} +} + +func (n *noCache) ReadGridAt(grid int, lat, lon float32) (float32, error) { + return n.datasource.ReadGridAt(grid) +} diff --git a/pkg/grib2/cache/cache_test.go b/pkg/grib2/cache/cache_test.go new file mode 100644 index 0000000..81079b9 --- /dev/null +++ b/pkg/grib2/cache/cache_test.go @@ -0,0 +1,36 @@ +package cache_test + +import ( + "testing" + + "github.com/scorix/grib-go/pkg/grib2/cache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockGridDataSource struct { + gridValue float32 + readCount int +} + +func (m *mockGridDataSource) ReadGridAt(grid int) (float32, error) { + m.readCount++ + return m.gridValue, nil +} + +func TestNoCache(t *testing.T) { + ds := &mockGridDataSource{gridValue: 100} + nc := cache.NewNoCache(ds) + + // first read should be from source + v, err := nc.ReadGridAt(1, 1, 1) + require.NoError(t, err) + assert.Equal(t, float32(100), v) + assert.Equal(t, 1, ds.readCount) + + // second read should not be cached + v, err = nc.ReadGridAt(1, 1, 1) + require.NoError(t, err) + assert.Equal(t, float32(100), v) + assert.Equal(t, 2, ds.readCount) +} diff --git a/pkg/grib2/message.go b/pkg/grib2/message.go index a26a2e6..f5e48df 100644 --- a/pkg/grib2/message.go +++ b/pkg/grib2/message.go @@ -8,6 +8,7 @@ import ( "time" "github.com/scorix/grib-go/internal/pkg/bitio" + "github.com/scorix/grib-go/pkg/grib2/cache" "github.com/scorix/grib-go/pkg/grib2/drt" gridpoint "github.com/scorix/grib-go/pkg/grib2/drt/grid_point" "github.com/scorix/grib-go/pkg/grib2/gdt" @@ -238,12 +239,13 @@ type MessageReader interface { } type simplePackingMessageReader struct { - sp *gridpoint.SimplePacking - spr *gridpoint.SimplePackingReader - gdt gdt.Template + sp *gridpoint.SimplePacking + spr *gridpoint.SimplePackingReader + gdt gdt.Template + cache cache.GridCache } -func NewSimplePackingMessageReaderFromMessage(r io.ReaderAt, m IndexedMessage) (MessageReader, error) { +func NewSimplePackingMessageReaderFromMessage(r io.ReaderAt, m IndexedMessage, opts ...SimplePackingMessageReaderOptions) (MessageReader, error) { sp, ok := m.GetDataRepresentationTemplate().(*gridpoint.SimplePacking) if !ok { return nil, fmt.Errorf("unsupported data representation template: %T", m.GetDataRepresentationTemplate()) @@ -251,31 +253,48 @@ func NewSimplePackingMessageReaderFromMessage(r io.ReaderAt, m IndexedMessage) ( gdt := m.GetGridDefinitionTemplate() - return NewSimplePackingMessageReader(r, m.GetOffset(), m.GetSize(), m.GetDataOffset(), sp, gdt) + return NewSimplePackingMessageReader(r, m.GetOffset(), m.GetSize(), m.GetDataOffset(), sp, gdt, opts...) } -func NewSimplePackingMessageReader(r io.ReaderAt, messageOffset int64, messageSize int64, dataOffset int64, sp *gridpoint.SimplePacking, gdt gdt.Template) (MessageReader, error) { - return &simplePackingMessageReader{ - spr: gridpoint.NewSimplePackingReader(r, dataOffset, messageOffset+messageSize, sp), - sp: sp, - gdt: gdt, - }, nil +type SimplePackingMessageReaderOptions func(r *simplePackingMessageReader) + +func WithBoundary(minLat, maxLat, minLon, maxLon float32) SimplePackingMessageReaderOptions { + return func(r *simplePackingMessageReader) { + r.cache = cache.NewBoundary(minLat, maxLat, minLon, maxLon, r.spr) + } +} + +func NewSimplePackingMessageReader(r io.ReaderAt, messageOffset int64, messageSize int64, dataOffset int64, sp *gridpoint.SimplePacking, gdt gdt.Template, opts ...SimplePackingMessageReaderOptions) (MessageReader, error) { + spr := gridpoint.NewSimplePackingReader(r, dataOffset, messageOffset+messageSize, sp) + + mr := &simplePackingMessageReader{ + spr: spr, + sp: sp, + gdt: gdt, + cache: cache.NewNoCache(spr), + } + + for _, opt := range opts { + opt(mr) + } + + return mr, nil } -func NewSimplePackingMessageReaderFromMessageIndex(r io.ReaderAt, mi *MessageIndex) (MessageReader, error) { +func NewSimplePackingMessageReaderFromMessageIndex(r io.ReaderAt, mi *MessageIndex, opts ...SimplePackingMessageReaderOptions) (MessageReader, error) { sp, ok := mi.Packing.(*gridpoint.SimplePacking) if !ok { return nil, fmt.Errorf("unsupported packing: %T", mi.Packing) } - return NewSimplePackingMessageReader(r, mi.Offset, mi.Size, mi.DataOffset, sp, mi.GridDefinition) + return NewSimplePackingMessageReader(r, mi.Offset, mi.Size, mi.DataOffset, sp, mi.GridDefinition, opts...) } func (r *simplePackingMessageReader) ReadLL(lat float32, lon float32) (float32, float32, float32, error) { grid := r.gdt.GetGridIndex(lat, lon) lat, lng := r.gdt.GetGridPoint(grid) - v, err := r.spr.ReadGridAt(grid) + v, err := r.cache.ReadGridAt(grid, lat, lng) if err != nil { return 0, 0, 0, fmt.Errorf("read grid at point %d (lat: %f, lon: %f): %w", grid, lat, lng, err) }