You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
`from future import annotations
from finrl.config import ERL_PARAMS
from finrl.config import INDICATORS
from finrl.config import RLlib_PARAMS
from finrl.config import SAC_PARAMS
from finrl.config import TRAIN_END_DATE
from finrl.config import TRAIN_START_DATE
from finrl.config_tickers import DOW_30_TICKER
from meta.data_processor import DataProcessor
from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
-
Follow the tutorial https://github.com/AI4Finance-Foundation/FinRL-Tutorials/blob/master/3-Practical/FinRL_MultiCrypto_Trading.ipynb .
Always in:
Trained_model=agent.train_model(
Model=model, cwd=cwd, total_timesteps=break_step
)
Error:
exception: no description
File "D:\per\AI\test1\traintest.py", line 66, in train
trained_model = agent.train_model(
^^^^^^^^^^^^^^^^^^
File "D:\per\AI\test1\Crypto.py", line 32, in
train(start_date=TRAIN_START_DATE,
AssertionError:
Please help!!!
Source code:
Cypto.py
from meta.env_crypto_trading.env_multiple_crypto import CryptoEnv
from traintest import train
TICKER_LIST = ['BTCUSDT','ETHUSDT','ADAUSDT','BNBUSDT','XRPUSDT',
'SOLUSDT','DOTUSDT', 'DOGEUSDT','AVAXUSDT','UNIUSDT']
env = CryptoEnv
TRAIN_START_DATE = '2021-09-01'
TRAIN_END_DATE = '2021-09-02'
TEST_START_DATE = '2021-09-21'
TEST_END_DATE = '2021-09-30'
INDICATORS = ['macd','rsi']#self-defined technical indicator list is NOT supported yet
ERL_PARAMS = {"learning_rate": 2**-15,"batch_size": 211,
"gamma": 0.99, "seed":312,"net_dimension": 29,
"target_step": 5000, "eval_gap": 30, "eval_times": 1}
train(start_date=TRAIN_START_DATE,
end_date=TRAIN_END_DATE,
ticker_list=TICKER_LIST,
data_source='binance',
time_interval='1h',
technical_indicator_list=INDICATORS,
drl_lib='elegantrl',
env=env,
model_name='ppo',
current_working_dir='./test_ppo',
erl_params=ERL_PARAMS,
break_step=5e4,
if_vix=False
)
`
traintest.py
`from future import annotations
from finrl.config import ERL_PARAMS
from finrl.config import INDICATORS
from finrl.config import RLlib_PARAMS
from finrl.config import SAC_PARAMS
from finrl.config import TRAIN_END_DATE
from finrl.config import TRAIN_START_DATE
from finrl.config_tickers import DOW_30_TICKER
from meta.data_processor import DataProcessor
from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv
def train(
start_date,
end_date,
ticker_list,
data_source,
time_interval,
technical_indicator_list,
drl_lib,
env,
model_name,
if_vix=True,
**kwargs,
):
# download data
dp = DataProcessor(data_source, start_date=start_date, end_date=end_date, time_interval=time_interval)
data = dp.download_data(ticker_list=ticker_list)
data = dp.clean_data()
df=dp.dataframe
df.head()
df.to_csv('binancedataYS.csv')
data = dp.add_technical_indicator(tech_indicator_list=technical_indicator_list)
df2=dp.dataframe
df2.head()
df2.to_csv('binancedataIND.csv')
if if_vix:
data = dp.add_vix(data)
price_array, tech_array, turbulence_array = dp.df_to_array(if_vix)
env_config = {
"price_array": price_array,
"tech_array": tech_array,
"turbulence_array": turbulence_array,
"if_train": True,
}
env_instance = env(config=env_config)
# read parameters
cwd = kwargs.get("cwd", "./" + str(model_name))
if drl_lib == "elegantrl":
from finrl.agents.elegantrl.models import DRLAgent as DRLAgent_erl
break_step = kwargs.get("break_step", 1e6)
erl_params = kwargs.get("erl_params")
agent = DRLAgent_erl(
env=env,
price_array=price_array,
tech_array=tech_array,
turbulence_array=turbulence_array,
)
model = agent.get_model(model_name, model_kwargs=erl_params)
trained_model = agent.train_model(
model=model, cwd=cwd, total_timesteps=break_step
)
elif drl_lib == "rllib":
total_episodes = kwargs.get("total_episodes", 100)
rllib_params = kwargs.get("rllib_params")
from finrl.agents.rllib.models import DRLAgent as DRLAgent_rllib
agent_rllib = DRLAgent_rllib(
env=env,
price_array=price_array,
tech_array=tech_array,
turbulence_array=turbulence_array,
)
model, model_config = agent_rllib.get_model(model_name)
model_config["lr"] = rllib_params["lr"]
model_config["train_batch_size"] = rllib_params["train_batch_size"]
model_config["gamma"] = rllib_params["gamma"]
# ray.shutdown()
trained_model = agent_rllib.train_model(
model=model,
model_name=model_name,
model_config=model_config,
total_episodes=total_episodes,
)
trained_model.save(cwd)
elif drl_lib == "stable_baselines3":
total_timesteps = kwargs.get("total_timesteps", 1e6)
agent_params = kwargs.get("agent_params")
from finrl.agents.stablebaselines3.models import DRLAgent as DRLAgent_sb3
agent = DRLAgent_sb3(env=env_instance)
model = agent.get_model(model_name, model_kwargs=agent_params)
trained_model = agent.train_model(
model=model, tb_log_name=model_name, total_timesteps=total_timesteps
)
print("Training is finished!")
trained_model.save(cwd)
print("Trained model is saved in " + str(cwd))
else:
raise ValueError("DRL library input is NOT supported. Please check.")
`
Beta Was this translation helpful? Give feedback.
All reactions