From 4831cc7583032a75f67af38f1d2cae89ac44fe31 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Jul 2024 21:50:12 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lai_tldr/data.py | 17 ++++++++--------- lai_tldr/module.py | 11 +++++------ 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/lai_tldr/data.py b/lai_tldr/data.py index 81299b3..d2dd4f6 100644 --- a/lai_tldr/data.py +++ b/lai_tldr/data.py @@ -19,7 +19,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. - """Code in this file is based on https://github.com/Shivanandroy/simpleT5 by Shivanand Roy.""" import os @@ -55,11 +54,11 @@ def __init__( self.target_max_token_len = target_max_token_len def __len__(self): - """returns length of data.""" + """Returns length of data.""" return len(self.data) def __getitem__(self, index: int): - """returns dictionary of input tensors to feed into T5/MT5 model.""" + """Returns dictionary of input tensors to feed into T5/MT5 model.""" data_row = self.data.iloc[index] source_text = data_row["source_text"] @@ -85,9 +84,9 @@ def __getitem__(self, index: int): ) labels = target_text_encoding["input_ids"] - labels[ - labels == 0 - ] = -100 # to make sure we have correct labels for T5 text generation + labels[labels == 0] = ( + -100 + ) # to make sure we have correct labels for T5 text generation return dict( source_text_input_ids=source_text_encoding["input_ids"].flatten(), @@ -170,7 +169,7 @@ def setup(self, stage=None): ) def train_dataloader(self): - """training dataloader.""" + """Training dataloader.""" return DataLoader( self.train_dataset, batch_size=self.batch_size, @@ -179,7 +178,7 @@ def train_dataloader(self): ) def test_dataloader(self): - """test dataloader.""" + """Test dataloader.""" return DataLoader( self.test_dataset, batch_size=self.batch_size, @@ -188,7 +187,7 @@ def test_dataloader(self): ) def val_dataloader(self): - """validation dataloader.""" + """Validation dataloader.""" return DataLoader( self.val_dataset, batch_size=self.batch_size, diff --git a/lai_tldr/module.py b/lai_tldr/module.py index a9b9d80..4e130f9 100644 --- a/lai_tldr/module.py +++ b/lai_tldr/module.py @@ -19,7 +19,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. - """Code in this file is based on https://github.com/Shivanandroy/simpleT5 by Shivanand Roy.""" from lightning.pytorch import LightningModule @@ -42,7 +41,7 @@ def __init__( self.save_only_last_epoch = False def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None): - """forward step.""" + """Forward step.""" output = self.model( input_ids, attention_mask=attention_mask, @@ -53,7 +52,7 @@ def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None return output.loss, output.logits def training_step(self, batch, batch_idx): - """training step.""" + """Training step.""" input_ids = batch["source_text_input_ids"] attention_mask = batch["source_text_attention_mask"] labels = batch["labels"] @@ -70,7 +69,7 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): - """validation step.""" + """Validation step.""" input_ids = batch["source_text_input_ids"] attention_mask = batch["source_text_attention_mask"] labels = batch["labels"] @@ -86,7 +85,7 @@ def validation_step(self, batch, batch_idx): self.log("val_loss", loss, prog_bar=True) def test_step(self, batch, batch_idx): - """test step.""" + """Test step.""" input_ids = batch["source_text_input_ids"] attention_mask = batch["source_text_attention_mask"] labels = batch["labels"] @@ -102,7 +101,7 @@ def test_step(self, batch, batch_idx): self.log("test_loss", loss, prog_bar=True) def configure_optimizers(self): - """configure optimizers.""" + """Configure optimizers.""" return AdamW(self.parameters(), lr=0.0001)