Commit 702d23a

Apply isort and black reformatting
Signed-off-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>
JimmyZhang12 committed Sep 21, 2024
1 parent e4dd479 · commit 702d23a
Showing 1 changed file with 4 additions and 3 deletions.
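
Both formatters are mechanical: isort sorts each import section alphabetically by module path, and black normalizes layout and whitespace without changing behavior. A minimal illustration of the ordering rule behind both import moves in this diff (a standalone sketch, not code from this repository):

# isort places "dataclasses" between "contextlib" and "pathlib" because
# each import section is sorted alphabetically by module name.
from contextlib import contextmanager
from dataclasses import asdict
from pathlib import Path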
7 changes: 4 additions & 3 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -20,9 +20,9 @@
 import tempfile
 from collections import OrderedDict, defaultdict
 from contextlib import contextmanager
+from dataclasses import asdict
 from pathlib import Path
 from typing import Any, Callable, Dict, Generator, Iterator, List, Literal, Mapping, Optional, Sized, Union
-from dataclasses import asdict
 
 import pytorch_lightning as pl
 import torch
@@ -97,6 +97,7 @@
 
 try:
     from megatron.core import dist_checkpointing, parallel_state
+    from megatron.core.dist_checkpointing.core import maybe_load_config
     from megatron.core.dist_checkpointing.dict_utils import dict_list_map_outplace
     from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject
     from megatron.core.dist_checkpointing.optimizer import (
@@ -105,10 +106,10 @@
         optim_state_to_sharding_state,
     )
     from megatron.core.dist_checkpointing.strategies import tensorstore
-    from megatron.core.dist_checkpointing.core import maybe_load_config
     from megatron.core.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate
     from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
     from megatron.core.transformer.transformer_layer import TransformerLayer as MCoreTransformerLayer
+
     from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO
 
     HAVE_MEGATRON_CORE = True
@@ -1353,7 +1354,7 @@ def dummy():
 metadata = asdict(maybe_load_config(tmp_model_weights_dir))
 sharded_state_dict = instance.sharded_state_dict(metadata=metadata)
 checkpoint['state_dict'] = sharded_state_dict
-
+
 checkpoint_io = DistributedCheckpointIO.from_config(conf)
 checkpoint = checkpoint_io.load_checkpoint(
     tmp_model_weights_dir,
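
The final hunk's surroundings are truncated in the diff, so here is a minimal sketch of the load path these lines belong to, assuming `instance` is the model being restored, `conf` its config, and `tmp_model_weights_dir` an unpacked checkpoint directory; the helper name and the `sharded_state_dict` keyword argument are illustrative assumptions, not NeMo's full implementation:

from dataclasses import asdict

from megatron.core.dist_checkpointing.core import maybe_load_config

from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO


def load_sharded_weights(instance, conf, tmp_model_weights_dir):
    # Hypothetical helper. Read the checkpointing config saved alongside the
    # weights; maybe_load_config returns None when no config file is present,
    # which the real code is assumed to guard against.
    metadata = asdict(maybe_load_config(tmp_model_weights_dir))

    # Ask the model for a sharded state dict matching that metadata: a mapping
    # describing which shard of each tensor this rank should load.
    checkpoint = {'state_dict': instance.sharded_state_dict(metadata=metadata)}

    # DistributedCheckpointIO resolves a load strategy from the model config
    # and fills the sharded state dict from the checkpoint directory.
    checkpoint_io = DistributedCheckpointIO.from_config(conf)
    return checkpoint_io.load_checkpoint(
        tmp_model_weights_dir, sharded_state_dict=checkpoint
    )

The design point visible in the diff: the sharded state dict is built before loading, so each rank reads only the shards it owns rather than materializing the full checkpoint everywhere.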
