diff --git a/cmd/livepeer/livepeer.go b/cmd/livepeer/livepeer.go index 1e8a97e9dc..49506cb946 100755 --- a/cmd/livepeer/livepeer.go +++ b/cmd/livepeer/livepeer.go @@ -156,6 +156,7 @@ func parseLivepeerConfig() starter.LivepeerConfig { cfg.HevcDecoding = flag.Bool("hevcDecoding", *cfg.HevcDecoding, "Enable or disable HEVC decoding") // AI: + cfg.AIServiceRegistry = flag.Bool("aiServiceRegistry", *cfg.AIServiceRegistry, "Set to true to use an AI ServiceRegistry contract address") cfg.AIWorker = flag.Bool("aiWorker", *cfg.AIWorker, "Set to true to run an AI worker") cfg.AIModels = flag.String("aiModels", *cfg.AIModels, "Set models (pipeline:model_id) for AI worker to load upon initialization") cfg.AIModelsDir = flag.String("aiModelsDir", *cfg.AIModelsDir, "Set directory where AI model weights are stored") diff --git a/cmd/livepeer/starter/starter.go b/cmd/livepeer/starter/starter.go index 69a33896e8..0d32e8f1f0 100755 --- a/cmd/livepeer/starter/starter.go +++ b/cmd/livepeer/starter/starter.go @@ -92,6 +92,7 @@ type LivepeerConfig struct { HttpIngest *bool Orchestrator *bool Transcoder *bool + AIServiceRegistry *bool AIWorker *bool Gateway *bool Broadcaster *bool @@ -199,6 +200,7 @@ func DefaultLivepeerConfig() LivepeerConfig { defaultTestTranscoder := true // AI: + defaultAIServiceRegistry := false defaultAIWorker := false defaultAIModels := "" defaultAIModelsDir := "" @@ -298,10 +300,11 @@ func DefaultLivepeerConfig() LivepeerConfig { TestTranscoder: &defaultTestTranscoder, // AI: - AIWorker: &defaultAIWorker, - AIModels: &defaultAIModels, - AIModelsDir: &defaultAIModelsDir, - AIRunnerImage: &defaultAIRunnerImage, + AIServiceRegistry: &defaultAIServiceRegistry, + AIWorker: &defaultAIWorker, + AIModels: &defaultAIModels, + AIModelsDir: &defaultAIModelsDir, + AIRunnerImage: &defaultAIRunnerImage, // Onchain: EthAcctAddr: &defaultEthAcctAddr, @@ -706,6 +709,11 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) { CheckTxTimeout: time.Duration(int64(*cfg.TxTimeout) * int64(*cfg.MaxTxReplacements+1)), } + if *cfg.AIServiceRegistry { + // For the time-being Livepeer AI Subnet uses its own ServiceRegistry, so we define it here + ethCfg.ServiceRegistryAddr = ethcommon.HexToAddress("0x04C0b249740175999E5BF5c9ac1dA92431EF34C5") + } + client, err := eth.NewClient(ethCfg) if err != nil { glog.Errorf("Failed to create Livepeer Ethereum client: %v", err) diff --git a/eth/client.go b/eth/client.go index 4452062898..1ed4cc1e46 100644 --- a/eth/client.go +++ b/eth/client.go @@ -167,6 +167,9 @@ type LivepeerEthClientConfig struct { Signer types.Signer ControllerAddr ethcommon.Address CheckTxTimeout time.Duration + + // For the time-being Livepeer AI Subnet uses its own ServiceRegistry, so we define it here + ServiceRegistryAddr ethcommon.Address } func NewClient(cfg LivepeerEthClientConfig) (LivepeerEthClient, error) { @@ -174,11 +177,12 @@ func NewClient(cfg LivepeerEthClientConfig) (LivepeerEthClient, error) { backend := NewBackend(cfg.EthClient, cfg.Signer, cfg.GasPriceMonitor, cfg.TransactionManager) return &client{ - accountManager: cfg.AccountManager, - backend: backend, - tm: cfg.TransactionManager, - controllerAddr: cfg.ControllerAddr, - checkTxTimeout: cfg.CheckTxTimeout, + accountManager: cfg.AccountManager, + backend: backend, + tm: cfg.TransactionManager, + controllerAddr: cfg.ControllerAddr, + checkTxTimeout: cfg.CheckTxTimeout, + serviceRegistryAddr: cfg.ServiceRegistryAddr, }, nil } @@ -211,28 +215,15 @@ func (c *client) setContracts(opts *bind.TransactOpts) error { glog.V(common.SHORT).Infof("LivepeerToken: %v", c.tokenAddr.Hex()) - chainID, err := c.backend.ChainID(context.Background()) - if err != nil { - glog.Errorf("Failed to get chain ID from remote ethereum node: %v", err) - return err - } - - // TODO: This is a temporary setup for a separate AIServiceRegistry. Revise this when AI subnet merges with the mainnet. - var serviceRegistryAddr ethcommon.Address - arbitrumOneChainID := big.NewInt(42161) - if chainID.Cmp(arbitrumOneChainID) == 0 { - serviceRegistryAddr = ethcommon.HexToAddress("0x04C0b249740175999E5BF5c9ac1dA92431EF34C5") - } else { - serviceRegistryAddr, err = c.GetContract(crypto.Keccak256Hash([]byte("ServiceRegistry"))) + if c.serviceRegistryAddr == (ethcommon.Address{}) { + c.serviceRegistryAddr, err = c.GetContract(crypto.Keccak256Hash([]byte("ServiceRegistry"))) if err != nil { glog.Errorf("Error getting ServiceRegistry address: %v", err) return err } } - c.serviceRegistryAddr = serviceRegistryAddr - - serviceRegistry, err := contracts.NewServiceRegistry(serviceRegistryAddr, c.backend) + serviceRegistry, err := contracts.NewServiceRegistry(c.serviceRegistryAddr, c.backend) if err != nil { glog.Errorf("Error creating ServiceRegistry binding: %v", err) return err