Skip to content

Commit

Permalink
Hot fixes: validator addresses and CSRF (#278)
Browse files Browse the repository at this point in the history
* Hot fixes: validator addresses and CSRF

* Fix: licences
  • Loading branch information
aopoltorzhicky authored Sep 9, 2024
1 parent ffd60cb commit 12ee097
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 52 deletions.
89 changes: 37 additions & 52 deletions cmd/api/handler/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package handler

import (
"context"
"net/http"
"time"

Expand Down Expand Up @@ -184,13 +185,10 @@ func (handler *AddressHandler) Transactions(c echo.Context) error {
return badRequestError(c, err)
}

addressId, err := handler.address.IdByHash(c.Request().Context(), hash)
addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash)
if err != nil {
return handleError(c, err, handler.address)
}
if len(addressId) != 1 {
return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash))
}

fltrs := storage.TxFilter{
Limit: req.Limit,
Expand All @@ -210,7 +208,7 @@ func (handler *AddressHandler) Transactions(c echo.Context) error {
fltrs.MessageTypes.SetByMsgType(storageTypes.MsgType(req.MsgType[i]))
}

txs, err := handler.txs.ByAddress(c.Request().Context(), addressId[0], fltrs)
txs, err := handler.txs.ByAddress(c.Request().Context(), addressId, fltrs)
if err != nil {
return handleError(c, err, handler.address)
}
Expand Down Expand Up @@ -279,16 +277,13 @@ func (handler *AddressHandler) Messages(c echo.Context) error {
return badRequestError(c, err)
}

addressId, err := handler.address.IdByHash(c.Request().Context(), hash)
addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash)
if err != nil {
return handleError(c, err, handler.address)
}
if len(addressId) != 1 {
return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash))
}

filters := req.ToFilters()
msgs, err := handler.messages.ByAddress(c.Request().Context(), addressId[0], filters)
msgs, err := handler.messages.ByAddress(c.Request().Context(), addressId, filters)
if err != nil {
return handleError(c, err, handler.address)
}
Expand Down Expand Up @@ -351,17 +346,14 @@ func (handler *AddressHandler) Blobs(c echo.Context) error {
return badRequestError(c, err)
}

addressId, err := handler.address.IdByHash(c.Request().Context(), hash)
addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash)
if err != nil {
return handleError(c, err, handler.address)
}
if len(addressId) != 1 {
return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash))
}

logs, err := handler.blobLogs.BySigner(
c.Request().Context(),
addressId[0],
addressId,
storage.BlobLogFilters{
Limit: req.Limit,
Offset: req.Offset,
Expand Down Expand Up @@ -440,17 +432,14 @@ func (handler *AddressHandler) Delegations(c echo.Context) error {
return badRequestError(c, err)
}

addressId, err := handler.address.IdByHash(c.Request().Context(), hash)
addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash)
if err != nil {
return handleError(c, err, handler.address)
}
if len(addressId) != 1 {
return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash))
}

delegations, err := handler.delegations.ByAddress(
c.Request().Context(),
addressId[0],
addressId,
req.Limit,
req.Offset,
req.ShowZero,
Expand Down Expand Up @@ -505,17 +494,14 @@ func (handler *AddressHandler) Undelegations(c echo.Context) error {
return badRequestError(c, err)
}

addressId, err := handler.address.IdByHash(c.Request().Context(), hash)
addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash)
if err != nil {
return handleError(c, err, handler.address)
}
if len(addressId) != 1 {
return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash))
}

undelegations, err := handler.undelegations.ByAddress(
c.Request().Context(),
addressId[0],
addressId,
req.Limit,
req.Offset,
)
Expand Down Expand Up @@ -557,17 +543,14 @@ func (handler *AddressHandler) Redelegations(c echo.Context) error {
return badRequestError(c, err)
}

addressId, err := handler.address.IdByHash(c.Request().Context(), hash)
addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash)
if err != nil {
return handleError(c, err, handler.address)
}
if len(addressId) != 1 {
return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash))
}

redelegations, err := handler.redelegations.ByAddress(
c.Request().Context(),
addressId[0],
addressId,
req.Limit,
req.Offset,
)
Expand Down Expand Up @@ -623,17 +606,13 @@ func (handler *AddressHandler) Vestings(c echo.Context) error {
return badRequestError(c, err)
}

addressId, err := handler.address.IdByHash(c.Request().Context(), hash)
addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash)
if err != nil {
return handleError(c, err, handler.address)
}
if len(addressId) != 1 {
return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash))
}

