Skip to content

Commit

Permalink
update
Browse files (browse the repository at this point in the history)
  • Loading branch information
akihironitta committed Dec 27, 2024
1 parent 1e3868d commit e3431d2
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions examples/trompt_multi_gpu.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -48,19 +48,24 @@ def train(
epoch: int,
loader: DataLoader,
optimizer: torch.optim.Optimizer,
num_classes: int,
metric: torchmetrics.Metric,
rank: int,
) -> float:
model.train()
loss_accum = torch.tensor(0.0, device=rank, dtype=torch.float32)
for tf in tqdm(loader, desc=f"Epoch {epoch:02d}", disable=rank != 0):
for tf in tqdm(
loader,
desc=f"Epoch {epoch:02d} (train)",
disable=rank != 0,
):
tf = tf.to(rank)
# [batch_size, num_layers, num_classes]
out = model(tf)

with torch.no_grad():
metric.update(out.mean(dim=1).argmax(dim=-1), tf.y)
num_layers = out.size(1)

_, num_layers, num_classes = out.size()
# [batch_size * num_layers, num_classes]
pred = out.view(-1, num_classes)
y = tf.y.repeat_interleave(num_layers)
Expand Down Expand Up @@ -105,15 +110,16 @@ def test(
return metric_value


def run(rank, world_size, args) -> None:
def run(rank: int, world_size: int, args: argparse.Namespace) -> None:
dist.init_process_group(
backend='nccl',
init_method='env://',
world_size=world_size,
rank=rank,
)
logging.basicConfig(
format=f"[rank={rank}] [%(asctime)s] %(levelname)s: %(message)s",
format=(f"[rank={rank}/{world_size}] "
f"[%(asctime)s] %(levelname)s: %(message)s"),
level=logging.INFO,
)
logger = logging.getLogger(__name__)
Expand All @@ -122,7 +128,7 @@ def run(rank, world_size, args) -> None:
assert dataset.task_type.is_classification

# Ensure train, val and test splits are the same across all ranks by
# setting the seed before shuffling.
# setting the seed on each rank.
torch.manual_seed(args.seed)
dataset = dataset.shuffle()
train_dataset, val_dataset, test_dataset = (
Expand Down Expand Up @@ -186,7 +192,6 @@ def run(rank, world_size, args) -> None:
epoch,
train_loader,
optimizer,
dataset.num_classes,
train_metric,
rank,
)
Expand Down

0 comments on commit e3431d2

Please sign in to comment.