Skip to content

Commit

Permalink
Feature: TTL cache for API (#228)
Browse files Browse the repository at this point in the history
* Feature: TTL cache for API

* Add license header
  • Loading branch information
aopoltorzhicky authored Jun 30, 2024
1 parent cb2c4e3 commit 137c0d8
Show file tree
Hide file tree
Showing 11 changed files with 479 additions and 194 deletions.
154 changes: 0 additions & 154 deletions cmd/api/cache/cache.go

This file was deleted.

8 changes: 8 additions & 0 deletions cmd/api/cache/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// SPDX-FileCopyrightText: 2024 PK Lab AG <contact@pklab.io>
// SPDX-License-Identifier: MIT

package cache

type Config struct {
MaxEntitiesCount int
}
10 changes: 10 additions & 0 deletions cmd/api/cache/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// SPDX-FileCopyrightText: 2024 PK Lab AG <contact@pklab.io>
// SPDX-License-Identifier: MIT

package cache

type ICache interface {
Get(key string) ([]byte, bool)
Set(key string, data []byte)
Clear()
}
71 changes: 71 additions & 0 deletions cmd/api/cache/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// SPDX-FileCopyrightText: 2024 PK Lab AG <contact@pklab.io>
// SPDX-License-Identifier: MIT

package cache

import (
"net/http"

"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/pkg/errors"
)

type CacheMiddleware struct {
cache ICache
skipper middleware.Skipper
}

func Middleware(cache ICache, skipper middleware.Skipper) echo.MiddlewareFunc {
mdlwr := CacheMiddleware{
cache: cache,
skipper: skipper,
}
return mdlwr.Handler
}

func (m *CacheMiddleware) Handler(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if m.skipper != nil {
if m.skipper(c) {
return next(c)
}
}
path := c.Request().URL.String()

if data, ok := m.cache.Get(path); ok {
entry := new(CacheEntry)
if err := entry.Decode(data); err != nil {
return err
}
return entry.Replay(c.Response())
}

recorder := NewResponseRecorder(c.Response().Writer)
c.Response().Writer = recorder

if err := next(c); err != nil {
return err
}
return m.cacheResult(path, recorder)
}
}

func (m *CacheMiddleware) cacheResult(key string, r *ResponseRecorder) error {
result := r.Result()
if !m.isStatusCacheable(result) {
return nil
}

data, err := result.Encode()
if err != nil {
return errors.Wrap(err, "unable to read recorded response")
}

m.cache.Set(key, data)
return nil
}

func (m *CacheMiddleware) isStatusCacheable(e *CacheEntry) bool {
return e.StatusCode == http.StatusOK || e.StatusCode == http.StatusNoContent
}
88 changes: 88 additions & 0 deletions cmd/api/cache/observable.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// SPDX-FileCopyrightText: 2024 PK Lab AG <contact@pklab.io>
// SPDX-License-Identifier: MIT

package cache

import (
"context"
"sync"

"github.com/celenium-io/celestia-indexer/cmd/api/bus"
"github.com/dipdup-io/workerpool"
)

type ObservableCache struct {
maxEntitiesCount int
observer *bus.Observer

m map[string][]byte
queue []string
mx *sync.RWMutex
g workerpool.Group
}

func NewObservableCache(cfg Config, observer *bus.Observer) *ObservableCache {
return &ObservableCache{
maxEntitiesCount: cfg.MaxEntitiesCount,
observer: observer,
m: make(map[string][]byte),
queue: make([]string, cfg.MaxEntitiesCount),
mx: new(sync.RWMutex),
g: workerpool.NewGroup(),
}
}

func (c *ObservableCache) Start(ctx context.Context) {
c.g.GoCtx(ctx, c.listen)
}

func (c *ObservableCache) listen(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-c.observer.Head():
c.Clear()
}
}
}

func (c *ObservableCache) Close() error {
c.g.Wait()
return nil
}

func (c *ObservableCache) Get(key string) ([]byte, bool) {
c.mx.RLock()
data, ok := c.m[key]
c.mx.RUnlock()
return data, ok
}

func (c *ObservableCache) Set(key string, data []byte) {
c.mx.Lock()
queueIdx := len(c.m)

if _, ok := c.m[key]; ok {
c.m[key] = data
} else {
c.m[key] = data
if queueIdx == c.maxEntitiesCount {
keyForRemove := c.queue[queueIdx-1]
c.queue = append([]string{key}, c.queue[:queueIdx-1]...)
delete(c.m, keyForRemove)
} else {
c.queue[c.maxEntitiesCount-queueIdx-1] = key
}
}
c.mx.Unlock()
}

func (c *ObservableCache) Clear() {
c.mx.Lock()
for key := range c.m {
delete(c.m, key)
}
c.queue = make([]string, c.maxEntitiesCount)
c.mx.Unlock()
}
16 changes: 8 additions & 8 deletions cmd/api/cache/cache_test.go → cmd/api/cache/observable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
"github.com/stretchr/testify/require"
)

func TestCache_SetGet(t *testing.T) {
func TestObservableCache_SetGet(t *testing.T) {
t.Run("set and get key from cache", func(t *testing.T) {
c := NewCache(Config{MaxEntitiesCount: 2}, nil)
c := NewObservableCache(Config{MaxEntitiesCount: 2}, nil)
c.Set("test", []byte{0, 1, 2, 3})

got, ok := c.Get("test")
Expand All @@ -25,7 +25,7 @@ func TestCache_SetGet(t *testing.T) {
})

t.Run("overflow set queue", func(t *testing.T) {
c := NewCache(Config{MaxEntitiesCount: 2}, nil)
c := NewObservableCache(Config{MaxEntitiesCount: 2}, nil)
for i := 0; i < 100; i++ {
c.Set(fmt.Sprintf("%d", i), []byte{byte(i)})
}
Expand All @@ -46,13 +46,13 @@ func TestCache_SetGet(t *testing.T) {
})

t.Run("overflow set queue multithread", func(t *testing.T) {
c := NewCache(Config{MaxEntitiesCount: 2}, nil)
c := NewObservableCache(Config{MaxEntitiesCount: 2}, nil)

var wg sync.WaitGroup

for i := 0; i < 10; i++ {
wg.Add(1)
go func(c *Cache, wg *sync.WaitGroup) {
go func(c *ObservableCache, wg *sync.WaitGroup) {
defer wg.Done()
for i := 0; i < 100; i++ {
c.Set(fmt.Sprintf("%d", i), []byte{byte(i)})
Expand All @@ -67,9 +67,9 @@ func TestCache_SetGet(t *testing.T) {
})
}

func TestCache_Clear(t *testing.T) {
t.Run("set and get key from cache", func(t *testing.T) {
c := NewCache(Config{MaxEntitiesCount: 100}, nil)
func TestObservableCache_Clear(t *testing.T) {
t.Run("clear cache", func(t *testing.T) {
c := NewObservableCache(Config{MaxEntitiesCount: 100}, nil)
for i := 0; i < 100; i++ {
c.Set(fmt.Sprintf("%d", i), []byte{byte(i)})
}
Expand Down
Loading

0 comments on commit 137c0d8

Please sign in to comment.