From 12ee097367db410b25cbe2844e4b9acab1188df7 Mon Sep 17 00:00:00 2001 From: Artem Poltorzhitskiy Date: Mon, 9 Sep 2024 15:19:43 +0200 Subject: [PATCH] Hot fixes: validator addresses and CSRF (#278) * Hot fixes: validator addresses and CSRF * Fix: licences --- cmd/api/handler/address.go | 89 ++++++++++------------- cmd/api/init.go | 3 + internal/storage/address.go | 1 + internal/storage/mock/address.go | 44 +++++++++++ internal/storage/postgres/address.go | 13 ++++ internal/storage/postgres/address_test.go | 9 +++ 6 files changed, 107 insertions(+), 52 deletions(-) diff --git a/cmd/api/handler/address.go b/cmd/api/handler/address.go index 3fc2ae28..e46761de 100644 --- a/cmd/api/handler/address.go +++ b/cmd/api/handler/address.go @@ -4,6 +4,7 @@ package handler import ( + "context" "net/http" "time" @@ -184,13 +185,10 @@ func (handler *AddressHandler) Transactions(c echo.Context) error { return badRequestError(c, err) } - addressId, err := handler.address.IdByHash(c.Request().Context(), hash) + addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash) if err != nil { return handleError(c, err, handler.address) } - if len(addressId) != 1 { - return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash)) - } fltrs := storage.TxFilter{ Limit: req.Limit, @@ -210,7 +208,7 @@ func (handler *AddressHandler) Transactions(c echo.Context) error { fltrs.MessageTypes.SetByMsgType(storageTypes.MsgType(req.MsgType[i])) } - txs, err := handler.txs.ByAddress(c.Request().Context(), addressId[0], fltrs) + txs, err := handler.txs.ByAddress(c.Request().Context(), addressId, fltrs) if err != nil { return handleError(c, err, handler.address) } @@ -279,16 +277,13 @@ func (handler *AddressHandler) Messages(c echo.Context) error { return badRequestError(c, err) } - addressId, err := handler.address.IdByHash(c.Request().Context(), hash) + addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash) if err != nil { return handleError(c, err, handler.address) } - if len(addressId) != 1 { - return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash)) - } filters := req.ToFilters() - msgs, err := handler.messages.ByAddress(c.Request().Context(), addressId[0], filters) + msgs, err := handler.messages.ByAddress(c.Request().Context(), addressId, filters) if err != nil { return handleError(c, err, handler.address) } @@ -351,17 +346,14 @@ func (handler *AddressHandler) Blobs(c echo.Context) error { return badRequestError(c, err) } - addressId, err := handler.address.IdByHash(c.Request().Context(), hash) + addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash) if err != nil { return handleError(c, err, handler.address) } - if len(addressId) != 1 { - return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash)) - } logs, err := handler.blobLogs.BySigner( c.Request().Context(), - addressId[0], + addressId, storage.BlobLogFilters{ Limit: req.Limit, Offset: req.Offset, @@ -440,17 +432,14 @@ func (handler *AddressHandler) Delegations(c echo.Context) error { return badRequestError(c, err) } - addressId, err := handler.address.IdByHash(c.Request().Context(), hash) + addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash) if err != nil { return handleError(c, err, handler.address) } - if len(addressId) != 1 { - return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash)) - } delegations, err := handler.delegations.ByAddress( c.Request().Context(), - addressId[0], + addressId, req.Limit, req.Offset, req.ShowZero, @@ -505,17 +494,14 @@ func (handler *AddressHandler) Undelegations(c echo.Context) error { return badRequestError(c, err) } - addressId, err := handler.address.IdByHash(c.Request().Context(), hash) + addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash) if err != nil { return handleError(c, err, handler.address) } - if len(addressId) != 1 { - return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash)) - } undelegations, err := handler.undelegations.ByAddress( c.Request().Context(), - addressId[0], + addressId, req.Limit, req.Offset, ) @@ -557,17 +543,14 @@ func (handler *AddressHandler) Redelegations(c echo.Context) error { return badRequestError(c, err) } - addressId, err := handler.address.IdByHash(c.Request().Context(), hash) + addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash) if err != nil { return handleError(c, err, handler.address) } - if len(addressId) != 1 { - return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash)) - } redelegations, err := handler.redelegations.ByAddress( c.Request().Context(), - addressId[0], + addressId, req.Limit, req.Offset, ) @@ -623,17 +606,13 @@ func (handler *AddressHandler) Vestings(c echo.Context) error { return badRequestError(c, err) } - addressId, err := handler.address.IdByHash(c.Request().Context(), hash) + addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash) if err != nil { return handleError(c, err, handler.address) } - if len(addressId) != 1 { - return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash)) - } - vestings, err := handler.vestings.ByAddress( c.Request().Context(), - addressId[0], + addressId, req.Limit, req.Offset, req.ShowEnded, @@ -676,17 +655,14 @@ func (handler *AddressHandler) Grants(c echo.Context) error { return badRequestError(c, err) } - addressId, err := handler.address.IdByHash(c.Request().Context(), hash) + addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash) if err != nil { return handleError(c, err, handler.address) } - if len(addressId) != 1 { - return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash)) - } grants, err := handler.grants.ByGranter( c.Request().Context(), - addressId[0], + addressId, req.Limit, req.Offset, ) @@ -727,17 +703,13 @@ func (handler *AddressHandler) Grantee(c echo.Context) error { return badRequestError(c, err) } - addressId, err := handler.address.IdByHash(c.Request().Context(), hash) + addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash) if err != nil { return handleError(c, err, handler.address) } - if len(addressId) != 1 { - return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash)) - } - grants, err := handler.grants.ByGrantee( c.Request().Context(), - addressId[0], + addressId, req.Limit, req.Offset, ) @@ -787,17 +759,14 @@ func (handler *AddressHandler) Stats(c echo.Context) error { return badRequestError(c, err) } - addressId, err := handler.address.IdByHash(c.Request().Context(), hash) + addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash) if err != nil { return handleError(c, err, handler.address) } - if len(addressId) != 1 { - return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash)) - } series, err := handler.address.Series( c.Request().Context(), - addressId[0], + addressId, storage.Timeframe(req.Timeframe), req.SeriesName, storage.NewSeriesRequest(req.From, req.To), @@ -812,3 +781,19 @@ func (handler *AddressHandler) Stats(c echo.Context) error { } return returnArray(c, response) } + +func (handler *AddressHandler) getIdByHash(ctx context.Context, hash []byte, address string) (uint64, error) { + addressId, err := handler.address.IdByHash(ctx, hash) + if err != nil { + return 0, err + } + + switch len(addressId) { + case 0: + return 0, errors.Errorf("can't find address: %s", address) + case 1: + return addressId[0], nil + default: + return handler.address.IdByAddress(ctx, address, addressId...) + } +} diff --git a/cmd/api/init.go b/cmd/api/init.go index 9e2bad24..2ff4e096 100644 --- a/cmd/api/init.go +++ b/cmd/api/init.go @@ -112,6 +112,9 @@ func metricsSkipper(c echo.Context) bool { } func postSkipper(c echo.Context) bool { + if c.Request().Method != http.MethodPost { + return true + } if strings.HasPrefix(c.Path(), "/v1/blob") { return true } diff --git a/internal/storage/address.go b/internal/storage/address.go index 60f4d136..7b103148 100644 --- a/internal/storage/address.go +++ b/internal/storage/address.go @@ -27,6 +27,7 @@ type IAddress interface { ListWithBalance(ctx context.Context, filters AddressListFilter) ([]Address, error) Series(ctx context.Context, addressId uint64, timeframe Timeframe, column string, req SeriesRequest) (items []HistogramItem, err error) IdByHash(ctx context.Context, hash ...[]byte) ([]uint64, error) + IdByAddress(ctx context.Context, address string, ids ...uint64) (uint64, error) } // Address - diff --git a/internal/storage/mock/address.go b/internal/storage/mock/address.go index b2fed3f2..5c566c82 100644 --- a/internal/storage/mock/address.go +++ b/internal/storage/mock/address.go @@ -161,6 +161,50 @@ func (c *MockIAddressGetByIDCall) DoAndReturn(f func(context.Context, uint64) (* return c } +// IdByAddress mocks base method. +func (m *MockIAddress) IdByAddress(ctx context.Context, address string, ids ...uint64) (uint64, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, address} + for _, a := range ids { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "IdByAddress", varargs...) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IdByAddress indicates an expected call of IdByAddress. +func (mr *MockIAddressMockRecorder) IdByAddress(ctx, address any, ids ...any) *MockIAddressIdByAddressCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, address}, ids...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IdByAddress", reflect.TypeOf((*MockIAddress)(nil).IdByAddress), varargs...) + return &MockIAddressIdByAddressCall{Call: call} +} + +// MockIAddressIdByAddressCall wrap *gomock.Call +type MockIAddressIdByAddressCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockIAddressIdByAddressCall) Return(arg0 uint64, arg1 error) *MockIAddressIdByAddressCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockIAddressIdByAddressCall) Do(f func(context.Context, string, ...uint64) (uint64, error)) *MockIAddressIdByAddressCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockIAddressIdByAddressCall) DoAndReturn(f func(context.Context, string, ...uint64) (uint64, error)) *MockIAddressIdByAddressCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // IdByHash mocks base method. func (m *MockIAddress) IdByHash(ctx context.Context, hash ...[]byte) ([]uint64, error) { m.ctrl.T.Helper() diff --git a/internal/storage/postgres/address.go b/internal/storage/postgres/address.go index f3a701a0..adfe7b20 100644 --- a/internal/storage/postgres/address.go +++ b/internal/storage/postgres/address.go @@ -121,3 +121,16 @@ func (a *Address) IdByHash(ctx context.Context, hash ...[]byte) (id []uint64, er Scan(ctx, &id) return } + +// IdByAddress - +func (a *Address) IdByAddress(ctx context.Context, address string, ids ...uint64) (id uint64, err error) { + query := a.DB().NewSelect(). + Model((*storage.Address)(nil)). + Column("id"). + Where("address = ?", address) + if len(ids) > 0 { + query = query.Where("id IN (?)", bun.In(ids)) + } + err = query.Scan(ctx, &id) + return +} diff --git a/internal/storage/postgres/address_test.go b/internal/storage/postgres/address_test.go index e3351865..a9930999 100644 --- a/internal/storage/postgres/address_test.go +++ b/internal/storage/postgres/address_test.go @@ -235,3 +235,12 @@ func (s *StorageTestSuite) TestAddressIdByHash() { s.Require().Len(id, 1) s.Require().EqualValues(1, id[0]) } + +func (s *StorageTestSuite) TestAddressIdByAddress() { + ctx, ctxCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer ctxCancel() + + id, err := s.storage.Address.IdByAddress(ctx, "celestia1jc92qdnty48pafummfr8ava2tjtuhfdw774w60", 2, 3, 4) + s.Require().NoError(err) + s.Require().EqualValues(2, id) +}