Skip to content

Commit

Permalink
Refactor notice struct and improve http server handling
Browse files Browse the repository at this point in the history
  • Loading branch information
perrornet committed Jun 15, 2024
1 parent f1b0072 commit 359f39f
Show file tree
Hide file tree
Showing 16 changed files with 175 additions and 63 deletions.
79 changes: 51 additions & 28 deletions cmd/configs.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
yaml_ncoder "github.com/zwgblue/yaml-encoder"
"net/http"
"omni-balance/internal/daemons"
"omni-balance/internal/db"
"omni-balance/internal/models"
"omni-balance/utils/configs"
"omni-balance/utils/constant"
"os"
Expand All @@ -26,46 +28,63 @@ var (
setPlaceholderFinished = make(chan struct{}, 1)
)

func startHttpServer(ctx context.Context, port string) (func(ctx context.Context) error, error) {
server := &http.Server{
Addr: port,
Handler: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
if !strings.EqualFold(request.Method, http.MethodPost) {
writer.WriteHeader(http.StatusMethodNotAllowed)
return
}
var args = make(map[string]interface{})
if err := json.NewDecoder(request.Body).Decode(&args); err != nil {
writer.WriteHeader(http.StatusBadRequest)
return
}
for k, v := range args {
placeholder.Store(k, v)
}
func startHttpServer(_ context.Context, port string) error {
http.Handle("/", http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
if !strings.EqualFold(request.Method, http.MethodPost) {
writer.WriteHeader(http.StatusMethodNotAllowed)
return
}
var args = make(map[string]interface{})
if err := json.NewDecoder(request.Body).Decode(&args); err != nil {
writer.WriteHeader(http.StatusBadRequest)
return
}
for k, v := range args {
placeholder.Store(k, v)
}

setPlaceholderFinished <- struct{}{}
}))

setPlaceholderFinished <- struct{}{}
}),
http.Handle("/remove_order", http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
if !strings.EqualFold(request.Method, http.MethodPost) {
writer.WriteHeader(http.StatusMethodNotAllowed)
return
}
var order = struct {
Id int `json:"id" form:"id"`
}{}
if err := json.NewDecoder(request.Body).Decode(&order); err != nil {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte(err.Error()))
return
}
err := db.DB().Model(&models.Order{}).Where("id = ?", order.Id).Limit(1).Delete(&models.Order{}).Error
if err != nil {
writer.WriteHeader(http.StatusInternalServerError)
return
}
writer.WriteHeader(http.StatusOK)
}))
server := &http.Server{
Addr: port,
Handler: http.DefaultServeMux,
}
go func() {
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
logrus.Panic(err)
}
}()
return server.Shutdown, nil
logrus.Infof("http server started on %s", port)
return nil
}

