Attribute is reset per batch in dp mode #3301
-
❓ Questions and Help
What is your question?

I don't know whether this is a bug...

Code

import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from argparse import Namespace


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)
        self._dummy_property = None

    @property
    def dummy_property(self):
        # lazily initialize the attribute; expected to happen only once per GPU
        if self._dummy_property is None:
            self._dummy_property = '*' * 30
            print('print only once per gpu')
        return self._dummy_property

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        print(self._dummy_property)
        # access the property every batch
        self.dummy_property
        print(self._dummy_property)
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return pl.TrainResult(loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


train_loader = DataLoader(
    MNIST(
        os.getcwd(),
        download=True,
        transform=transforms.ToTensor()
    ),
    batch_size=128
)
trainer = pl.Trainer(gpus=2,
                     distributed_backend='dp',
                     max_epochs=2)
model = LitModel()
trainer.fit(model, train_loader)
Replies: 3 comments
-
@yukw777 mind having a look? 🐰
-
@Borda sure thing! A bit busy right now, but I’ll try to get to it later this week or weekend!
-
oh, this is actually a known problem and comes from DataParallel in PyTorch itself.
See #565 and #1649 for reference.
@ananyahjha93 has been working on a workaround, but it seems to be super non-trivial: #1895
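For reference, here is a minimal sketch that reproduces the same behavior with plain torch.nn.DataParallel (a hypothetical Toy module; assumes at least two visible GPUs): every forward call re-replicates the wrapped module onto each device, so any attribute initialized lazily inside forward/training_step is set on a throw-away replica and never persists on the original module.

import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.cached = None  # lazily-initialized state, like _dummy_property above

    def forward(self, x):
        if self.cached is None:
            # runs on a per-batch replica, not on the module wrapped by DataParallel
            self.cached = '*' * 30
            print('initialized on', x.device)
        return self.linear(x)


model = Toy().cuda()
dp = torch.nn.DataParallel(model, device_ids=[0, 1])

for _ in range(3):
    dp(torch.randn(8, 4, device='cuda'))
    # stays None after every batch: DataParallel rebuilds the replicas on each call
    print(model.cached)

The practical consequence is that any state the module needs during dp training has to exist on the original module before replication (created in __init__ or setup, or registered as a buffer) rather than being initialized lazily per batch.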