Model Card
This is the model card of this implementation.
I imitate the model cards of some reference papers on this page. It may not be perfect, but it should help to understand the implementation comprehensively.
Sample model card of PaLM 2
Model Summary | |
---|---|
Summary | Muesli is a model-based RL algorithm. It maps an RGB image observation to a probability distribution over the action space defined by the environment, and it learns from episode data collected through agent-environment interaction to maximize the cumulative reward. |
Input | RGB image tensor; [channel = 3, height = 72 pixels, width = 96 pixels] (see the preprocessing sketch below this table) |
Output | Probability distribution vector; [action_space] |
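
For illustration, here is a minimal sketch (not the repository's actual code) of turning a raw environment frame into the [3, 72, 96] input tensor listed above. The function name and the use of bilinear interpolation are assumptions.

```python
# Hypothetical preprocessing sketch: HxWx3 uint8 frame -> [3, 72, 96] float tensor.
import torch
import torch.nn.functional as F

def preprocess_observation(frame_hwc_uint8: torch.Tensor) -> torch.Tensor:
    """[H, W, 3] uint8 -> [3, 72, 96] float32 in [0, 1]."""
    x = frame_hwc_uint8.float() / 255.0          # scale to [0, 1]
    x = x.permute(2, 0, 1).unsqueeze(0)          # [1, 3, H, W]
    x = F.interpolate(x, size=(72, 96), mode="bilinear", align_corners=False)
    return x.squeeze(0)                          # [3, 72, 96]

# Example: a fake 400x600 RGB frame (roughly LunarLander's render size)
frame = torch.randint(0, 256, (400, 600, 3), dtype=torch.uint8)
obs = preprocess_observation(frame)
print(obs.shape)  # torch.Size([3, 72, 96])
```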
Model architecture | |
---|---|
Agent Network | |
'Agent network' is not an official term; it refers to the networks that are unrolled and optimized (also called the 'learning network' or 'online network'). It consists of the networks below; a code sketch follows this table. | |
Representation Network | The CNN-based observation encoder. (Main role: image -> hidden state) |
Dynamics Network | An LSTM that infers the hidden states of future time steps, conditioned on the actions selected in the episode data. (Main role: hidden state -> future hidden states) |
Policy Network | Infers a probability distribution over actions from the hidden state. (Main role: hidden state -> softmaxed distribution) |
Value Network | Infers a probability distribution representing the scalar value from the hidden state. (Main role: hidden state -> softmaxed distribution) |
Reward Network | Infers a probability distribution representing the scalar reward from the hidden state. (Main role: hidden state -> softmaxed distribution) |
Target Network | |
It has the same components as the agent network. Its parameters are an exponential moving average of the agent network's parameters over past updates. It is used for the actor's environment interaction and for the learner's inference, except when unrolling the agent network. |
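
Below is a minimal, hedged sketch of how these pieces could fit together in PyTorch. The module names, layer sizes, the way the one-hot action is fed to the LSTM, and the symmetric value/reward support are assumptions for illustration, not the repository's actual code; min-max normalization of the hidden state (see Techniques) is omitted for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AgentNetwork(nn.Module):
    """Sketch of the agent network: representation CNN, LSTM dynamics,
    and softmaxed policy/value/reward heads (assumed layout)."""

    def __init__(self, action_space=4, hidden=512, support=30):
        super().__init__()
        # Representation network: [B, 3, 72, 96] image -> [B, hidden] state
        self.representation = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 10, hidden),  # 7x10 is the conv output for a 72x96 input
        )
        # Dynamics network: (hidden state, one-hot action) -> next hidden state
        self.dynamics = nn.LSTMCell(action_space, hidden)
        # Heads; the value/reward supports assume a symmetric 2*support+1 binning
        self.policy_head = nn.Linear(hidden, action_space)
        self.value_head = nn.Linear(hidden, 2 * support + 1)
        self.reward_head = nn.Linear(hidden, 2 * support + 1)

    def predict(self, h):
        # The actual implementation min-max normalizes h before the heads (see Techniques)
        return (F.softmax(self.policy_head(h), dim=-1),
                F.softmax(self.value_head(h), dim=-1),
                F.softmax(self.reward_head(h), dim=-1))

    def forward(self, obs, actions_onehot):
        """obs: [B, 3, 72, 96]; actions_onehot: [B, K, action_space]."""
        h = self.representation(obs)
        c = torch.zeros_like(h)
        outputs = [self.predict(h)]                   # predictions at k = 0
        for k in range(actions_onehot.shape[1]):      # unroll k = 1..K
            h, c = self.dynamics(actions_onehot[:, k], (h, c))
            outputs.append(self.predict(h))
        return outputs

@torch.no_grad()
def ema_update(target: nn.Module, online: nn.Module, alpha_target: float = 0.01):
    """Target network update: exponential moving average of the online parameters."""
    for t, o in zip(target.parameters(), online.parameters()):
        t.mul_(1.0 - alpha_target).add_(alpha_target * o)
```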
Loss function | |
---|---|
PG loss | Auxiliary policy-gradient loss. (Eq. 10 in the Muesli paper; first_term in the code) |
Model loss | The policy component of the model loss. (Eq. 13 in the Muesli paper, extended to the first time step, i.e. starting from k=0; L_m in the code) |
Value model loss | Cross-entropy loss, as described in the MuZero paper's supplementary materials. (L_v in the code) |
Reward model loss | Cross-entropy loss, as described in the MuZero paper's supplementary materials. (L_r in the code; a sketch of how the four terms combine follows this table) |
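
The sketch below shows one plausible way the four terms could be combined, reusing the names mentioned above (first_term, L_m, L_v, L_r) and the loss weights from the hyperparameter table. It is an illustration under those assumptions, not the repository's exact loss code.

```python
import torch

def categorical_cross_entropy(pred_dist, target_dist, eps=1e-12):
    """CE between two categorical distributions over the same support,
    as used for the value and reward models (MuZero-style).
    eps guards against log(0), mirroring the 'epsilon for CE loss' entry below."""
    return -(target_dist * torch.log(pred_dist + eps)).sum(-1).mean()

def total_loss(first_term, L_m, pred_v, target_v, pred_r, target_r,
               value_weight=0.25, reward_weight=1.0):
    # first_term: auxiliary policy-gradient loss (eq. 10),
    # L_m: policy component of the model loss (eq. 13); both assumed precomputed.
    L_v = categorical_cross_entropy(pred_v, target_v)
    L_r = categorical_cross_entropy(pred_r, target_r)
    return first_term + L_m + value_weight * L_v + reward_weight * L_r
```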
Replay Buffer | |
---|---|
Sampling method | Uniform: randomly pick 1 sequence (5 transitions) from each randomly selected episode (sketched below this table). This works empirically, but it is arbitrary and not the same as the paper; it may change later. |
Replay proportion in a batch | Off-policy data 75% + on-policy data 25% |
Capacity | Not yet implemented |
Handling frame stacking | Prepend the first image (stacking_frame - 1) times before interaction starts. |
Handling unroll beyond episode end | Append zero elements (unroll_step + 1) times after the last interaction step. |
The storage logic is somewhat convoluted and not verified; it should be checked and improved. |
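
A hedged sketch of the sampling and padding scheme described above. The data layout (episodes as Python lists of transitions) and function names are assumptions, not the repository's actual buffer code.

```python
import random

SEQUENCE_LENGTH = 5

def sample_sequence(episode):
    """episode: list of transitions (assumed length >= 5); returns one contiguous window."""
    start = random.randint(0, len(episode) - SEQUENCE_LENGTH)
    return episode[start:start + SEQUENCE_LENGTH]

def sample_batch(replay_episodes, onpolicy_episodes, batch_size=128, replay_proportion=0.75):
    """75% of the batch from the replay buffer, 25% from the newest (on-policy) episodes."""
    n_replay = int(batch_size * replay_proportion)
    batch = [sample_sequence(random.choice(replay_episodes)) for _ in range(n_replay)]
    batch += [sample_sequence(random.choice(onpolicy_episodes))
              for _ in range(batch_size - n_replay)]
    return batch

def pad_for_frame_stack(frames, stacking_frame=4):
    """Prepend the first frame (stacking_frame - 1) times, as described above."""
    return [frames[0]] * (stacking_frame - 1) + list(frames)
```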
Evaluation | |
---|---|
LunarLander-v2 | Cumulative returns are simply tracked while gathering data from agent-environment interaction |
This needs improvement; results should probably be averaged over at least 3 random seeds with controlled randomness (see the seed-averaging sketch below) |
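
A minimal sketch of the suggested seed averaging. run_one_evaluation is a hypothetical stand-in for whatever routine collects one run's cumulative return; it is not something from the repository.

```python
import statistics

def evaluate_over_seeds(run_one_evaluation, seeds=(0, 1, 2)):
    """Average the cumulative return over several seeds and report the spread."""
    returns = [run_one_evaluation(seed) for seed in seeds]
    return statistics.mean(returns), statistics.stdev(returns)

# Example with a dummy evaluation function:
mean_ret, std_ret = evaluate_over_seeds(lambda seed: 100.0 + seed)
print(mean_ret, std_ret)  # 101.0 1.0
```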
Techniques | |
---|---|
Categorical reparametrization | Yes. Used for the value model and reward model (distribution <-> scalar; see the sketch after this table). |
Advantage normalization | Yes. |
Target Network | Yes. Exponential moving average update |
Mixed prior policy | Yes. Mixed with 0.3% uniform distribution. (Mixing with 3% of the behavior policy is not used or verified, due to my lack of knowledge.) Its role is that of a regulariser, as described in the Ada paper (p. 6). |
Stacking observations | Yes. |
Min-max normalization | Yes, to [0,1]. Used to normalize the hidden state before policy, value, and reward inference. |
β-LOO action-dependent baselines | None. |
Retrace | Not yet. (The target value should be estimated with the Retrace estimator, but this is not implemented yet.) |
Vectorized environment | Not yet. |
Distributed computing (actor-learner decomposition framework) | Not yet. |
Pop-Art | Not yet. |
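
Below is a hedged sketch of three of the techniques marked "Yes" above: categorical reparametrization, the mixed prior policy, and min-max normalization. The symmetric 61-bin support and the two-hot projection follow the MuZero convention and are assumptions; the repository may implement these differently.

```python
import torch

SUPPORT_SIZE = 30
SUPPORT = torch.arange(-SUPPORT_SIZE, SUPPORT_SIZE + 1, dtype=torch.float32)  # 61 bins

def scalar_to_distribution(x):
    """Two-hot encoding of a scalar batch [B] onto the support: [B, 61]."""
    x = x.clamp(-SUPPORT_SIZE, SUPPORT_SIZE)
    low = x.floor()
    frac = x - low                                    # weight on the upper neighbor
    dist = torch.zeros(x.shape[0], SUPPORT.numel())
    low_idx = (low + SUPPORT_SIZE).long()
    dist.scatter_(1, low_idx.unsqueeze(1), (1.0 - frac).unsqueeze(1))
    high_idx = (low_idx + 1).clamp(max=SUPPORT.numel() - 1)
    dist.scatter_add_(1, high_idx.unsqueeze(1), frac.unsqueeze(1))
    return dist

def distribution_to_scalar(dist):
    """Expectation of a categorical distribution over the support: [B, 61] -> [B]."""
    return dist @ SUPPORT

def mix_with_uniform(policy_probs, mix=0.003):
    """Mixed prior policy: blend in 0.3% of a uniform distribution."""
    uniform = torch.full_like(policy_probs, 1.0 / policy_probs.shape[-1])
    return (1.0 - mix) * policy_probs + mix * uniform

def min_max_normalize(h, eps=1e-8):
    """Normalize a hidden state to [0, 1] before the p/v/r heads."""
    return (h - h.min(-1, keepdim=True).values) / (
        h.max(-1, keepdim=True).values - h.min(-1, keepdim=True).values + eps)

# Example: scalars survive the round trip through the categorical representation
v = torch.tensor([2.7, -5.0])
print(distribution_to_scalar(scalar_to_distribution(v)))  # tensor([ 2.7000, -5.0000])
```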
Main Hyperparameters | |
---|---|
Following Table 5 in the Muesli paper (collected into a config sketch below this table) | |
Batch size | 128 sequences |
Sequence length | 5 frames |
Model unroll length K | 4 |
Replay proportion in a batch | 75% |
Replay buffer capacity | Not implemented |
Initial learning rate | 3e-4 |
Final learning rate | 0 |
AdamW weight decay | 0 |
Discount | 0.997 |
Target network update rate (α_target) | 0.01 |
Value loss weight | 0.25 |
Reward loss weight | 1.0 |
RETRACE estimator samples | (not yet) |
KL(π_CMPO, π) estimator samples | None (exact KL used) |
Variance moving average decay (β_var) | 0.99 |
Variance offset (ϵ_var) | 1e-12 |
Added by this implementation | |
Not exhaustive; only the more important ones are listed | |
Learner iteration | 20 |
resize_height | 72 |
resize_width | 96 |
hidden state size (of LSTM) | 512 |
mlp hidden layer width | 128 |
support size (of categorical reparametrization) | 30 |
experiment_length (for LunarLander) | 20000 |
epsilon for CE loss | 1e-12 (prevents NaN from taking the log of zero; not tuned as a hyperparameter yet) |
stacking frame | 4 |
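
For convenience, here is the table above gathered into a single config object. This is an illustrative (hypothetical) dataclass, not the repository's actual configuration code.

```python
from dataclasses import dataclass

@dataclass
class MuesliConfig:
    # From Table 5 of the Muesli paper
    batch_size: int = 128           # sequences
    sequence_length: int = 5        # frames
    unroll_length_k: int = 4
    replay_proportion: float = 0.75
    initial_lr: float = 3e-4
    final_lr: float = 0.0
    adamw_weight_decay: float = 0.0
    discount: float = 0.997
    alpha_target: float = 0.01      # target network update rate
    value_loss_weight: float = 0.25
    reward_loss_weight: float = 1.0
    beta_var: float = 0.99          # variance moving average decay
    eps_var: float = 1e-12          # variance offset
    # Added by this implementation
    learner_iterations: int = 20
    resize_height: int = 72
    resize_width: int = 96
    lstm_hidden_size: int = 512
    mlp_hidden_width: int = 128
    support_size: int = 30
    experiment_length: int = 20000  # for LunarLander
    ce_loss_epsilon: float = 1e-12
    stacking_frames: int = 4
```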
Optimization | |
---|---|
Mini-batch | Yes. |
Optimizer | AdamW (weight_decay=0) |
Gradient clipping | [-1,1] |
Learning rate schedule | Decays to zero (see the sketch after this table) |
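
A hedged sketch of the optimizer setup implied by this table: AdamW with zero weight decay, gradient clipping to [-1, 1] (interpreted here as element-wise value clipping), and a learning rate decayed to zero. The linear shape of the decay and the stand-in model are assumptions.

```python
import torch
import torch.nn as nn

model = nn.Linear(512, 4)   # stand-in for the agent network
total_updates = 20000       # stand-in for the total number of learner updates

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: max(0.0, 1.0 - step / total_updates))

def optimization_step(loss):
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)  # clip to [-1, 1]
    optimizer.step()
    scheduler.step()

# Example step with a dummy loss:
dummy_loss = model(torch.randn(8, 512)).pow(2).mean()
optimization_step(dummy_loss)
```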
Implementation Frameworks | |
---|---|
Hardware | Intel Xeon, NVIDIA H100 GPU |
Software | PyTorch, Python, MS nni, ... |
Computing resource usage | |
---|---|
GPU memory | This version uses about 1.6 GB of VRAM per experiment |
CPU cores | 1 core per experiment |
Main memory | This version uses approximately 10-100 GB of RAM per experiment. (The likely cause is the missing replay buffer capacity limit, or a memory leak in the code.) |
Model Usage & Limitations | |
---|---|
TBD | |
TBD |