Skip to content

Commit

Permalink
adapted tsai to work with mps
Browse files Browse the repository at this point in the history
  • Loading branch information
oguiza committed Feb 11, 2024
1 parent 2db1421 commit bb69bef
Show file tree
Hide file tree
Showing 20 changed files with 1,340 additions and 981 deletions.
260 changes: 130 additions & 130 deletions nbs/006_data.core.ipynb

Large diffs are not rendered by default.

422 changes: 280 additions & 142 deletions nbs/010_data.transforms.ipynb

Large diffs are not rendered by default.

55 changes: 28 additions & 27 deletions nbs/012_data.image.ipynb

Large diffs are not rendered by default.

189 changes: 92 additions & 97 deletions nbs/022_tslearner.ipynb

Large diffs are not rendered by default.

76 changes: 39 additions & 37 deletions nbs/026_callback.noisy_student.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"metadata": {},
"outputs": [],
"source": [
"#|export \n",
"#|export\n",
"from tsai.imports import *\n",
"from tsai.utils import *\n",
"from tsai.data.preprocessing import *\n",
Expand All @@ -61,26 +61,26 @@
"#|export\n",
"\n",
"# This is an unofficial implementation of noisy student based on:\n",
"# Xie, Q., Luong, M. T., Hovy, E., & Le, Q. V. (2020). Self-training with noisy student improves imagenet classification. \n",
"# Xie, Q., Luong, M. T., Hovy, E., & Le, Q. V. (2020). Self-training with noisy student improves imagenet classification.\n",
"# In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10687-10698).\n",
"# Official tensorflow implementation available in https://github.com/google-research/noisystudent\n",
"\n",
"\n",
"class NoisyStudent(Callback):\n",
" \"\"\"A callback to implement the Noisy Student approach. In the original paper this was used in combination with noise: \n",
" \"\"\"A callback to implement the Noisy Student approach. In the original paper this was used in combination with noise:\n",
" - stochastic depth: .8\n",
" - RandAugment: N=2, M=27\n",
" - dropout: .5\n",
" \n",
"\n",
" Steps:\n",
" 1. Build the dl you will use as a teacher\n",
" 2. Create dl2 with the pseudolabels (either soft or hard preds)\n",
" 3. Pass any required batch_tfms to the callback\n",
" \n",
"\n",
" \"\"\"\n",
" \n",
" def __init__(self, dl2:DataLoader, bs:Optional[int]=None, l2pl_ratio:int=1, batch_tfms:Optional[list]=None, do_setup:bool=True, \n",
" pseudolabel_sample_weight:float=1., verbose=False): \n",
"\n",
" def __init__(self, dl2:DataLoader, bs:Optional[int]=None, l2pl_ratio:int=1, batch_tfms:Optional[list]=None, do_setup:bool=True,\n",
" pseudolabel_sample_weight:float=1., verbose=False):\n",
" r'''\n",
" Args:\n",
" dl2: dataloader with the pseudolabels\n",
Expand All @@ -90,18 +90,18 @@
" do_setup: perform a transform setup on the labeled dataset.\n",
" pseudolabel_sample_weight: weight of each pseudolabel sample relative to the labeled one of the loss.\n",
" '''\n",
" \n",
"\n",
" self.dl2, self.bs, self.l2pl_ratio, self.batch_tfms, self.do_setup, self.verbose = dl2, bs, l2pl_ratio, batch_tfms, do_setup, verbose\n",
" self.pl_sw = pseudolabel_sample_weight\n",
" \n",
"\n",
" def before_fit(self):\n",
" if self.batch_tfms is None: self.batch_tfms = self.dls.train.after_batch\n",
" self.old_bt = self.dls.train.after_batch # Remove and store dl.train.batch_tfms\n",
" self.old_bs = self.dls.train.bs\n",
" self.dls.train.after_batch = noop \n",
" self.dls.train.after_batch = noop\n",
"\n",
" if self.do_setup and self.batch_tfms:\n",
" for bt in self.batch_tfms: \n",
" for bt in self.batch_tfms:\n",
" bt.setup(self.dls.train)\n",
"\n",
" if self.bs is None: self.bs = self.dls.train.bs\n",
Expand All @@ -111,12 +111,12 @@
" pv(f'labels / pseudolabels per training batch : {self.dls.train.bs} / {self.dl2.bs}', self.verbose)\n",
" rel_weight = (self.dls.train.bs/self.dl2.bs) * (len(self.dl2.dataset)/len(self.dls.train.dataset))\n",
" pv(f'relative labeled/ pseudolabel sample weight in dataset: {rel_weight:.1f}', self.verbose)\n",
" \n",
"\n",
" self.dl2iter = iter(self.dl2)\n",
" \n",
"\n",
" self.old_loss_func = self.learn.loss_func\n",
" self.learn.loss_func = self.loss\n",
" \n",
"\n",
" def before_batch(self):\n",
" if self.training:\n",
" X, y = self.x, self.y\n",
Expand All @@ -125,26 +125,26 @@
" self.dl2iter = iter(self.dl2)\n",
" X2, y2 = next(self.dl2iter)\n",
" if y.ndim == 1 and y2.ndim == 2: y = torch.eye(self.learn.dls.c, device=y.device)[y]\n",
" \n",
"\n",
" X_comb, y_comb = concat(X, X2), concat(y, y2)\n",
" \n",
" if self.batch_tfms is not None: \n",
"\n",
" if self.batch_tfms is not None:\n",
" X_comb = compose_tfms(X_comb, self.batch_tfms, split_idx=0)\n",
" y_comb = compose_tfms(y_comb, self.batch_tfms, split_idx=0)\n",
" self.learn.xb = (X_comb,)\n",
" self.learn.yb = (y_comb,)\n",
" pv(f'\\nX: {X.shape} X2: {X2.shape} X_comb: {X_comb.shape}', self.verbose)\n",
" pv(f'y: {y.shape} y2: {y2.shape} y_comb: {y_comb.shape}', self.verbose)\n",
" \n",
" def loss(self, output, target): \n",
"\n",
" def loss(self, output, target):\n",
" if target.ndim == 2: _, target = target.max(dim=1)\n",
" if self.training and self.pl_sw != 1: \n",
" if self.training and self.pl_sw != 1:\n",
" loss = (1 - self.pl_sw) * self.old_loss_func(output[:self.dls.train.bs], target[:self.dls.train.bs])\n",
" loss += self.pl_sw * self.old_loss_func(output[self.dls.train.bs:], target[self.dls.train.bs:])\n",
" return loss \n",
" else: \n",
" return loss\n",
" else:\n",
" return self.old_loss_func(output, target)\n",
" \n",
"\n",
" def after_fit(self):\n",
" self.dls.train.after_batch = self.old_bt\n",
" self.learn.loss_func = self.old_loss_func\n",
Expand All @@ -170,7 +170,8 @@
"outputs": [],
"source": [
"dsid = 'NATOPS'\n",
"X, y, splits = get_UCR_data(dsid, return_split=False)"
"X, y, splits = get_UCR_data(dsid, return_split=False)\n",
"X = X.astype(np.float32)"
]
},
{
Expand Down Expand Up @@ -229,10 +230,10 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.884984</td>\n",
" <td>1.809759</td>\n",
" <td>0.166667</td>\n",
" <td>00:06</td>\n",
" <td>1.782144</td>\n",
" <td>1.758471</td>\n",
" <td>0.250000</td>\n",
" <td>00:00</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
Expand All @@ -249,7 +250,7 @@
"output_type": "stream",
"text": [
"\n",
"X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 58])\n",
"X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 41])\n",
"y: torch.Size([171]) y2: torch.Size([85]) y_comb: torch.Size([256])\n"
]
}
Expand Down Expand Up @@ -323,10 +324,10 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.894964</td>\n",
" <td>1.814770</td>\n",
" <td>0.177778</td>\n",
" <td>00:03</td>\n",
" <td>1.898401</td>\n",
" <td>1.841182</td>\n",
" <td>0.155556</td>\n",
" <td>00:00</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
Expand All @@ -343,7 +344,7 @@
"output_type": "stream",
"text": [
"\n",
"X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 45])\n",
"X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 51])\n",
"y: torch.Size([171, 6]) y2: torch.Size([85, 6]) y_comb: torch.Size([256, 6])\n"
]
}
Expand All @@ -353,6 +354,7 @@
"soft_preds = False\n",
"\n",
"pseudolabels = ToNumpyCategory()(y) if soft_preds else OneHot()(y)\n",
"pseudolabels = pseudolabels.astype(np.float32)\n",
"dsets2 = TSDatasets(pseudolabeled_data, pseudolabels)\n",
"dl2 = TSDataLoader(dsets2, num_workers=0)\n",
"noisy_student_cb = NoisyStudent(dl2, bs=256, l2pl_ratio=2, verbose=True)\n",
Expand Down Expand Up @@ -380,9 +382,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"/Users/nacho/notebooks/tsai/nbs/026_callback.noisy_student.ipynb saved at 2023-01-21 14:30:23\n",
"/Users/nacho/notebooks/tsai/nbs/026_callback.noisy_student.ipynb saved at 2024-02-10 21:53:24\n",
"Correct notebook to script conversion! 😃\n",
"Saturday 21/01/23 14:30:25 CET\n"
"Saturday 10/02/24 21:53:27 CET\n"
]
},
{
Expand Down
Loading

0 comments on commit bb69bef

Please sign in to comment.