Minor style changes and docstring adaptations #39

Merged 2 commits on Sep 23, 2024
66 changes: 34 additions & 32 deletions aurora/model/aurora.py
@@ -5,7 +5,7 @@
import warnings
from datetime import timedelta
from functools import partial
from typing import Any, Optional
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
@@ -93,7 +93,9 @@ def __init__(
separate parameter.
perceiver_ln_eps (float, optional): Epsilon in the perceiver layer norm layers. Used
to stabilise the model.
max_history_size (int, optional): Maximum number of history steps.
max_history_size (int, optional): Maximum number of history steps. You can load
checkpoints that were trained with a smaller `max_history_size`, but you cannot load
checkpoints that were trained with a larger `max_history_size`.
use_lora (bool, optional): Use LoRA adaptation.
lora_steps (int, optional): Use different LoRA adaptation for the first so-many roll-out
steps.
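
The new `max_history_size` note above boils down to: the model's history dimension may be wider than the checkpoint's, never narrower. A minimal usage sketch, assuming the public `Aurora` constructor and `load_checkpoint` API described in the project README; the repository and checkpoint names are illustrative, not taken from this diff:

```python
from aurora import Aurora

# Build a model with a larger history than the pretrained checkpoint
# (the pretrained weights ship with a 2-step history dimension).
# Loading works because the checkpoint's history is smaller than the
# model's `max_history_size`; the extra steps are zero-filled.
model = Aurora(max_history_size=4)

# Repository and checkpoint name follow the project README; adjust as needed.
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")

# The reverse direction fails: a model with `max_history_size=1` cannot load
# a checkpoint trained with a history of 2 and raises an AssertionError.
```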
@@ -316,54 +318,54 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i]
d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i]

# check if history size is compatible and adjust weights if necessary
if self.max_history_size > d["encoder.surf_token_embeds.weights.2t"].shape[2]:
d = self.adapt_checkpoint_max_history_size(d)
elif self.max_history_size < d["encoder.surf_token_embeds.weights.2t"].shape[2]:
raise AssertionError(f"Cannot load checkpoint with max_history_size \
{d['encoder.surf_token_embeds.weights.2t'].shape[2]} \
into model with max_history_size {self.max_history_size}")
# Check if the history size is compatible and adjust weights if necessary.
current_history_size = d["encoder.surf_token_embeds.weights.2t"].shape[2]
if self.max_history_size > current_history_size:
self.adapt_checkpoint_max_history_size(d)
elif self.max_history_size < current_history_size:
raise AssertionError(
f"Cannot load checkpoint with `max_history_size` {current_history_size} "
f"into model with `max_history_size` {self.max_history_size}."
)

self.load_state_dict(d, strict=strict)

def adapt_checkpoint_max_history_size(self, checkpoint) -> Any:
"""Adapt a checkpoint with smaller max_history_size to a model with a larger
max_history_size than the current model.
def adapt_checkpoint_max_history_size(self, checkpoint: dict[str, torch.Tensor]) -> None:
"""Adapt a checkpoint with smaller `max_history_size` to a model with a larger
`max_history_size` than the current model.

If a checkpoint was trained with a larger max_history_size than the current model,
If a checkpoint was trained with a larger `max_history_size` than the current model,
this function will raise an `AssertionError` to prevent loading the checkpoint. This is to
avoid loading a checkpoint which would likely degrade the model's performance.

This implementation copies weights from the checkpoint to the model and fills 0
for the new history width dimension.
This implementation copies weights from the checkpoint to the model and fills zeros
for the new history width dimension. It mutates `checkpoint`.
"""
# Find all weights with prefix "encoder.surf_token_embeds.weights."
for name, weight in list(checkpoint.items()):
if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith(
"encoder.atmos_token_embeds.weights."
):
# We only need to adapt the patch embedding in the encoder.
enc_surf_embedding = name.startswith("encoder.surf_token_embeds.weights.")
enc_atmos_embedding = name.startswith("encoder.atmos_token_embeds.weights.")
if enc_surf_embedding or enc_atmos_embedding:
# This shouldn't get called with the current logic, but it is left here for future-proofing
# and in cases where its called outside current context
assert (
weight.shape[2] <= self.max_history_size
), f"Cannot load checkpoint with max_history_size {weight.shape[2]} \
into model with max_history_size {self.max_history_size} for weight {name}"

# Initialize the new weight tensor
# and for cases where it's called outside the current context.
if not (weight.shape[2] <= self.max_history_size):
raise AssertionError(
f"Cannot load checkpoint with `max_history_size` {weight.shape[2]} "
f"into model with `max_history_size` {self.max_history_size}."
)

# Initialize the new weight tensor.
new_weight = torch.zeros(
(weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]),
device=weight.device,
dtype=weight.dtype,
)

# Copy the existing weights to the new tensor by duplicating the histories provided
# into any new history dimensions
for j in range(weight.shape[2]):
# only fill existing weights, others are zeros
new_weight[:, :, j, :, :] = weight[:, :, j, :, :]
# into any new history dimensions. The rest remains at zero.
new_weight[:, :, : weight.shape[2]] = weight

checkpoint[name] = new_weight
return checkpoint

def configure_activation_checkpointing(self):
"""Configure activation checkpointing.
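For reference, the adaptation performed by `adapt_checkpoint_max_history_size` amounts to zero-padding the history dimension of the encoder's patch-embedding weights. A self-contained sketch of the same idea on a dummy weight; the helper name and shapes are illustrative, not part of the package:

```python
import torch

def pad_history(weight: torch.Tensor, max_history_size: int) -> torch.Tensor:
    """Zero-pad the history dimension (dim 2) of an embedding weight.

    `weight` has shape (out, 1, history, patch, patch), as in the encoder's
    surface/atmospheric token embeddings. Illustrative helper only.
    """
    assert weight.shape[2] <= max_history_size, "Checkpoint history exceeds model history."
    new_weight = torch.zeros(
        (weight.shape[0], 1, max_history_size, weight.shape[3], weight.shape[4]),
        device=weight.device,
        dtype=weight.dtype,
    )
    # Copy the existing history steps; the newly added steps stay at zero.
    new_weight[:, :, : weight.shape[2]] = weight
    return new_weight

# Example: a checkpoint weight with a 2-step history adapted to a 4-step model.
w = torch.randn(8, 1, 2, 4, 4)
w4 = pad_history(w, max_history_size=4)
assert w4.shape[2] == 4 and torch.equal(w4[:, :, :2], w)
```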
35 changes: 16 additions & 19 deletions tests/test_checkpoint_adaptation.py
@@ -1,5 +1,6 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import numpy as np
import pytest
import torch

@@ -19,45 +20,41 @@ def checkpoint():
}


# check both history sizes which are divisible by 2 (original shape) and not
# Check both a history size that is divisible by 2 (the original shape) and one that is not.
@pytest.mark.parametrize("model", [4, 5], indirect=True)
def test_adapt_checkpoint_max_history(model, checkpoint):
# checkpoint starts with history dim, shape[2], as size 2
# The checkpoint starts with the history dimension, `shape[2]`, equal to 2.
assert checkpoint["encoder.surf_token_embeds.weights.0"].shape[2] == 2
adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
model.adapt_checkpoint_max_history_size(checkpoint)

for name, weight in adapted_checkpoint.items():
for name, weight in checkpoint.items():
assert weight.shape[2] == model.max_history_size
for j in range(weight.shape[2]):
if j >= checkpoint[name].shape[2]:
assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
np.testing.assert_allclose(weight[:, :, j, :, :], 0 * weight[:, :, j, :, :])
else:
assert torch.equal(
weight[:, :, j, :, :],
checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
)
np.testing.assert_allclose(weight[:, :, j, :, :], checkpoint[name][:, :, j, :, :])


# check that assert is thrown when trying to load a larger checkpoint to a smaller history size
@pytest.mark.parametrize("model", [1], indirect=True)
def test_adapt_checkpoint_max_history_fail(model, checkpoint):
"""Check that an assertion error is thrown when trying to load a larger checkpoint to a
smaller history size."""
with pytest.raises(AssertionError):
model.adapt_checkpoint_max_history_size(checkpoint)


# test adapting the checkpoint twice to ensure that the second time should not change the weights
@pytest.mark.parametrize("model", [4], indirect=True)
def test_adapt_checkpoint_max_history_twice(model, checkpoint):
adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
adapted_checkpoint = model.adapt_checkpoint_max_history_size(adapted_checkpoint)
"""Test adapting the checkpoint twice to ensure that the second time should not change the
weights."""
model.adapt_checkpoint_max_history_size(checkpoint)
model.adapt_checkpoint_max_history_size(checkpoint)

for name, weight in adapted_checkpoint.items():
for name, weight in checkpoint.items():
assert weight.shape[2] == model.max_history_size
for j in range(weight.shape[2]):
if j >= checkpoint[name].shape[2]:
assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
np.testing.assert_allclose(weight[:, :, j, :, :], 0 * weight[:, :, j, :, :])
else:
assert torch.equal(
weight[:, :, j, :, :],
checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
)
np.testing.assert_allclose(weight[:, :, j, :, :], checkpoint[name][:, :, j, :, :])
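
The `model` fixture itself is not shown in this diff; the tests rely on pytest's indirect parametrization, so the values in `@pytest.mark.parametrize("model", [...], indirect=True)` arrive as `request.param` inside the fixture. A plausible sketch, assuming the fixture simply forwards that value to `max_history_size` (the real fixture presumably builds a much smaller model for speed):

```python
import pytest
from aurora import Aurora

@pytest.fixture
def model(request):
    # `request.param` is the max_history_size passed via indirect parametrization.
    # Illustrative only: the actual conftest may reduce other model dimensions.
    return Aurora(max_history_size=request.param)
```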