Skip to content

Commit

Permalink
Handle when etag changes on remote archive (#130)
Browse files Browse the repository at this point in the history
Implements ETag and retry logic for S3 and HTTP buckets, so archives can be correctly updated in-place.
  • Loading branch information
msbarry authored Feb 8, 2024
1 parent c402c54 commit 1f898fd
Show file tree
Hide file tree
Showing 7 changed files with 627 additions and 45 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.21
require (
github.com/RoaringBitmap/roaring v1.5.0
github.com/alecthomas/kong v0.8.0
github.com/aws/aws-sdk-go v1.45.12
github.com/caddyserver/caddy/v2 v2.7.5
github.com/dustin/go-humanize v1.0.1
github.com/paulmach/orb v0.10.0
Expand Down Expand Up @@ -38,7 +39,6 @@ require (
github.com/Microsoft/go-winio v0.6.0 // indirect
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df // indirect
github.com/aryann/difflib v0.0.0-20210328193216-ff5ff6dc229b // indirect
github.com/aws/aws-sdk-go v1.45.12 // indirect
github.com/aws/aws-sdk-go-v2 v1.20.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.11 // indirect
github.com/aws/aws-sdk-go-v2/config v1.18.32 // indirect
Expand Down
30 changes: 30 additions & 0 deletions go.sum

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func main() {
}
w.WriteHeader(statusCode)
w.Write(body)
logger.Printf("served %s in %s", r.URL.Path, time.Since(start))
logger.Printf("served %d %s in %s", statusCode, r.URL.Path, time.Since(start))
})

logger.Printf("Serving %s %s on port %d and interface %s with Access-Control-Allow-Origin: %s\n", cli.Serve.Bucket, cli.Serve.Path, cli.Serve.Port, cli.Serve.Interface, cli.Serve.Cors)
Expand Down
118 changes: 106 additions & 12 deletions pmtiles/bucket.go
Original file line number Diff line number Diff line change
@@ -1,65 +1,159 @@
package pmtiles

import (
"bytes"
"context"
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"gocloud.dev/blob"
"io"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/s3"
"gocloud.dev/blob"
)

// Bucket is an abstraction over a gocloud or plain HTTP bucket.
//
// Implementations serve byte ranges of objects addressed by key.
type Bucket interface {
	// Close releases any resources held by the bucket.
	Close() error

	// NewRangeReader returns a reader over length bytes of key starting at
	// offset, without any etag precondition.
	NewRangeReader(ctx context.Context, key string, offset, length int64) (io.ReadCloser, error)

	// NewRangeReaderEtag is like NewRangeReader but additionally returns the
	// etag reported for the object. When etag is non-empty it is used as a
	// precondition on the read.
	NewRangeReaderEtag(ctx context.Context, key string, offset, length int64, etag string) (io.ReadCloser, string, error)
}

// RefreshRequiredError is an error that indicates the etag has changed on the
// remote file, so cached state derived from it must be refreshed.
type RefreshRequiredError struct {
	// StatusCode is the HTTP status code associated with the failure.
	StatusCode int
}

// Error implements the error interface, embedding the status code in the message.
func (e *RefreshRequiredError) Error() string {
	const format = "HTTP error indicates file has changed: %d"
	return fmt.Sprintf(format, e.StatusCode)
}

// mockBucket is an in-memory Bucket implementation whose objects live in a map.
type mockBucket struct {
	items map[string][]byte // object contents, indexed by key
}

// Close is a no-op: there is nothing to release, so it always returns nil.
func (mockBucket) Close() error {
	return nil
}

// NewRangeReader returns a reader over length bytes of key starting at offset.
// It delegates to NewRangeReaderEtag with an empty etag and drops the etag
// that call returns.
func (m mockBucket) NewRangeReader(ctx context.Context, key string, offset, length int64) (io.ReadCloser, error) {
	reader, _, err := m.NewRangeReaderEtag(ctx, key, offset, length, "")
	return reader, err
}
// NewRangeReaderEtag returns a reader over length bytes of the item stored
// under key, starting at offset, together with the item's etag. The etag is
// the hex-encoded MD5 digest of the item's full contents.
//
// Errors:
//   - an unknown key yields a plain error;
//   - a non-empty etag that differs from the item's etag yields a
//     *RefreshRequiredError carrying 412 (Precondition Failed), matching what
//     an HTTP server answers to a failed If-Match;
//   - a negative offset or length yields a plain error instead of panicking;
//   - a range that extends past the end of the item yields a
//     *RefreshRequiredError carrying 416 (Range Not Satisfiable).
func (m mockBucket) NewRangeReaderEtag(_ context.Context, key string, offset int64, length int64, etag string) (io.ReadCloser, string, error) {
	bs, ok := m.items[key]
	if !ok {
		return nil, "", fmt.Errorf("Not found %s", key)
	}

	hash := md5.Sum(bs)
	resultEtag := hex.EncodeToString(hash[:])
	if len(etag) > 0 && resultEtag != etag {
		return nil, "", &RefreshRequiredError{StatusCode: http.StatusPreconditionFailed}
	}

	// Reject negative values up front: they would otherwise panic in the
	// slice expression below.
	if offset < 0 || length < 0 {
		return nil, "", fmt.Errorf("invalid range: offset %d, length %d", offset, length)
	}

	// Compare against the remaining size rather than computing offset+length,
	// which could overflow int64 and slip past the bounds check.
	size := int64(len(bs))
	if offset > size || length > size-offset {
		return nil, "", &RefreshRequiredError{StatusCode: http.StatusRequestedRangeNotSatisfiable}
	}

	return io.NopCloser(bytes.NewReader(bs[offset : offset+length])), resultEtag, nil
}

// HTTPClient is the minimal HTTP surface the bucket code needs. It exists so
// the default client can be swapped for a mock in tests.
type HTTPClient interface {
	// Do performs request and returns the response or an error.
	Do(request *http.Request) (*http.Response, error)
}

// HTTPBucket is a bucket whose objects are fetched over plain HTTP.
type HTTPBucket struct {
// baseURL is joined with "/" and the object key to build each request URL.
baseURL string
// client performs the HTTP requests; it can be replaced with a mock in tests.
client HTTPClient
}

// NewRangeReader returns a reader over length bytes of key starting at offset.
// It delegates to NewRangeReaderEtag with an empty etag and discards the etag
// that call returns.
func (b HTTPBucket) NewRangeReader(_ context.Context, key string, offset, length int64) (io.ReadCloser, error) {
func (b HTTPBucket) NewRangeReader(ctx context.Context, key string, offset, length int64) (io.ReadCloser, error) {
body, _, err := b.NewRangeReaderEtag(ctx, key, offset, length, "")
return body, err
}

// NewRangeReaderEtag issues a ranged GET for baseURL + "/" + key and returns
// the response body together with the value of the response's ETag header.
// When etag is non-empty it is sent as an If-Match header. A 412 or 416
// response is reported as *RefreshRequiredError; any other status besides
// 200 and 206 is reported as a generic "HTTP error". The caller is
// responsible for closing the returned body.
func (b HTTPBucket) NewRangeReaderEtag(ctx context.Context, key string, offset, length int64, etag string) (io.ReadCloser, string, error) {
reqURL := b.baseURL + "/" + key

req, err := http.NewRequest("GET", reqURL, nil)
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
if err != nil {
return nil, err
return nil, "", err
}

// The end of the byte range is offset+length-1, i.e. the index of the last byte wanted.
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", offset, offset+length-1))
// Only constrain the request to a specific etag when the caller supplied one.
if len(etag) > 0 {
req.Header.Set("If-Match", etag)
}

resp, err := http.DefaultClient.Do(req)
resp, err := b.client.Do(req)
if err != nil {
return nil, err
return nil, "", err
}

if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
// The body is not returned on failure, so close it here.
resp.Body.Close()
return nil, fmt.Errorf("HTTP error: %d", resp.StatusCode)
// Distinguish "the remote file changed" statuses from other HTTP failures.
if isRefreshRequredCode(resp.StatusCode) {
err = &RefreshRequiredError{resp.StatusCode}
} else {
err = fmt.Errorf("HTTP error: %d", resp.StatusCode)
}
return nil, "", err
}

return resp.Body, nil
return resp.Body, resp.Header.Get("ETag"), nil
}

// Close is a no-op: it releases nothing and always returns nil.
func (HTTPBucket) Close() error {
	return nil
}

// isRefreshRequredCode reports whether an HTTP status code signals that the
// remote file has changed: 412 (Precondition Failed) or 416 (Requested Range
// Not Satisfiable).
func isRefreshRequredCode(code int) bool {
	switch code {
	case http.StatusPreconditionFailed, http.StatusRequestedRangeNotSatisfiable:
		return true
	default:
		return false
	}
}

// BucketAdapter wraps a gocloud *blob.Bucket and exposes range reads on it.
type BucketAdapter struct {
// Bucket is the underlying gocloud bucket that reads are forwarded to.
Bucket *blob.Bucket
}

// NewRangeReader returns a reader over length bytes of key starting at offset.
// It delegates to NewRangeReaderEtag with an empty etag and discards the etag
// that call returns.
func (ba BucketAdapter) NewRangeReader(ctx context.Context, key string, offset, length int64) (io.ReadCloser, error) {
reader, err := ba.Bucket.NewRangeReader(ctx, key, offset, length, nil)
body, _, err := ba.NewRangeReaderEtag(ctx, key, offset, length, "")
return body, err
}

// NewRangeReaderEtag opens a range reader on the wrapped gocloud bucket and
// returns it together with the object's etag when the reader exposes an
// s3.GetObjectOutput (otherwise the returned etag is empty). When etag is
// non-empty and the underlying request is an S3 GetObjectInput, the etag is
// set as its IfMatch condition. An AWS request failure whose status code is
// 412 or 416 is reported as *RefreshRequiredError; other errors are returned
// unchanged.
func (ba BucketAdapter) NewRangeReaderEtag(ctx context.Context, key string, offset, length int64, etag string) (io.ReadCloser, string, error) {
reader, err := ba.Bucket.NewRangeReader(ctx, key, offset, length, &blob.ReaderOptions{
// BeforeRead lets us reach the driver-specific request before it is sent.
BeforeRead: func(asFunc func(interface{}) bool) error {
var req *s3.GetObjectInput
// asFunc only succeeds when the underlying request is an S3 GetObjectInput;
// for other drivers the etag precondition is silently not applied.
if len(etag) > 0 && asFunc(&req) {
req.IfMatch = &etag
}
return nil
},
})
if err != nil {
return nil, err
// resp stays nil when err is not an AWS request failure.
var resp awserr.RequestFailure
errors.As(err, &resp)
if resp != nil && isRefreshRequredCode(resp.StatusCode()) {
return nil, "", &RefreshRequiredError{resp.StatusCode()}
}
return nil, "", err
}
resultETag := ""
var resp s3.GetObjectOutput
if reader.As(&resp) {
// NOTE(review): resp.ETag is dereferenced without a nil check — confirm S3 always populates it.
resultETag = *resp.ETag
}
return reader, nil
return reader, resultETag, nil
}

func (ba BucketAdapter) Close() error {
Expand Down Expand Up @@ -101,7 +195,7 @@ func NormalizeBucketKey(bucket string, prefix string, key string) (string, strin

func OpenBucket(ctx context.Context, bucketURL string, bucketPrefix string) (Bucket, error) {
if strings.HasPrefix(bucketURL, "http") {
bucket := HTTPBucket{bucketURL}
bucket := HTTPBucket{bucketURL, http.DefaultClient}
return bucket, nil
}
bucket, err := blob.OpenBucket(ctx, bucketURL)
Expand Down
81 changes: 80 additions & 1 deletion pmtiles/bucket_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package pmtiles

import (
"github.com/stretchr/testify/assert"
"context"
"io"
"net/http"
"os"
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestNormalizeLocalFile(t *testing.T) {
Expand Down Expand Up @@ -35,3 +39,78 @@ func TestNormalizePathPrefixServer(t *testing.T) {
assert.True(t, strings.HasSuffix(bucket, "/foo"))
assert.True(t, strings.HasPrefix(bucket, "file://"))
}

// ClientMock is an HTTPClient test double: it remembers the last request it
// was given and answers every call with a preconfigured response.
type ClientMock struct {
	request  *http.Request  // most recent request passed to Do
	response *http.Response // response handed back by every Do call
}

// Do records incoming and returns the canned response with a nil error.
func (c *ClientMock) Do(incoming *http.Request) (*http.Response, error) {
	canned := c.response
	c.request = incoming
	return canned, nil
}

// TestHttpBucketRequestNormal checks a successful read without an etag
// precondition: no If-Match header is sent, the Range header and URL are built
// from the arguments, and the body and ETag of the response are returned.
func TestHttpBucketRequestNormal(t *testing.T) {
	respHeader := http.Header{}
	respHeader.Set("ETag", "etag")
	client := &ClientMock{
		response: &http.Response{
			StatusCode: 200,
			Body:       io.NopCloser(strings.NewReader("abc")),
			Header:     respHeader,
		},
	}
	httpBucket := HTTPBucket{baseURL: "http://tiles.example.com/tiles", client: client}

	reader, gotEtag, err := httpBucket.NewRangeReaderEtag(context.Background(), "a/b/c", 100, 3, "")

	sent := client.request
	assert.Equal(t, "", sent.Header.Get("If-Match"))
	assert.Equal(t, "bytes=100-102", sent.Header.Get("Range"))
	assert.Equal(t, "http://tiles.example.com/tiles/a/b/c", sent.URL.String())
	assert.Nil(t, err)

	payload, err := io.ReadAll(reader)
	assert.Nil(t, err)
	assert.Equal(t, "abc", string(payload))
	assert.Equal(t, "etag", gotEtag)
	assert.Nil(t, err)
}

// TestHttpBucketRequestRequestEtag checks that a caller-supplied etag is sent
// as If-Match and that the etag returned is the one from the response header,
// not the one that was requested.
func TestHttpBucketRequestRequestEtag(t *testing.T) {
	respHeader := http.Header{}
	respHeader.Set("ETag", "etag2")
	client := &ClientMock{
		response: &http.Response{
			StatusCode: 200,
			Body:       io.NopCloser(strings.NewReader("abc")),
			Header:     respHeader,
		},
	}
	httpBucket := HTTPBucket{baseURL: "http://tiles.example.com/tiles", client: client}

	reader, gotEtag, err := httpBucket.NewRangeReaderEtag(context.Background(), "a/b/c", 0, 3, "etag1")

	assert.Equal(t, "etag1", client.request.Header.Get("If-Match"))
	assert.Nil(t, err)

	payload, err := io.ReadAll(reader)
	assert.Nil(t, err)
	assert.Equal(t, "abc", string(payload))
	assert.Equal(t, "etag2", gotEtag)
	assert.Nil(t, err)
}

// TestHttpBucketRequestRequestEtagFailed checks how failure statuses are
// classified: 412 and 416 must be reported as refresh-required errors, while
// 404 must not.
func TestHttpBucketRequestRequestEtagFailed(t *testing.T) {
	respHeader := http.Header{}
	respHeader.Set("ETag", "etag2")
	client := &ClientMock{
		response: &http.Response{
			StatusCode: 412,
			Body:       io.NopCloser(strings.NewReader("abc")),
			Header:     respHeader,
		},
	}
	httpBucket := HTTPBucket{baseURL: "http://tiles.example.com/tiles", client: client}

	// fetch replays the same request against the mock after setting the
	// status code it will answer with, and returns only the error.
	fetch := func(status int) error {
		client.response.StatusCode = status
		_, _, err := httpBucket.NewRangeReaderEtag(context.Background(), "a/b/c", 0, 3, "etag1")
		return err
	}

	err := fetch(412)
	assert.Equal(t, "etag1", client.request.Header.Get("If-Match"))
	assert.True(t, isRefreshRequredError(err))

	assert.True(t, isRefreshRequredError(fetch(416)))

	assert.False(t, isRefreshRequredError(fetch(404)))
}
Loading

0 comments on commit 1f898fd

Please sign in to comment.