
Llama 3 #565

Open
peregilk opened this issue Apr 25, 2024 · 26 comments

@peregilk

Do you have any plans for adding support for Llama-3? Any idea how complex this would be, apart from new configs?

@dlwh
Member

dlwh commented Apr 25, 2024

@Helw150 said it worked out of the box. Just configs, I think.

@peregilk
Author

That's fantastic, @dlwh! It would be great if you could share your configs, @Helw150.

I must admit I have not dug into the details here yet, but I understood the biggest architectural changes were a larger tokenizer and adding GQA to the smaller models. I haven't seen GQA used in any of the Levanter models, but I found a post saying it was supported. Can this also just be enabled through the configs?

I also read a post about them doing some masking on longer sequences so that the attention did not "spill over" to new documents.
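
On the GQA question: in Levanter it appears to be driven purely by the head counts in the model config. A minimal sketch, using the field names from the full config posted later in this thread:

model:
  type: llama
  num_heads: 32     # query heads
  num_kv_heads: 8   # fewer KV heads than query heads turns on grouped-query attention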

@peregilk
Author

The model seems to start training with:

data:  
  tokenizer: "meta-llama/Meta-Llama-3-8B"
model:
  type: llama
initialize_from_hf: "meta-llama/Meta-Llama-3-8B"
use_hf_model_config: true

However, I keep getting the message: "The tokenizers appear to be different. You may want to check this.".

Not really sure what is causing this.
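
One way to see what the warning refers to is to compare the configured tokenizer with the stock Llama tokenizer the converter defaults to. A diagnostic sketch using the standard transformers API; the Llama 2 repo name is an assumption about the converter's default:

from transformers import AutoTokenizer

configured = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# Assumption: the converter's built-in default is the Llama 2 tokenizer.
default = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

print(len(configured), len(default))  # 128256 vs. 32000: very different vocabularies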

@peregilk
Author

@dlwh: Unfortunately, I cannot seem to get it to work right out of the box. The model is training, but when training on a domain-specific corpus, the loss starts way too high and never fully recovers.

I am pretty sure the issue is the vocab size here. I cannot seem to override the vocab size in the model config.

This line seems to return the default Llama tokenizer:

converter = config.model.default_hf_checkpoint_converter

Although it is overwritten later, I think this is the main issue.

I have tried both reading the configs from HF and creating them from scratch.

Please advise.
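
A quick sanity check on the vocab-size suspicion is to compare the HF model config against the tokenizer (standard transformers API):

from transformers import AutoConfig, AutoTokenizer

cfg = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B")
tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

# Both should report 128256 for Llama 3; if the Levanter-side vocab axis is built
# from a different tokenizer, the embedding rows will be misaligned.
print(cfg.vocab_size, len(tok))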

@dlwh
Member

dlwh commented Apr 26, 2024

OK, I'll try to take a look this weekend. Do you have a full config I can use as a reproducer, by any chance?

@peregilk
Author

peregilk commented Apr 27, 2024

Awesome. Here is the config I have been using; I have just replaced the URLs.

data:
  train_urls:
    - "gs://mydatabucket/train-shard-{0001..0147}-of-0147.json.gz"
  validation_urls:
    - "gs://mydatabucket/NCC_plus_scandi/validation-shard-0001-of-0001.json.gz"
  cache_dir: "gs://mycachebucket/tokenized/llama3hfconfigfalse/"
  tokenizer: "meta-llama/Meta-Llama-3-8B"
model:
  type: llama
  seq_len: 2048
  hidden_dim: 4096
  intermediate_dim: 14336
  num_layers: 32
  num_heads: 32
  num_kv_heads: 8
  initializer_range: 0.02
  use_flash_attention: true
initialize_from_hf: "meta-llama/Meta-Llama-3-8B"
use_hf_model_config: false
trainer:
  wandb:
    entity: "myentity"
    project: "myproject"
    tags: ["llama3"]
    name: north-llamatre-hfconfigfalse
  mp: p=f32,c=bfloat16
  train_batch_size: 256 
  num_train_steps: 10000
  steps_per_eval: 250
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  checkpointer:
    base_path: "gs://mycheckpointbucket/north-llama3-hfconfigfalse/checkpoints"
    keep:
      - every: 1000
optimizer:
  learning_rate: 1.2e-5
  weight_decay: 0.1
  min_lr_ratio: 0.1
  warmup: 1000
hf_save_steps: 5000
hf_save_path: "gs://myhfbucket/north-llama3-hfconfigfalse/hf"

I have also tried setting
use_hf_model_config: true

This gave the same result.

What I am seeing can be illustrated here:
[image: W&B eval-loss curves for the two runs described below]

The red line is the loss of a Mistral model. The grey line is from Llama 3. Apart from that, the settings are identical, and they are both trying to use the HF tokenizer. The pattern is very similar to what we see when simply hot-swapping to a new tokenizer.

@Helw150
Collaborator

Helw150 commented May 1, 2024

Do you have a reproduction of a case where the Levanter implementation gives you a different prediction than the HuggingFace implementation? As an example, here's a round-trip test I used to verify the Whisper implementation:

def test_hf_roundtrip():

The only architectural change in Llama 3 is grouped-query attention, which is supported here:

QHeadsPerGroup = hax.Axis("q_heads_per_group", config.num_heads // config.num_kv_heads)

I've exported a few Llama 3 finetunes from Levanter to HuggingFace successfully, and the models seem to work as expected for inference, so it's unclear to me whether the above case suggests a bug or is a function of the much larger vocab size of Llama 3 vs. Mistral. I'm not sure what the data mix is above, but if it's multilingual, it's also likely that Mistral starts from a lower loss because it's more explicitly designed for multilinguality.

If you send over a case where HuggingFace and Levanter output different logits for the Llama 3 weights, I'd be happy to take on the debugging from there!
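
For anyone attempting that comparison, here is a minimal sketch of what such a logit check could look like. The HuggingFace side uses the standard transformers API; levanter_logits is a hypothetical placeholder for loading the same weights through Levanter's HF checkpoint converter and running a forward pass:

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "meta-llama/Meta-Llama-3-8B"

def hf_logits(text):
    # Reference logits from the HuggingFace implementation.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float32)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        return model(**inputs).logits[0].numpy()

def test_levanter_matches_hf():
    text = "The quick brown fox jumps over the lazy dog."
    reference = hf_logits(text)
    # levanter_logits is a hypothetical helper that loads the same weights
    # through Levanter's converter and runs a forward pass.
    candidate = levanter_logits(text)
    np.testing.assert_allclose(candidate, reference, rtol=1e-4, atol=1e-3)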

@peregilk
Author

peregilk commented May 4, 2024

I am trying to debug this and test on downstream tasks by exporting to HF. However, I noticed that for Llama 3, no tokenizer.model file is created when saving to HF. Have you experienced this, @Helw150?

Edit: I see the reason for this is that the HF repo does not contain a tokenizer.model file.

@peregilk peregilk closed this as completed May 5, 2024
@peregilk
Author

Reopening this. I have trained a bit more, and I am really not satisfied with the result, even though the train/eval loss looks fine.

Do you have a working Llama 3 config file, @Helw150? I want to double-check whether I have made any mistakes here.

@peregilk peregilk reopened this May 21, 2024
@Helw150
Collaborator

Helw150 commented May 21, 2024

Hi!

My use case is a bit non-standard (training multi-modal encoders), so I'm not sure my configs will help much. If you want to check them anyway, you can find them on the will/distill branch tagged with via_*! In these cases, I'm leaving Llama frozen but still need to get gradients from it. I've done runs with both Llama 2 and Llama 3 and haven't seen any surprising-looking issues when switching to Llama 3!

Could you give a bit more detail about the issue you are facing? Does it seem like the model isn't training properly? Or is it that the results aren't satisfactory?

If it's the latter, additional context (e.g. specific symptoms, expected behavior) would help me understand whether there's an underlying bug that could cause this or whether it's a matter of hyperparameters/underlying capabilities!

@dlwh
Member

dlwh commented May 21, 2024

What revision/commit were you using to train? My usage of the TPU splash attention had/has a bug that messed everything up. I'm about 60% sure I know how to fix it (and you can probably fix your checkpoints post hoc), but I need another day or so. If you want to try something, can you pre-multiply all of the q_proj weights by sqrt(head_dim)? I haven't verified that yet, but I strongly suspect that's the fix.
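
As an illustration only, the suggested post-hoc fix could be applied to an exported HF-format checkpoint roughly like this (a sketch, assuming the sqrt(head_dim) rescaling above turns out to be correct; the paths are placeholders):

import math
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/exported-checkpoint")  # placeholder path
head_dim = model.config.hidden_size // model.config.num_attention_heads  # 4096 // 32 = 128 for Llama-3-8B

with torch.no_grad():
    for layer in model.model.layers:
        # Apply the suggested rescaling to every query projection.
        layer.self_attn.q_proj.weight.mul_(math.sqrt(head_dim))

model.save_pretrained("path/to/fixed-checkpoint")  # placeholder path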

@Helw150
Collaborator

Helw150 commented May 21, 2024

Ah yes, worth noting that I haven't pulled in the Splash Attention changes yet

@dlwh
Member

dlwh commented May 21, 2024

Splash attention is currently disabled, so main is fine 🤞 right now.

@peregilk
Author

I was using splash attention, so that might have caused the error.

However, I suspected this was a tokenizer-size issue; I also remember getting a warning about non-matching tokenizers here.

But I can retry this without splash and see whether that is related.

@dlwh
Member

dlwh commented May 22, 2024

I believe splash is now fixed in latest main, but it's now off by default.

Can you try

--model.attn_backend splash

and

--model.attn_backend jax_flash

and let me know if things seem ok?

@peregilk
Author

Awesome! I have not been training for long, but in general my good runs have started with an eval loss of around 2.5, while the broken runs have started at 6. On the latest main, training seems to start with a 2.5 loss both with and without flash attention. Looks very good.

For reference (in case others are having the same issue), the correct arguments are uppercase:
--model.attn_backend SPLASH
--model.attn_backend JAX_FLASH

Splash automatically upcasts to 32-bit, since 16-bit is not working. I understand this is expected.

@dlwh
Member

dlwh commented May 22, 2024

Awesome, thanks for your patience.

Yeah, for whatever reason they don't support bf16 for attention with that kernel yet.

The uppercase issue can be fixed by upgrading draccus to >=0.8.0.

@dlwh dlwh closed this as completed May 22, 2024
@Aphoh
Contributor

Aphoh commented Jun 25, 2024

@peregilk Llama 3 shouldn't work nicely out of the box, as it uses a different theta for RoPE, and configuring that isn't yet supported in Levanter. This issue should probably be re-opened. Even when I use the correct RoPE theta, I don't get reasonable results in Levanter (i.e. eval_lm gives me a loss of ~7 on neutral pretraining datasets like SlimPajama). @dlwh, any ideas?
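
For context, the different theta is visible directly in the HF config (standard transformers API):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B")
print(cfg.rope_theta)  # 500000.0 for Llama 3, vs. 10000.0 for Llama 2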

@dlwh dlwh reopened this Jun 25, 2024
@dlwh
Member

dlwh commented Jun 25, 2024

That's not great. I probably need to spend some time in a debugger.

@dlwh
Member

dlwh commented Jun 25, 2024

I probably won't get to this for at least a few days myself, but I'm happy to provide some support.

@mayankjobanputra

mayankjobanputra commented Sep 3, 2024

@dlwh, any progress on this one? I was thinking of switching to Levanter from Composer.

@dlwh
Member

dlwh commented Sep 23, 2024

I don't really understand the issue. We have a unit test (which I recognize is not necessarily proof it's correct) and support RoPE scaling now. Does anyone have code that fails?

@dlwh
Member

dlwh commented Sep 23, 2024

OK, I see. This is becoming a priority for me, so I will try to tackle it by Wednesday.

@dlwh
Member

dlwh commented Sep 24, 2024

I haven't fully tested it, but can you try main? I added the new Llama 3 RoPE stuff.

@dlwh
Member

dlwh commented Oct 2, 2024

@mayankjobanputra did you have a chance to try it?

@mayankjobanputra

@dlwh I haven't tried it yet. I'm still preprocessing the data and meanwhile writing some infra code around the framework. If everything goes smoothly, I should be able to answer your question in about 15 days.
