LightningCLI access DataModule methods in model constructor #13403
Hello, I am switching my code to use LightningCLI for easier config/reproducibility. I want to access dataset examples when creating my model, but I don't want to have to make my data loading code explicitly available to both classes. Currently my project uses code similar to this:

    data = MyLightningDataModule(...)
    model = MyLightningModule(data.input_dim, ...)
    trainer.fit(model, datamodule=data)

This is how I use LightningCLI:

    cli = LightningCLI(MyLightningModule, MyLightningDataModule)

But it seems that when using LightningCLI, the LightningModule is instantiated before the LightningDataModule, and I can't figure out how to configure LightningCLI to make the LightningDataModule available to the LightningModule constructor. I've tried to make this happen by supplying the datamodule as a parameter to my model class:

    class MyLightningModule(pl.LightningModule):
        def __init__(self, data: BigVulDatasetLineVDDataModule):
            super().__init__()
            self.data = data
            # instantiate model...

        def train_dataloader(self):
            return self.data.train_dataloader()

        def val_dataloader(self):
            return self.data.val_dataloader()

        def test_dataloader(self):
            return self.data.test_dataloader()

        # other LightningModule related methods...

And my config looks like this:

    model:
      # model args...
      data:
        class_path: project.MyLightningDataModule
        init_args:
          # datamodule args...

...but it would be nice if I could keep these two classes independent. Is there any better way to achieve the functionality I want?
Replies: 2 comments 9 replies
For this you use link_arguments, which requires subclassing LightningCLI. Something like:

    class MyLightningModule(LightningModule):
        def __init__(self, input_dim: int):
            ...

    class MyCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            # Read input_dim from the instantiated datamodule and pass it to the model.
            parser.link_arguments('data.input_dim', 'model.input_dim', apply_on='instantiate')

It might also be possible, as in your second example, to link the entire data object, though this needs the latest version of jsonargparse:

    class MyLightningModule(pl.LightningModule):
        def __init__(self, data: BigVulDatasetLineVDDataModule):
            super().__init__()
            self.data = data
            # instantiate model...

    class MyCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            # Pass the whole instantiated datamodule to the model constructor.
            parser.link_arguments('data', 'model.data', apply_on='instantiate')
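To make the effect of apply_on='instantiate' a bit more concrete, here is a plain-Python sketch of the order of operations it implies: the data object is built first, the linked attribute is read from it, and only then is the model constructed. This is only an illustration, not how LightningCLI or jsonargparse is actually implemented. ToyDataModule, ToyModel, instantiate_with_links, num_features and hidden_dim are all hypothetical names made up for the example.

```python
from typing import Any, Dict, List, Tuple


class ToyDataModule:
    """Hypothetical stand-in for a LightningDataModule."""

    def __init__(self, num_features: int = 8):
        # The linked attribute has to exist once __init__ returns,
        # because it is read right after the object is created.
        self.input_dim = num_features


class ToyModel:
    """Hypothetical stand-in for a LightningModule."""

    def __init__(self, input_dim: int, hidden_dim: int = 16):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim


def instantiate_with_links(
    data_kwargs: Dict[str, Any],
    model_kwargs: Dict[str, Any],
    links: List[Tuple[str, str]],
) -> Tuple[ToyDataModule, ToyModel]:
    """Build the data object first, then feed its attributes to the model."""
    data = ToyDataModule(**data_kwargs)
    # Each link is (attribute name on data, constructor argument of the model).
    linked = {target: getattr(data, source) for source, target in links}
    model = ToyModel(**model_kwargs, **linked)
    return data, model


data, model = instantiate_with_links(
    data_kwargs={"num_features": 32},
    model_kwargs={"hidden_dim": 64},
    links=[("input_dim", "input_dim")],
)
print(model.input_dim)  # prints 32
```

One practical consequence, assuming the real linking behaves like this sketch: whatever you link from the datamodule (for example input_dim) should be set in the datamodule's __init__, since it is read right after instantiation rather than after setup() has run.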
Thank you for pointing