diff --git a/dagsync/ipnisync/cid_schema_hint.go b/dagsync/ipnisync/cid_schema_hint.go new file mode 100644 index 0000000..839f2e1 --- /dev/null +++ b/dagsync/ipnisync/cid_schema_hint.go @@ -0,0 +1,24 @@ +package ipnisync + +const ( + // CidSchemaHeader is the HTTP header used as an optional hint about the + // type of data requested by a CID. + CidSchemaHeader = "Ipni-Cid-Schema-Type" + // CidSchemaAd is a value for the CidSchemaHeader specifying advertisement + // data is being requested. + CidSchemaAd = "advertisement" + // CidSchemaEntries is a value for the CidSchemaHeader specifying + // advertisement entries (multihash chunks) data is being requested. + CidSchemaEntries = "entries" +) + +// cidSchemaTypeCtxKey is the type used for the key of CidSchemaHeader when set as +// a context value. +type cidSchemaTypeCtxKey string + +// CidSchemaCtxKey is used as the key when creating a context with a value or extracting the cid schema from a context. Examples: +// +// ctx := context.WithValue(ctx, CidSchemaCtxKey, CidSchemaAd) +// +// cidSchemaType, ok := ctx.Value(CidSchemaCtxKey).(string) +const CidSchemaCtxKey cidSchemaTypeCtxKey = CidSchemaHeader diff --git a/dagsync/ipnisync/publisher.go b/dagsync/ipnisync/publisher.go index 358f01e..0f0297e 100644 --- a/dagsync/ipnisync/publisher.go +++ b/dagsync/ipnisync/publisher.go @@ -1,6 +1,7 @@ package ipnisync import ( + "context" "errors" "fmt" "net/http" @@ -41,6 +42,10 @@ var _ http.Handler = (*Publisher)(nil) // NewPublisher creates a new ipni-sync publisher. Optionally, a libp2p stream // host can be provided to serve HTTP over libp2p. +// +// If the publisher receives a request that contains a non-empty CidSchemaHeader +// header, then the ipld.LinkContext passed to the lsys Load function contains a +// context that has that header's value stored under the CidSchemaCtxKey key. 
func NewPublisher(lsys ipld.LinkSystem, privKey ic.PrivKey, options ...Option) (*Publisher, error) { opts, err := getOpts(options) if err != nil { @@ -218,7 +223,15 @@ func (p *Publisher) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "invalid request: not a cid", http.StatusBadRequest) return } - item, err := p.lsys.Load(ipld.LinkContext{}, cidlink.Link{Cid: c}, basicnode.Prototype.Any) + + ipldCtx := ipld.LinkContext{} + reqType := r.Header.Get(CidSchemaHeader) + if reqType != "" { + // Derive from the request context so the load is canceled when the request is. + ipldCtx.Ctx = context.WithValue(r.Context(), CidSchemaCtxKey, reqType) + } + + item, err := p.lsys.Load(ipldCtx, cidlink.Link{Cid: c}, basicnode.Prototype.Any) if err != nil { if errors.Is(err, ipld.ErrNotExists{}) { http.Error(w, "cid not found", http.StatusNotFound) diff --git a/dagsync/ipnisync/sync.go b/dagsync/ipnisync/sync.go index e4e18e4..fff68d2 100644 --- a/dagsync/ipnisync/sync.go +++ b/dagsync/ipnisync/sync.go @@ -226,6 +226,16 @@ func (s *Syncer) Sync(ctx context.Context, nextCid cid.Cid, sel ipld.Node) error return fmt.Errorf("failed to compile selector: %w", err) } + // Check for valid cid schema type if set. + cidSchemaType, ok := ctx.Value(CidSchemaCtxKey).(string) + if ok { + switch cidSchemaType { + case CidSchemaAd, CidSchemaEntries: + default: + return fmt.Errorf("invalid cid schema type value: %s", cidSchemaType) + } + } + cids, err := s.walkFetch(ctx, nextCid, xsel) if err != nil { return fmt.Errorf("failed to traverse requested dag: %w", err) @@ -307,6 +317,12 @@ retry: return err } + // Value already checked in Sync. 
+ reqType, ok := ctx.Value(CidSchemaCtxKey).(string) + if ok { + req.Header.Set(CidSchemaHeader, reqType) + } + resp, err := s.client.Do(req) if err != nil { if len(s.urls) != 0 { diff --git a/dagsync/ipnisync/sync_test.go b/dagsync/ipnisync/sync_test.go index c748b27..055e02f 100644 --- a/dagsync/ipnisync/sync_test.go +++ b/dagsync/ipnisync/sync_test.go @@ -230,3 +230,62 @@ func TestIPNIsync_NotFoundReturnsContentNotFoundErr(t *testing.T) { require.NotNil(t, err) require.Contains(t, err.Error(), "content not found") } + +func TestRequestTypeHint(t *testing.T) { + pubPrK, _, err := crypto.GenerateKeyPairWithReader(crypto.RSA, 2048, rand.Reader) + require.NoError(t, err) + pubID, err := peer.IDFromPrivateKey(pubPrK) + require.NoError(t, err) + + var lastReqTypeHint string + + // Instantiate a dagsync publisher. + publs := cidlink.DefaultLinkSystem() + + publs.StorageReadOpener = func(lnkCtx linking.LinkContext, lnk datamodel.Link) (io.Reader, error) { + if lnkCtx.Ctx != nil { + hint, ok := lnkCtx.Ctx.Value(ipnisync.CidSchemaCtxKey).(string) + require.True(t, ok) + require.NotEmpty(t, hint) + lastReqTypeHint = hint + t.Log("Request type hint:", hint) + } else { + lastReqTypeHint = "" + } + + require.NotEmpty(t, lastReqTypeHint, "missing expected context value") + return nil, ipld.ErrNotExists{} + } + + pub, err := ipnisync.NewPublisher(publs, pubPrK, ipnisync.WithHTTPListenAddrs("0.0.0.0:0")) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, pub.Close()) }) + + ls := cidlink.DefaultLinkSystem() + store := &memstore.Store{} + ls.SetWriteStorage(store) + ls.SetReadStorage(store) + + sync := ipnisync.NewSync(ls, nil) + pubInfo := peer.AddrInfo{ + ID: pubID, + Addrs: pub.Addrs(), + } + syncer, err := sync.NewSyncer(pubInfo) + require.NoError(t, err) + + testCid, err := cid.Decode(sampleNFTStorageCid) + require.NoError(t, err) + + ctx := context.WithValue(context.Background(), ipnisync.CidSchemaCtxKey, ipnisync.CidSchemaAd) + _ = syncer.Sync(ctx, testCid, 
selectorparse.CommonSelector_MatchPoint) + require.Equal(t, ipnisync.CidSchemaAd, lastReqTypeHint) + + ctx = context.WithValue(context.Background(), ipnisync.CidSchemaCtxKey, ipnisync.CidSchemaEntries) + _ = syncer.Sync(ctx, testCid, selectorparse.CommonSelector_MatchPoint) + require.Equal(t, ipnisync.CidSchemaEntries, lastReqTypeHint) + + ctx = context.WithValue(context.Background(), ipnisync.CidSchemaCtxKey, "bad") + err = syncer.Sync(ctx, testCid, selectorparse.CommonSelector_MatchPoint) + require.ErrorContains(t, err, "invalid cid schema type value") +} diff --git a/dagsync/subscriber.go b/dagsync/subscriber.go index 4170ad9..1c64335 100644 --- a/dagsync/subscriber.go +++ b/dagsync/subscriber.go @@ -488,6 +488,7 @@ func (s *Subscriber) SyncAdChain(ctx context.Context, peerInfo peer.AddrInfo, op sel := ExploreRecursiveWithStopNode(depthLimit, s.adsSelectorSeq, stopLnk) + ctx = context.WithValue(ctx, ipnisync.CidSchemaCtxKey, ipnisync.CidSchemaAd) syncCount, err := hnd.handle(ctx, nextCid, sel, syncer, opts.blockHook, segdl, stopAtCid) if err != nil { return cid.Undef, fmt.Errorf("sync handler failed: %w", err) @@ -571,6 +572,7 @@ func (s *Subscriber) syncEntries(ctx context.Context, peerInfo peer.AddrInfo, en log.Debugw("Start entries sync", "peer", peerInfo.ID, "cid", entCid) + ctx = context.WithValue(ctx, ipnisync.CidSchemaCtxKey, ipnisync.CidSchemaEntries) _, err = hnd.handle(ctx, entCid, sel, syncer, bh, segdl, cid.Undef) if err != nil { return fmt.Errorf("sync handler failed: %w", err) @@ -872,6 +874,7 @@ func (h *handler) asyncSyncAdChain(ctx context.Context) { return } + ctx = context.WithValue(ctx, ipnisync.CidSchemaCtxKey, ipnisync.CidSchemaAd) sel := ExploreRecursiveWithStopNode(adsDepthLimit, h.subscriber.adsSelectorSeq, latestSyncLink) syncCount, err := h.handle(ctx, nextCid, sel, syncer, h.subscriber.generalBlockHook, h.subscriber.segDepthLimit, stopAtCid) if err != nil { diff --git a/dagsync/test/util.go b/dagsync/test/util.go index 
aa4bd8c..4e62f48 100644 --- a/dagsync/test/util.go +++ b/dagsync/test/util.go @@ -170,8 +170,12 @@ func encode(lsys ipld.LinkSystem, n ipld.Node) (ipld.Node, ipld.Link) { func MkLinkSystem(ds datastore.Batching) ipld.LinkSystem { lsys := cidlink.DefaultLinkSystem() - lsys.StorageReadOpener = func(_ ipld.LinkContext, lnk ipld.Link) (io.Reader, error) { - val, err := ds.Get(context.Background(), datastore.NewKey(lnk.String())) + lsys.StorageReadOpener = func(ipldCtx ipld.LinkContext, lnk ipld.Link) (io.Reader, error) { + ctx := ipldCtx.Ctx + if ctx == nil { + ctx = context.Background() + } + val, err := ds.Get(ctx, datastore.NewKey(lnk.String())) if err != nil { return nil, err }