diff --git a/cmd/rpcdaemon/cli/config.go b/cmd/rpcdaemon/cli/config.go index fa41f35df29..37a1e0a5bb6 100644 --- a/cmd/rpcdaemon/cli/config.go +++ b/cmd/rpcdaemon/cli/config.go @@ -96,6 +96,9 @@ func RootCommand() (*cobra.Command, *httpcfg.HttpCfg) { rootCmd.PersistentFlags().Uint64Var(&cfg.MaxTraces, "trace.maxtraces", 200, "Sets a limit on traces that can be returned in trace_filter") rootCmd.PersistentFlags().BoolVar(&cfg.WebsocketEnabled, "ws", false, "Enable Websockets - Same port as HTTP") rootCmd.PersistentFlags().BoolVar(&cfg.WebsocketCompression, "ws.compression", false, "Enable Websocket compression (RFC 7692)") + rootCmd.PersistentFlags().StringVar(&cfg.WebSocketListenAddress, "ws.addr", nodecfg.DefaultHTTPHost, "Websocket server listening interface") + rootCmd.PersistentFlags().IntVar(&cfg.WebSocketPort, "ws.port", nodecfg.DefaultHTTPPort, "Websocket server listening port") + rootCmd.PersistentFlags().StringSliceVar(&cfg.WebsocketCORSDomain, "ws.corsdomain", []string{}, "Comma separated list of domains from which to accept cross origin requests (browser enforced)") rootCmd.PersistentFlags().StringVar(&cfg.RpcAllowListFilePath, utils.RpcAccessListFlag.Name, "", "Specify granular (method-by-method) API allowlist") rootCmd.PersistentFlags().UintVar(&cfg.RpcBatchConcurrency, utils.RpcBatchConcurrencyFlag.Name, 2, utils.RpcBatchConcurrencyFlag.Usage) rootCmd.PersistentFlags().BoolVar(&cfg.RpcStreamingDisable, utils.RpcStreamingDisableFlag.Name, false, utils.RpcStreamingDisableFlag.Usage) @@ -533,14 +536,10 @@ func startRegularRpcServer(ctx context.Context, cfg httpcfg.HttpCfg, rpcAPI []rp } httpHandler := node.NewHTTPHandlerStack(srv, cfg.HttpCORSDomain, cfg.HttpVirtualHost, cfg.HttpCompression) - var wsHandler http.Handler - if cfg.WebsocketEnabled { - wsHandler = srv.WebsocketHandler([]string{"*"}, nil, cfg.WebsocketCompression) - } graphQLHandler := graphql.CreateHandler(defaultAPIList) - apiHandler, err := createHandler(cfg, defaultAPIList, httpHandler, wsHandler, graphQLHandler, nil) + apiHandler, err := createHandler(cfg, defaultAPIList, httpHandler, nil, graphQLHandler, nil) if err != nil { return err } @@ -609,6 +608,69 @@ func startRegularRpcServer(ctx context.Context, cfg httpcfg.HttpCfg, rpcAPI []rp log.Info("GRPC endpoint closed", "url", grpcEndpoint) } }() + + if cfg.WebsocketEnabled { + wsSrv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, cfg.RpcStreamingDisable) + + allowListForRPC, err := parseAllowListForRPC(cfg.RpcAllowListFilePath) + if err != nil { + return err + } + + var wsApiFlags []string + for _, flag := range cfg.WebSocketApi { + if flag != "engine" { + wsApiFlags = append(wsApiFlags, flag) + } + } + + if err := node.RegisterApisFromWhitelist(defaultAPIList, wsApiFlags, wsSrv, false); err != nil { + return fmt.Errorf("could not start register WS apis: %w", err) + } + wsSrv.SetAllowList(allowListForRPC) + + wsSrv.SetBatchLimit(cfg.BatchLimit) + + var defaultAPIList []rpc.API + + for _, api := range rpcAPI { + if api.Namespace != "engine" { + defaultAPIList = append(defaultAPIList, api) + } + } + + var apiFlags []string + for _, flag := range cfg.API { + if flag != "engine" { + apiFlags = append(apiFlags, flag) + } + } + + if err := node.RegisterApisFromWhitelist(defaultAPIList, apiFlags, wsSrv, false); err != nil { + return fmt.Errorf("could not start register RPC apis: %w", err) + } + + wsEndpoint := fmt.Sprintf("%s:%d", cfg.WebSocketListenAddress, cfg.WebSocketPort) + + wsHttpHandler := wsSrv.WebsocketHandler(cfg.WebsocketCORSDomain, nil, cfg.WebsocketCompression) + wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsHttpHandler.ServeHTTP(w, r) + }) + + wsListener, wsHttpAddr, err := node.StartHTTPEndpoint(wsEndpoint, cfg.HTTPTimeouts, wsHandler) + if err != nil { + return fmt.Errorf("could not start ws RPC api: %w", err) + } + + defer func() { + wsSrv.Stop() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = wsListener.Shutdown(shutdownCtx) + log.Info("WS endpoint closed", "url", wsHttpAddr) + }() + } + <-ctx.Done() log.Info("Exiting...") return nil @@ -685,7 +747,7 @@ func obtainJWTSecret(cfg httpcfg.HttpCfg) ([]byte, error) { return jwtSecret, nil } -func createHandler(cfg httpcfg.HttpCfg, apiList []rpc.API, httpHandler http.Handler, wsHandler http.Handler, graphQLHandler http.Handler, jwtSecret []byte) (http.Handler, error) { +func createHandler(cfg httpcfg.HttpCfg, apiList []rpc.API, httpHandler, wsHandler, graphQLHandler http.Handler, jwtSecret []byte) (http.Handler, error) { var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if cfg.GraphQLEnabled && graphql.ProcessGraphQLcheckIfNeeded(graphQLHandler, w, r) { return diff --git a/cmd/rpcdaemon/cli/httpcfg/http_cfg.go b/cmd/rpcdaemon/cli/httpcfg/http_cfg.go index f42bb94f38f..601c9218b78 100644 --- a/cmd/rpcdaemon/cli/httpcfg/http_cfg.go +++ b/cmd/rpcdaemon/cli/httpcfg/http_cfg.go @@ -32,15 +32,20 @@ type HttpCfg struct { MaxTraces uint64 WebsocketEnabled bool WebsocketCompression bool - RpcAllowListFilePath string - RpcBatchConcurrency uint - RpcStreamingDisable bool - DBReadConcurrency int - TraceCompatibility bool // Bug for bug compatibility for trace_ routines with OpenEthereum - TxPoolApiAddr string - StateCache kvcache.CoherentConfig - Snap ethconfig.Snapshot - Sync ethconfig.Sync + WebSocketListenAddress string + WebSocketPort int + WebsocketCORSDomain []string + WebSocketApi []string + + RpcAllowListFilePath string + RpcBatchConcurrency uint + RpcStreamingDisable bool + DBReadConcurrency int + TraceCompatibility bool // Bug for bug compatibility for trace_ routines with OpenEthereum + TxPoolApiAddr string + StateCache kvcache.CoherentConfig + Snap ethconfig.Snapshot + Sync ethconfig.Sync // GRPC server GRPCServerEnabled bool diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 42aea6dde4b..d10edd10854 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -333,7 +333,7 @@ var ( HTTPCORSDomainFlag = cli.StringFlag{ Name: "http.corsdomain", Usage: "Comma separated list of domains from which to accept cross origin requests (browser enforced)", - Value: "", + Value: "*", } HTTPVirtualHostsFlag = cli.StringFlag{ Name: "http.vhosts", diff --git a/turbo/cli/default_flags.go b/turbo/cli/default_flags.go index 02a83871fc3..6f8fabc694b 100644 --- a/turbo/cli/default_flags.go +++ b/turbo/cli/default_flags.go @@ -62,6 +62,9 @@ var DefaultFlags = []cli.Flag{ &utils.AuthRpcVirtualHostsFlag, &utils.HTTPApiFlag, &utils.WSEnabledFlag, + &utils.WSListenAddrFlag, + &utils.WSPortFlag, + &utils.WSApiFlag, &utils.WsCompressionFlag, &utils.HTTPTraceFlag, &utils.StateCacheFlag, diff --git a/turbo/cli/flags.go b/turbo/cli/flags.go index 8bb035d4306..e347c91df82 100644 --- a/turbo/cli/flags.go +++ b/turbo/cli/flags.go @@ -356,6 +356,11 @@ func setEmbeddedRpcDaemon(ctx *cli.Context, cfg *nodecfg.Config) { apis := ctx.String(utils.HTTPApiFlag.Name) log.Info("starting HTTP APIs", "APIs", apis) + wsEnabled := ctx.IsSet(utils.WSEnabledFlag.Name) + wsApis := strings.Split(ctx.String(utils.WSApiFlag.Name), ",") + if wsEnabled { + log.Info("starting WS APIs", "APIs", wsApis) + } c := &httpcfg.HttpCfg{ Enabled: ctx.Bool(utils.HTTPEnabledFlag.Name), Dirs: cfg.Dirs, @@ -387,16 +392,20 @@ func setEmbeddedRpcDaemon(ctx *cli.Context, cfg *nodecfg.Config) { }, EvmCallTimeout: ctx.Duration(EvmCallTimeoutFlag.Name), - WebsocketEnabled: ctx.IsSet(utils.WSEnabledFlag.Name), - RpcBatchConcurrency: ctx.Uint(utils.RpcBatchConcurrencyFlag.Name), - RpcStreamingDisable: ctx.Bool(utils.RpcStreamingDisableFlag.Name), - DBReadConcurrency: ctx.Int(utils.DBReadConcurrencyFlag.Name), - RpcAllowListFilePath: ctx.String(utils.RpcAccessListFlag.Name), - Gascap: ctx.Uint64(utils.RpcGasCapFlag.Name), - MaxTraces: ctx.Uint64(utils.TraceMaxtracesFlag.Name), - TraceCompatibility: ctx.Bool(utils.RpcTraceCompatFlag.Name), - BatchLimit: ctx.Int(utils.RpcBatchLimit.Name), - ReturnDataLimit: ctx.Int(utils.RpcReturnDataLimit.Name), + WebsocketEnabled: wsEnabled, + WebSocketListenAddress: ctx.String(utils.WSListenAddrFlag.Name), + WebSocketPort: ctx.Int(utils.WSPortFlag.Name), + WebsocketCORSDomain: strings.Split(ctx.String(utils.WSAllowedOriginsFlag.Name), ","), + WebSocketApi: wsApis, + RpcBatchConcurrency: ctx.Uint(utils.RpcBatchConcurrencyFlag.Name), + RpcStreamingDisable: ctx.Bool(utils.RpcStreamingDisableFlag.Name), + DBReadConcurrency: ctx.Int(utils.DBReadConcurrencyFlag.Name), + RpcAllowListFilePath: ctx.String(utils.RpcAccessListFlag.Name), + Gascap: ctx.Uint64(utils.RpcGasCapFlag.Name), + MaxTraces: ctx.Uint64(utils.TraceMaxtracesFlag.Name), + TraceCompatibility: ctx.Bool(utils.RpcTraceCompatFlag.Name), + BatchLimit: ctx.Int(utils.RpcBatchLimit.Name), + ReturnDataLimit: ctx.Int(utils.RpcReturnDataLimit.Name), TxPoolApiAddr: ctx.String(utils.TxpoolApiAddrFlag.Name),