func waitForPlaceholder(ctx context.Context, configPath, port string) (newConfigPath string, err error) {
func waitForPlaceholder(_ context.Context, configPath string) (newConfigPath string, err error) {
data, err := os.ReadFile(configPath)
if err != nil {
return "", err
}
shutdown, err := startHttpServer(ctx, port)
if err != nil {
return "", err
}
defer func() {
_ = shutdown(ctx)
}()

<-setPlaceholderFinished
placeholder.Range(func(key, value interface{}) bool {
data = bytes.ReplaceAll(data, []byte(key.(string)), []byte(cast.ToString(value)))
Expand All @@ -79,13 +98,17 @@ func waitForPlaceholder(ctx context.Context, configPath, port string) (newConfig
}

func initConfig(ctx context.Context, enablePlaceholder bool, configPath, serverPort string) (err error) {
err = startHttpServer(ctx, serverPort)
if err != nil {
return err
}
if enablePlaceholder {
ports := strings.Split(serverPort, ":")
if len(ports) < 2 {
ports = append([]string{}, "", "8080")
}
logrus.Infof("waiting for placeholder, you can use `curl -X POST -d '{\"<you_placeholder>\":\"0x1234567890\"}' http://127.0.0.1:%s` to set placeholder", ports[1])
configPath, err = waitForPlaceholder(context.Background(), configPath, serverPort)
configPath, err = waitForPlaceholder(context.Background(), configPath)
if err != nil {
return err
}
Expand Down
49 changes: 47 additions & 2 deletions cmd/main.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package main

import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
"io"
"net/http"
"net/url"
"omni-balance/internal/daemons"
_ "omni-balance/internal/daemons/cross_chain"
_ "omni-balance/internal/daemons/monitor"
Expand Down Expand Up @@ -80,10 +85,9 @@ func Action(cli *cli.Context) error {
logrus.SetFormatter(&logrus.JSONFormatter{})
}

if err := notice.Init(notice.Type(config.Notice.Type), config.Notice.Config); err != nil {
if err := notice.Init(notice.Type(config.Notice.Type), config.Notice.Config, config.Notice.Interval); err != nil {
logrus.Warnf("init notice error: %v", err)
}
notice.SetMsgInterval(config.Notice.Interval)

if err := db.InitDb(*config); err != nil {
return errors.Wrap(err, "init db")
Expand Down Expand Up @@ -114,6 +118,47 @@ func main() {
app.Name = "omni-balance"
app.Action = Action
app.Commands = []*cli.Command{
{
Name: "del_order",
Usage: "delete order by id",
Flags: []cli.Flag{
&cli.IntFlag{
Name: "id",
Usage: "order id",
},
&cli.StringFlag{
Name: "server",
Usage: "server host, example: http://127.0.0.1:8080",
Value: "http://127.0.0.1:8080",
},
},
Action: func(c *cli.Context) error {
u, err := url.Parse(c.String("server"))
if err != nil {
return errors.Wrap(err, "parse server url")
}
u.RawPath = "/remove_order"
u.Path = u.RawPath
var body = bytes.NewBuffer(nil)
err = json.NewEncoder(body).Encode(map[string]interface{}{
"id": c.Int("id"),
})
if err != nil {
return errors.Wrap(err, "encode body")
}
resp, err := http.Post(u.String(), "application/json", body)
if err != nil {
return errors.Wrap(err, "post")
}
defer resp.Body.Close()
data, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return errors.Errorf("http status code: %d, body is: %s", resp.StatusCode, data)
}
logrus.Infof("delete order #%d success", c.Int64("id"))
return nil
},
},
{
Name: "version",
Usage: "show version",
Expand Down
6 changes: 3 additions & 3 deletions internal/daemons/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,17 @@ func runForever(ctx context.Context, conf configs.Config, task Task) {
func Run(ctx context.Context, conf configs.Config) error {
for index := range tasks {
if tasks[index].RunOnStart {
logrus.Infof("task %s run on start, wait for the task finished", tasks[index].Name)
logrus.Debugf("task %s run on start, wait for the task finished", tasks[index].Name)
if err := tasks[index].TaskFunc(ctx, conf); err != nil {
logrus.Errorf("task %s failed, err: %v", tasks[index].Name, err)
continue
}
logrus.Infof("task %s run on start finished", tasks[index].Name)
logrus.Debugf("task %s run on start finished", tasks[index].Name)
continue
}
}
for index := range tasks {
logrus.Infof("task %s run in background", tasks[index].Name)
logrus.Debugf("task %s run in background", tasks[index].Name)
go runForever(ctx, conf, tasks[index])
}
return nil
Expand Down
31 changes: 26 additions & 5 deletions internal/daemons/rebalance/rebalance.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,29 @@ func Run(ctx context.Context, conf configs.Config) error {
go func(order *models.Order) {
defer w.Done()
log := order.GetLogs()
utils.SetLogToCtx(ctx, log)
if err := reBalance(ctx, order, conf); err != nil {
subCtx, cancel := context.WithCancel(utils.SetLogToCtx(ctx, log))
defer cancel()

go func() {
defer cancel()
var t = time.NewTicker(time.Second * 5)
defer t.Stop()

for {
select {
case <-subCtx.Done():
return
case <-t.C:
var count int64
_ = db.DB().Model(&models.Order{}).Where("id = ?", order.ID).Count(&count)
if count == 0 {
log.Infof("order #%d not found, exit this order rebalance", order.ID)
return
}
}
}
}()
if err := reBalance(subCtx, order, conf); err != nil {
log.Errorf("reBalance order #%d error: %s", order.ID, err)
return
}
Expand Down Expand Up @@ -149,7 +170,8 @@ func reBalance(ctx context.Context, order *models.Order, conf configs.Config) er
return errors.Wrap(err, "save provider error")
}

log.Infof("start reBalance %s on %s use %s provider", order.TokenOutName, order.TargetChainName, providerObj.Name())
log.Infof("start reBalance #%d %s on %s use %s provider", order.ID, order.TokenOutName,
order.TargetChainName, providerObj.Name())
result, err := providerObj.Swap(ctx, args)
if err != nil {
return errors.Wrapf(err, "reBalance %s on %s error", order.TokenOutName, providerObj.Name())
Expand Down Expand Up @@ -321,12 +343,11 @@ func getReBalanceProvider(ctx context.Context, order models.Order, conf configs.
}

func providerSupportsOrder(ctx context.Context, p provider.Provider, order models.Order, conf configs.Config, log *logrus.Entry) (provider.TokenInCosts, bool) {
wallet := conf.GetWallet(order.Wallet)
tokenInCosts, err := p.GetCost(ctx, provider.SwapParams{
SourceToken: order.TokenInName,
Sender: conf.GetWallet(order.Wallet),
TargetToken: order.TokenOutName,
Receiver: wallet.GetAddress().Hex(),
Receiver: order.Wallet,
TargetChain: order.TargetChainName,
Amount: order.Amount,
})
Expand Down
2 changes: 1 addition & 1 deletion internal/daemons/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func CreateSwapParams(order models.Order, orderProcess models.OrderProcess, log
return provider.SwapParams{
SourceChain: order.CurrentChainName,
Sender: wallet,
Receiver: wallet.GetAddress().Hex(),
Receiver: order.Wallet,
TargetChain: order.TargetChainName,
SourceToken: order.TokenOutName,
TargetToken: order.TokenOutName,
Expand Down
16 changes: 8 additions & 8 deletions internal/models/order.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@ import (
"github.com/shopspring/decimal"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
"omni-balance/utils"
"omni-balance/utils/configs"
"time"
)

type OrderStatus string

const (
OrderStatusWait OrderStatus = "wait"
OrderStatusProcessing OrderStatus = "processing"
OrderStatusSuccess OrderStatus = "success"
OrderStatusFail OrderStatus = "fail"
OrderStatusWait OrderStatus = "wait"
//OrderStatusProcessing OrderStatus = "processing"
OrderStatusSuccess OrderStatus = "success"
//OrderStatusFail OrderStatus = "fail"
OrderStatusWaitTransferFromOperator OrderStatus = "wait_transfer_from_operator"
OrderStatusWaitCrossChain OrderStatus = "wait_cross_chain"
OrderStatusUnknown OrderStatus = "unknown"
Expand Down Expand Up @@ -95,8 +96,7 @@ func GetLastOrderProcess(ctx context.Context, db *gorm.DB, orderId uint) OrderPr
}

func (o *Order) GetLogs() *logrus.Entry {
data, _ := json.Marshal(o)
var fields logrus.Fields
_ = json.Unmarshal(data, &fields)
return logrus.WithFields(fields)
return logrus.WithFields(logrus.Fields{
"order": utils.ToMap(o),
})
}
4 changes: 2 additions & 2 deletions utils/chains/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func WaitForTx(ctx context.Context, client simulated.Client, txHash common.Hash)
return err
}
if errors.Is(err, ethereum.NotFound) {
log.Debugf("tx not found, txHash: %s, try again later", txHash.Hex())
log.Infof("tx not found, txHash: %s, try again later", txHash.Hex())
continue
}
if err != nil {
Expand All @@ -133,7 +133,7 @@ func WaitForTx(ctx context.Context, client simulated.Client, txHash common.Hash)
if tx.Status == 0 {
return errors.New("tx failed")
}
log.Debugf("tx success, txHash: %s", txHash.Hex())
log.Infof("tx success, txHash: %s", txHash.Hex())
return nil
}
}
Expand Down
13 changes: 7 additions & 6 deletions utils/configs/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,13 @@ type Config struct {

TaskInterval map[string]time.Duration `json:"task_interval" yaml:"task_interval"`

// 通知渠道, 当成功rebalance时, 发送通知
Notice struct {
Type string `json:"type" yaml:"type" comment:"Notice type, support: slack"`
Config map[string]interface{} `json:"config" yaml:"config" comment:"It depends on the notification type, slack needs ['webhook','channel']"`
Interval time.Duration `json:"interval" yaml:"interval" comment:"Same message send interval, minimum interval must be greater than or equal to 1 hour, default 1h"`
} `json:"notice" yaml:"notice" comment:"Notice config. When rebalance success, send notice"`
Notice Notice `json:"notice" yaml:"notice" comment:"Notice config. When rebalance success, send notice"`
}

type Notice struct {
Type string `json:"type" yaml:"type" comment:"Notice type, support: slack"`
Config map[string]interface{} `json:"config" yaml:"config" comment:"It depends on the notification type, slack needs ['webhook','channel']"`
Interval time.Duration `json:"interval" yaml:"interval" comment:"Same message send interval, minimum interval must be greater than or equal to 1 hour, default 1h"`
}

type Chain struct {
Expand Down
4 changes: 3 additions & 1 deletion utils/notice/notice.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Notice interface {

func SetMsgInterval(interval time.Duration) {
if interval.Seconds() < time.Hour.Seconds() {
logrus.Warnf("msg interval %s is too short, set to 1 hour", interval)
msgInterval = time.Hour
return
}
Expand All @@ -47,7 +48,7 @@ func WithFields(ctx context.Context, fields Fields) context.Context {
return context.WithValue(ctx, constant.NoticeFieldsKeyInCtx, fields)
}

func Init(noticeType Type, conf map[string]interface{}) error {
func Init(noticeType Type, conf map[string]interface{}, interval time.Duration) error {
if notice != nil {
return nil
}
Expand All @@ -64,6 +65,7 @@ func Init(noticeType Type, conf map[string]interface{}) error {
return errors.Errorf("notice type %s not support", noticeType)
}
}
SetMsgInterval(interval)
return nil
}

Expand Down
1 change: 1 addition & 0 deletions utils/provider/bridge/darwinia/darwinia.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ func (b *Bridge) Swap(ctx context.Context, args provider.SwapParams) (result pro
recordFn(provider.SwapHistory{Actions: sourceChainSendingAction, Status: string(provider.TxStatusPending),
CurrentChain: args.SourceChain})
ctx = provider.WithNotify(ctx, provider.WithNotifyParams{
Receiver: common.HexToAddress(args.Receiver),
TokenIn: args.SourceToken,
TokenOut: args.TargetToken,
TokenInChain: args.SourceChain,
Expand Down
1 change: 1 addition & 0 deletions utils/provider/bridge/helix/helix.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ func (b *Bridge) Swap(ctx context.Context, args provider.SwapParams) (result pro
tx.Gas = 406775

ctx = provider.WithNotify(ctx, provider.WithNotifyParams{
Receiver: common.HexToAddress(args.Receiver),
TokenIn: args.SourceToken,
TokenOut: args.TargetToken,
TokenInChain: args.SourceChain,
Expand Down
Loading

0 comments on commit 359f39f

Please sign in to comment.