vestings, err := handler.vestings.ByAddress(
c.Request().Context(),
addressId[0],
addressId,
req.Limit,
req.Offset,
req.ShowEnded,
Expand Down Expand Up @@ -676,17 +655,14 @@ func (handler *AddressHandler) Grants(c echo.Context) error {
return badRequestError(c, err)
}

addressId, err := handler.address.IdByHash(c.Request().Context(), hash)
addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash)
if err != nil {
return handleError(c, err, handler.address)
}
if len(addressId) != 1 {
return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash))
}

grants, err := handler.grants.ByGranter(
c.Request().Context(),
addressId[0],
addressId,
req.Limit,
req.Offset,
)
Expand Down Expand Up @@ -727,17 +703,13 @@ func (handler *AddressHandler) Grantee(c echo.Context) error {
return badRequestError(c, err)
}

addressId, err := handler.address.IdByHash(c.Request().Context(), hash)
addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash)
if err != nil {
return handleError(c, err, handler.address)
}
if len(addressId) != 1 {
return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash))
}

grants, err := handler.grants.ByGrantee(
c.Request().Context(),
addressId[0],
addressId,
req.Limit,
req.Offset,
)
Expand Down Expand Up @@ -787,17 +759,14 @@ func (handler *AddressHandler) Stats(c echo.Context) error {
return badRequestError(c, err)
}

addressId, err := handler.address.IdByHash(c.Request().Context(), hash)
addressId, err := handler.getIdByHash(c.Request().Context(), hash, req.Hash)
if err != nil {
return handleError(c, err, handler.address)
}
if len(addressId) != 1 {
return badRequestError(c, errors.Errorf("can't find address: %s", req.Hash))
}

series, err := handler.address.Series(
c.Request().Context(),
addressId[0],
addressId,
storage.Timeframe(req.Timeframe),
req.SeriesName,
storage.NewSeriesRequest(req.From, req.To),
Expand All @@ -812,3 +781,19 @@ func (handler *AddressHandler) Stats(c echo.Context) error {
}
return returnArray(c, response)
}

func (handler *AddressHandler) getIdByHash(ctx context.Context, hash []byte, address string) (uint64, error) {
addressId, err := handler.address.IdByHash(ctx, hash)
if err != nil {
return 0, err
}

switch len(addressId) {
case 0:
return 0, errors.Errorf("can't find address: %s", address)
case 1:
return addressId[0], nil
default:
return handler.address.IdByAddress(ctx, address, addressId...)
}
}
3 changes: 3 additions & 0 deletions cmd/api/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ func metricsSkipper(c echo.Context) bool {
}

func postSkipper(c echo.Context) bool {
if c.Request().Method != http.MethodPost {
return true
}
if strings.HasPrefix(c.Path(), "/v1/blob") {
return true
}
Expand Down
1 change: 1 addition & 0 deletions internal/storage/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type IAddress interface {
ListWithBalance(ctx context.Context, filters AddressListFilter) ([]Address, error)
Series(ctx context.Context, addressId uint64, timeframe Timeframe, column string, req SeriesRequest) (items []HistogramItem, err error)
IdByHash(ctx context.Context, hash ...[]byte) ([]uint64, error)
IdByAddress(ctx context.Context, address string, ids ...uint64) (uint64, error)
}

// Address -
Expand Down
44 changes: 44 additions & 0 deletions internal/storage/mock/address.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions internal/storage/postgres/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,16 @@ func (a *Address) IdByHash(ctx context.Context, hash ...[]byte) (id []uint64, er
Scan(ctx, &id)
return
}

// IdByAddress -
func (a *Address) IdByAddress(ctx context.Context, address string, ids ...uint64) (id uint64, err error) {
query := a.DB().NewSelect().
Model((*storage.Address)(nil)).
Column("id").
Where("address = ?", address)
if len(ids) > 0 {
query = query.Where("id IN (?)", bun.In(ids))
}
err = query.Scan(ctx, &id)
return
}
9 changes: 9 additions & 0 deletions internal/storage/postgres/address_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,12 @@ func (s *StorageTestSuite) TestAddressIdByHash() {
s.Require().Len(id, 1)
s.Require().EqualValues(1, id[0])
}

func (s *StorageTestSuite) TestAddressIdByAddress() {
ctx, ctxCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer ctxCancel()

id, err := s.storage.Address.IdByAddress(ctx, "celestia1jc92qdnty48pafummfr8ava2tjtuhfdw774w60", 2, 3, 4)
s.Require().NoError(err)
s.Require().EqualValues(2, id)
}

0 comments on commit 12ee097

Please sign in to comment.