diff --git a/dlordinal/datasets/fgnet.py b/dlordinal/datasets/fgnet.py index e72611b..b1d5b68 100644 --- a/dlordinal/datasets/fgnet.py +++ b/dlordinal/datasets/fgnet.py @@ -1,10 +1,11 @@ import re import shutil from pathlib import Path -from typing import Union +from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np import pandas as pd +from PIL import Image from skimage.io import imread, imsave from skimage.transform import resize from skimage.util import img_as_ubyte @@ -18,44 +19,82 @@ class FGNet(VisionDataset): """ Base class for FGNet dataset. + Attributes + ---------- + root : Path + Root directory of the dataset. + target_size : tuple + Size of the images after resizing. + categories : list + List of categories to be used. + test_size : float + Size of the test set. + validation_size : float + Size of the validation set. + transform : callable, optional + A function/transform that takes in a PIL image and returns a transformed version. + target_transform : callable, optional + A function/transform that takes in the target and transforms it. + data : pd.DataFrame + Dataframe containing the dataset. + Parameters ---------- root : str or Path - Root directory of dataset - download : bool, optional - If True, downloads the dataset from the internet and puts it in root directory. - If dataset is already downloaded, it is not downloaded again. - process_data : bool, optional - If True, processes the dataset and puts it in root directory. - If dataset is already processed, it is not processed again. + Root directory of the dataset. + download : bool, optional, default = True + If True, downloads the dataset from the internet and puts it in the root directory. + If the dataset is already downloaded, it is not downloaded again. target_size : tuple, optional - Size of the images after resizing. + Size of the images after resizing. Default is (128, 128). categories : list, optional - List of categories to be used. + List of categories to be used. Default is [3, 11, 16, 24, 40]. test_size : float, optional - Size of the test set. + Size of the test set. Default is 0.2. validation_size : float, optional - Size of the validation set. + Size of the validation set. Default is 0.15. + train : bool, optional + If True, returns the training dataset, otherwise returns the test dataset. Default is True. + transform : callable, optional + A function/transform that takes in a PIL image and returns a transformed version. + target_transform : callable, optional + A function/transform that takes in the target and transforms it. """ + # Attributes + root: Path + target_size: tuple + categories: list + test_size: float + validation_size: float + transform: Optional[Callable] + target_transform: Optional[Callable] + data: pd.DataFrame + def __init__( self, root: Union[str, Path], - download: bool = False, - process_data: bool = True, + download: bool = True, target_size: tuple = (128, 128), categories: list = [3, 11, 16, 24, 40], test_size: float = 0.2, validation_size: float = 0.15, + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, ) -> None: - super(FGNet, self).__init__(root) + super(FGNet, self).__init__( + str(root), transform=transform, target_transform=target_transform + ) - self.root = Path(self.root) + self.root = Path(root) self.root.parent.mkdir(parents=True, exist_ok=True) self.target_size = target_size self.categories = categories self.test_size = test_size self.validation_size = validation_size + self.transform = transform + self.target_transform = target_transform original_path = self.root / "FGNET/images" processed_path = self.root / "FGNET/data_processed" @@ -76,20 +115,97 @@ def __init__( " download it" ) - if process_data: - self.process(original_path, processed_path) - self.split( - original_csv_path, - train_csv_path, - test_csv_path, - original_images_path, - train_images_path, - test_images_path, - ) + self.process(original_path, processed_path) + self.split( + original_csv_path, + train_csv_path, + test_csv_path, + original_images_path, + train_images_path, + test_images_path, + ) + + # Load train and test dataframes + if train: + self.data = pd.read_csv(train_csv_path) + else: + self.data = pd.read_csv(test_csv_path) def __str__(self) -> str: return "FGNet" + def __len__(self) -> int: + """ + Obtain the number of samples in the dataset. + + Returns + ------- + int + Number of samples in the dataset. + """ + + return len(self.data) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Get a sample from the dataset. + + Parameters + ---------- + index : int + Index of the sample to get. + + Returns + ------- + tuple + (image, target) where target is the class index of the target class. + """ + img_path = ( + self.root / "FGNET" / "data_processed" / self.data.iloc[index]["path"] + ) + + # Cargar la imagen como PIL.Image.Image + image = Image.open(img_path) + image = image.convert("RGB") + + # Aplicar transformación si está definida + if self.transform: + image = self.transform(image) + + target = int(self.data.iloc[index]["category"]) + if self.target_transform: + target = self.target_transform(target) + + return image, target + + @property + def targets(self) -> List[int]: + """ + Return the targets of the dataset. + + Returns + ------- + list + List of targets. + """ + + if self.target_transform: + return self.target_transform(self.data["category"]) + else: + return self.data["category"].tolist() + + @property + def classes(self) -> List[int]: + """ + Return the unique classes in the dataset. + + Returns + ------- + list + List of unique classes. + """ + return np.unique(self.data["category"]).tolist() + def download(self) -> None: """ Download the FGNet dataset and extract it. @@ -100,7 +216,7 @@ def download(self) -> None: download_and_extract_archive( "http://yanweifu.github.io/FG_NET_data/FGNET.zip", - self.root, + str(self.root), filename="fgnet.zip", md5="1206978cac3626321b84c22b24cc8d19", ) @@ -205,7 +321,9 @@ def get_age_from_filename(self, filename): Filename of the image. """ m = re.match("[0-9]+A([0-9]+).*", filename) - return int(m.groups()[0]) + if m: + return int(m.groups()[0]) + return None def find_category(self, real_age): """ @@ -289,7 +407,11 @@ def split_dataframe( x = np.array(df["path"]) y = np.array(df["category"]) x_train, x_test, y_train, y_test = train_test_split( - x, y, test_size=self.test_size, random_state=1 + x, + y, + test_size=self.test_size, + random_state=1, + stratify=y, ) for path, label in zip(x_train, y_train): @@ -303,7 +425,11 @@ def split_dataframe( shutil.copy(original_images_path / path, test_path / path) x_train, x_val, y_train, y_val = train_test_split( - x_train, y_train, test_size=self.validation_size, random_state=1 + x_train, + y_train, + test_size=self.validation_size, + random_state=1, + stratify=y_train, ) train = np.hstack((x_train[:, np.newaxis], y_train[:, np.newaxis])) diff --git a/dlordinal/datasets/tests/test_fgnet.py b/dlordinal/datasets/tests/test_fgnet.py index cd073cb..c1d639b 100644 --- a/dlordinal/datasets/tests/test_fgnet.py +++ b/dlordinal/datasets/tests/test_fgnet.py @@ -1,84 +1,133 @@ -import shutil -from pathlib import Path - +import numpy as np import pytest +import torch +from PIL import Image +from torchvision.transforms import ToTensor from dlordinal.datasets import FGNet -TMP_DIR = "./tmp_test_dir_fgnet" + +@pytest.fixture +def fgnet_train(tmp_path): + fgnet = FGNet( + root=tmp_path, + download=True, + train=True, + ) + return fgnet @pytest.fixture -def fgnet_instance(): - root = TMP_DIR - fgnet = FGNet(root, download=True, process_data=True) +def fgnet_test(tmp_path): + fgnet = FGNet( + root=tmp_path, + download=True, + train=False, + ) return fgnet -def test_download(fgnet_instance): - fgnet_instance.download() - assert fgnet_instance._check_integrity_download() +def test_download(fgnet_train): + fgnet_train.download() + assert fgnet_train._check_integrity_download() -def test_process(fgnet_instance): - fgnet_instance.process( - fgnet_instance.root / "FGNET/images", - fgnet_instance.root / "FGNET/data_processed", +def test_process(fgnet_train): + fgnet_train.process( + fgnet_train.root / "FGNET/images", + fgnet_train.root / "FGNET/data_processed", ) - assert fgnet_instance._check_integrity_process() + assert fgnet_train._check_integrity_process() -def test_split(fgnet_instance): - fgnet_instance.split( - fgnet_instance.root / "FGNET/data_processed/fgnet.csv", - fgnet_instance.root / "FGNET/data_processed/train.csv", - fgnet_instance.root / "FGNET/data_processed/test.csv", - fgnet_instance.root / "FGNET/data_processed", - fgnet_instance.root / "FGNET/train", - fgnet_instance.root / "FGNET/test", +def test_split(fgnet_train): + fgnet_train.split( + fgnet_train.root / "FGNET/data_processed/fgnet.csv", + fgnet_train.root / "FGNET/data_processed/train.csv", + fgnet_train.root / "FGNET/data_processed/test.csv", + fgnet_train.root / "FGNET/data_processed", + fgnet_train.root / "FGNET/train", + fgnet_train.root / "FGNET/test", ) - assert fgnet_instance._check_integrity_split() + assert fgnet_train._check_integrity_split() -def test_find_category(fgnet_instance): - assert fgnet_instance.find_category(1) == 0 - assert fgnet_instance.find_category(9) == 1 - assert fgnet_instance.find_category(14) == 2 - assert fgnet_instance.find_category(21) == 3 - assert fgnet_instance.find_category(33) == 4 +def test_find_category(fgnet_train): + assert fgnet_train.find_category(1) == 0 + assert fgnet_train.find_category(9) == 1 + assert fgnet_train.find_category(14) == 2 + assert fgnet_train.find_category(21) == 3 + assert fgnet_train.find_category(33) == 4 -def test_get_age_from_filename(fgnet_instance): +def test_get_age_from_filename(fgnet_train): filename = "001A12X_X.jpg" - assert fgnet_instance.get_age_from_filename(filename) == 12 + assert fgnet_train.get_age_from_filename(filename) == 12 -def test_load_data(fgnet_instance): - data = fgnet_instance.load_data(fgnet_instance.root / "FGNET/images") +def test_load_data(fgnet_train): + data = fgnet_train.load_data(fgnet_train.root / "FGNET/images") assert len(data) > 0 -def test_process_images_from_df(fgnet_instance): - data = fgnet_instance.load_data(fgnet_instance.root / "FGNET/images") - processed_images = list( - (fgnet_instance.root / "FGNET/data_processed").rglob("*.JPG") - ) +def test_process_images_from_df(fgnet_train): + data = fgnet_train.load_data(fgnet_train.root / "FGNET/images") + processed_images = list((fgnet_train.root / "FGNET/data_processed").rglob("*.JPG")) assert len(processed_images) == len(data) -def test_split_dataframe(fgnet_instance): - csv_path = fgnet_instance.root / "FGNET/data_processed/fgnet.csv" - train_images_path = fgnet_instance.root / "FGNET/train" - original_images_path = fgnet_instance.root / "FGNET/images" - test_images_path = fgnet_instance.root / "FGNET/test" - train_df, test_df = fgnet_instance.split_dataframe( +def test_split_dataframe(fgnet_train): + csv_path = fgnet_train.root / "FGNET/data_processed/fgnet.csv" + train_images_path = fgnet_train.root / "FGNET/train" + original_images_path = fgnet_train.root / "FGNET/images" + test_images_path = fgnet_train.root / "FGNET/test" + train_df, test_df = fgnet_train.split_dataframe( csv_path, train_images_path, original_images_path, test_images_path ) assert len(train_df) > 0 assert len(test_df) > 0 -def test_clean_up(): - path = Path(TMP_DIR) - if path.exists(): - shutil.rmtree(path) +def test_getitem(fgnet_train, fgnet_test): + for fgnet in [fgnet_train, fgnet_test]: + for i in range(len(fgnet)): + assert isinstance(fgnet[i][0], Image.Image) + assert isinstance(fgnet[i][1], int) + assert fgnet[i][1] == fgnet.targets[i] + assert np.array(fgnet[i][0]).ndim == 3 + + fgnet.transform = ToTensor() + + for i in range(len(fgnet)): + assert isinstance(fgnet[i][0], torch.Tensor) + assert isinstance(fgnet[i][1], int) + assert fgnet[i][1] == fgnet.targets[i] + assert len(fgnet[i][0].shape) == 3 + + fgnet.target_transform = lambda target: np.array(target) + for i in range(len(fgnet)): + assert isinstance(fgnet[i][0], torch.Tensor) + assert isinstance(fgnet[i][1], np.ndarray) + assert np.array_equal(fgnet[i][1], fgnet.targets[i]) + + +def test_len(fgnet_train, fgnet_test): + for fgnet in [fgnet_train, fgnet_test]: + assert len(fgnet) == len(fgnet.targets) + assert len(fgnet) == len(fgnet.data) + + +def test_targets(fgnet_train): + assert len(fgnet_train.targets) > 0 + assert isinstance(fgnet_train.targets, list) + assert isinstance(fgnet_train.targets[0], int) + assert np.all(np.array(fgnet_train.targets) >= 0) + + +def test_classes(fgnet_train, fgnet_test): + assert len(fgnet_train.classes) == 6 + assert isinstance(fgnet_train.classes, list) + assert fgnet_train.classes == fgnet_test.classes + assert fgnet_train.classes == np.unique(fgnet_train.targets).tolist() + assert fgnet_test.classes == np.unique(fgnet_test.targets).tolist() + assert fgnet_train.classes == [0, 1, 2, 3, 4, 5] diff --git a/tutorials/dlordinal_with_skorch_tutorial.ipynb b/tutorials/dlordinal_with_skorch_tutorial.ipynb index 339672c..8a098db 100644 --- a/tutorials/dlordinal_with_skorch_tutorial.ipynb +++ b/tutorials/dlordinal_with_skorch_tutorial.ipynb @@ -20,7 +20,6 @@ "from torch import cuda, nn\n", "from torch.optim import Adam\n", "from torchvision import models\n", - "from torchvision.datasets import ImageFolder\n", "from torchvision.transforms import Compose, ToTensor\n", "from skorch import NeuralNetClassifier" ] @@ -37,7 +36,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -58,12 +57,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now we use the `FGNet` method to download and preprocess the images. Once that is done with the training data, we create a validation partition comprising 15% of the data using the `StratifiedShuffleSplit` method. Finally, with all the partitions, we load the images using a method called `DataLoader`." + "Now we use the `FGNet` method to download and preprocess the images." ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -73,23 +72,33 @@ "Files already downloaded and verified\n", "Files already processed and verified\n", "Files already split and verified\n", - "Using cuda device\n" + "Files already downloaded and verified\n", + "Files already processed and verified\n", + "Files already split and verified\n", + "Using cpu device\n" ] } ], "source": [ - "fgnet = FGNet(root=\"./datasets/fgnet\", download=True, process_data=True)\n", - "\n", - "train_data = ImageFolder(\n", - " root=\"./datasets/fgnet/FGNET/train\", transform=Compose([ToTensor()])\n", + "fgnet_train = FGNet(\n", + " root=\"./datasets\",\n", + " download=True,\n", + " train=True,\n", + " target_transform=np.array,\n", + " transform=Compose([ToTensor()]),\n", ")\n", - "test_data = ImageFolder(\n", - " root=\"./datasets/fgnet/FGNET/test\", transform=Compose([ToTensor()])\n", + "\n", + "fgnet_test = FGNet(\n", + " root=\"./datasets\",\n", + " download=True,\n", + " train=False,\n", + " target_transform=np.array,\n", + " transform=Compose([ToTensor()]),\n", ")\n", "\n", - "num_classes = len(train_data.classes)\n", - "classes = train_data.classes\n", - "targets = train_data.targets\n", + "num_classes = len(fgnet_train.classes)\n", + "classes = fgnet_train.classes\n", + "targets = fgnet_train.targets\n", "\n", "# Get CUDA device\n", "device = \"cuda\" if cuda.is_available() else \"cpu\"\n", @@ -110,7 +119,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -146,7 +155,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -240,61 +249,59 @@ ")" ] }, - "execution_count": 22, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# estimator.fit(X=train_data, y=stargets)\n", - "targets = np.array(targets)\n", - "estimator.fit(X=train_data, y=targets)" + "estimator.fit(X=fgnet_train, y=targets)" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Train probabilities = train_probs=array([[ 3.7356095 , 2.280471 , 0.27906656, -6.5274134 , -0.1423619 ,\n", - " -0.89688575],\n", - " [ 6.7312865 , 4.2798862 , -2.754112 , -7.015324 , -0.58337 ,\n", - " -1.614481 ],\n", - " [ 1.5178611 , 0.3035042 , -1.0939833 , -1.1187612 , 0.3635073 ,\n", - " -1.4568175 ],\n", + "Train probabilities = train_probs=array([[ 2.9335966e+00, 1.6396730e+00, -4.1097212e+00, 4.2018917e-01,\n", + " -1.4565204e+00, -4.3641205e+00],\n", + " [ 8.1791782e+00, 2.8562646e+00, -9.0971680e+00, 1.6580626e-02,\n", + " -3.8278933e+00, -7.0976620e+00],\n", + " [ 9.5499592e+00, 3.5849128e+00, -8.4043884e+00, -4.2262230e-02,\n", + " -5.7516675e+00, -7.0723133e+00],\n", " ...,\n", - " [-8.981531 , -2.1939955 , -1.2311378 , -1.599317 , 1.7429321 ,\n", - " 8.956122 ],\n", - " [-7.570979 , -2.4199474 , -0.9986418 , 2.073321 , 1.8904057 ,\n", - " 3.514359 ],\n", - " [-4.612633 , -2.3110492 , 1.4501587 , 1.0073776 , 0.30610457,\n", - " 1.1957583 ]], dtype=float32)\n", + " [ 1.7538257e+00, 1.8087029e-04, -6.2550454e+00, 2.6264958e+00,\n", + " 3.1415778e-01, -4.3745127e+00],\n", + " [-6.1797386e-01, -6.9311045e-02, -3.2306197e+00, 1.5903845e+00,\n", + " -5.8013773e-01, -1.2073216e+00],\n", + " [-1.2833995e+00, -2.5462928e-01, -3.4825776e+00, 2.6367762e+00,\n", + " 2.4873786e-01, -1.8333929e+00]], dtype=float32)\n", "\n", - "Test probabilities = test_probs=array([[-0.7816221 , -0.87308043, -1.196569 , -2.6637518 , 1.0128176 ,\n", - " 2.083984 ],\n", - " [-0.04164401, 1.2640952 , 1.5022627 , -2.0729616 , -0.1945019 ,\n", - " -1.6816527 ],\n", - " [ 7.281721 , 2.9113057 , -3.4834485 , -7.3575487 , 1.2093832 ,\n", - " -1.4325407 ],\n", + "Test probabilities = test_probs=array([[-0.30555367, -0.06686927, -2.9128444 , 2.4259052 , 0.38098484,\n", + " -2.2845333 ],\n", + " [ 6.7591906 , 4.8296413 , -4.5853643 , 0.25159836, -5.5993004 ,\n", + " -6.9446783 ],\n", + " [ 1.775908 , 2.4953337 , -3.323121 , 1.694654 , -1.8339258 ,\n", + " -4.3164835 ],\n", " ...,\n", - " [-9.944385 , -3.6944542 , -0.42169666, 0.5072165 , 2.8238878 ,\n", - " 7.273284 ],\n", - " [-5.4416714 , -3.3233507 , -2.3007298 , 0.98697877, 2.2850878 ,\n", - " 4.3965707 ],\n", - " [-4.6799173 , -1.7463862 , -0.5284957 , -1.7213606 , 0.89393014,\n", - " 4.456833 ]], dtype=float32)\n" + " [-1.7040929 , -1.338182 , -4.237404 , 3.5072465 , 1.2669804 ,\n", + " -2.8282924 ],\n", + " [ 1.7215905 , 1.0412043 , -3.7369957 , 1.7478234 , -0.50562465,\n", + " -3.7393394 ],\n", + " [-1.3184199 , 0.42436934, -1.2091427 , 1.731024 , 0.15624674,\n", + " -2.6779366 ]], dtype=float32)\n" ] } ], "source": [ - "train_probs = estimator.predict_proba(train_data)\n", + "train_probs = estimator.predict_proba(fgnet_train)\n", "print(f\"Train probabilities = {train_probs=}\\n\")\n", "\n", - "test_probs = estimator.predict_proba(test_data)\n", + "test_probs = estimator.predict_proba(fgnet_test)\n", "print(f\"Test probabilities = {test_probs=}\")" ] } @@ -315,7 +322,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.18" + "version": "3.8.19" }, "orig_nbformat": 4, "vscode": { diff --git a/tutorials/ecoc_tutorial.ipynb b/tutorials/ecoc_tutorial.ipynb index 269fce1..fe04105 100644 --- a/tutorials/ecoc_tutorial.ipynb +++ b/tutorials/ecoc_tutorial.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -29,7 +29,6 @@ "from torch import cuda\n", "from torch.optim import Adam\n", "from torch.utils.data import DataLoader, Subset\n", - "from torchvision.datasets import ImageFolder\n", "from torchvision.transforms import Compose, ToTensor\n", "from torchvision.models import resnet18\n", "from tqdm import tqdm" @@ -47,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -73,33 +72,62 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already processed and verified\n", + "Files already split and verified\n", + "Files already downloaded and verified\n", + "Files already processed and verified\n", + "Files already split and verified\n", + "Using cpu device\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Detected image shape: [3, 128, 128]\n", + "class_weights=tensor([1.0191, 1.5345, 0.7946, 1.1314, 0.5517, 2.4273])\n" + ] + } + ], "source": [ - "fgnet = FGNet(root=\"./datasets/fgnet\", download=True, process_data=True)\n", - "\n", - "complete_train_data = ImageFolder(\n", - " root=\"./datasets/fgnet/FGNET/train\", transform=Compose([ToTensor()])\n", + "fgnet_trainval = FGNet(\n", + " root=\"./datasets\",\n", + " download=True,\n", + " train=True,\n", + " target_transform=np.array,\n", + " transform=Compose([ToTensor()]),\n", ")\n", - "test_data = ImageFolder(\n", - " root=\"./datasets/fgnet/FGNET/test\", transform=Compose([ToTensor()])\n", + "\n", + "test_data = FGNet(\n", + " root=\"./datasets\",\n", + " download=True,\n", + " train=False,\n", + " target_transform=np.array,\n", + " transform=Compose([ToTensor()]),\n", ")\n", "\n", - "num_classes = len(complete_train_data.classes)\n", - "classes = complete_train_data.classes\n", - "targets = complete_train_data.targets\n", + "num_classes = len(fgnet_trainval.classes)\n", + "classes = fgnet_trainval.classes\n", + "targets = fgnet_trainval.targets\n", "\n", "# Create a validation split\n", "sss = StratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=0)\n", "sss_splits = list(\n", - " sss.split(X=np.zeros(len(complete_train_data)), y=complete_train_data.targets)\n", + " sss.split(X=np.zeros(len(fgnet_trainval)), y=fgnet_trainval.targets)\n", ")\n", "train_idx, val_idx = sss_splits[0]\n", "\n", "# Create subsets for training and validation\n", - "train_data = Subset(complete_train_data, train_idx)\n", - "val_data = Subset(complete_train_data, val_idx)\n", + "train_data = Subset(fgnet_trainval, train_idx)\n", + "val_data = Subset(fgnet_trainval, val_idx)\n", "\n", "# Get CUDA device\n", "device = \"cuda\" if cuda.is_available() else \"cpu\"\n", @@ -152,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -182,7 +210,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -201,13 +229,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Metrics computation\n", - "\n", - "\n", "def compute_metrics(y_true: np.ndarray, \n", " y_pred: np.ndarray, \n", " num_classes: int):\n", @@ -301,7 +327,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -326,8 +352,6 @@ "\n", " # Compute prediction error and accuracy of the training process\n", " pred = model(X)\n", - " print(pred)\n", - " print(y)\n", " loss = loss_fn(pred, y)\n", "\n", " mean_loss += loss\n", @@ -381,7 +405,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -438,7 +462,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -496,9 +520,246 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Train progress: 100%|█████████████████████████| 4/4 [00:11<00:00, 2.90s/it, accuracy=74, loss=39.1]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Train Confusion matrix :\n", + "[[ 0 71 2 1 0 0]\n", + " [ 0 175 16 13 0 1]\n", + " [ 0 44 25 32 3 7]\n", + " [ 0 19 21 81 9 13]\n", + " [ 0 11 8 35 7 39]\n", + " [ 0 0 0 6 3 38]]\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO] EPOCH: 1/5\n", + "Train loss: 86.180344, Train accuracy: 0.1088\n", + "Val loss: 138.933105, Val accuracy: 0.1074\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Train progress: 100%|█████████████████████████| 4/4 [00:10<00:00, 2.69s/it, accuracy=74, loss=31.9]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Train Confusion matrix :\n", + "[[ 0 70 2 2 0 0]\n", + " [ 0 181 16 8 0 0]\n", + " [ 0 44 25 37 1 4]\n", + " [ 0 14 17 99 3 10]\n", + " [ 0 5 9 59 11 16]\n", + " [ 0 0 0 1 3 43]]\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO] EPOCH: 2/5\n", + "Train loss: 73.077477, Train accuracy: 0.1088\n", + "Val loss: 124.414589, Val accuracy: 0.1074\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Train progress: 100%|███████████████████████████| 4/4 [00:10<00:00, 2.74s/it, accuracy=74, loss=22]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Train Confusion matrix :\n", + "[[ 0 73 0 1 0 0]\n", + " [ 0 192 10 3 0 0]\n", + " [ 0 42 36 32 1 0]\n", + " [ 0 12 20 96 13 2]\n", + " [ 0 3 7 44 27 19]\n", + " [ 0 0 0 0 2 45]]\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO] EPOCH: 3/5\n", + "Train loss: 59.367184, Train accuracy: 0.1088\n", + "Val loss: 112.904716, Val accuracy: 0.1074\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Train progress: 100%|█████████████████████████| 4/4 [00:11<00:00, 2.78s/it, accuracy=74, loss=28.2]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Train Confusion matrix :\n", + "[[ 0 72 0 2 0 0]\n", + " [ 0 195 7 3 0 0]\n", + " [ 0 37 43 29 2 0]\n", + " [ 0 4 14 113 9 3]\n", + " [ 0 1 7 41 33 18]\n", + " [ 0 0 0 0 4 43]]\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO] EPOCH: 4/5\n", + "Train loss: 51.117718, Train accuracy: 0.1088\n", + "Val loss: 101.604919, Val accuracy: 0.1074\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Train progress: 100%|█████████████████████████| 4/4 [00:11<00:00, 2.78s/it, accuracy=74, loss=22.5]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Train Confusion matrix :\n", + "[[ 0 74 0 0 0 0]\n", + " [ 0 196 7 2 0 0]\n", + " [ 0 23 49 39 0 0]\n", + " [ 0 3 9 131 0 0]\n", + " [ 0 1 1 49 31 18]\n", + " [ 0 0 0 0 0 47]]\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO] EPOCH: 5/5\n", + "Train loss: 44.203930, Train accuracy: 0.1088\n", + "Val loss: 114.216644, Val accuracy: 0.1074\n", + "\n", + "[INFO] Network evaluation ...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Test progress: 100%|█████████████████████████████████████████| 2/2 [00:01<00:00, 1.37it/s, loss=88]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Confusion matrix :\n", + "[[ 0 14 5 3 0 0]\n", + " [ 0 27 16 14 2 1]\n", + " [ 0 11 7 13 2 0]\n", + " [ 0 5 11 22 2 2]\n", + " [ 0 0 3 21 3 3]\n", + " [ 0 1 2 3 3 5]]\n", + "\n", + "MS: 0.0000\n", + "\n", + "QWK: 0.5270\n", + "\n", + "MAE: 0.9502\n", + "\n", + "CCR: 0.3184\n", + "\n", + "1-off: 0.7861\n", + "\n", + "[INFO] Total training time: 61.98s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "H = {\"train_loss\": [], \"train_acc\": [], \"val_loss\": [], \"val_acc\": []}\n", "\n", @@ -545,13 +806,6 @@ "end_time = time.time()\n", "print(\"\\n[INFO] Total training time: {:.2f}s\".format(end_time - start_time))" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/tutorials/fgnet_tutorial.ipynb b/tutorials/fgnet_tutorial.ipynb new file mode 100644 index 0000000..d524151 --- /dev/null +++ b/tutorials/fgnet_tutorial.ipynb @@ -0,0 +1,218 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Importing libraries\n", + "\n", + "From the *Ordinal Deep Learning* package, we import the methods that will allow us to work with ordinal datasets.\n", + "\n", + "We also import methods from libraries such as *pytorch* and *torchvision* that will allow us to process and work with the datasets.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from dlordinal.datasets import FGNet\n", + "from torchvision.transforms import ToTensor, Compose\n", + "import numpy as np" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## FGNet\n", + "\n", + "To utilize the [FGNet dataset](https://yanweifu.github.io/FG_NET_data/), two instances of the dataset will be created: one for the training data and one for the test data. Each instance will include the following fields:\n", + "\n", + "* __root__: an attribute that defines the path where the dataset will be downloaded and extracted.\n", + "* __download__: an attribute that indicates the desire to perform the dataset download.\n", + "* __train__: an attribute indicating that only the processed input dataset will be returned if its value is set to TRUE.\n", + "* __target_transform__: an attribute that defines the transformation to be applied to the targets.\n", + "* __transform__: an attribute that defines the transformation to be applied to the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already processed and verified\n", + "Files already split and verified\n", + "Files already downloaded and verified\n", + "Files already processed and verified\n", + "Files already split and verified\n" + ] + } + ], + "source": [ + "fgnet_train = FGNet(\n", + " root=\"./datasets\",\n", + " download=True,\n", + " train=True,\n", + " target_transform=np.array,\n", + " transform=Compose([ToTensor()]),\n", + ")\n", + "\n", + "fgnet_test = FGNet(\n", + " root=\"./datasets\",\n", + " download=True,\n", + " train=False,\n", + " target_transform=np.array,\n", + " transform=Compose([ToTensor()]),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, the `FGNet` objects can be used as any other `VisionDataset` from `torchvision`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of samples in the FGNet train dataset: 801\n", + "Targets of the FGNet train dataset: [2 0 0 2 2 3 4 1 3 1 3 3 2 1 2 3 3 2 4 1 3 1 3 2 0 0 3 4 1 3 2 4 1 1 4 4 4\n", + " 1 3 2 2 2 1 0 1 3 3 3 1 1 5 1 2 2 4 0 1 1 1 3 5 1 2 5 0 1 1 1 2 1 3 1 2 1\n", + " 2 1 2 5 1 3 1 2 3 2 4 1 0 1 1 1 0 2 1 4 3 3 2 1 1 3 0 4 1 0 5 1 3 3 4 1 2\n", + " 4 5 1 0 4 5 1 0 1 3 4 1 2 5 3 1 3 4 4 4 2 1 3 3 3 2 3 2 1 0 0 1 3 1 3 1 2\n", + " 5 3 1 4 4 0 2 3 3 2 4 1 3 1 3 2 5 3 1 3 2 1 2 1 1 4 1 1 1 2 3 1 1 4 4 2 3\n", + " 5 1 4 5 3 5 1 3 0 0 1 3 1 3 1 5 4 1 3 4 1 5 1 0 0 1 3 2 1 4 3 1 1 2 4 1 1\n", + " 3 1 1 2 1 1 3 3 1 2 4 3 3 0 2 5 4 1 2 1 3 0 1 4 4 2 1 4 3 2 5 3 3 5 3 1 0\n", + " 4 0 2 3 2 4 1 1 0 4 1 2 2 4 2 0 1 1 0 4 1 1 2 3 4 1 2 4 3 3 1 4 3 5 3 3 2\n", + " 2 1 4 4 0 2 0 1 5 0 2 0 4 1 3 4 5 0 4 4 3 1 5 1 1 1 3 2 1 2 3 1 4 2 0 3 3\n", + " 3 4 0 1 1 1 3 2 0 1 1 2 5 1 1 3 2 1 2 0 1 3 3 3 1 1 3 1 1 2 2 1 3 4 3 0 3\n", + " 4 2 1 1 3 2 4 3 3 4 1 5 1 1 2 1 1 4 1 1 2 3 1 2 2 0 5 1 3 1 3 0 3 4 1 1 2\n", + " 1 1 1 4 4 0 3 2 1 1 3 1 3 3 2 2 3 3 5 2 4 4 2 4 1 2 1 2 5 1 4 0 2 4 3 0 2\n", + " 4 4 4 1 1 4 3 1 1 1 1 4 2 3 4 1 3 3 0 1 1 1 2 2 3 4 2 3 3 1 3 5 1 3 2 0 0\n", + " 0 0 1 5 3 3 1 4 3 5 3 3 1 0 1 4 2 1 2 1 1 2 1 1 1 0 3 4 2 0 3 4 4 1 5 2 5\n", + " 2 4 4 1 3 4 3 3 1 1 3 1 1 1 1 2 4 1 0 3 0 3 3 2 2 1 0 4 3 1 1 3 1 3 1 1 0\n", + " 2 0 3 2 4 3 1 4 2 3 5 4 1 5 5 4 5 5 3 4 2 5 5 4 3 0 0 1 2 3 0 0 1 2 4 3 3\n", + " 5 4 1 4 5 5 2 0 1 1 0 1 2 4 2 4 2 3 4 4 4 1 4 0 1 3 1 2 5 3 1 0 1 2 4 4 3\n", + " 5 0 0 0 1 1 3 0 5 1 1 1 3 1 2 1 1 3 3 3 2 4 4 1 2 4 5 1 0 1 0 0 1 5 2 1 3\n", + " 2 1 1 2 3 3 3 3 0 0 1 0 1 0 0 1 0 3 1 1 3 5 2 1 1 2 4 3 3 2 3 4 4 4 3 0 1\n", + " 4 2 1 5 1 1 2 0 1 0 4 1 2 0 4 1 4 5 1 0 2 1 3 5 0 2 1 4 0 3 2 1 2 1 4 3 2\n", + " 1 1 2 1 1 1 5 1 1 3 4 1 2 3 3 3 0 1 1 3 4 3 4 2 3 3 1 2 2 4 2 4 1 4 3 3 2\n", + " 1 3 0 0 3 3 4 1 1 0 5 1 2 3 1 1 3 1 1 5 2 4 5 3]\n", + "Classes of the FGNet train dataset: [2 0 3 4 1 5]\n", + "3rd sample of the FGNet train dataset: (tensor([[[0.7294, 0.7294, 0.7294, ..., 0.8039, 0.8039, 0.8039],\n", + " [0.7294, 0.7294, 0.7294, ..., 0.8039, 0.8039, 0.8039],\n", + " [0.7333, 0.7333, 0.7333, ..., 0.8039, 0.8039, 0.8039],\n", + " ...,\n", + " [0.6078, 0.6078, 0.6039, ..., 0.4549, 0.4471, 0.4353],\n", + " [0.6039, 0.6039, 0.6039, ..., 0.4431, 0.4353, 0.4196],\n", + " [0.6039, 0.6039, 0.6000, ..., 0.4314, 0.4235, 0.4196]],\n", + "\n", + " [[0.7490, 0.7490, 0.7490, ..., 0.8157, 0.8157, 0.8157],\n", + " [0.7490, 0.7490, 0.7490, ..., 0.8157, 0.8157, 0.8157],\n", + " [0.7490, 0.7490, 0.7490, ..., 0.8157, 0.8157, 0.8157],\n", + " ...,\n", + " [0.6627, 0.6627, 0.6588, ..., 0.5451, 0.5255, 0.5059],\n", + " [0.6588, 0.6588, 0.6588, ..., 0.5490, 0.5294, 0.5137],\n", + " [0.6588, 0.6588, 0.6549, ..., 0.5451, 0.5294, 0.5255]],\n", + "\n", + " [[0.7647, 0.7647, 0.7608, ..., 0.7725, 0.7725, 0.7725],\n", + " [0.7647, 0.7608, 0.7608, ..., 0.7725, 0.7725, 0.7725],\n", + " [0.7608, 0.7529, 0.7529, ..., 0.7725, 0.7725, 0.7725],\n", + " ...,\n", + " [0.6627, 0.6627, 0.6588, ..., 0.5686, 0.5608, 0.5451],\n", + " [0.6588, 0.6588, 0.6588, ..., 0.5765, 0.5686, 0.5529],\n", + " [0.6588, 0.6588, 0.6549, ..., 0.5765, 0.5725, 0.5686]]]), 2)\n", + "\n", + "\n", + "Number of samples in the FGNet test dataset: 201\n", + "Targets of the FGNet test dataset: [3 0 2 3 1 3 1 4 5 3 2 3 3 0 1 0 4 1 1 1 1 2 2 1 0 1 1 1 3 5 1 2 1 4 4 4 2\n", + " 4 3 2 5 0 4 4 4 2 5 2 5 5 1 1 1 3 1 2 5 1 5 0 2 2 1 2 3 1 1 1 4 0 0 3 1 3\n", + " 4 3 0 2 1 3 5 3 3 3 1 1 1 3 0 1 2 3 3 1 1 4 3 3 2 1 0 4 5 1 1 1 2 3 3 2 3\n", + " 4 1 4 4 2 5 1 1 0 2 0 2 1 3 3 1 1 1 5 3 1 1 3 4 4 1 4 4 1 2 4 4 3 1 0 1 2\n", + " 2 3 0 3 2 5 1 4 4 1 2 2 3 3 2 0 4 1 0 2 1 1 2 1 3 3 3 1 0 0 0 1 1 0 5 1 3\n", + " 4 2 0 3 4 1 3 1 4 3 4 1 2 3 4 2]\n", + "Classes of the FGNet test dataset: [3 0 2 1 4 5]\n", + "3rd sample of the FGNet test dataset: (tensor([[[0.7843, 0.7843, 0.7882, ..., 0.7882, 0.7882, 0.7843],\n", + " [0.7843, 0.7843, 0.7882, ..., 0.7882, 0.7882, 0.7843],\n", + " [0.7843, 0.7843, 0.7882, ..., 0.7882, 0.7843, 0.7843],\n", + " ...,\n", + " [0.3961, 0.3922, 0.3882, ..., 0.8196, 0.7451, 0.6863],\n", + " [0.3765, 0.3725, 0.3725, ..., 0.7451, 0.7020, 0.6784],\n", + " [0.3765, 0.3725, 0.3725, ..., 0.6941, 0.6941, 0.6941]],\n", + "\n", + " [[0.7922, 0.7922, 0.7922, ..., 0.7843, 0.7843, 0.7804],\n", + " [0.7922, 0.7922, 0.7922, ..., 0.7843, 0.7843, 0.7804],\n", + " [0.7922, 0.7922, 0.7922, ..., 0.7843, 0.7804, 0.7804],\n", + " ...,\n", + " [0.3451, 0.3412, 0.3373, ..., 0.6078, 0.5176, 0.4510],\n", + " [0.3333, 0.3294, 0.3294, ..., 0.5098, 0.4510, 0.4157],\n", + " [0.3333, 0.3294, 0.3294, ..., 0.4510, 0.4275, 0.4196]],\n", + "\n", + " [[0.7412, 0.7412, 0.7608, ..., 0.7647, 0.7647, 0.7608],\n", + " [0.7412, 0.7412, 0.7608, ..., 0.7647, 0.7647, 0.7608],\n", + " [0.7412, 0.7412, 0.7608, ..., 0.7647, 0.7608, 0.7608],\n", + " ...,\n", + " [0.2784, 0.2745, 0.2706, ..., 0.5608, 0.4745, 0.4118],\n", + " [0.2627, 0.2588, 0.2588, ..., 0.4667, 0.4118, 0.3804],\n", + " [0.2627, 0.2588, 0.2588, ..., 0.4078, 0.3922, 0.3882]]]), 3)\n" + ] + } + ], + "source": [ + "print(f'Number of samples in the FGNet train dataset: {len(fgnet_train)}')\n", + "print(f'Targets of the FGNet train dataset: {fgnet_train.targets}')\n", + "print(f'Classes of the FGNet train dataset: {fgnet_train.classes}')\n", + "print(f'3rd sample of the FGNet train dataset: {fgnet_train[3]}')\n", + "print(\"\\n\")\n", + "\n", + "print(f'Number of samples in the FGNet test dataset: {len(fgnet_test)}')\n", + "print(f'Targets of the FGNet test dataset: {fgnet_test.targets}')\n", + "print(f'Classes of the FGNet test dataset: {fgnet_test.classes}')\n", + "print(f'3rd sample of the FGNet test dataset: {fgnet_test[3]}')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Torch", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.19" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "385611db6ca4af2663855b1744f455946eef985f7b33eb977c97667790417df3" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/hybrid_dropout_tutorial.ipynb b/tutorials/hybrid_dropout_tutorial.ipynb index 44b45cd..17fbf16 100644 --- a/tutorials/hybrid_dropout_tutorial.ipynb +++ b/tutorials/hybrid_dropout_tutorial.ipynb @@ -29,7 +29,6 @@ "from torch.optim import Adam\n", "from torch.utils.data import DataLoader, Subset\n", "from torchvision import models\n", - "from torchvision.datasets import ImageFolder\n", "from torchvision.transforms import Compose, ToTensor\n", "from tqdm import tqdm\n", "from dlordinal.dropout import HybridDropout, HybridDropoutContainer" @@ -83,37 +82,53 @@ "Files already downloaded and verified\n", "Files already processed and verified\n", "Files already split and verified\n", - "Using cuda device\n", + "Files already downloaded and verified\n", + "Files already processed and verified\n", + "Files already split and verified\n", + "Using cpu device\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "Detected image shape: [3, 128, 128]\n", - "class_weights=array([1.5 , 0.55165289, 1.01136364, 0.84493671, 1.0511811 ,\n", - " 2.51886792])\n" + "class_weights=array([1.01908397, 1.53448276, 0.79464286, 1.13135593, 0.55165289,\n", + " 2.42727273])\n" ] } ], "source": [ - "fgnet = FGNet(root=\"./datasets/fgnet\", download=True, process_data=True)\n", - "\n", - "complete_train_data = ImageFolder(\n", - " root=\"./datasets/fgnet/FGNET/train\", transform=Compose([ToTensor()])\n", + "fgnet_trainval = FGNet(\n", + " root=\"./datasets\",\n", + " download=True,\n", + " train=True,\n", + " target_transform=np.array,\n", + " transform=Compose([ToTensor()]),\n", ")\n", - "test_data = ImageFolder(\n", - " root=\"./datasets/fgnet/FGNET/test\", transform=Compose([ToTensor()])\n", + "\n", + "test_data = FGNet(\n", + " root=\"./datasets\",\n", + " download=True,\n", + " train=False,\n", + " target_transform=np.array,\n", + " transform=Compose([ToTensor()]),\n", ")\n", "\n", - "num_classes = len(complete_train_data.classes)\n", - "classes = complete_train_data.classes\n", - "targets = complete_train_data.targets\n", + "num_classes = len(fgnet_trainval.classes)\n", + "classes = fgnet_trainval.classes\n", + "targets = fgnet_trainval.targets\n", "\n", "# Create a validation split\n", "sss = StratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=0)\n", "sss_splits = list(\n", - " sss.split(X=np.zeros(len(complete_train_data)), y=complete_train_data.targets)\n", + " sss.split(X=np.zeros(len(fgnet_trainval)), y=fgnet_trainval.targets)\n", ")\n", "train_idx, val_idx = sss_splits[0]\n", "\n", "# Create subsets for training and validation\n", - "train_data = Subset(complete_train_data, train_idx)\n", - "val_data = Subset(complete_train_data, val_idx)\n", + "train_data = Subset(fgnet_trainval, train_idx)\n", + "val_data = Subset(fgnet_trainval, val_idx)\n", "\n", "# Get CUDA device\n", "device = \"cuda\" if cuda.is_available() else \"cpu\"\n", @@ -528,7 +543,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 100%|████████████████████████| 4/4 [00:28<00:00, 7.01s/it, accuracy=203, loss=1.41]" + "Train progress: 100%|████████████████████████| 4/4 [00:12<00:00, 3.18s/it, accuracy=218, loss=1.58]" ] }, { @@ -537,12 +552,12 @@ "text": [ "\n", "Train Confusion matrix :\n", - "[[ 2 55 1 1 17 0]\n", - " [ 3 125 6 17 53 1]\n", - " [ 1 54 10 20 24 3]\n", - " [ 3 50 9 31 36 5]\n", - " [ 1 42 4 30 29 2]\n", - " [ 1 8 5 9 16 6]]\n", + "[[ 0 67 1 1 0 5]\n", + " [ 0 181 4 9 2 9]\n", + " [ 1 83 5 16 1 5]\n", + " [ 1 93 14 25 1 9]\n", + " [ 2 60 4 23 3 8]\n", + " [ 0 29 2 10 2 4]]\n", "\n" ] }, @@ -558,8 +573,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 1/5\n", - "Train loss: 1.630586, Train accuracy: 0.2985\n", - "Val loss: 3.244811, Val accuracy: 0.3802\n", + "Train loss: 1.659357, Train accuracy: 0.3206\n", + "Val loss: 3.339191, Val accuracy: 0.3636\n", "\n" ] }, @@ -567,7 +582,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 100%|████████████████████████| 4/4 [00:01<00:00, 3.87it/s, accuracy=413, loss=1.21]" + "Train progress: 100%|████████████████████████| 4/4 [00:11<00:00, 2.91s/it, accuracy=384, loss=1.07]" ] }, { @@ -576,12 +591,12 @@ "text": [ "\n", "Train Confusion matrix :\n", - "[[ 29 46 0 1 0 0]\n", - " [ 2 187 4 9 3 0]\n", - " [ 2 48 25 33 4 0]\n", - " [ 0 19 10 94 10 1]\n", - " [ 0 12 3 24 69 0]\n", - " [ 0 2 1 16 17 9]]\n", + "[[ 58 15 0 1 0 0]\n", + " [ 17 174 0 13 1 0]\n", + " [ 2 38 1 69 1 0]\n", + " [ 0 11 0 132 0 0]\n", + " [ 1 1 0 79 19 0]\n", + " [ 0 2 0 31 14 0]]\n", "\n" ] }, @@ -597,8 +612,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 2/5\n", - "Train loss: 1.223561, Train accuracy: 0.6074\n", - "Val loss: 2.884943, Val accuracy: 0.4711\n", + "Train loss: 1.194799, Train accuracy: 0.5647\n", + "Val loss: 6.592565, Val accuracy: 0.3967\n", "\n" ] }, @@ -606,7 +621,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 100%|███████████████████████| 4/4 [00:01<00:00, 3.34it/s, accuracy=527, loss=0.776]" + "Train progress: 100%|███████████████████████| 4/4 [00:11<00:00, 2.92s/it, accuracy=521, loss=0.824]" ] }, { @@ -615,12 +630,12 @@ "text": [ "\n", "Train Confusion matrix :\n", - "[[ 69 7 0 0 0 0]\n", - " [ 12 179 10 1 3 0]\n", - " [ 2 20 46 37 6 1]\n", - " [ 0 2 6 112 13 1]\n", - " [ 0 0 0 16 91 1]\n", - " [ 0 0 2 3 10 30]]\n", + "[[ 68 6 0 0 0 0]\n", + " [ 8 191 2 4 0 0]\n", + " [ 1 27 57 24 2 0]\n", + " [ 0 5 12 110 16 0]\n", + " [ 0 0 3 11 86 0]\n", + " [ 5 0 1 1 31 9]]\n", "\n" ] }, @@ -636,8 +651,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 3/5\n", - "Train loss: 0.839435, Train accuracy: 0.7750\n", - "Val loss: 4.085755, Val accuracy: 0.4545\n", + "Train loss: 0.866719, Train accuracy: 0.7662\n", + "Val loss: 4.893533, Val accuracy: 0.4050\n", "\n" ] }, @@ -645,7 +660,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 100%|███████████████████████| 4/4 [00:00<00:00, 4.01it/s, accuracy=624, loss=0.535]" + "Train progress: 100%|███████████████████████| 4/4 [00:12<00:00, 3.09s/it, accuracy=597, loss=0.662]" ] }, { @@ -654,12 +669,12 @@ "text": [ "\n", "Train Confusion matrix :\n", - "[[ 71 5 0 0 0 0]\n", - " [ 2 202 1 0 0 0]\n", - " [ 3 9 86 13 1 0]\n", - " [ 0 0 9 116 7 2]\n", - " [ 0 0 0 1 106 1]\n", - " [ 0 1 0 0 1 43]]\n", + "[[ 70 4 0 0 0 0]\n", + " [ 9 184 9 3 0 0]\n", + " [ 0 8 88 15 0 0]\n", + " [ 0 1 10 131 1 0]\n", + " [ 0 0 1 8 91 0]\n", + " [ 0 0 1 1 12 33]]\n", "\n" ] }, @@ -675,8 +690,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 4/5\n", - "Train loss: 0.574313, Train accuracy: 0.9176\n", - "Val loss: 8.397540, Val accuracy: 0.2562\n", + "Train loss: 0.649500, Train accuracy: 0.8779\n", + "Val loss: 5.146667, Val accuracy: 0.5372\n", "\n" ] }, @@ -684,7 +699,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 100%|████████████████████████| 4/4 [00:00<00:00, 4.07it/s, accuracy=659, loss=0.44]" + "Train progress: 100%|███████████████████████| 4/4 [00:11<00:00, 2.81s/it, accuracy=654, loss=0.459]" ] }, { @@ -693,12 +708,12 @@ "text": [ "\n", "Train Confusion matrix :\n", - "[[ 76 0 0 0 0 0]\n", - " [ 3 201 1 0 0 0]\n", - " [ 0 9 101 2 0 0]\n", - " [ 0 0 3 131 0 0]\n", - " [ 0 0 0 3 105 0]\n", - " [ 0 0 0 0 0 45]]\n", + "[[ 74 0 0 0 0 0]\n", + " [ 1 202 1 0 1 0]\n", + " [ 0 7 95 9 0 0]\n", + " [ 0 0 3 137 3 0]\n", + " [ 0 0 0 0 99 1]\n", + " [ 0 0 0 0 0 47]]\n", "\n" ] }, @@ -714,8 +729,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 5/5\n", - "Train loss: 0.458296, Train accuracy: 0.9691\n", - "Val loss: 5.839002, Val accuracy: 0.4628\n", + "Train loss: 0.486269, Train accuracy: 0.9618\n", + "Val loss: 5.453663, Val accuracy: 0.5289\n", "\n", "[INFO] Network evaluation ...\n" ] @@ -724,7 +739,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Test progress: 100%|███████████████████████████████████████| 2/2 [00:00<00:00, 2.06it/s, loss=17.6]" + "Test progress: 100%|███████████████████████████████████████| 2/2 [00:01<00:00, 1.32it/s, loss=12.6]" ] }, { @@ -733,24 +748,24 @@ "text": [ "\n", "Confusion matrix :\n", - "[[18 2 0 0 0 0]\n", - " [12 42 1 5 0 0]\n", - " [ 0 14 6 12 0 0]\n", - " [ 0 6 4 41 1 0]\n", - " [ 0 1 1 13 2 4]\n", - " [ 0 2 0 5 3 6]]\n", + "[[ 7 15 0 0 0 0]\n", + " [ 1 49 2 7 1 0]\n", + " [ 0 12 1 17 3 0]\n", + " [ 0 3 3 26 10 0]\n", + " [ 0 0 0 7 23 0]\n", + " [ 0 0 0 0 14 0]]\n", "\n", - "MS: 0.0952\n", + "MS: 0.0000\n", "\n", - "QWK: 0.7815\n", + "QWK: 0.8185\n", "\n", - "MAE: 0.5522\n", + "MAE: 0.5473\n", "\n", - "CCR: 0.5721\n", + "CCR: 0.5274\n", "\n", - "1-off: 0.9005\n", + "1-off: 0.9303\n", "\n", - "[INFO] Total training time: 35.99s\n" + "[INFO] Total training time: 66.07s\n" ] }, { diff --git a/tutorials/losses_tutorial.ipynb b/tutorials/losses_tutorial.ipynb index 0703369..791d6c1 100644 --- a/tutorials/losses_tutorial.ipynb +++ b/tutorials/losses_tutorial.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -35,7 +35,6 @@ "from torch.optim import Adam\n", "from torch.utils.data import DataLoader, Subset\n", "from torchvision import models\n", - "from torchvision.datasets import ImageFolder\n", "from torchvision.transforms import Compose, ToTensor\n", "from tqdm import tqdm" ] @@ -52,7 +51,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -77,7 +76,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -87,44 +86,47 @@ "Files already downloaded and verified\n", "Files already processed and verified\n", "Files already split and verified\n", - "num_classes=6\n", - "Using cuda device\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "Files already downloaded and verified\n", + "Files already processed and verified\n", + "Files already split and verified\n", + "Using cpu device\n", "Detected image shape: [3, 128, 128]\n", - "class_weights=array([1.60843373, 0.55394191, 1.02692308, 0.78070175, 1.12184874,\n", - " 2.34210526])\n" + "class_weights=array([1.01908397, 1.53448276, 0.79464286, 1.13135593, 0.55165289,\n", + " 2.42727273])\n" ] } ], "source": [ - "fgnet = FGNet(root=\"./datasets/fgnet\", download=True, process_data=True)\n", - "\n", - "complete_train_data = ImageFolder(\n", - " root=\"./datasets/fgnet/FGNET/train\", transform=Compose([ToTensor()])\n", + "fgnet_trainval = FGNet(\n", + " root=\"./datasets\",\n", + " download=True,\n", + " train=True,\n", + " target_transform=np.array,\n", + " transform=Compose([ToTensor()]),\n", ")\n", - "test_data = ImageFolder(\n", - " root=\"./datasets/fgnet/FGNET/test\", transform=Compose([ToTensor()])\n", + "\n", + "test_data = FGNet(\n", + " root=\"./datasets\",\n", + " download=True,\n", + " train=False,\n", + " target_transform=np.array,\n", + " transform=Compose([ToTensor()]),\n", ")\n", "\n", - "num_classes = len(complete_train_data.classes)\n", - "classes = complete_train_data.classes\n", - "targets = complete_train_data.targets\n", + "num_classes = len(fgnet_trainval.classes)\n", + "classes = fgnet_trainval.classes\n", + "targets = fgnet_trainval.targets\n", "\n", "# Create a validation split\n", "sss = StratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=0)\n", "sss_splits = list(\n", - " sss.split(X=np.zeros(len(complete_train_data)), y=complete_train_data.targets)\n", + " sss.split(X=np.zeros(len(fgnet_trainval)), y=fgnet_trainval.targets)\n", ")\n", "train_idx, val_idx = sss_splits[0]\n", "\n", "# Create subsets for training and validation\n", - "train_data = Subset(complete_train_data, train_idx)\n", - "val_data = Subset(complete_train_data, val_idx)\n", + "train_data = Subset(fgnet_trainval, train_idx)\n", + "val_data = Subset(fgnet_trainval, val_idx)\n", "\n", "# Get CUDA device\n", "device = \"cuda\" if cuda.is_available() else \"cpu\"\n", @@ -174,7 +176,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -226,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -245,7 +247,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -345,7 +347,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -423,7 +425,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -480,7 +482,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -545,249 +547,28 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 50%|████████████ | 2/4 [00:01<00:01, 1.69it/s, accuracy=106, loss=1.62]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[ 0.4965, 0.5200, 0.2156, 0.9261, -0.6116, 1.0949],\n", - " [-0.4715, -0.7595, 1.1330, 0.7932, 0.0749, 1.2884],\n", - " [ 0.8929, 0.5330, 0.0984, 0.3900, -0.7238, 0.4939],\n", - " ...,\n", - " [ 0.3272, -0.5772, -0.2451, -0.1964, -1.1121, 1.1002],\n", - " [ 0.3922, -1.2060, -0.2031, 0.7491, -0.4196, -0.4266],\n", - " [-0.1965, -0.1802, 1.1205, 0.7332, 0.4739, 0.6164]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([5, 3, 1, 3, 1, 5, 4, 1, 3, 3, 4, 1, 1, 1, 5, 1, 1, 0, 1, 4, 5, 1, 1, 0,\n", - " 3, 2, 5, 3, 2, 1, 3, 2, 3, 1, 0, 2, 4, 0, 3, 4, 5, 3, 1, 5, 1, 4, 1, 1,\n", - " 2, 3, 5, 1, 3, 3, 0, 2, 2, 1, 4, 1, 1, 2, 2, 4, 2, 3, 1, 3, 3, 1, 2, 1,\n", - " 4, 3, 1, 3, 4, 1, 0, 5, 1, 2, 3, 0, 4, 2, 2, 2, 1, 1, 4, 4, 1, 2, 0, 4,\n", - " 1, 2, 1, 1, 1, 4, 4, 3, 2, 2, 0, 0, 4, 3, 5, 0, 2, 4, 3, 5, 1, 2, 0, 2,\n", - " 1, 4, 2, 4, 3, 3, 1, 4, 1, 1, 3, 5, 4, 4, 3, 3, 4, 3, 2, 4, 1, 1, 3, 2,\n", - " 3, 5, 1, 2, 3, 2, 1, 0, 5, 1, 4, 4, 4, 4, 1, 3, 3, 1, 3, 2, 0, 5, 5, 0,\n", - " 1, 1, 1, 2, 3, 4, 4, 2, 1, 3, 5, 1, 2, 1, 3, 2, 1, 1, 3, 3, 1, 1, 1, 1,\n", - " 5, 1, 4, 4, 4, 0, 3, 1], device='cuda:0')\n", - "tensor([[-0.6226, -0.2386, 0.8428, 0.0764, 0.8742, -1.1802],\n", - " [ 0.4918, -1.5402, 1.2385, 0.9472, 0.4741, -0.7197],\n", - " [ 0.5192, 0.8373, 0.4735, -0.2677, -0.3455, -0.7448],\n", - " ...,\n", - " [ 1.4621, 1.5240, 0.5250, 0.0796, -2.1565, -0.9186],\n", - " [-0.2466, 1.7202, 0.1187, -1.0205, -0.8667, -0.1564],\n", - " [ 0.1776, 1.8291, 0.1979, 0.5532, -0.1398, 0.4808]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([5, 3, 2, 4, 4, 2, 4, 1, 2, 1, 1, 4, 0, 3, 2, 3, 4, 5, 3, 1, 3, 3, 2, 4,\n", - " 1, 4, 1, 0, 0, 1, 3, 4, 5, 3, 1, 1, 0, 1, 1, 2, 3, 1, 5, 1, 1, 0, 1, 3,\n", - " 0, 1, 1, 1, 1, 1, 3, 2, 1, 3, 1, 1, 1, 1, 3, 1, 4, 1, 3, 2, 4, 4, 4, 0,\n", - " 0, 3, 3, 1, 0, 1, 0, 2, 5, 4, 1, 4, 1, 2, 4, 1, 2, 0, 1, 1, 3, 2, 0, 3,\n", - " 4, 1, 2, 2, 3, 0, 3, 2, 1, 4, 1, 2, 3, 3, 2, 3, 1, 3, 1, 4, 3, 2, 0, 4,\n", - " 1, 0, 5, 0, 2, 0, 2, 3, 1, 1, 3, 3, 0, 4, 1, 1, 2, 2, 0, 1, 5, 5, 3, 3,\n", - " 1, 4, 3, 1, 2, 1, 2, 1, 0, 2, 3, 5, 1, 2, 0, 0, 2, 1, 1, 0, 2, 1, 4, 3,\n", - " 1, 1, 5, 1, 0, 1, 4, 3, 3, 0, 2, 2, 0, 2, 3, 5, 2, 3, 2, 3, 2, 2, 1, 3,\n", - " 3, 2, 0, 1, 2, 1, 1, 1], device='cuda:0')\n" + "Train progress: 0%| | 0/4 [00:00)\n", - "tensor([1, 4, 3, 3, 3, 2, 3, 1, 2, 3, 3, 4, 1, 2, 2, 4, 4, 1, 2, 2, 2, 1, 1, 3,\n", - " 3, 5, 5, 4, 5, 2, 0, 3, 1, 1, 3, 4, 3, 2, 1, 1, 1, 1, 3, 1, 3, 5, 1, 4,\n", - " 1, 3, 3, 1, 1, 2, 1, 2, 3, 4, 2, 2, 0, 0, 1, 3, 4, 4, 1, 1, 3, 4, 3, 3,\n", - " 3, 0, 3, 1, 3, 0, 0, 1, 5, 1, 3, 1, 1, 2, 5, 4, 5, 3, 4, 0, 1, 2, 0, 3,\n", - " 0, 1, 1, 3, 1, 2, 2, 0, 5, 1, 1, 3, 0, 3, 4, 2, 2, 4, 1, 0, 3, 2, 1, 0,\n", - " 1, 4, 4, 1, 3, 5, 0, 3, 2, 1, 3, 3, 3, 1, 4, 1, 3, 4, 0, 1, 3, 1, 2, 1,\n", - " 1, 1, 3, 2, 4, 4, 1, 3, 3, 3, 5, 0, 2, 4, 1, 3, 3, 1, 4, 2, 1, 3, 2, 4,\n", - " 1, 4, 1, 0, 3, 5, 4, 2, 1, 1, 2, 3, 1, 3, 5, 1, 5, 0, 1, 1, 1, 1, 1, 1,\n", - " 3, 1, 4, 4, 2, 0, 1, 0], device='cuda:0')\n", - "tensor([[ 5.0785e-01, 2.9463e+00, 2.0519e-01, -6.4336e-01, -1.4951e+00,\n", - " -2.9072e+00],\n", - " [-3.9066e+00, -9.5237e-01, 1.2700e+00, 3.4056e+00, 4.8741e-01,\n", - " 1.3640e+00],\n", - " [-1.9492e+00, 5.5756e-04, 1.5570e-01, 1.7706e+00, 7.4992e-01,\n", - " -8.2138e-02],\n", - " [ 8.3534e-02, 2.8151e+00, 1.4220e+00, -1.0271e+00, -1.0803e+00,\n", - " -3.0577e+00],\n", - " [-3.8524e+00, -7.1910e-01, 1.5204e+00, 1.4530e+00, 6.6245e-01,\n", - " 1.3033e+00],\n", - " [ 4.9016e+00, 4.5504e+00, 7.1953e-01, -3.2331e+00, -1.6977e+00,\n", - " -4.0197e+00],\n", - " [ 9.2423e-01, 4.8244e+00, 8.9677e-01, -4.6250e-01, -2.6709e-01,\n", - " -2.5811e+00],\n", - " [ 3.7795e+00, 3.3462e+00, 1.2082e+00, -3.5658e+00, -1.1291e+00,\n", - " -2.6120e+00],\n", - " [-2.5993e+00, -3.4399e-01, 1.0973e+00, 1.8007e+00, -1.2054e-01,\n", - " -4.3068e-01],\n", - " [-5.6877e+00, -3.0176e+00, 2.5947e-01, 8.9188e-01, 2.0930e-01,\n", - " 4.6083e+00],\n", - " [ 3.0521e-02, 1.8225e+00, 6.0508e-01, -2.4535e+00, -5.1980e-01,\n", - " -1.5503e+00],\n", - " [ 1.1276e+00, 2.0274e+00, 8.5653e-01, -4.4079e-01, -6.4320e-02,\n", - " -3.1030e+00],\n", - " [-3.1500e+00, -1.4696e+00, 1.6117e-01, 1.4539e+00, 1.1949e+00,\n", - " 9.2815e-01],\n", - " [-1.0993e+00, 1.3193e+00, 6.9700e-01, 1.6869e+00, -1.5432e+00,\n", - " -2.1687e+00],\n", - " [-2.9303e+00, 1.0355e+00, 1.0843e+00, 2.5226e+00, -4.9454e-01,\n", - " -1.5932e+00],\n", - " [ 1.4987e+00, -1.7278e-02, -1.4861e-02, -4.6216e-01, 1.5976e+00,\n", - " -1.8801e-01],\n", - " [-2.8412e+00, -8.2997e-01, 1.6827e+00, 1.0800e+00, 1.6243e+00,\n", - " 5.7521e-03],\n", - " [-3.0704e+00, -1.7520e+00, 6.0630e-01, 1.3353e+00, 8.6340e-01,\n", - " 1.5803e+00],\n", - " [-9.8278e-01, -2.3779e-01, 4.0921e-01, 1.6559e+00, 3.9322e-01,\n", - " -1.6939e+00],\n", - " [ 3.2536e+00, 4.2503e+00, 1.6639e+00, -1.8315e+00, -2.1549e+00,\n", - " -3.0990e+00],\n", - " [-1.6405e+00, 6.8942e-02, 1.6318e-01, 8.2727e-01, 6.6896e-01,\n", - " 1.8618e+00],\n", - " [ 2.9201e+00, 1.5114e+00, 5.0524e-01, -1.6511e-01, 2.9880e-01,\n", - " -2.1735e+00],\n", - " [-2.5374e+00, -4.7302e-01, 1.8691e+00, 1.9098e+00, 1.1853e+00,\n", - " -4.3656e-01],\n", - " [ 3.3823e-01, 3.1440e+00, 8.2310e-01, -1.4717e+00, -9.2432e-01,\n", - " -1.7311e+00],\n", - " [ 5.0953e+00, 4.3090e+00, 1.1053e+00, -2.7455e+00, -1.5613e+00,\n", - " -3.7000e+00],\n", - " [-1.5205e+00, 7.7727e-01, 1.1180e+00, 9.0386e-01, -7.6575e-01,\n", - " -1.6860e+00],\n", - " [-2.3669e+00, 4.5491e-01, 9.5854e-01, 1.0452e+00, 1.6990e-01,\n", - " -2.6355e-01],\n", - " [-2.8932e+00, -1.7233e+00, 1.7979e+00, 2.4833e+00, 1.1864e+00,\n", - " -7.4040e-02],\n", - " [-2.2461e+00, -5.8027e-02, 2.6228e-01, 1.1173e+00, 1.2226e+00,\n", - " 1.1392e-01],\n", - " [-1.9828e+00, -9.9089e-01, 6.3517e-02, 1.4670e+00, 8.9338e-01,\n", - " 2.6797e-01],\n", - " [-3.7190e-01, 2.6808e+00, 1.0387e+00, 6.3618e-01, -1.4385e+00,\n", - " -2.9805e+00],\n", - " [ 7.1242e-02, -4.0289e-01, -1.7164e-01, 1.0950e+00, 9.3152e-01,\n", - " -1.5375e+00],\n", - " [ 3.1002e+00, 1.9405e+00, 7.5105e-01, -1.6233e+00, 4.4521e-02,\n", - " -3.1467e+00],\n", - " [-2.3722e-01, 1.3470e+00, 1.1610e+00, -5.2840e-02, 1.8195e-01,\n", - " -2.1983e+00],\n", - " [-2.1503e+00, 1.2085e+00, 1.0802e+00, 2.8033e+00, -3.0764e-01,\n", - " -2.7647e+00],\n", - " [-1.4158e+00, -1.0916e+00, 6.0906e-01, 1.0527e+00, 1.3323e+00,\n", - " 9.0856e-01],\n", - " [-4.7977e-01, 2.2366e+00, 4.3573e-01, 1.0780e+00, -2.7697e-01,\n", - " -2.3948e+00],\n", - " [-3.1155e+00, -1.6745e+00, 1.0671e+00, 1.1734e+00, 5.6097e-01,\n", - " 2.1138e+00],\n", - " [-4.2830e+00, -1.2069e+00, 4.9926e-01, 2.1320e+00, 1.7750e+00,\n", - " 8.0150e-01],\n", - " [-5.5305e-01, 2.4604e+00, 1.4384e+00, 2.5068e-01, -5.4250e-01,\n", - " -2.8098e+00],\n", - " [-1.8052e+00, 1.4433e+00, 2.1587e+00, 1.9838e+00, -1.4471e+00,\n", - " -2.4890e+00],\n", - " [-2.4005e+00, 6.8539e-01, 9.8687e-02, 1.3604e+00, 4.8813e-01,\n", - " -1.3280e+00],\n", - " [ 3.9354e+00, 4.6335e+00, 1.8245e-01, -1.7969e+00, -1.5609e+00,\n", - " -3.1146e+00],\n", - " [-7.6038e-01, 1.2493e+00, 1.0653e+00, 1.2007e+00, -1.2995e+00,\n", - " -2.2076e+00],\n", - " [ 5.4903e+00, 2.5086e+00, 6.5000e-01, -2.8259e+00, 1.7969e-01,\n", - " -3.5369e+00],\n", - " [ 1.5038e+00, 3.6694e+00, 3.9354e-01, -7.3680e-01, -1.3285e+00,\n", - " -3.6811e+00],\n", - " [ 1.2969e+00, 3.1081e+00, 3.6188e-01, -5.2957e-01, -1.3221e+00,\n", - " -2.8087e+00],\n", - " [-1.3090e-01, -1.0887e+00, 1.0467e+00, 2.2517e-01, 1.6431e+00,\n", - " -5.2488e-01],\n", - " [-7.5452e-01, 1.5061e+00, 1.5349e+00, -5.0616e-01, -1.4828e+00,\n", - " -5.3993e-01],\n", - " [ 9.5122e-01, 1.2501e+00, 1.0010e+00, -1.1966e+00, -5.1521e-01,\n", - " -9.6049e-01],\n", - " [-3.9732e+00, -3.4292e+00, -6.9478e-01, 1.1377e+00, 2.7303e+00,\n", - " 4.1629e+00],\n", - " [-2.1543e+00, 7.7302e-01, 7.5039e-01, 2.1543e+00, 5.5946e-01,\n", - " -1.4266e+00],\n", - " [-1.7515e+00, 2.9126e-01, 1.6751e+00, 1.7331e+00, -3.1740e-01,\n", - " -9.3652e-01],\n", - " [-1.4391e+00, 7.7127e-01, 2.4291e+00, 6.1647e-01, -8.6179e-02,\n", - " -1.9039e+00],\n", - " [ 9.2329e-01, 1.8318e+00, 1.4539e+00, -8.4983e-01, -8.7171e-01,\n", - " -1.5879e+00],\n", - " [-7.2465e-01, -5.7600e-02, 4.5831e-01, 1.4647e+00, 6.0855e-01,\n", - " -8.1992e-01],\n", - " [ 1.9350e+00, 4.0024e+00, -8.5899e-02, -1.4729e+00, -5.6349e-01,\n", - " -2.6009e+00],\n", - " [-2.1015e+00, 2.7938e-02, 1.3945e+00, 1.3888e+00, -2.0826e-01,\n", - " -1.0349e+00],\n", - " [-1.8656e+00, 5.8421e-01, 1.0487e+00, 1.3087e+00, -1.8196e-01,\n", - " 3.3737e-01],\n", - " [-4.2524e+00, -1.8048e+00, 2.0297e+00, 3.2886e+00, 1.2285e+00,\n", - " 2.2204e-01],\n", - " [ 7.7276e+00, 4.0654e+00, 8.4859e-01, -4.9864e+00, -6.2105e-03,\n", - " -2.6726e+00],\n", - " [-1.6841e+00, -2.2979e-01, 1.5988e+00, 4.5938e-01, 9.3984e-01,\n", - " -1.1130e+00],\n", - " [ 2.0038e-01, 3.0273e-01, 3.8662e-01, 5.2808e-01, -1.9859e-01,\n", - " -1.4430e+00],\n", - " [-3.3751e+00, 6.1276e-01, 9.5610e-01, 3.2154e+00, 3.1233e-01,\n", - " -1.0388e+00],\n", - " [-4.3014e+00, -2.3547e+00, 1.2787e+00, 2.0740e+00, 1.5496e+00,\n", - " 3.4805e+00],\n", - " [-3.6794e+00, -1.6216e-01, -5.9225e-01, 2.5496e-01, 1.0738e+00,\n", - " 1.7842e+00],\n", - " [-1.1106e+00, 1.3425e+00, 1.1768e+00, 1.9018e-01, -3.2973e-01,\n", - " -1.2715e+00],\n", - " [ 4.1599e+00, 2.2493e+00, 6.1875e-01, -2.9866e+00, -1.5451e-01,\n", - " -2.0031e+00],\n", - " [ 7.3384e-01, 3.6260e+00, 5.0990e-01, -7.2424e-01, -1.1784e+00,\n", - " -2.9340e+00],\n", - " [ 5.6991e-01, 1.5751e+00, 4.7283e-01, -8.0644e-01, -3.6754e-01,\n", - " -1.7110e+00],\n", - " [-1.0967e+00, 9.1527e-01, 1.9903e+00, 1.2590e+00, -6.0013e-02,\n", - " -1.6075e+00],\n", - " [ 3.9550e-02, 8.5155e-01, 1.5364e+00, 1.6481e-01, -6.2684e-01,\n", - " -1.9494e+00],\n", - " [-1.8348e+00, 2.0944e-01, 9.6946e-01, 2.7326e+00, 3.3227e-01,\n", - " -1.9949e+00],\n", - " [-2.2953e+00, -3.1102e-01, 1.1002e+00, 1.0050e+00, 7.5941e-01,\n", - " 1.4559e+00],\n", - " [-4.2828e+00, -2.2393e+00, -7.1243e-03, 6.6680e-01, 3.4315e-01,\n", - " 3.3873e+00],\n", - " [ 5.1251e+00, 3.4887e+00, -1.7700e-01, -4.8753e+00, -8.6678e-01,\n", - " -1.7642e+00],\n", - " [-2.9411e+00, -1.7802e+00, 2.5235e-01, 2.1104e+00, 6.5751e-01,\n", - " 8.6060e-01],\n", - " [ 4.0611e-01, 2.4398e+00, 5.5562e-01, 4.4931e-01, -1.8003e-01,\n", - " -3.2134e+00],\n", - " [-1.1275e+00, -4.8439e-01, 7.1867e-01, 1.5762e+00, 2.1451e+00,\n", - " -3.2391e-01],\n", - " [ 1.4330e+00, 3.9699e+00, 1.2871e+00, -2.3817e+00, -1.1352e+00,\n", - " -3.0764e+00]], device='cuda:0', grad_fn=)\n", - "tensor([1, 3, 4, 1, 2, 0, 1, 0, 4, 5, 2, 1, 3, 2, 3, 1, 3, 4, 1, 1, 5, 0, 3, 1,\n", - " 1, 1, 2, 3, 4, 4, 3, 4, 0, 1, 3, 4, 2, 2, 4, 2, 3, 3, 1, 3, 0, 1, 1, 4,\n", - " 1, 1, 5, 5, 3, 2, 1, 3, 1, 4, 4, 4, 0, 3, 3, 2, 5, 5, 4, 0, 1, 1, 2, 2,\n", - " 3, 4, 3, 0, 4, 2, 4, 1], device='cuda:0')\n", "\n", "Train Confusion matrix :\n", - "[[44 17 0 9 0 1]\n", - " [34 75 43 35 14 4]\n", - " [10 24 32 32 7 5]\n", - " [23 14 43 39 18 8]\n", - " [13 7 16 32 25 8]\n", - " [ 4 3 6 8 10 17]]\n", + "[[30 13 9 18 2 2]\n", + " [33 58 51 37 17 9]\n", + " [ 3 35 30 23 14 6]\n", + " [ 3 20 43 40 25 12]\n", + " [ 1 23 22 23 19 12]\n", + " [ 1 7 10 9 11 9]]\n", "\n" ] }, @@ -803,8 +584,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 1/5\n", - "Train loss: 1.656975, Train accuracy: 0.3412\n", - "Val loss: 2.432539, Val accuracy: 0.2562\n", + "Train loss: 1.912837, Train accuracy: 0.2735\n", + "Val loss: 1.803974, Val accuracy: 0.2893\n", "\n" ] }, @@ -812,249 +593,21 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 25%|██████ | 1/4 [00:00<00:01, 2.16it/s, accuracy=264, loss=1.18]" + "Train progress: 100%|████████████████████████| 4/4 [00:10<00:00, 2.62s/it, accuracy=282, loss=1.55]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "tensor([[ 5.7647, 2.5873, 1.0617, -2.8842, -1.4300, -4.6185],\n", - " [-0.4238, 1.6832, 1.0209, 0.4468, -0.2965, -2.5169],\n", - " [-4.7260, -0.1016, 1.1316, 3.5639, 1.7196, -0.6065],\n", - " ...,\n", - " [ 2.7263, 1.2018, 1.3862, -1.1125, -0.8777, -2.0362],\n", - " [-0.2673, 3.3685, 1.7865, -2.5046, -0.9490, -1.6585],\n", - " [-1.0268, 0.3049, 1.1081, 1.0508, 0.7597, -2.3137]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([0, 1, 2, 2, 2, 3, 4, 4, 1, 1, 2, 4, 1, 1, 3, 4, 0, 0, 5, 0, 1, 3, 1, 3,\n", - " 1, 1, 0, 1, 4, 4, 4, 4, 4, 2, 4, 2, 2, 0, 3, 4, 4, 1, 4, 0, 4, 2, 4, 2,\n", - " 1, 0, 5, 5, 3, 1, 4, 4, 0, 4, 4, 1, 1, 1, 2, 3, 3, 3, 1, 3, 0, 1, 1, 0,\n", - " 1, 0, 4, 0, 5, 3, 1, 1, 2, 1, 3, 3, 4, 1, 4, 1, 4, 4, 5, 4, 0, 1, 2, 1,\n", - " 3, 1, 3, 0, 5, 1, 1, 1, 1, 2, 1, 1, 3, 5, 4, 3, 4, 1, 3, 1, 3, 1, 2, 3,\n", - " 3, 1, 0, 0, 2, 1, 2, 1, 2, 2, 1, 0, 2, 2, 3, 4, 1, 2, 1, 2, 1, 3, 0, 2,\n", - " 3, 1, 2, 1, 0, 1, 0, 1, 3, 1, 2, 3, 1, 1, 3, 1, 4, 5, 5, 3, 4, 3, 3, 5,\n", - " 3, 2, 1, 4, 3, 1, 2, 1, 4, 2, 1, 1, 4, 2, 0, 3, 3, 4, 3, 4, 3, 2, 3, 2,\n", - " 3, 2, 1, 1, 5, 0, 2, 2], device='cuda:0')\n", - "tensor([[-1.5025, 0.6489, 1.7107, 1.1047, -0.5289, -2.6486],\n", - " [-5.1265, -1.1864, -0.2679, -0.6918, 1.1148, 5.7352],\n", - " [-3.0374, -0.9213, 3.8099, 3.3727, 0.4661, -2.7374],\n", - " ...,\n", - " [-2.2769, -0.5094, 0.7765, 1.7265, 1.8101, -1.3096],\n", - " [-1.0978, 1.4358, 2.0385, 0.5130, -0.3928, -2.5509],\n", - " [-2.0134, 1.4741, 2.4085, 2.0922, -0.2273, -2.5904]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([1, 5, 2, 0, 2, 1, 1, 1, 1, 4, 0, 1, 2, 1, 1, 4, 2, 2, 5, 1, 1, 4, 5, 1,\n", - " 1, 3, 4, 2, 2, 3, 3, 4, 1, 3, 0, 1, 5, 2, 1, 1, 3, 2, 1, 0, 2, 0, 4, 3,\n", - " 2, 2, 1, 2, 3, 3, 5, 3, 4, 1, 5, 3, 1, 4, 1, 1, 0, 1, 4, 3, 3, 4, 1, 3,\n", - " 2, 3, 3, 1, 2, 3, 1, 0, 5, 1, 4, 5, 4, 2, 0, 4, 3, 1, 3, 3, 5, 4, 2, 4,\n", - " 1, 3, 1, 1, 4, 5, 1, 4, 1, 4, 5, 1, 1, 2, 1, 0, 1, 1, 3, 1, 1, 2, 1, 1,\n", - " 3, 1, 3, 0, 0, 3, 2, 3, 1, 2, 1, 4, 1, 3, 1, 5, 3, 4, 2, 4, 3, 3, 1, 4,\n", - " 2, 1, 4, 2, 1, 5, 0, 5, 3, 1, 2, 3, 1, 4, 3, 4, 3, 2, 4, 2, 3, 4, 4, 5,\n", - " 4, 3, 3, 4, 3, 5, 3, 3, 0, 2, 1, 3, 3, 1, 1, 1, 2, 0, 4, 2, 1, 3, 0, 2,\n", - " 0, 3, 5, 0, 1, 3, 2, 3], device='cuda:0')\n", - "tensor([[ 1.2951, 2.7126, 1.9132, -1.9032, -0.8930, -2.1204],\n", - " [-1.1742, 0.1655, 1.5694, 1.4164, 0.5871, -2.1103],\n", - " [-4.0437, -1.0213, 0.9476, 2.1266, 2.9856, 0.8012],\n", - " ...,\n", - " [ 0.7465, 2.3658, 1.5894, -1.0207, -1.2942, -3.0039],\n", - " [-2.9101, 0.3122, 1.6026, 2.5441, 0.8244, -1.4983],\n", - " [ 2.4817, 4.2817, 2.3485, -1.9702, -2.1477, -4.0261]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([1, 2, 4, 1, 4, 4, 0, 1, 4, 2, 1, 0, 3, 0, 3, 4, 2, 3, 1, 1, 0, 3, 1, 4,\n", - " 1, 1, 3, 1, 4, 4, 4, 1, 2, 2, 0, 1, 2, 5, 2, 1, 2, 1, 1, 3, 1, 0, 2, 3,\n", - " 3, 3, 4, 3, 0, 0, 4, 5, 1, 2, 2, 1, 5, 3, 3, 1, 5, 0, 5, 3, 4, 1, 4, 0,\n", - " 1, 3, 3, 1, 5, 2, 1, 1, 2, 3, 1, 2, 1, 5, 1, 5, 4, 5, 1, 2, 1, 1, 4, 4,\n", - " 1, 1, 0, 1, 1, 1, 3, 0, 3, 1, 1, 4, 4, 2, 2, 3, 3, 1, 2, 3, 2, 2, 1, 3,\n", - " 0, 2, 3, 3, 0, 1, 1, 4, 5, 1, 1, 1, 1, 3, 3, 5, 1, 5, 3, 4, 1, 1, 0, 1,\n", - " 3, 4, 3, 2, 3, 3, 4, 3, 3, 0, 3, 1, 5, 3, 0, 1, 0, 5, 3, 0, 1, 3, 1, 0,\n", - " 3, 2, 0, 0, 2, 1, 3, 5, 1, 0, 5, 2, 3, 2, 3, 3, 3, 3, 1, 3, 4, 2, 5, 2,\n", - " 1, 1, 0, 2, 4, 1, 3, 1], device='cuda:0')\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Train progress: 100%|████████████████████████| 4/4 [00:00<00:00, 4.94it/s, accuracy=464, loss=1.16]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[-1.2981e+00, 3.6719e-02, 1.3975e+00, 1.9658e+00, 9.8866e-01,\n", - " -3.1032e+00],\n", - " [-5.1254e+00, -3.1349e-01, 2.8030e+00, 3.6724e+00, 2.5594e+00,\n", - " -1.3290e+00],\n", - " [ 1.0225e+00, 2.9629e+00, 2.3145e+00, -7.6046e-01, -1.2127e+00,\n", - " -5.3115e+00],\n", - " [-9.8114e-01, 1.7212e+00, 2.1476e+00, 4.2624e-01, 6.4795e-01,\n", - " -3.3145e+00],\n", - " [ 5.6081e-01, 2.1507e+00, 1.5812e+00, 8.1633e-02, -3.0267e-01,\n", - " -3.7644e+00],\n", - " [ 2.7219e+00, 3.2385e+00, 4.5577e-01, -1.5888e+00, -1.7410e+00,\n", - " -3.0634e+00],\n", - " [-2.1623e+00, 1.0956e+00, 9.7882e-01, 6.4806e-01, 4.8530e-01,\n", - " 1.0750e+00],\n", - " [-2.1557e+00, 6.6095e-01, -8.2063e-01, 7.4062e-01, 1.6283e+00,\n", - " 1.1100e-01],\n", - " [ 1.5759e+00, 3.4091e+00, 4.0915e-01, -6.5247e-01, -1.5097e+00,\n", - " -1.7020e+00],\n", - " [-2.3667e+00, 1.1188e+00, 1.3944e+00, 1.5247e+00, 4.9559e-01,\n", - " -1.8111e+00],\n", - " [-2.5314e+00, -2.5184e-01, 2.4474e+00, 3.0658e+00, 6.6135e-01,\n", - " -2.1589e+00],\n", - " [-2.4207e+00, 2.7240e-01, 2.5886e+00, 1.8213e+00, 6.6707e-01,\n", - " -2.0807e+00],\n", - " [-4.1269e+00, -4.5043e-01, 9.2161e-01, 1.4691e+00, 3.1527e+00,\n", - " -1.5056e-01],\n", - " [-2.7853e+00, -3.5188e-01, 1.6534e+00, 2.2346e+00, 1.4354e+00,\n", - " -2.0461e+00],\n", - " [ 1.8503e+00, 1.9897e+00, 7.5863e-01, -6.2073e-01, 1.0135e-02,\n", - " -3.6038e+00],\n", - " [-4.4610e+00, -1.2230e+00, 2.7291e+00, 3.2644e+00, 1.7930e+00,\n", - " -1.0184e+00],\n", - " [ 6.1886e-01, 3.1915e+00, 2.7025e+00, -7.4590e-01, -1.9022e+00,\n", - " -3.2335e+00],\n", - " [-3.8002e+00, -1.0500e-01, 1.8049e+00, 3.0342e+00, 1.4849e+00,\n", - " -1.2769e+00],\n", - " [-5.2888e+00, -2.9493e+00, -8.5479e-01, 4.0235e+00, 5.1271e+00,\n", - " 2.0490e+00],\n", - " [-5.4666e+00, -4.9528e-01, 3.3004e+00, 3.2127e+00, 1.3618e+00,\n", - " -5.7120e-01],\n", - " [ 7.1989e-01, 5.2092e+00, 2.6883e+00, -4.0900e-01, -3.5386e+00,\n", - " -3.0361e+00],\n", - " [-1.7176e+00, -3.6512e-01, -1.0818e+00, -7.8799e-01, 2.4587e+00,\n", - " 9.7480e-01],\n", - " [-7.2784e+00, -1.8236e+00, 2.2881e+00, 2.7634e+00, 3.8241e+00,\n", - " 1.5915e+00],\n", - " [-4.3716e+00, -1.6250e+00, 1.0525e+00, 4.1845e+00, 3.6376e+00,\n", - " 3.7620e-01],\n", - " [-4.9016e-01, 2.2416e+00, 2.4430e+00, -2.5016e-02, -8.3553e-01,\n", - " -3.1740e+00],\n", - " [ 1.8930e+00, 1.1850e+00, 2.3317e-01, -6.3692e-01, -9.8048e-02,\n", - " -2.1093e+00],\n", - " [ 4.3807e-01, 2.5475e+00, 3.2457e-01, -6.4670e-01, -1.1239e+00,\n", - " -2.5770e+00],\n", - " [-4.2702e+00, -1.6392e+00, 1.7047e+00, 2.6904e+00, 2.4803e+00,\n", - " 1.9911e-02],\n", - " [ 7.7943e+00, 4.0401e+00, 1.2884e+00, -4.1030e+00, -3.3884e+00,\n", - " -4.5441e+00],\n", - " [ 2.4113e+00, 3.3635e+00, 2.2616e+00, -1.6821e+00, -2.1401e+00,\n", - " -4.3564e+00],\n", - " [-4.0855e+00, -1.5475e-01, 8.3700e-01, 1.1874e+00, 1.5183e+00,\n", - " 1.0926e+00],\n", - " [-8.4415e+00, -2.9288e+00, 5.7940e-03, 3.1943e+00, 4.9659e+00,\n", - " 6.3546e+00],\n", - " [ 6.6418e+00, 3.7318e+00, 5.0536e-01, -3.8989e+00, -3.1989e+00,\n", - " -3.0521e+00],\n", - " [-2.2896e+00, 9.7577e-01, 2.5564e+00, 1.5771e+00, -2.6881e-01,\n", - " -2.9885e+00],\n", - " [ 2.7372e+00, 4.2798e+00, 2.1429e+00, -2.3077e+00, -2.1199e+00,\n", - " -3.2777e+00],\n", - " [ 3.0463e-01, 3.3558e+00, 1.4420e+00, -1.2118e+00, -1.0489e+00,\n", - " -2.8213e+00],\n", - " [-3.3059e+00, -4.4157e-02, 1.9744e+00, 1.7623e+00, 1.4575e+00,\n", - " -1.1185e+00],\n", - " [ 2.7504e+00, 3.1113e+00, 8.1389e-01, -1.1762e+00, -1.6611e+00,\n", - " -3.0482e+00],\n", - " [-2.9674e+00, 2.1317e+00, 3.5365e+00, 4.7639e-01, 1.4342e-01,\n", - " -2.6407e+00],\n", - " [ 2.7043e+00, 3.1055e+00, 9.6020e-01, -1.4170e+00, -1.3620e+00,\n", - " -4.2584e+00],\n", - " [ 1.5848e+00, 2.2054e+00, 1.9485e+00, -4.5409e-01, -1.4943e+00,\n", - " -4.5091e+00],\n", - " [ 1.5181e+00, 1.6401e+00, 5.3803e-01, 1.3239e-01, -7.6496e-01,\n", - " -2.5849e+00],\n", - " [-2.9188e-01, 2.3029e+00, 2.6058e+00, -2.2827e-01, -1.0306e+00,\n", - " -3.9448e+00],\n", - " [ 2.6057e+00, 1.0020e+00, 7.7339e-01, -9.8371e-01, -1.3269e-01,\n", - " -3.3339e+00],\n", - " [ 2.4581e+00, 2.9126e+00, 1.5435e+00, -1.0966e+00, -1.6562e+00,\n", - " -4.3502e+00],\n", - " [-2.9470e+00, 3.4618e-01, 4.8793e-01, 1.2012e+00, 1.6592e+00,\n", - " -8.4760e-02],\n", - " [-5.8648e+00, -9.0010e-01, -9.9714e-01, 6.8681e-01, 2.5240e+00,\n", - " 4.1970e+00],\n", - " [-9.8852e+00, -4.1562e+00, -2.2414e+00, 1.8383e+00, 4.8508e+00,\n", - " 7.3542e+00],\n", - " [-5.2891e+00, -1.5740e+00, 1.9180e+00, 2.3890e+00, 2.4482e+00,\n", - " 1.5224e+00],\n", - " [ 1.1158e+00, 3.3713e+00, 1.8357e+00, -4.6016e-01, -2.4027e+00,\n", - " -3.1247e+00],\n", - " [-1.0906e+00, 7.9975e-01, 2.0089e+00, 1.1730e+00, 2.9028e-01,\n", - " -2.8310e+00],\n", - " [-1.2160e+00, 2.1093e+00, 1.4182e+00, 2.8977e-01, -1.0014e+00,\n", - " -2.2533e+00],\n", - " [ 3.0169e+00, 2.2649e+00, 6.6073e-01, -2.1667e-01, -6.6347e-01,\n", - " -4.0107e+00],\n", - " [ 7.2828e+00, 3.4267e+00, -2.6812e-01, -2.6288e+00, -2.1866e+00,\n", - " -3.7492e+00],\n", - " [ 2.1078e+00, 3.2458e+00, 1.4892e+00, -8.3576e-01, -1.0498e+00,\n", - " -4.0883e+00],\n", - " [ 9.9136e-01, 1.9796e+00, 1.4583e+00, -1.0151e+00, -1.2891e+00,\n", - " -1.3638e+00],\n", - " [-7.7966e-01, 1.1257e+00, 1.2947e+00, 1.3868e-01, 1.3676e-01,\n", - " -1.6489e+00],\n", - " [-6.0922e+00, -1.9489e+00, 8.5093e-01, 3.7377e+00, 3.8795e+00,\n", - " -2.5683e-01],\n", - " [-2.0073e+00, 2.1025e+00, 2.3277e+00, 5.6390e-01, -2.4134e-01,\n", - " -1.4850e+00],\n", - " [-4.4859e+00, -1.1317e-01, 2.5847e+00, 2.9806e+00, 1.2162e+00,\n", - " -8.8207e-01],\n", - " [-8.9899e+00, -3.1544e+00, 7.1012e-01, 2.9412e+00, 3.5958e+00,\n", - " 5.4251e+00],\n", - " [-3.6049e+00, 1.2840e-01, 1.3965e-02, 7.4898e-01, 1.3913e+00,\n", - " 1.8497e+00],\n", - " [-1.2072e+00, 2.7003e+00, 3.5773e-01, -5.4970e-01, -3.8736e-01,\n", - " -4.0733e-01],\n", - " [ 2.1258e+00, 2.3971e+00, 1.6802e+00, -7.5136e-01, -1.1414e+00,\n", - " -5.5124e+00],\n", - " [-3.7435e+00, -3.9802e-01, 2.3528e+00, 2.7016e+00, 1.4353e+00,\n", - " -6.8330e-01],\n", - " [-2.2039e+00, 9.0749e-01, 2.2574e+00, 1.0406e+00, 3.8351e-01,\n", - " -1.6617e+00],\n", - " [-5.3093e+00, -1.6578e+00, 2.5515e-01, 2.3330e+00, 3.1566e+00,\n", - " 1.8179e+00],\n", - " [ 5.9966e-01, 3.7377e+00, 2.4657e+00, -1.7036e+00, -1.9166e+00,\n", - " -3.7371e+00],\n", - " [-2.0344e+00, 8.4664e-01, 2.1123e+00, 1.7406e+00, 2.0791e-01,\n", - " -2.2446e+00],\n", - " [-2.2902e+00, -3.2645e-02, 9.1668e-01, 8.6234e-01, 2.0169e+00,\n", - " -7.5806e-01],\n", - " [-1.5156e+00, 1.5738e+00, 1.7467e+00, 6.6018e-01, -1.5517e-01,\n", - " -9.8385e-01],\n", - " [-8.9564e-01, 2.4162e+00, 3.2291e+00, 2.1519e-01, -1.6003e+00,\n", - " -3.1226e+00],\n", - " [ 8.6466e+00, 5.3089e+00, 9.0610e-01, -4.6141e+00, -3.7922e+00,\n", - " -5.5445e+00],\n", - " [ 4.5800e+00, 5.0493e+00, 2.6587e+00, -2.5016e+00, -4.3738e+00,\n", - " -5.7015e+00],\n", - " [-8.0868e+00, -2.3441e+00, -2.0153e-01, 2.5324e+00, 5.3202e+00,\n", - " 4.3981e+00],\n", - " [-4.0602e+00, -1.1250e+00, 1.9350e+00, 1.9198e+00, 2.2325e+00,\n", - " -2.1904e-01],\n", - " [ 9.8492e-01, 2.3163e+00, 1.2344e+00, -1.1127e+00, -6.5684e-01,\n", - " -8.0291e-01],\n", - " [-3.6433e+00, -6.3669e-01, 1.5297e+00, 2.1103e+00, 1.5818e+00,\n", - " -4.9447e-01],\n", - " [ 2.8359e+00, 3.7350e+00, 1.2176e+00, -1.9901e+00, -1.4417e+00,\n", - " -3.4475e+00],\n", - " [ 4.3541e+00, 4.1501e+00, 1.9728e+00, -2.7605e+00, -3.1534e+00,\n", - " -4.7998e+00]], device='cuda:0', grad_fn=)\n", - "tensor([3, 3, 2, 2, 2, 1, 3, 1, 1, 2, 3, 2, 4, 3, 1, 4, 1, 3, 4, 2, 1, 4, 4, 3,\n", - " 2, 0, 1, 4, 0, 1, 3, 5, 0, 2, 1, 1, 2, 1, 2, 1, 1, 1, 2, 0, 1, 2, 5, 5,\n", - " 4, 1, 2, 1, 0, 0, 1, 1, 1, 3, 1, 3, 3, 4, 1, 3, 3, 2, 3, 1, 3, 4, 1, 2,\n", - " 0, 1, 4, 4, 1, 3, 1, 1], device='cuda:0')\n", "\n", "Train Confusion matrix :\n", - "[[ 54 16 1 0 0 0]\n", - " [ 3 150 36 11 5 0]\n", - " [ 0 13 61 30 6 0]\n", - " [ 0 10 16 100 18 1]\n", - " [ 0 3 4 22 68 4]\n", - " [ 0 1 1 3 12 31]]\n", + "[[43 25 6 0 0 0]\n", + " [30 95 59 11 9 1]\n", + " [ 5 25 40 32 8 1]\n", + " [ 1 21 50 44 21 6]\n", + " [ 1 5 15 24 38 17]\n", + " [ 0 4 3 2 16 22]]\n", "\n" ] }, @@ -1070,8 +623,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 2/5\n", - "Train loss: 1.189767, Train accuracy: 0.6824\n", - "Val loss: 1.697787, Val accuracy: 0.3719\n", + "Train loss: 1.593585, Train accuracy: 0.4147\n", + "Val loss: 1.756185, Val accuracy: 0.2645\n", "\n" ] }, @@ -1079,170 +632,21 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 25%|██████ | 1/4 [00:00<00:01, 2.08it/s, accuracy=322, loss=1.02]" + "Train progress: 100%|████████████████████████| 4/4 [00:08<00:00, 2.19s/it, accuracy=378, loss=1.53]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "tensor([[-1.1047, 0.9514, 2.5697, 1.0762, -0.4552, -3.7422],\n", - " [ 3.3418, 3.8453, 1.8604, -2.1147, -2.5755, -4.4263],\n", - " [-3.2442, -1.0296, -0.6309, 1.0331, 2.5131, 1.4929],\n", - " ...,\n", - " [-3.1408, -0.6257, 1.8499, 3.9204, 1.4363, -0.8423],\n", - " [-1.6896, 0.0204, 1.7826, 2.5442, 0.6649, -2.4143],\n", - " [-6.1421, -1.9445, 3.1884, 4.7668, 2.2768, -1.4381]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([3, 1, 4, 1, 2, 2, 4, 3, 3, 2, 1, 1, 2, 0, 1, 1, 0, 1, 3, 4, 1, 3, 1, 2,\n", - " 1, 3, 3, 2, 1, 4, 3, 1, 1, 1, 4, 3, 3, 2, 0, 1, 1, 2, 5, 1, 4, 3, 4, 1,\n", - " 2, 2, 4, 1, 1, 2, 3, 3, 0, 4, 1, 0, 1, 1, 4, 1, 1, 0, 2, 0, 0, 1, 4, 1,\n", - " 3, 4, 1, 0, 0, 2, 4, 3, 5, 0, 2, 0, 1, 1, 5, 3, 0, 5, 2, 5, 1, 1, 5, 0,\n", - " 1, 2, 4, 3, 1, 1, 4, 1, 1, 3, 1, 3, 1, 1, 2, 3, 4, 2, 2, 4, 2, 5, 2, 2,\n", - " 0, 1, 4, 3, 4, 4, 3, 2, 2, 4, 3, 1, 1, 1, 3, 3, 3, 0, 1, 1, 2, 4, 1, 3,\n", - " 1, 2, 4, 2, 5, 0, 1, 3, 1, 1, 3, 2, 0, 0, 3, 1, 1, 1, 1, 1, 1, 2, 1, 1,\n", - " 0, 1, 2, 3, 4, 5, 5, 5, 1, 5, 4, 1, 1, 1, 3, 1, 2, 2, 5, 1, 2, 3, 0, 2,\n", - " 1, 5, 4, 3, 1, 3, 3, 3], device='cuda:0')\n", - "tensor([[ 3.5107, 4.9996, 2.4691, -2.6681, -2.6880, -5.0174],\n", - " [-4.3821, -0.5751, 2.9944, 3.9737, 1.7038, -2.0697],\n", - " [ 5.4987, 4.0270, 1.8538, -2.9613, -2.8473, -4.8623],\n", - " ...,\n", - " [-4.1906, -0.0205, 2.8085, 3.6087, 0.3009, -1.4167],\n", - " [ 3.6800, 5.3459, 3.0893, -2.3959, -2.9190, -5.2647],\n", - " [-5.2653, -1.6810, 2.1467, 3.3414, 3.0073, 0.6391]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([1, 3, 0, 3, 3, 2, 1, 3, 3, 1, 2, 0, 5, 1, 1, 3, 1, 5, 4, 1, 5, 2, 3, 2,\n", - " 0, 0, 5, 4, 2, 5, 1, 4, 2, 5, 2, 3, 1, 1, 1, 3, 4, 5, 3, 2, 2, 4, 0, 0,\n", - " 1, 2, 3, 1, 1, 1, 5, 0, 0, 3, 3, 1, 2, 1, 3, 0, 3, 2, 1, 2, 0, 1, 4, 1,\n", - " 1, 4, 2, 4, 3, 5, 1, 1, 3, 4, 3, 0, 1, 2, 1, 2, 4, 4, 5, 3, 4, 4, 1, 1,\n", - " 3, 1, 4, 1, 1, 4, 1, 4, 2, 3, 1, 1, 3, 3, 4, 2, 4, 0, 1, 3, 2, 3, 3, 3,\n", - " 1, 3, 4, 0, 4, 3, 4, 0, 2, 3, 3, 2, 3, 4, 4, 3, 0, 2, 4, 3, 4, 2, 0, 1,\n", - " 1, 1, 0, 3, 4, 1, 2, 5, 3, 1, 1, 3, 4, 3, 1, 4, 4, 3, 4, 4, 0, 3, 2, 3,\n", - " 1, 5, 3, 1, 1, 4, 4, 4, 4, 5, 1, 3, 0, 0, 1, 3, 1, 1, 3, 1, 2, 1, 3, 1,\n", - " 0, 3, 1, 1, 1, 3, 1, 3], device='cuda:0')\n", - "tensor([[-4.3138, -1.5459, 0.1167, 1.9383, 2.6190, 2.0232],\n", - " [-5.8996, -1.1784, 0.8182, 3.2418, 3.2888, 1.2049],\n", - " [-7.0555, -2.9277, -1.0806, 3.7485, 4.6092, 5.0661],\n", - " ...,\n", - " [ 3.2228, 3.9096, 2.9786, -1.8295, -2.4216, -4.1299],\n", - " [-6.0889, -2.0205, 2.0845, 4.8269, 2.0379, 0.0937],\n", - " [-4.3744, -0.8840, 1.2005, 3.2371, 1.6959, -0.0470]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([5, 4, 5, 4, 1, 4, 3, 2, 1, 3, 4, 4, 1, 3, 0, 4, 2, 2, 5, 1, 1, 0, 1, 0,\n", - " 3, 1, 2, 2, 3, 1, 1, 1, 0, 3, 2, 3, 2, 1, 3, 0, 5, 4, 4, 1, 1, 1, 1, 1,\n", - " 3, 2, 2, 5, 3, 2, 1, 4, 5, 2, 0, 2, 2, 2, 4, 0, 4, 5, 4, 2, 3, 4, 4, 1,\n", - " 1, 1, 1, 0, 1, 0, 2, 1, 4, 4, 4, 2, 2, 1, 2, 1, 2, 1, 1, 3, 1, 1, 5, 2,\n", - " 2, 1, 1, 2, 1, 2, 3, 3, 1, 5, 4, 1, 4, 0, 4, 3, 4, 1, 0, 1, 0, 3, 2, 4,\n", - " 5, 0, 4, 1, 3, 4, 3, 1, 3, 4, 1, 1, 0, 5, 3, 2, 2, 2, 1, 3, 4, 4, 2, 1,\n", - " 3, 1, 1, 2, 4, 2, 4, 3, 3, 1, 3, 4, 3, 2, 1, 5, 1, 0, 2, 3, 1, 5, 3, 1,\n", - " 4, 2, 2, 2, 0, 5, 0, 1, 3, 5, 1, 4, 0, 0, 4, 1, 2, 1, 1, 0, 3, 0, 1, 1,\n", - " 3, 3, 3, 1, 3, 1, 3, 3], device='cuda:0')\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Train progress: 100%|███████████████████████| 4/4 [00:00<00:00, 4.86it/s, accuracy=558, loss=0.968]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[ 0.5765, 1.2425, 1.2605, -0.3023, -1.4327, -2.6292],\n", - " [ -5.4997, -1.3672, 1.6019, 3.5771, 3.3256, 0.1516],\n", - " [ -3.3461, -0.0209, 1.3426, 2.6482, 1.7152, -1.0642],\n", - " [ -2.9590, -0.3886, 1.9588, 2.9181, 1.1409, -1.6321],\n", - " [ -3.8681, -0.4661, 1.5115, 2.8366, 2.0967, -1.0420],\n", - " [ -4.0740, -0.2040, 2.2998, 2.7793, 1.0356, -0.9093],\n", - " [ -0.0752, 1.7127, 1.7855, 0.4269, -1.0077, -3.2956],\n", - " [ 2.2539, 3.4158, 3.2997, -1.1336, -2.3147, -5.3433],\n", - " [ -1.5451, 0.9156, 2.4823, 2.4143, 0.2776, -3.1742],\n", - " [ -2.2051, 1.4833, 2.2982, 1.0994, 0.0144, -2.4884],\n", - " [ -4.0980, -0.1499, 1.6221, 2.3669, 1.4651, -0.2774],\n", - " [ 2.2626, 3.2048, 1.8029, -1.4880, -2.0317, -4.5878],\n", - " [ 6.8127, 5.2500, 4.7880, -3.5068, -4.3655, -6.9875],\n", - " [ 1.5521, 2.6776, 1.2342, -0.3926, -2.0934, -3.9213],\n", - " [ -3.1305, 0.4345, 1.9745, 2.1021, 0.8095, -1.6314],\n", - " [ 1.1153, 3.8355, 3.1217, -0.6245, -2.0842, -5.1081],\n", - " [ -4.2461, -1.5614, -0.8173, 2.3128, 3.9529, 2.2445],\n", - " [ -4.1740, -1.0328, -0.5410, 1.9310, 3.1676, 1.9175],\n", - " [ 7.3648, 5.2286, 3.5542, -4.3322, -4.4827, -5.7021],\n", - " [ -3.6405, 0.2625, 1.8826, 2.8233, 1.2826, -1.7250],\n", - " [ 0.0476, 1.8299, 2.1771, 1.0116, -1.3220, -4.5319],\n", - " [ -5.0605, -1.0604, 1.9959, 3.3316, 2.3916, -0.1499],\n", - " [ 1.3626, 3.0586, 2.9934, -0.1691, -2.6867, -4.0386],\n", - " [ -4.3545, -0.5325, 1.9106, 3.3945, 2.1654, -1.0381],\n", - " [ -2.9373, 0.2946, 1.8415, 2.6841, 0.7640, -2.5379],\n", - " [ 1.0636, 2.3131, 2.2669, -0.1149, -1.1477, -4.3912],\n", - " [ -5.2982, -0.0857, 0.6887, 2.2687, 2.6591, 1.3745],\n", - " [ 2.7709, 3.0484, 2.8399, -2.0192, -2.2887, -3.9676],\n", - " [ -3.4717, 0.1765, 2.1564, 2.8190, 0.6462, -1.5732],\n", - " [ 3.0571, 3.8245, 2.9674, -1.9512, -2.2509, -4.7071],\n", - " [ -1.2419, 1.6222, 2.3481, 1.4175, -0.3930, -3.1330],\n", - " [ -4.7918, -0.7495, 1.2588, 2.5338, 2.7130, 0.9862],\n", - " [ 4.3627, 3.5016, 3.6730, -1.2136, -3.1255, -5.8406],\n", - " [ -6.3407, -1.8142, 0.0894, 2.8111, 3.6739, 3.1447],\n", - " [ -3.0620, 0.0796, 2.6367, 3.1476, 0.2504, -2.2772],\n", - " [ -4.6556, -1.3862, 1.4723, 3.3855, 2.4423, -0.4691],\n", - " [ -7.5991, -1.4266, -2.2352, 2.4291, 3.6433, 6.4373],\n", - " [ -1.7765, 0.6927, 0.6245, 1.2318, 0.6520, -0.9060],\n", - " [ 4.2020, 4.5705, 3.6864, -3.1590, -3.2302, -5.3752],\n", - " [ -4.0139, -0.5886, 1.4627, 2.8464, 2.0920, -1.1894],\n", - " [-10.9151, -3.7566, -2.5847, 3.4806, 6.9471, 7.6234],\n", - " [ 1.5438, 3.2715, 3.1761, -0.4035, -2.1175, -4.0942],\n", - " [ 2.4336, 3.4535, 3.8367, -1.0734, -2.4758, -5.0331],\n", - " [ -0.8397, 1.9589, 2.1374, 1.1826, -1.1951, -3.7411],\n", - " [ -4.4436, -1.1445, 1.2240, 3.6771, 1.6521, -0.8033],\n", - " [ -3.9789, 0.6732, 2.7100, 3.0268, 0.2618, -2.3606],\n", - " [ 0.7740, 3.4374, 3.7538, -0.5211, -2.3639, -5.4299],\n", - " [ -1.7277, 0.7823, 1.8695, 0.7556, 0.5624, -1.8624],\n", - " [ -5.2307, -0.3372, 1.7956, 3.5618, 1.6650, -0.5688],\n", - " [ -1.2256, 1.6510, 1.5812, 0.9361, 0.1209, -2.4021],\n", - " [ -9.5271, -2.1561, -2.9489, 1.9157, 5.0855, 7.9770],\n", - " [ 0.1118, 2.0894, 1.8909, 0.1472, -1.4766, -2.9976],\n", - " [ 0.8911, 2.7300, 2.5881, -0.0405, -1.9257, -3.8839],\n", - " [ -9.2480, -2.7420, -2.9960, 2.6785, 5.1223, 6.8353],\n", - " [ -5.3649, -0.8349, 1.4628, 2.8182, 2.9249, 1.0305],\n", - " [ -0.1645, 2.0038, 2.8104, 0.4227, -1.2500, -3.6848],\n", - " [ -9.4310, -2.7591, -0.2638, 3.5122, 5.6857, 4.2586],\n", - " [ 0.8285, 3.3499, 2.8682, -0.5976, -2.3923, -3.9095],\n", - " [ -2.7186, 0.6258, 2.4882, 2.3340, 0.4228, -1.7555],\n", - " [ 2.6098, 3.4463, 3.3204, -1.4514, -2.6326, -4.7444],\n", - " [ -3.6875, -0.4948, -0.6562, 1.7700, 2.8460, 1.6576],\n", - " [ 4.9053, 3.4070, 3.5387, -2.0510, -3.2267, -5.2030],\n", - " [ -5.9348, -1.3922, 1.1995, 3.3275, 2.5638, 0.5652],\n", - " [ -0.6186, 1.8070, 2.8535, 1.1906, -1.1191, -4.2350],\n", - " [ -4.6352, -1.1387, 1.3102, 2.8953, 2.7229, -0.2608],\n", - " [ 1.0800, 2.4892, 1.4571, -0.1865, -1.2878, -3.1902],\n", - " [ 9.3193, 6.0738, 4.3086, -6.3189, -6.0916, -6.4464],\n", - " [ 1.3260, 2.3923, 2.3839, -0.3540, -1.5840, -4.0442],\n", - " [ 4.3597, 3.8694, 2.5483, -2.8838, -3.4682, -4.8840],\n", - " [ 7.8931, 5.5561, 4.3343, -4.6100, -5.1641, -6.2475],\n", - " [ 6.0716, 4.8735, 2.4372, -3.6796, -4.4242, -6.0127],\n", - " [ -7.2096, -2.2404, -1.1948, 2.5522, 4.5911, 4.9529],\n", - " [ -3.3122, -0.3935, 1.4474, 2.3470, 2.0048, -0.6916],\n", - " [ 2.7613, 3.6924, 2.8528, -1.2416, -2.6379, -4.6913],\n", - " [ -4.8188, -0.3445, 1.4441, 3.0498, 2.6242, -0.1404],\n", - " [ -5.4879, -1.2319, -1.2224, 1.8586, 3.6890, 3.6816],\n", - " [ 2.1689, 3.1185, 1.9168, -1.3647, -1.8776, -3.4661],\n", - " [ -4.5519, -0.0515, 1.7006, 2.1962, 2.2071, -0.5028],\n", - " [ -3.9288, 0.2517, 1.5641, 2.5708, 1.4273, -0.4405],\n", - " [ 4.3194, 3.9928, 2.9821, -1.8597, -3.0637, -5.6397]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([1, 3, 3, 2, 3, 3, 2, 1, 3, 2, 3, 1, 0, 1, 2, 1, 4, 4, 0, 3, 2, 3, 1, 3,\n", - " 3, 1, 4, 1, 3, 1, 2, 4, 0, 5, 3, 3, 5, 2, 1, 3, 5, 1, 1, 1, 3, 2, 1, 2,\n", - " 3, 1, 5, 2, 1, 5, 3, 2, 4, 1, 2, 1, 4, 0, 3, 2, 3, 1, 0, 1, 1, 0, 0, 5,\n", - " 3, 1, 3, 4, 1, 3, 3, 0], device='cuda:0')\n", "\n", "Train Confusion matrix :\n", - "[[ 66 5 0 0 0 0]\n", - " [ 4 192 8 0 1 0]\n", - " [ 0 13 87 10 0 0]\n", - " [ 0 2 29 109 5 0]\n", - " [ 0 0 3 25 64 9]\n", - " [ 0 0 0 2 6 40]]\n", + "[[ 57 11 3 2 1 0]\n", + " [ 26 108 59 11 1 0]\n", + " [ 1 15 58 27 9 1]\n", + " [ 0 4 33 76 30 0]\n", + " [ 0 1 6 31 51 11]\n", + " [ 0 0 3 1 15 28]]\n", "\n" ] }, @@ -1258,8 +662,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 3/5\n", - "Train loss: 0.988835, Train accuracy: 0.8206\n", - "Val loss: 1.855518, Val accuracy: 0.3140\n", + "Train loss: 1.503537, Train accuracy: 0.5559\n", + "Val loss: 1.828071, Val accuracy: 0.2397\n", "\n" ] }, @@ -1267,170 +671,21 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 25%|█████▊ | 1/4 [00:00<00:01, 1.92it/s, accuracy=356, loss=0.909]" + "Train progress: 100%|████████████████████████| 4/4 [00:08<00:00, 2.15s/it, accuracy=447, loss=1.42]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "tensor([[-4.9337, -0.9885, 1.5400, 3.5426, 1.8137, -0.2495],\n", - " [ 1.7527, 4.0484, 2.6650, -0.0102, -2.7709, -6.0298],\n", - " [-2.7510, 0.1679, 0.7059, 1.8415, 0.7303, -0.5378],\n", - " ...,\n", - " [-8.1178, -2.1425, -0.0396, 3.0900, 5.1731, 3.3863],\n", - " [-0.7351, 1.7490, 2.4605, 0.8606, -0.8792, -2.8746],\n", - " [-3.5798, 0.7257, 3.3366, 3.0798, -0.1374, -2.5952]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([3, 1, 3, 3, 1, 3, 3, 1, 0, 0, 0, 1, 3, 5, 3, 2, 1, 1, 1, 1, 1, 5, 4, 0,\n", - " 1, 1, 0, 0, 1, 3, 1, 3, 0, 1, 4, 1, 4, 4, 3, 3, 1, 0, 4, 2, 4, 1, 2, 5,\n", - " 2, 5, 2, 5, 1, 1, 0, 2, 3, 1, 1, 4, 3, 3, 2, 3, 4, 2, 3, 1, 3, 3, 5, 0,\n", - " 5, 1, 4, 4, 1, 4, 0, 1, 3, 0, 1, 3, 3, 2, 3, 3, 2, 3, 1, 3, 4, 1, 4, 1,\n", - " 4, 4, 1, 1, 3, 2, 3, 4, 0, 1, 4, 3, 1, 2, 1, 2, 1, 1, 2, 1, 4, 3, 1, 0,\n", - " 4, 1, 1, 3, 4, 4, 1, 3, 2, 3, 3, 1, 5, 0, 2, 0, 1, 4, 1, 1, 1, 3, 2, 1,\n", - " 1, 4, 1, 3, 3, 1, 4, 4, 0, 3, 4, 3, 3, 3, 1, 3, 3, 4, 4, 3, 4, 3, 3, 4,\n", - " 2, 1, 0, 1, 1, 3, 3, 5, 3, 1, 1, 0, 2, 4, 3, 2, 3, 0, 3, 1, 1, 3, 4, 3,\n", - " 1, 0, 1, 1, 1, 4, 2, 2], device='cuda:0')\n", - "tensor([[ 4.1090, 4.7621, 3.0220, -2.3045, -3.4729, -5.1341],\n", - " [-2.3722, 0.1892, 0.1744, 0.0353, 0.6002, 1.8398],\n", - " [-4.3874, -0.2849, 2.7682, 3.6690, 1.3217, -2.0147],\n", - " ...,\n", - " [ 0.6917, 2.8364, 1.9947, 0.2951, -2.1395, -3.9947],\n", - " [-5.9529, -1.5716, -0.6779, 1.3925, 3.1945, 4.9690],\n", - " [ 1.3479, 3.5883, 2.3890, -0.4944, -2.0787, -3.8998]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([1, 5, 2, 4, 4, 1, 3, 5, 0, 5, 1, 4, 3, 1, 1, 2, 1, 2, 0, 0, 2, 4, 4, 4,\n", - " 4, 2, 1, 1, 1, 2, 2, 3, 4, 3, 1, 1, 5, 2, 1, 5, 1, 4, 1, 2, 4, 0, 2, 1,\n", - " 4, 1, 3, 3, 4, 1, 4, 2, 4, 4, 3, 3, 1, 1, 3, 4, 4, 2, 3, 1, 4, 0, 1, 3,\n", - " 3, 5, 5, 1, 3, 4, 3, 3, 4, 1, 3, 1, 0, 3, 0, 5, 1, 1, 1, 5, 1, 2, 4, 3,\n", - " 0, 3, 1, 4, 1, 2, 5, 3, 1, 1, 3, 1, 3, 4, 3, 2, 1, 1, 3, 2, 2, 3, 1, 1,\n", - " 4, 2, 1, 2, 3, 1, 2, 1, 2, 4, 1, 3, 4, 3, 4, 3, 1, 2, 1, 4, 2, 2, 0, 1,\n", - " 1, 0, 1, 4, 2, 1, 1, 3, 5, 1, 2, 0, 3, 3, 2, 1, 0, 3, 5, 4, 1, 1, 5, 5,\n", - " 0, 1, 2, 4, 2, 0, 1, 1, 1, 1, 1, 2, 1, 3, 1, 1, 3, 4, 4, 0, 4, 5, 1, 3,\n", - " 1, 5, 5, 5, 0, 1, 5, 1], device='cuda:0')\n", - "tensor([[-6.0771, -1.4452, -0.6184, 1.6283, 3.6519, 3.7136],\n", - " [ 5.2414, 5.2906, 2.1526, -3.8356, -3.3543, -4.1647],\n", - " [-2.0709, 1.3723, 3.2082, 2.7570, -0.8230, -3.8778],\n", - " ...,\n", - " [ 3.4053, 4.9225, 2.3481, -1.5346, -3.4529, -6.2331],\n", - " [-3.1430, -0.0794, 2.4756, 2.9915, 0.5997, -2.0509],\n", - " [-1.7533, 1.0990, 2.6620, 2.2392, -1.0744, -3.4488]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([5, 0, 2, 5, 4, 4, 3, 3, 1, 1, 5, 0, 3, 1, 1, 2, 4, 1, 1, 4, 1, 4, 1, 4,\n", - " 2, 1, 2, 2, 2, 0, 1, 0, 3, 3, 2, 1, 2, 3, 2, 0, 1, 1, 2, 1, 1, 1, 1, 3,\n", - " 3, 2, 3, 5, 1, 0, 0, 3, 2, 1, 3, 2, 0, 1, 0, 0, 4, 1, 3, 0, 2, 2, 2, 3,\n", - " 0, 2, 3, 2, 2, 4, 4, 4, 5, 1, 5, 1, 1, 1, 0, 1, 0, 1, 1, 2, 5, 2, 3, 3,\n", - " 1, 1, 5, 4, 5, 2, 1, 2, 3, 3, 1, 5, 5, 3, 3, 4, 4, 0, 4, 0, 1, 1, 1, 3,\n", - " 1, 5, 1, 2, 4, 4, 2, 2, 0, 2, 3, 5, 2, 2, 3, 2, 0, 2, 0, 1, 3, 4, 3, 4,\n", - " 1, 3, 4, 4, 4, 3, 2, 1, 3, 4, 2, 4, 3, 2, 3, 4, 3, 2, 1, 3, 4, 5, 1, 0,\n", - " 1, 1, 1, 2, 1, 3, 2, 3, 1, 5, 4, 0, 1, 2, 3, 3, 2, 2, 3, 4, 0, 0, 4, 5,\n", - " 2, 0, 1, 1, 0, 1, 2, 2], device='cuda:0')\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Train progress: 100%|███████████████████████| 4/4 [00:00<00:00, 4.55it/s, accuracy=608, loss=0.944]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[-10.4296, -3.8104, -2.7157, 2.5582, 6.7644, 9.8326],\n", - " [ 2.7911, 4.7170, 2.4238, -1.4134, -3.0964, -5.3750],\n", - " [ 0.8153, 3.2530, 2.2854, 0.1257, -2.1097, -4.0836],\n", - " [ 6.4898, 4.8606, 1.8179, -3.2139, -3.3854, -5.8207],\n", - " [ 0.4480, 2.9307, 1.5726, 0.0746, -1.6138, -3.2415],\n", - " [ 2.8429, 3.3472, 1.5060, -2.0152, -2.0890, -3.3464],\n", - " [ -3.1532, 0.3427, 2.5318, 2.2673, 0.3682, -1.6826],\n", - " [ 5.7756, 4.5486, 2.2251, -2.2895, -3.4067, -6.6343],\n", - " [ -1.8289, 0.2570, 3.2912, 2.5157, -0.1835, -3.6363],\n", - " [ -0.1207, 2.4728, 2.3312, 0.6245, -1.6630, -3.4890],\n", - " [ -4.9384, -0.9318, 1.2012, 2.1304, 2.1196, 2.0281],\n", - " [ 5.8274, 4.3660, 1.1185, -2.9798, -3.2530, -4.5934],\n", - " [ 0.1814, 2.6968, 2.5282, 0.2058, -1.7981, -3.9865],\n", - " [ 2.5852, 3.8653, 1.3733, -2.2356, -2.1874, -3.5946],\n", - " [ -2.9637, 0.8321, 3.0167, 2.4129, 0.0794, -2.6776],\n", - " [ 6.3480, 5.1667, 2.4490, -3.1681, -3.7982, -6.0045],\n", - " [ -5.0080, -1.0061, 3.1061, 3.8146, 1.8097, -0.2981],\n", - " [ 2.0050, 4.2631, 2.7213, 0.2393, -3.3254, -6.7129],\n", - " [ -5.5301, -1.1428, 2.7354, 3.6992, 1.7613, -0.4868],\n", - " [ 1.7544, 4.1112, 1.9041, 0.1791, -2.5757, -5.0608],\n", - " [ -5.7589, -1.8701, -1.2972, 1.2458, 3.9458, 5.1966],\n", - " [ 1.2777, 3.5089, 2.1827, -0.9147, -2.1570, -3.6267],\n", - " [ 5.9842, 3.6398, 1.5361, -3.7559, -2.8713, -5.2659],\n", - " [ -1.4188, 1.7968, 2.7819, 1.8459, -1.1084, -3.4240],\n", - " [ -0.8848, 2.1914, 1.8731, 1.1235, -1.5171, -3.3139],\n", - " [ 0.4543, 2.0357, 2.4139, 0.6919, -1.4126, -4.3122],\n", - " [ 0.7948, 3.5270, 1.9539, -0.3822, -1.9718, -3.9399],\n", - " [-12.0563, -3.5067, -3.1285, 1.9097, 6.5274, 10.7917],\n", - " [ -7.0609, -1.6023, 3.8921, 4.7460, 2.3193, -0.7062],\n", - " [ -7.6770, -2.5070, 0.3203, 3.0193, 4.2284, 4.7442],\n", - " [ -3.7027, -0.2909, 1.8157, 2.6358, 1.2444, -0.6864],\n", - " [ 1.4952, 4.0216, 2.2308, -0.8220, -2.6040, -4.9555],\n", - " [ 2.3353, 5.0735, 2.1370, -1.3626, -3.0128, -4.9274],\n", - " [ -1.6809, 0.1204, 2.8038, 2.4835, -0.2667, -2.1119],\n", - " [ -0.8956, 1.8091, 2.6209, 1.7931, -1.2469, -4.3996],\n", - " [ 4.8168, 2.7624, 0.5508, -3.1443, -2.5744, -3.3113],\n", - " [ -5.0815, -0.5820, 2.0385, 3.5322, 2.0673, -0.1394],\n", - " [ -5.4826, -1.5388, 2.4536, 3.6069, 2.4676, 0.2463],\n", - " [ 2.3224, 4.6590, 1.5354, -1.7607, -2.7640, -3.6781],\n", - " [ -4.5778, -1.1882, 2.2500, 2.7952, 1.6929, 0.1307],\n", - " [ -2.3701, 0.7256, 2.2783, 1.4841, -0.0963, -1.4081],\n", - " [ 4.9377, 3.5221, 0.9397, -1.7912, -2.6280, -4.7762],\n", - " [ -6.1309, -1.7224, 1.1682, 3.1179, 3.5583, 1.4492],\n", - " [ -6.0394, -1.4907, 2.8821, 3.4891, 2.0974, 0.5163],\n", - " [ -1.3883, 1.3575, 3.8397, 1.7139, -0.7168, -3.1748],\n", - " [ -5.1815, -1.1008, 2.6374, 3.5543, 2.0140, 0.0416],\n", - " [ 2.4156, 4.8487, 1.7903, -1.2152, -2.7921, -5.5910],\n", - " [ -1.5103, 1.6678, 2.9186, 1.9692, -1.0538, -3.5508],\n", - " [ 2.9188, 4.1986, 2.1034, -0.8220, -3.1444, -5.7260],\n", - " [ -5.6778, -1.2657, 3.1952, 4.4029, 2.2285, -0.8547],\n", - " [ -2.9026, 0.8512, 3.0836, 2.3911, -0.2982, -2.1828],\n", - " [ -6.0960, -1.3252, 2.9230, 3.9785, 2.0517, -0.2055],\n", - " [ -8.8336, -3.2210, 0.2579, 3.7906, 5.6232, 4.8129],\n", - " [ -2.7542, 0.4026, 2.4254, 2.2916, -0.0682, -2.3038],\n", - " [ -5.9363, -1.6542, -0.7113, 1.4942, 3.7367, 3.5232],\n", - " [ -7.6216, -3.3655, 0.6976, 3.8658, 5.6127, 3.4087],\n", - " [ 0.9211, 3.8290, 2.2580, -0.5556, -2.3652, -3.6705],\n", - " [ 0.0502, 2.5752, 1.9354, 0.9322, -1.6187, -3.2095],\n", - " [ 3.1170, 4.6282, 2.2956, -1.4036, -3.3262, -5.4101],\n", - " [ -6.8252, -2.7325, 1.3881, 3.2644, 5.1348, 2.9797],\n", - " [ -4.2034, -1.8419, 1.5066, 3.3848, 2.5146, 0.5307],\n", - " [ -3.9798, 0.3744, 2.9042, 3.0653, 1.0312, -2.7103],\n", - " [ 6.6035, 5.1678, 1.8864, -3.2562, -3.9221, -6.5888],\n", - " [ 1.2449, 3.9797, 2.3355, 0.4620, -2.6661, -5.8404],\n", - " [ -5.9322, -1.6597, 1.7353, 3.1994, 2.1473, 1.9017],\n", - " [ -6.8454, -2.2789, 2.3159, 3.5258, 3.5985, 1.6813],\n", - " [ 1.4788, 3.9385, 2.6118, -0.2963, -2.8277, -5.0201],\n", - " [ -6.5837, -1.9387, 1.7019, 3.2176, 3.4560, 2.1135],\n", - " [ 3.2233, 4.8331, 2.2700, -1.7802, -3.1831, -4.6161],\n", - " [ -2.2311, 1.2838, 3.2664, 2.6219, -0.6135, -3.7277],\n", - " [ -1.5448, 1.3221, 1.5162, 0.7086, 0.1059, -1.1249],\n", - " [ 6.7785, 4.9108, 0.7523, -4.6322, -2.7298, -4.2038],\n", - " [ -7.0066, -1.6586, 3.2228, 3.6473, 3.2055, 0.8055],\n", - " [ -1.0433, 2.5111, 2.6748, 1.8221, -1.4014, -4.9078],\n", - " [ -6.3764, -2.5054, -0.3301, 2.2550, 4.1997, 3.5383],\n", - " [ -5.7098, -1.1517, 3.7915, 4.3556, 1.8750, -1.1035],\n", - " [ 0.2975, 3.3524, 1.1723, -0.4793, -1.2602, -3.4301],\n", - " [ 2.9746, 4.8861, 1.7642, -1.0564, -2.9545, -5.6332],\n", - " [ -6.5024, -1.8332, 1.4816, 3.7636, 3.9813, 1.2802],\n", - " [ 2.1063, 4.0084, 2.2482, -1.1702, -2.2348, -4.5038]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([5, 1, 1, 0, 1, 1, 2, 0, 2, 1, 3, 0, 2, 1, 2, 0, 3, 1, 3, 1, 5, 1, 0, 2,\n", - " 1, 2, 1, 5, 3, 5, 3, 1, 1, 3, 2, 0, 3, 3, 1, 3, 2, 0, 4, 3, 2, 3, 1, 2,\n", - " 1, 3, 2, 3, 4, 2, 4, 4, 1, 1, 1, 4, 3, 2, 0, 1, 3, 2, 1, 3, 1, 2, 1, 0,\n", - " 3, 1, 4, 3, 1, 1, 3, 1], device='cuda:0')\n", "\n", "Train Confusion matrix :\n", - "[[ 67 4 0 0 0 0]\n", - " [ 1 189 15 0 0 0]\n", - " [ 0 5 79 25 1 0]\n", - " [ 0 0 2 139 4 0]\n", - " [ 0 0 1 5 93 2]\n", - " [ 0 0 0 0 7 41]]\n", + "[[ 53 21 0 0 0 0]\n", + " [ 10 154 38 3 0 0]\n", + " [ 1 25 42 40 3 0]\n", + " [ 0 3 16 104 18 2]\n", + " [ 0 0 2 25 61 12]\n", + " [ 0 0 0 0 14 33]]\n", "\n" ] }, @@ -1446,8 +701,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 4/5\n", - "Train loss: 0.920702, Train accuracy: 0.8941\n", - "Val loss: 1.686137, Val accuracy: 0.4298\n", + "Train loss: 1.423580, Train accuracy: 0.6574\n", + "Val loss: 1.763735, Val accuracy: 0.3471\n", "\n" ] }, @@ -1455,170 +710,21 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 25%|██████ | 1/4 [00:00<00:01, 1.99it/s, accuracy=366, loss=0.88]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[-5.7759, -1.5733, 0.8380, 2.3969, 3.4815, 2.0601],\n", - " [ 2.6634, 3.9990, 2.0021, -1.4259, -2.8793, -5.1193],\n", - " [-2.5737, 1.0527, 3.9949, 2.4751, -0.8109, -3.7516],\n", - " ...,\n", - " [-4.7469, -1.4491, 1.0603, 2.4618, 3.1493, 1.5438],\n", - " [ 5.1014, 3.8083, 1.8096, -2.2639, -2.8155, -5.2277],\n", - " [-5.8051, -1.1016, 3.2812, 3.5380, 2.1814, 0.4412]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([4, 1, 2, 3, 1, 0, 4, 2, 1, 3, 5, 3, 0, 0, 0, 4, 1, 3, 0, 4, 0, 2, 4, 3,\n", - " 1, 1, 3, 1, 3, 3, 1, 2, 3, 5, 2, 3, 0, 1, 5, 2, 4, 3, 3, 3, 5, 1, 0, 1,\n", - " 2, 4, 4, 1, 3, 3, 3, 1, 2, 0, 4, 0, 2, 2, 5, 3, 1, 1, 0, 2, 3, 3, 0, 2,\n", - " 5, 1, 4, 1, 2, 4, 1, 1, 1, 3, 2, 3, 0, 1, 1, 4, 3, 1, 2, 1, 5, 2, 1, 2,\n", - " 0, 2, 1, 3, 0, 4, 3, 3, 4, 3, 4, 4, 3, 0, 1, 1, 5, 1, 4, 2, 2, 2, 1, 4,\n", - " 2, 3, 4, 0, 5, 0, 4, 5, 1, 3, 1, 1, 1, 5, 3, 1, 2, 2, 3, 1, 2, 2, 3, 4,\n", - " 3, 3, 1, 1, 3, 1, 1, 3, 0, 3, 4, 3, 4, 5, 4, 2, 1, 4, 0, 0, 1, 1, 3, 1,\n", - " 3, 1, 3, 5, 0, 3, 3, 5, 1, 5, 2, 1, 1, 4, 5, 1, 4, 1, 0, 3, 1, 5, 1, 2,\n", - " 5, 1, 3, 4, 0, 4, 0, 3], device='cuda:0')\n", - "tensor([[ 5.7259, 4.0119, 2.1300, -2.3374, -3.2446, -5.6680],\n", - " [ -4.9224, -0.2833, 3.5729, 3.8218, 1.0780, -1.4681],\n", - " [-10.9443, -3.4520, 1.6554, 5.0432, 5.7462, 5.2836],\n", - " ...,\n", - " [ 2.0453, 4.1023, 2.4270, -0.5131, -2.7423, -5.1347],\n", - " [ 0.2039, 2.5082, 1.8900, -0.1353, -1.4990, -2.5726],\n", - " [ 5.4161, 3.0861, 0.8866, -2.8623, -2.1852, -3.7849]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([0, 3, 4, 1, 3, 3, 0, 1, 2, 1, 3, 1, 1, 2, 5, 3, 4, 4, 3, 4, 3, 2, 1, 0,\n", - " 4, 1, 4, 3, 3, 1, 3, 3, 2, 0, 1, 0, 1, 4, 1, 1, 1, 4, 4, 1, 3, 0, 3, 3,\n", - " 0, 3, 1, 1, 1, 2, 3, 1, 2, 3, 1, 1, 2, 1, 5, 5, 0, 4, 3, 3, 3, 1, 1, 1,\n", - " 1, 4, 1, 5, 5, 5, 3, 5, 1, 5, 5, 0, 1, 1, 1, 2, 4, 2, 2, 1, 2, 4, 3, 1,\n", - " 1, 0, 2, 2, 0, 5, 2, 3, 2, 1, 0, 5, 4, 2, 0, 4, 1, 3, 4, 1, 2, 1, 4, 3,\n", - " 1, 2, 3, 1, 2, 4, 3, 1, 3, 2, 2, 0, 3, 1, 2, 3, 3, 1, 0, 1, 2, 3, 1, 1,\n", - " 1, 3, 4, 2, 2, 1, 0, 3, 0, 0, 0, 0, 2, 1, 5, 1, 2, 2, 4, 4, 3, 4, 3, 3,\n", - " 0, 2, 3, 2, 3, 1, 3, 2, 0, 1, 2, 1, 4, 0, 1, 1, 5, 4, 4, 3, 1, 5, 1, 1,\n", - " 2, 3, 1, 1, 3, 1, 1, 0], device='cuda:0')\n", - "tensor([[ 2.5420, 3.1875, 2.1648, -0.9269, -2.5691, -4.5365],\n", - " [ 2.7113, 3.7353, 2.8463, -0.9866, -2.8635, -4.5132],\n", - " [-5.1453, -1.2633, -1.2984, 0.4569, 2.9596, 5.3405],\n", - " ...,\n", - " [ 2.1384, 3.8418, 2.3368, -1.4104, -2.6838, -4.0226],\n", - " [ 1.6626, 3.6129, 2.2975, -0.8998, -2.4072, -4.1306],\n", - " [-6.5156, -1.7716, 0.2386, 2.8963, 4.5214, 3.0853]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([1, 1, 5, 1, 2, 3, 2, 3, 2, 3, 1, 1, 1, 3, 1, 1, 1, 0, 2, 0, 1, 5, 3, 4,\n", - " 0, 4, 1, 3, 3, 1, 1, 2, 1, 5, 1, 3, 4, 3, 4, 3, 1, 4, 1, 2, 3, 4, 2, 3,\n", - " 4, 2, 5, 3, 1, 0, 5, 3, 4, 1, 3, 1, 1, 4, 5, 2, 1, 1, 1, 2, 4, 3, 1, 1,\n", - " 3, 0, 0, 3, 0, 1, 1, 2, 3, 3, 2, 4, 0, 1, 3, 2, 4, 1, 2, 1, 0, 0, 0, 1,\n", - " 1, 2, 0, 4, 1, 4, 2, 1, 4, 1, 3, 2, 3, 4, 4, 4, 3, 1, 3, 2, 1, 0, 3, 2,\n", - " 3, 1, 1, 2, 4, 3, 4, 3, 3, 0, 0, 4, 2, 5, 2, 3, 2, 0, 2, 3, 4, 1, 4, 2,\n", - " 3, 3, 1, 0, 2, 1, 1, 4, 5, 3, 1, 3, 1, 3, 2, 1, 4, 2, 4, 1, 4, 1, 3, 1,\n", - " 4, 1, 2, 3, 0, 1, 4, 1, 4, 2, 1, 1, 1, 3, 2, 1, 5, 1, 1, 1, 2, 4, 1, 3,\n", - " 1, 1, 3, 4, 1, 1, 1, 4], device='cuda:0')\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Train progress: 100%|███████████████████████| 4/4 [00:00<00:00, 4.81it/s, accuracy=633, loss=0.888]" + "Train progress: 100%|████████████████████████| 4/4 [00:08<00:00, 2.17s/it, accuracy=486, loss=1.47]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "tensor([[ -3.2497, 0.5673, 3.1601, 2.3869, -0.0335, -2.0497],\n", - " [ -2.4593, 0.5386, 1.7118, 2.2894, 0.1709, -2.1822],\n", - " [ -1.2751, 1.5196, 3.4448, 2.3384, -1.4411, -4.2802],\n", - " [ 2.0649, 3.7405, 2.5357, -1.3918, -2.4507, -4.4183],\n", - " [ -4.1344, -1.1459, 0.1346, 2.2316, 2.6874, 1.2632],\n", - " [ -8.3915, -3.0208, 0.3923, 3.8464, 6.2180, 4.2387],\n", - " [ 3.3730, 4.8781, 2.5860, -2.2050, -3.4234, -5.5517],\n", - " [ 8.3960, 4.3488, 2.7891, -4.0404, -4.0036, -6.3879],\n", - " [-10.3759, -3.5367, -1.6626, 3.1039, 7.0578, 6.8458],\n", - " [ 1.1714, 2.6555, 1.2416, -1.1252, -2.0335, -3.0712],\n", - " [ -6.8211, -2.5425, 0.7500, 3.5888, 4.6562, 3.0443],\n", - " [ -7.5116, -1.6554, -0.4702, 2.0625, 4.2912, 4.5781],\n", - " [ 3.4008, 4.7224, 3.2029, -1.3657, -4.1060, -5.6949],\n", - " [ -0.3255, 2.2532, 3.4362, 1.6081, -2.2351, -4.9790],\n", - " [ -4.2005, -0.5496, 2.6046, 3.4961, 0.6267, -1.1292],\n", - " [ -2.4484, 0.1961, 1.8516, 3.3229, 0.8628, -3.0822],\n", - " [ 3.8468, 4.4284, 3.1539, -1.2599, -3.5479, -6.2777],\n", - " [ 3.3208, 4.6737, 2.8503, -2.0720, -3.9097, -5.0311],\n", - " [ -6.5871, -2.2764, 0.4940, 3.5449, 4.6360, 3.0704],\n", - " [ -7.9097, -2.8796, 1.1501, 4.4332, 5.1891, 2.5854],\n", - " [ 2.6897, 3.8641, 2.0514, -1.5266, -3.3193, -4.2935],\n", - " [ -3.1083, 0.1514, 2.5449, 3.1346, 0.5927, -2.3941],\n", - " [ -0.3564, 2.3192, 3.6723, 2.3636, -2.2541, -6.2640],\n", - " [ -0.1915, 2.1245, 2.9623, 1.1767, -2.0912, -3.5429],\n", - " [ -0.6759, 2.3227, 3.6681, 1.4574, -1.9775, -4.7240],\n", - " [ 3.1258, 4.5656, 2.5672, -2.2634, -3.4849, -4.2778],\n", - " [ -0.9123, 1.8090, 3.8298, 1.6944, -1.5815, -4.4170],\n", - " [ -5.9920, -1.7374, 1.0287, 3.1680, 4.4624, 1.1680],\n", - " [ 3.7907, 4.4047, 3.0457, -1.1351, -3.6876, -6.0250],\n", - " [ -3.3688, -0.1783, -1.3004, -0.3965, 2.2764, 3.1127],\n", - " [ 2.7761, 4.1873, 2.1086, -0.7149, -3.0538, -5.3742],\n", - " [ 3.8900, 3.8658, 2.8887, -1.6419, -3.3357, -5.4547],\n", - " [ -0.5850, 2.9294, 3.5951, 1.2724, -2.1297, -5.3975],\n", - " [ -6.7407, -2.0373, 0.7286, 3.2515, 3.8931, 2.3443],\n", - " [ 2.3462, 4.3209, 2.9216, -1.6209, -2.9469, -5.0696],\n", - " [ 2.5477, 3.3807, 2.4401, -0.7962, -2.7455, -4.4175],\n", - " [ -1.8783, 1.2982, 2.0425, 2.1692, -0.6005, -2.6517],\n", - " [ 0.8112, 3.3155, 2.7968, 0.1713, -2.7553, -4.4765],\n", - " [ -1.0702, 2.1001, 3.7740, 2.2885, -1.4044, -5.3781],\n", - " [ 2.9898, 3.7143, 2.3157, -1.6990, -2.8787, -4.2093],\n", - " [ 3.9521, 4.1913, 2.9394, -1.9234, -3.6885, -6.1421],\n", - " [ -9.2552, -2.7169, -1.5086, 2.4637, 5.2994, 7.1799],\n", - " [ 1.3106, 3.4553, 2.5456, -0.1109, -2.7854, -4.6200],\n", - " [ -7.7199, -2.4138, -0.9714, 2.0882, 4.7669, 5.5203],\n", - " [ -1.6886, 1.1704, 2.5370, 2.0853, -0.4791, -3.3866],\n", - " [ -4.8490, -1.2334, 2.2039, 3.9253, 2.3341, -1.1983],\n", - " [ 3.5109, 3.7076, 2.6758, -1.6990, -3.3221, -5.1113],\n", - " [ 2.2563, 3.5966, 2.2668, -1.7175, -2.5136, -3.6700],\n", - " [ -9.8098, -3.9056, 0.6071, 5.2101, 6.9284, 4.5129],\n", - " [ -0.4145, 2.3470, 2.8647, 1.4938, -1.9801, -4.6477],\n", - " [ 6.4263, 4.2989, 2.6321, -3.7260, -3.8457, -5.2750],\n", - " [ -6.8993, -1.5562, -0.4367, 1.6091, 4.0018, 4.4247],\n", - " [ -4.0698, -0.2169, 2.6032, 3.2307, 0.8881, -2.1681],\n", - " [ -3.2765, 0.1776, 2.0771, 2.8453, 0.5345, -1.8848],\n", - " [ -1.4043, 1.8810, 3.8091, 2.3010, -1.3787, -4.7129],\n", - " [ 2.6953, 3.7391, 2.1932, -1.6802, -2.9188, -4.7977],\n", - " [-10.5831, -3.2202, -2.0012, 2.4135, 6.4770, 7.5244],\n", - " [ -4.4925, -0.7056, 0.6993, 2.2317, 2.5942, 1.0553],\n", - " [ -2.8994, 0.7600, 3.0939, 2.7299, -0.4803, -2.6573],\n", - " [ 2.8883, 4.2360, 2.7968, -0.7679, -3.4427, -5.9431],\n", - " [ 1.5932, 3.6149, 2.7090, 0.4709, -2.7898, -6.0612],\n", - " [ 8.9576, 5.6995, 3.9559, -4.8407, -4.8331, -7.3363],\n", - " [ -4.2245, -0.6129, 2.3711, 3.7665, 1.5506, -1.4135],\n", - " [ -5.3299, -1.3727, 2.6542, 4.2236, 2.3501, -0.9212],\n", - " [ -7.4862, -2.6174, 0.9792, 3.7369, 4.7495, 2.6694],\n", - " [ -5.9296, -1.7436, 2.3963, 4.7665, 2.7905, -0.0745],\n", - " [ -1.3144, 2.1166, 3.2180, 1.5221, -1.4481, -3.9369],\n", - " [ -0.5381, 2.4285, 3.1258, 1.1249, -2.0329, -4.3294],\n", - " [ -8.6213, -2.5459, 1.1549, 3.9584, 4.8105, 2.8051],\n", - " [ -2.6219, 0.3634, 3.2660, 3.0531, -0.0452, -3.4234],\n", - " [ 0.4080, 2.3107, 1.8506, 0.7688, -1.1446, -3.6645],\n", - " [ -3.4532, -1.0942, 0.1554, 1.8389, 2.7802, 0.7740],\n", - " [ 3.6948, 4.8332, 2.7365, -2.4059, -3.7626, -4.9458],\n", - " [ -3.7767, 0.1466, 3.0137, 3.7753, 0.5258, -2.8847],\n", - " [ -6.0416, -2.1325, 1.0063, 3.3022, 4.1929, 2.1103],\n", - " [ 6.3636, 4.9418, 3.0033, -3.3700, -4.1646, -6.0754],\n", - " [ -4.3211, -0.2404, 3.3578, 4.5301, 1.0619, -3.0855],\n", - " [ -3.8663, -0.4867, 2.0463, 3.2049, 1.3199, -1.1269],\n", - " [ -6.7982, -2.1376, 0.9124, 3.5671, 4.3968, 1.7382],\n", - " [ 2.5751, 4.4468, 3.0757, -1.2578, -3.4666, -5.4434]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([2, 3, 2, 1, 4, 4, 1, 0, 5, 1, 4, 5, 1, 2, 3, 3, 1, 1, 4, 4, 1, 3, 2, 2,\n", - " 2, 1, 2, 4, 1, 5, 1, 1, 2, 4, 1, 1, 3, 1, 2, 1, 1, 5, 1, 5, 2, 3, 1, 1,\n", - " 4, 2, 0, 5, 3, 3, 2, 1, 5, 4, 2, 1, 1, 0, 3, 3, 4, 3, 2, 2, 4, 2, 1, 4,\n", - " 1, 3, 4, 0, 3, 3, 4, 1], device='cuda:0')\n", "\n", "Train Confusion matrix :\n", - "[[ 71 0 0 0 0 0]\n", - " [ 8 196 1 0 0 0]\n", - " [ 0 2 107 1 0 0]\n", - " [ 0 0 29 116 0 0]\n", - " [ 0 0 0 1 96 4]\n", - " [ 0 0 0 0 1 47]]\n", + "[[ 67 7 0 0 0 0]\n", + " [ 15 163 25 2 0 0]\n", + " [ 1 25 74 9 2 0]\n", + " [ 0 0 47 80 16 0]\n", + " [ 0 0 2 26 63 9]\n", + " [ 0 0 1 0 7 39]]\n", "\n" ] }, @@ -1634,8 +740,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 5/5\n", - "Train loss: 0.888149, Train accuracy: 0.9309\n", - "Val loss: 1.563793, Val accuracy: 0.5372\n", + "Train loss: 1.416767, Train accuracy: 0.7147\n", + "Val loss: 1.757895, Val accuracy: 0.3554\n", "\n", "[INFO] Network evaluation ...\n" ] @@ -1644,7 +750,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Test progress: 100%|███████████████████████████████████████| 2/2 [00:01<00:00, 1.21it/s, loss=4.43]" + "Test progress: 100%|████████████████████████████████████████| 2/2 [00:01<00:00, 1.36it/s, loss=1.7]" ] }, { @@ -1653,24 +759,24 @@ "text": [ "\n", "Confusion matrix :\n", - "[[17 8 0 0 1 0]\n", - " [ 8 40 10 2 1 0]\n", - " [ 0 8 7 18 1 0]\n", - " [ 0 2 8 20 9 0]\n", - " [ 0 2 2 9 16 0]\n", - " [ 0 1 0 4 7 0]]\n", + "[[12 9 1 0 0 0]\n", + " [14 45 1 0 0 0]\n", + " [ 0 21 12 0 0 0]\n", + " [ 0 20 20 1 1 0]\n", + " [ 0 7 19 0 4 0]\n", + " [ 0 2 8 0 3 1]]\n", "\n", - "MS: 0.0000\n", + "MS: 0.0238\n", "\n", - "QWK: 0.7589\n", + "QWK: 0.4650\n", "\n", - "MAE: 0.6169\n", + "MAE: 1.0050\n", "\n", - "CCR: 0.4975\n", + "CCR: 0.3731\n", "\n", - "1-off: 0.9204\n", + "1-off: 0.7164\n", "\n", - "[INFO] Total training time: 8.96s\n" + "[INFO] Total training time: 56.30s\n" ] }, { @@ -1745,7 +851,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.8.19" }, "orig_nbformat": 4 }, diff --git a/tutorials/stick_breaking_tutorial.ipynb b/tutorials/stick_breaking_tutorial.ipynb index 817d6d8..0629f74 100644 --- a/tutorials/stick_breaking_tutorial.ipynb +++ b/tutorials/stick_breaking_tutorial.ipynb @@ -30,7 +30,6 @@ "from torch.optim import Adam\n", "from torch.utils.data import DataLoader, Subset\n", "from torchvision import models\n", - "from torchvision.datasets import ImageFolder\n", "from torchvision.transforms import Compose, ToTensor\n", "from tqdm import tqdm" ] @@ -83,37 +82,47 @@ "Files already downloaded and verified\n", "Files already processed and verified\n", "Files already split and verified\n", - "Using cuda device\n", + "Files already downloaded and verified\n", + "Files already processed and verified\n", + "Files already split and verified\n", + "Using cpu device\n", "Detected image shape: [3, 128, 128]\n", - "class_weights=array([1.60843373, 0.55394191, 1.02692308, 0.78070175, 1.12184874,\n", - " 2.34210526])\n" + "class_weights=array([1.01908397, 1.53448276, 0.79464286, 1.13135593, 0.55165289,\n", + " 2.42727273])\n" ] } ], "source": [ - "fgnet = FGNet(root=\"./datasets/fgnet\", download=True, process_data=True)\n", - "\n", - "complete_train_data = ImageFolder(\n", - " root=\"./datasets/fgnet/FGNET/train\", transform=Compose([ToTensor()])\n", + "fgnet_trainval = FGNet(\n", + " root=\"./datasets\",\n", + " download=True,\n", + " train=True,\n", + " target_transform=np.array,\n", + " transform=Compose([ToTensor()]),\n", ")\n", - "test_data = ImageFolder(\n", - " root=\"./datasets/fgnet/FGNET/test\", transform=Compose([ToTensor()])\n", + "\n", + "test_data = FGNet(\n", + " root=\"./datasets\",\n", + " download=True,\n", + " train=False,\n", + " target_transform=np.array,\n", + " transform=Compose([ToTensor()]),\n", ")\n", "\n", - "num_classes = len(complete_train_data.classes)\n", - "classes = complete_train_data.classes\n", - "targets = complete_train_data.targets\n", + "num_classes = len(fgnet_trainval.classes)\n", + "classes = fgnet_trainval.classes\n", + "targets = fgnet_trainval.targets\n", "\n", "# Create a validation split\n", "sss = StratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=0)\n", "sss_splits = list(\n", - " sss.split(X=np.zeros(len(complete_train_data)), y=complete_train_data.targets)\n", + " sss.split(X=np.zeros(len(fgnet_trainval)), y=fgnet_trainval.targets)\n", ")\n", "train_idx, val_idx = sss_splits[0]\n", "\n", "# Create subsets for training and validation\n", - "train_data = Subset(complete_train_data, train_idx)\n", - "val_data = Subset(complete_train_data, val_idx)\n", + "train_data = Subset(fgnet_trainval, train_idx)\n", + "val_data = Subset(fgnet_trainval, val_idx)\n", "\n", "# Get CUDA device\n", "device = \"cuda\" if cuda.is_available() else \"cpu\"\n", @@ -524,169 +533,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 25%|██████▎ | 1/4 [00:01<00:05, 1.85s/it, accuracy=57, loss=1.89]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[-1.4524, -2.3870, -3.8288, -3.5049, -4.5601, -6.4897],\n", - " [-1.8007, -3.1167, -3.4451, -3.7290, -4.3031, -5.5605],\n", - " [-1.5087, -1.7912, -2.5809, -3.0644, -3.0782, -3.8081],\n", - " ...,\n", - " [-1.4835, -2.4432, -3.3595, -3.5126, -4.2919, -5.4658],\n", - " [-1.4676, -2.1262, -3.5059, -4.6662, -5.5921, -6.3200],\n", - " [-1.7479, -2.2322, -3.6313, -4.1351, -4.0532, -5.1453]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([3, 3, 4, 1, 4, 3, 0, 1, 0, 1, 1, 2, 5, 2, 4, 1, 3, 1, 1, 2, 3, 1, 4, 1,\n", - " 3, 3, 2, 3, 3, 1, 0, 1, 0, 3, 4, 4, 1, 0, 1, 1, 3, 2, 0, 4, 3, 3, 4, 5,\n", - " 4, 1, 2, 3, 0, 4, 1, 1, 3, 4, 1, 3, 0, 2, 1, 1, 0, 0, 4, 4, 2, 1, 2, 4,\n", - " 5, 4, 2, 4, 1, 4, 2, 4, 1, 2, 1, 1, 1, 0, 3, 3, 0, 1, 2, 1, 3, 3, 1, 0,\n", - " 3, 3, 1, 0, 4, 1, 3, 4, 2, 0, 3, 2, 2, 1, 3, 2, 3, 4, 3, 3, 3, 5, 3, 1,\n", - " 3, 3, 3, 1, 3, 4, 4, 0, 3, 2, 0, 3, 3, 1, 2, 1, 5, 3, 3, 4, 5, 2, 1, 1,\n", - " 3, 2, 1, 3, 4, 0, 3, 3, 5, 0, 0, 1, 3, 1, 1, 1, 5, 1, 3, 0, 3, 1, 1, 1,\n", - " 0, 5, 3, 4, 1, 4, 3, 5, 1, 0, 1, 1, 1, 4, 0, 3, 2, 2, 4, 2, 1, 1, 5, 4,\n", - " 4, 3, 4, 5, 1, 5, 4, 4], device='cuda:0')\n", - "tensor([[-1.5130, -3.5241, -4.9549, -5.1838, -5.8980, -6.5854],\n", - " [-1.8332, -1.8776, -2.9285, -2.0628, -2.9925, -3.4792],\n", - " [-1.3934, -2.6727, -3.0570, -2.9145, -2.8310, -3.8281],\n", - " ...,\n", - " [-1.5236, -1.8109, -2.7062, -3.4412, -4.6127, -4.9685],\n", - " [-1.4274, -2.3605, -2.8682, -3.2514, -3.8929, -5.1776],\n", - " [-1.6656, -3.7816, -5.2436, -7.1552, -6.4038, -7.2137]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([1, 2, 3, 0, 1, 1, 4, 1, 1, 4, 2, 2, 3, 1, 1, 3, 4, 1, 1, 2, 0, 2, 3, 0,\n", - " 2, 1, 2, 1, 2, 2, 2, 2, 4, 3, 4, 3, 0, 2, 0, 2, 3, 1, 3, 3, 1, 1, 3, 3,\n", - " 1, 4, 3, 1, 2, 4, 3, 0, 0, 5, 1, 0, 1, 1, 1, 3, 2, 4, 3, 4, 1, 3, 1, 0,\n", - " 1, 1, 4, 2, 2, 4, 3, 3, 0, 1, 2, 2, 1, 1, 1, 2, 3, 5, 2, 5, 3, 3, 0, 1,\n", - " 2, 0, 3, 4, 3, 3, 1, 1, 3, 2, 0, 1, 2, 2, 3, 1, 0, 1, 2, 1, 2, 5, 2, 1,\n", - " 1, 2, 1, 1, 1, 2, 2, 1, 2, 2, 4, 2, 1, 3, 1, 2, 1, 3, 1, 2, 3, 0, 5, 2,\n", - " 1, 5, 1, 3, 3, 1, 5, 3, 5, 3, 1, 4, 4, 1, 0, 1, 4, 1, 3, 0, 5, 4, 4, 4,\n", - " 3, 3, 1, 1, 1, 5, 2, 3, 4, 3, 0, 2, 2, 1, 1, 1, 2, 1, 2, 1, 2, 0, 1, 3,\n", - " 5, 1, 2, 4, 2, 1, 1, 0], device='cuda:0')\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Train progress: 75%|██████████████████▊ | 3/4 [00:02<00:00, 1.82it/s, accuracy=125, loss=1.6]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[-1.6488, -2.5838, -2.8312, -1.9495, -2.4759, -3.4682],\n", - " [-1.7872, -3.9560, -6.2586, -8.3967, -7.9779, -8.0079],\n", - " [-2.4078, -2.5133, -2.5584, -1.7485, -2.3675, -2.6099],\n", - " ...,\n", - " [-2.0260, -2.0912, -2.4092, -2.4302, -2.1441, -2.6813],\n", - " [-1.9830, -3.4873, -5.6127, -4.9098, -5.1187, -5.4987],\n", - " [-1.4839, -3.4648, -3.5702, -3.6756, -2.7831, -4.2399]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([4, 0, 5, 1, 2, 2, 5, 1, 4, 1, 2, 3, 1, 0, 3, 1, 1, 3, 4, 3, 3, 4, 3, 1,\n", - " 1, 2, 2, 1, 4, 4, 2, 0, 1, 4, 1, 4, 1, 4, 1, 4, 3, 4, 1, 4, 1, 5, 2, 0,\n", - " 3, 1, 4, 2, 2, 2, 4, 3, 4, 1, 3, 3, 1, 1, 1, 1, 1, 0, 2, 2, 5, 3, 4, 1,\n", - " 4, 4, 4, 3, 3, 1, 3, 4, 5, 1, 1, 3, 5, 1, 1, 1, 1, 5, 1, 2, 0, 3, 5, 2,\n", - " 3, 2, 4, 1, 3, 1, 1, 1, 4, 4, 1, 1, 2, 1, 1, 0, 0, 3, 1, 4, 3, 3, 1, 2,\n", - " 1, 4, 2, 5, 4, 4, 4, 3, 1, 0, 5, 3, 5, 5, 2, 3, 4, 5, 1, 0, 5, 1, 1, 0,\n", - " 0, 4, 3, 1, 2, 1, 0, 3, 5, 4, 2, 1, 1, 2, 3, 4, 0, 1, 1, 2, 3, 3, 1, 4,\n", - " 1, 3, 1, 0, 1, 2, 3, 3, 1, 2, 5, 1, 3, 2, 3, 4, 0, 0, 0, 3, 3, 2, 4, 4,\n", - " 3, 3, 1, 4, 3, 5, 2, 1], device='cuda:0')\n", - "tensor([[-1.8020, -2.6416, -2.7470, -2.3354, -2.2209, -3.2640],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8659, -2.8370],\n", - " [-2.4079, -4.7105, -7.0131, -9.3157, -9.4211, -8.5060],\n", - " [-2.3657, -2.5187, -2.6241, -1.7702, -2.1811, -2.9236],\n", - " [-1.8144, -2.6373, -4.9399, -7.2424, -7.3478, -7.3860],\n", - " [-2.3484, -2.5211, -2.6264, -2.7318, -1.9783, -2.2196],\n", - " [-2.4079, -2.5133, -2.6187, -1.8491, -2.2126, -2.5091],\n", - " [-1.5313, -1.8085, -2.6948, -3.5838, -3.8939, -4.7027],\n", - " [-1.8156, -3.4182, -4.7286, -5.3209, -5.2911, -5.7105],\n", - " [-1.4275, -3.3254, -3.1292, -2.6073, -3.5181, -4.4076],\n", - " [-2.4079, -2.5133, -2.3004, -1.7533, -2.6257, -3.8405],\n", - " [-2.4079, -1.9894, -1.7846, -2.9020, -2.5126, -3.8590],\n", - " [-2.0140, -2.5806, -2.6860, -1.8005, -2.4223, -3.6258],\n", - " [-2.3964, -2.5148, -2.6201, -1.7108, -2.6042, -3.7245],\n", - " [-2.4079, -2.5133, -2.4516, -1.8638, -2.1612, -3.0725],\n", - " [-1.9131, -2.6064, -2.7118, -2.0410, -2.1181, -2.6941],\n", - " [-1.6061, -2.7335, -2.8389, -2.4023, -2.5902, -3.8640],\n", - " [-2.4079, -2.5133, -2.6187, -1.8101, -2.1201, -2.8880],\n", - " [-1.3865, -3.0871, -2.9397, -2.4926, -2.7682, -3.7490],\n", - " [-1.4839, -2.7011, -2.9588, -2.3457, -2.3305, -3.2041],\n", - " [-1.6357, -2.7159, -2.8213, -2.2436, -2.1865, -2.7383],\n", - " [-2.4079, -2.5133, -2.6187, -2.3371, -1.9540, -3.0941],\n", - " [-2.4079, -2.5133, -2.6187, -2.2303, -1.9371, -2.8820],\n", - " [-2.3147, -2.5258, -2.6311, -2.7365, -1.9496, -2.9846],\n", - " [-2.2629, -2.5335, -2.0113, -2.1558, -2.2342, -3.1952],\n", - " [-2.4079, -2.3465, -1.6321, -2.7560, -2.4546, -3.5264],\n", - " [-2.4079, -2.5133, -2.6187, -2.0865, -2.0648, -3.0469],\n", - " [-1.8879, -3.1359, -3.6269, -4.0197, -4.4688, -5.4832],\n", - " [-2.4079, -2.5133, -2.6187, -1.9235, -2.2047, -3.2516],\n", - " [-2.4079, -2.5133, -2.6187, -1.7789, -2.2354, -2.6416],\n", - " [-1.9619, -2.5933, -2.3674, -2.1401, -2.1414, -3.0676],\n", - " [-2.0507, -2.1201, -2.0660, -2.1727, -2.5465, -2.7236],\n", - " [-1.5128, -2.8049, -5.1075, -6.9748, -7.5928, -6.7450],\n", - " [-1.7115, -3.8496, -3.9550, -3.0487, -3.6884, -4.6871],\n", - " [-1.5369, -2.1574, -3.7777, -4.5381, -3.6242, -4.3952],\n", - " [-2.4079, -2.5133, -4.8159, -6.5040, -7.2504, -6.6136],\n", - " [-1.4711, -2.8502, -2.9556, -3.0610, -2.3559, -2.5951],\n", - " [-1.7345, -2.6678, -4.9704, -7.2059, -7.2375, -6.7216],\n", - " [-1.7998, -1.7404, -2.7485, -2.2529, -2.8998, -2.9346],\n", - " [-2.4079, -2.5133, -2.2270, -1.8987, -2.1648, -2.8768],\n", - " [-2.4079, -3.7767, -5.7461, -5.8514, -5.2240, -5.3080],\n", - " [-2.4079, -4.7105, -4.6955, -3.9194, -4.7205, -5.7029],\n", - " [-2.4079, -2.5133, -1.8430, -2.0042, -2.3677, -3.3199],\n", - " [-2.4079, -4.1227, -5.3997, -6.7084, -5.7923, -7.1344],\n", - " [-2.4079, -3.6894, -5.3813, -5.5380, -4.8414, -6.1292],\n", - " [-2.4079, -3.6904, -4.6588, -4.8395, -5.1667, -6.3069],\n", - " [-2.4079, -4.7105, -6.2597, -7.3108, -6.4688, -7.2942],\n", - " [-2.4079, -2.5133, -2.6187, -1.8631, -2.0713, -2.8090],\n", - " [-2.4079, -2.5133, -2.4556, -1.8610, -2.1735, -2.5925],\n", - " [-2.4079, -4.7105, -7.0131, -9.3157, -9.4211, -8.8668],\n", - " [-2.4079, -2.5133, -4.8159, -6.2508, -6.7785, -7.0962],\n", - " [-1.4990, -2.8186, -2.9240, -2.2136, -2.3830, -2.8587],\n", - " [-1.8439, -4.0325, -4.1378, -3.9263, -3.7481, -5.3852],\n", - " [-2.4079, -2.5133, -4.7115, -7.0000, -7.1053, -6.2114],\n", - " [-2.3463, -2.5214, -2.6267, -1.7375, -2.3338, -3.4606],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.0045, -3.0646],\n", - " [-2.4079, -2.5133, -2.6187, -2.0837, -1.9491, -2.6585],\n", - " [-2.4079, -2.5133, -2.6187, -1.7750, -2.2376, -3.2353],\n", - " [-2.4079, -2.5133, -2.6187, -2.2013, -1.9426, -2.5109],\n", - " [-2.4079, -2.2341, -1.6597, -2.5398, -4.0950, -3.9153],\n", - " [-2.4079, -2.0818, -1.7181, -2.7305, -2.4353, -2.8380],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.3554, -2.0013],\n", - " [-1.7713, -2.6530, -2.6912, -2.4290, -2.0841, -2.5506],\n", - " [-2.4079, -4.2364, -5.7324, -6.3360, -6.1686, -7.6593],\n", - " [-2.4079, -1.5658, -2.5201, -3.4695, -4.1227, -4.3908],\n", - " [-2.4079, -1.4922, -2.1987, -3.5896, -3.2733, -3.8963],\n", - " [-2.4079, -2.5133, -2.6187, -2.3615, -1.9073, -2.3995],\n", - " [-2.4079, -1.9095, -3.0894, -4.9545, -4.0482, -5.0562],\n", - " [-2.4079, -1.5682, -2.6350, -4.2821, -3.7997, -4.5301],\n", - " [-2.4079, -3.7450, -4.6539, -5.4054, -4.5443, -5.1677],\n", - " [-1.9171, -4.0036, -4.2501, -3.6140, -3.6953, -4.6751],\n", - " [-1.5892, -2.7445, -2.8498, -2.2067, -2.3542, -3.4145],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8078, -2.5932],\n", - " [-2.4079, -2.5133, -2.6187, -2.3025, -2.1147, -3.3619],\n", - " [-1.8271, -4.0100, -4.1154, -3.2053, -3.9942, -5.1244],\n", - " [-2.1208, -2.5580, -2.1486, -1.9134, -2.3911, -2.7613],\n", - " [-2.4079, -2.5133, -2.6187, -2.2342, -1.9020, -2.5742],\n", - " [-1.6311, -2.7060, -2.8255, -2.9309, -2.0274, -2.6007],\n", - " [-2.4079, -2.5133, -2.4176, -1.7384, -2.5535, -3.6184],\n", - " [-1.7558, -2.6591, -2.7645, -2.1344, -2.1547, -2.9992]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([2, 5, 0, 3, 0, 4, 3, 1, 1, 3, 3, 2, 3, 1, 2, 3, 3, 3, 2, 1, 3, 4, 2, 3,\n", - " 1, 4, 5, 1, 4, 5, 3, 1, 0, 2, 1, 1, 5, 1, 4, 1, 1, 2, 1, 1, 2, 1, 0, 4,\n", - " 2, 0, 1, 3, 1, 1, 1, 4, 3, 2, 5, 2, 1, 5, 2, 1, 1, 1, 5, 1, 0, 0, 0, 3,\n", - " 4, 3, 2, 1, 4, 0, 1, 3], device='cuda:0')\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Train progress: 100%|█████████████████████████| 4/4 [00:02<00:00, 1.78it/s, accuracy=125, loss=1.6]" + "Train progress: 100%|████████████████████████| 4/4 [00:11<00:00, 2.81s/it, accuracy=118, loss=1.63]" ] }, { @@ -695,12 +542,12 @@ "text": [ "\n", "Train Confusion matrix :\n", - "[[ 69 2 0 0 0 0]\n", - " [166 23 7 7 2 0]\n", - " [ 85 14 4 4 3 0]\n", - " [116 6 1 15 7 0]\n", - " [ 75 5 2 6 13 0]\n", - " [ 33 3 1 5 5 1]]\n", + "[[ 64 4 5 0 1 0]\n", + " [143 20 20 13 9 0]\n", + " [ 73 12 9 11 6 0]\n", + " [ 92 12 12 10 16 1]\n", + " [ 56 8 8 15 12 1]\n", + " [ 29 3 1 4 7 3]]\n", "\n" ] }, @@ -716,8 +563,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 1/5\n", - "Train loss: 1.943677, Train accuracy: 0.1838\n", - "Val loss: 3.073993, Val accuracy: 0.1240\n", + "Train loss: 1.847242, Train accuracy: 0.1735\n", + "Val loss: 2.574011, Val accuracy: 0.1074\n", "\n" ] }, @@ -725,170 +572,21 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 25%|██████ | 1/4 [00:00<00:01, 1.52it/s, accuracy=205, loss=1.43]" + "Train progress: 100%|████████████████████████| 4/4 [00:09<00:00, 2.28s/it, accuracy=379, loss=1.33]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "tensor([[-2.4079, -2.5133, -1.9160, -1.9258, -2.3771, -3.2150],\n", - " [-2.4079, -2.5133, -2.6187, -1.7665, -2.1753, -2.9375],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.0245, -2.1541],\n", - " ...,\n", - " [-2.4079, -2.5133, -2.6187, -1.9590, -2.9679, -2.1467],\n", - " [-2.4079, -2.3059, -2.1552, -1.8285, -2.8456, -4.5923],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8506, -2.8107]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([1, 3, 4, 4, 2, 3, 3, 2, 0, 3, 1, 4, 2, 2, 4, 0, 1, 4, 1, 1, 3, 0, 1, 2,\n", - " 2, 0, 5, 4, 0, 1, 3, 4, 0, 2, 2, 3, 4, 1, 3, 3, 4, 1, 3, 3, 1, 3, 4, 0,\n", - " 4, 1, 5, 1, 3, 0, 1, 1, 3, 1, 4, 1, 1, 1, 2, 1, 3, 1, 4, 4, 0, 1, 2, 4,\n", - " 0, 1, 3, 4, 3, 3, 4, 3, 1, 1, 5, 0, 4, 0, 0, 1, 3, 4, 4, 0, 3, 1, 1, 2,\n", - " 4, 3, 3, 3, 3, 0, 4, 4, 1, 1, 4, 4, 1, 2, 2, 2, 3, 3, 1, 2, 0, 3, 5, 1,\n", - " 3, 2, 2, 1, 1, 0, 0, 5, 1, 3, 3, 1, 1, 5, 0, 3, 1, 4, 2, 3, 1, 1, 3, 5,\n", - " 3, 5, 2, 5, 1, 3, 0, 3, 0, 1, 4, 2, 3, 3, 1, 5, 2, 4, 5, 3, 4, 1, 2, 5,\n", - " 1, 3, 1, 1, 2, 4, 1, 4, 1, 1, 4, 3, 2, 3, 4, 4, 1, 0, 1, 0, 2, 4, 3, 3,\n", - " 0, 2, 0, 2, 2, 5, 1, 5], device='cuda:0')\n", - "tensor([[-2.3384, -2.5224, -2.6278, -2.7332, -2.8385, -1.9286],\n", - " [-2.4079, -2.5133, -2.6187, -1.7466, -2.2076, -2.9389],\n", - " [-2.4079, -2.5133, -2.6187, -2.7052, -2.7559, -1.9257],\n", - " ...,\n", - " [-2.4079, -2.5133, -1.6167, -2.5279, -3.0081, -4.1926],\n", - " [-2.4079, -2.5133, -2.6187, -1.9270, -2.1441, -2.4808],\n", - " [-2.4079, -2.5133, -2.5897, -1.7069, -2.8747, -4.1869]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([5, 3, 5, 4, 0, 0, 1, 1, 1, 3, 5, 3, 1, 3, 1, 2, 1, 3, 0, 1, 5, 3, 1, 3,\n", - " 2, 3, 5, 1, 0, 1, 4, 2, 1, 2, 1, 4, 5, 3, 0, 2, 5, 4, 3, 4, 0, 1, 1, 4,\n", - " 0, 1, 1, 0, 1, 4, 2, 3, 1, 3, 2, 4, 5, 4, 5, 1, 3, 1, 0, 3, 0, 3, 1, 2,\n", - " 3, 3, 1, 3, 0, 4, 2, 1, 1, 0, 2, 2, 4, 4, 2, 2, 4, 0, 2, 5, 3, 0, 1, 3,\n", - " 2, 1, 2, 2, 3, 1, 1, 4, 2, 4, 3, 3, 1, 1, 0, 1, 3, 2, 3, 3, 0, 2, 3, 0,\n", - " 0, 5, 4, 1, 2, 3, 2, 1, 3, 3, 1, 5, 3, 0, 1, 5, 0, 3, 1, 4, 3, 4, 1, 4,\n", - " 2, 1, 2, 4, 0, 2, 1, 3, 1, 2, 4, 1, 1, 0, 1, 3, 3, 1, 1, 5, 1, 3, 2, 1,\n", - " 0, 2, 2, 4, 1, 5, 2, 2, 2, 2, 3, 4, 1, 3, 3, 3, 4, 1, 0, 2, 5, 3, 3, 4,\n", - " 1, 0, 4, 1, 3, 0, 2, 3], device='cuda:0')\n", - "tensor([[-2.4079, -2.5133, -2.6187, -2.0652, -3.7299, -5.1068],\n", - " [-2.4079, -2.1725, -2.1459, -3.3318, -3.9624, -4.4700],\n", - " [-2.4079, -1.7924, -2.0101, -4.1030, -3.3445, -4.4812],\n", - " ...,\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.8202, -2.0449],\n", - " [-2.4079, -2.5133, -1.6921, -2.2241, -4.1871, -5.7217],\n", - " [-2.4079, -2.5133, -2.1586, -1.8436, -2.8317, -3.9434]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([2, 2, 1, 1, 1, 1, 1, 1, 3, 1, 3, 2, 3, 1, 3, 1, 2, 2, 1, 3, 1, 0, 4, 2,\n", - " 2, 1, 2, 1, 1, 3, 1, 4, 1, 3, 4, 4, 2, 1, 3, 1, 1, 0, 1, 5, 1, 4, 1, 0,\n", - " 2, 2, 4, 4, 1, 1, 4, 2, 5, 3, 3, 1, 1, 5, 5, 2, 0, 1, 4, 3, 1, 1, 1, 3,\n", - " 3, 2, 0, 1, 2, 0, 4, 4, 2, 2, 1, 1, 2, 1, 0, 1, 3, 3, 2, 0, 1, 4, 3, 3,\n", - " 0, 1, 1, 1, 0, 1, 0, 4, 4, 0, 1, 1, 3, 3, 3, 2, 5, 2, 1, 3, 3, 0, 3, 1,\n", - " 0, 1, 0, 3, 1, 4, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 4, 2, 3, 3, 1, 2, 0, 1,\n", - " 4, 0, 5, 5, 3, 5, 0, 4, 1, 1, 1, 2, 4, 2, 0, 3, 1, 1, 4, 5, 3, 4, 3, 2,\n", - " 3, 1, 4, 1, 4, 5, 3, 2, 1, 1, 2, 3, 3, 1, 1, 4, 4, 1, 2, 5, 1, 5, 1, 2,\n", - " 3, 2, 3, 3, 3, 4, 1, 1], device='cuda:0')\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Train progress: 100%|█████████████████████████| 4/4 [00:01<00:00, 3.97it/s, accuracy=313, loss=1.5]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[-2.4079, -2.5133, -1.9351, -2.5741, -4.7170, -6.3726],\n", - " [-2.4079, -2.5133, -4.1471, -6.2970, -5.5215, -6.7864],\n", - " [-2.4079, -2.5133, -1.7565, -2.1640, -3.7764, -5.3159],\n", - " [-2.3723, -2.5179, -2.0950, -2.5039, -2.1153, -2.7624],\n", - " [-2.4079, -2.5133, -2.0883, -1.9580, -3.2718, -4.5838],\n", - " [-1.4294, -3.3310, -5.5351, -5.7522, -5.8575, -5.3719],\n", - " [-2.4079, -2.5133, -2.6187, -1.7032, -2.5025, -3.5190],\n", - " [-2.4079, -2.5133, -2.6187, -2.1777, -2.0606, -2.3036],\n", - " [-2.4079, -2.5133, -2.6187, -1.7095, -2.5151, -3.0914],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -5.0266, -6.7685],\n", - " [-2.4079, -2.5133, -2.6187, -2.4576, -1.9045, -2.3753],\n", - " [-1.5934, -1.7343, -2.6637, -4.6723, -3.8892, -4.8865],\n", - " [-2.4079, -1.6239, -1.8833, -2.8012, -3.1644, -3.2043],\n", - " [-2.4079, -2.5133, -1.6277, -2.6413, -4.8464, -6.4091],\n", - " [-2.4079, -2.0500, -1.7096, -3.2425, -2.5706, -2.7330],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.9051, -2.6536],\n", - " [-2.2824, -1.8297, -2.9461, -4.6159, -3.7591, -4.3788],\n", - " [-1.7065, -2.6801, -2.7855, -2.3352, -3.1075, -5.2215],\n", - " [-2.1664, -2.5495, -2.2540, -2.5319, -2.9418, -4.8455],\n", - " [-2.4079, -2.5133, -2.6187, -1.8444, -2.1728, -2.6622],\n", - " [-2.4079, -2.5133, -2.6187, -2.6872, -1.8173, -2.6088],\n", - " [-2.4079, -2.5133, -1.8366, -1.9112, -2.5858, -3.3165],\n", - " [-2.4079, -4.0631, -4.9598, -5.0652, -4.2777, -5.8058],\n", - " [-2.4079, -2.5133, -1.7558, -2.0155, -3.9208, -5.6581],\n", - " [-1.3919, -3.1787, -2.4698, -2.5963, -3.8317, -5.4883],\n", - " [-2.4079, -2.5133, -2.6187, -1.7182, -2.2954, -3.2386],\n", - " [-2.4079, -2.5133, -1.9821, -3.2428, -4.6105, -5.9341],\n", - " [-2.4079, -1.7098, -3.7911, -6.0936, -5.2987, -6.5814],\n", - " [-2.4079, -2.5133, -1.7162, -2.2029, -3.4977, -4.7543],\n", - " [-2.4079, -2.5133, -2.6187, -1.7041, -3.1458, -4.4827],\n", - " [-2.4079, -2.5133, -1.6751, -2.0739, -2.9540, -4.2803],\n", - " [-2.4079, -1.7791, -2.2108, -3.3984, -4.3644, -5.3652],\n", - " [-2.4079, -2.5133, -1.6210, -2.6924, -3.9260, -5.1502],\n", - " [-2.4079, -2.5133, -1.7796, -3.0737, -3.4688, -4.5963],\n", - " [-2.0813, -2.5659, -1.9842, -2.4010, -4.5984, -6.3922],\n", - " [-2.4079, -2.5133, -2.1463, -1.9196, -3.9342, -5.5527],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.3652, -2.0078],\n", - " [-2.4079, -2.0043, -1.7208, -2.5860, -3.1291, -3.7378],\n", - " [-2.4079, -2.5133, -2.6187, -2.2485, -2.2158, -3.5532],\n", - " [-2.4079, -2.5133, -1.7310, -2.7484, -4.0674, -5.3502],\n", - " [-2.4079, -2.5133, -3.9016, -5.0757, -5.4568, -5.6487],\n", - " [-2.4079, -2.5133, -3.8584, -5.7958, -5.6470, -6.0451],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.4995, -1.9721],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.8294, -1.9613],\n", - " [-2.4079, -2.5133, -2.6187, -1.7510, -2.6814, -3.5069],\n", - " [-1.3865, -3.0881, -3.1934, -2.4158, -4.4380, -6.3535],\n", - " [-2.4079, -1.9974, -1.8618, -2.0602, -2.8545, -3.1358],\n", - " [-2.4079, -2.5133, -2.6187, -2.0535, -2.0200, -2.5592],\n", - " [-2.4079, -4.7105, -7.0131, -7.1185, -7.2238, -7.3292],\n", - " [-2.4079, -1.5031, -2.1230, -4.0175, -3.3207, -4.4068],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.5067, -2.0039],\n", - " [-1.6072, -2.6753, -4.4691, -4.5745, -4.3731, -3.8754],\n", - " [-1.7793, -2.6500, -1.7394, -2.5757, -2.9658, -4.1755],\n", - " [-2.4079, -2.5133, -1.6766, -2.0766, -3.6007, -5.3752],\n", - " [-2.4079, -2.5133, -4.0306, -5.8741, -5.2802, -6.1582],\n", - " [-2.4079, -2.5133, -1.6138, -2.2350, -3.1695, -4.0677],\n", - " [-2.0632, -2.5697, -4.8723, -7.1749, -7.2538, -7.3890],\n", - " [-1.8744, -4.0725, -6.3751, -6.4805, -6.5859, -6.6912],\n", - " [-2.4079, -1.6385, -1.9810, -2.9849, -3.9533, -5.0263],\n", - " [-2.4079, -2.5133, -2.6187, -1.9193, -2.0353, -2.9385],\n", - " [-2.4079, -2.5133, -2.1996, -1.9155, -3.9388, -5.4675],\n", - " [-2.4079, -1.9578, -4.0198, -5.5189, -5.5987, -6.5233],\n", - " [-2.4079, -2.5133, -2.6187, -2.0217, -1.9909, -2.6503],\n", - " [-2.4079, -2.5133, -1.8103, -2.1410, -3.4671, -4.8419],\n", - " [-2.4079, -4.3782, -6.6270, -6.7323, -5.8726, -6.3308],\n", - " [-2.4079, -2.5133, -2.3219, -2.1948, -2.1707, -2.3130],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.8294, -2.0629],\n", - " [-2.4079, -2.5133, -2.0034, -1.9250, -2.3927, -3.4318],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8344, -2.6874],\n", - " [-2.4079, -2.5133, -1.9402, -2.5164, -4.0235, -5.5646],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8787, -2.2860],\n", - " [-2.4079, -2.5133, -1.7054, -2.6797, -3.3834, -4.4149],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.0879, -2.1100],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.8294, -1.9425],\n", - " [-2.4079, -2.5133, -1.6815, -2.2784, -2.5469, -3.6127],\n", - " [-2.4079, -2.5133, -1.6721, -2.6524, -3.8595, -5.1588],\n", - " [-2.4079, -2.5133, -1.6094, -2.1849, -3.1757, -4.4443],\n", - " [-2.4079, -2.5133, -1.5971, -2.6901, -2.8088, -4.0883],\n", - " [-2.4079, -2.5133, -1.7211, -2.1134, -3.3022, -4.5852],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8087, -2.5379]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([3, 1, 3, 1, 3, 1, 3, 4, 2, 3, 3, 1, 2, 2, 1, 5, 1, 2, 1, 3, 4, 3, 1, 3,\n", - " 1, 3, 1, 1, 3, 3, 4, 2, 2, 1, 1, 1, 5, 2, 4, 1, 1, 1, 4, 5, 2, 1, 1, 4,\n", - " 0, 1, 5, 1, 2, 2, 1, 3, 1, 1, 4, 3, 4, 1, 4, 3, 1, 3, 5, 4, 4, 3, 4, 1,\n", - " 4, 5, 2, 2, 2, 1, 2, 4], device='cuda:0')\n", "\n", "Train Confusion matrix :\n", - "[[62 2 2 5 0 0]\n", - " [93 36 33 35 8 0]\n", - " [ 8 10 35 52 5 0]\n", - " [ 8 2 22 99 12 2]\n", - " [ 3 2 6 34 51 5]\n", - " [ 0 1 0 10 7 30]]\n", + "[[ 60 12 0 2 0 0]\n", + " [ 26 109 22 31 16 1]\n", + " [ 4 14 18 55 20 0]\n", + " [ 7 4 1 106 24 1]\n", + " [ 2 0 0 32 63 3]\n", + " [ 0 1 1 7 15 23]]\n", "\n" ] }, @@ -904,8 +602,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 2/5\n", - "Train loss: 1.448030, Train accuracy: 0.4603\n", - "Val loss: 2.594396, Val accuracy: 0.1488\n", + "Train loss: 1.367832, Train accuracy: 0.5574\n", + "Val loss: 2.786200, Val accuracy: 0.1074\n", "\n" ] }, @@ -913,170 +611,21 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 25%|██████ | 1/4 [00:00<00:01, 1.64it/s, accuracy=228, loss=1.32]" + "Train progress: 100%|█████████████████████████| 4/4 [00:09<00:00, 2.41s/it, accuracy=422, loss=1.2]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "tensor([[-2.4079, -2.5133, -1.6020, -2.2353, -3.2371, -4.6382],\n", - " [-2.4079, -2.5133, -2.6187, -2.6082, -1.8703, -2.4251],\n", - " [-2.4079, -1.8561, -3.9906, -5.2992, -5.8944, -7.1456],\n", - " ...,\n", - " [-2.4079, -2.5133, -2.6187, -1.7393, -3.6285, -5.4105],\n", - " [-1.7963, -1.8466, -2.2472, -4.3748, -3.5160, -3.9487],\n", - " [-2.4079, -2.5133, -2.1640, -1.7930, -3.1957, -4.6181]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([2, 4, 1, 0, 1, 1, 0, 3, 1, 1, 4, 2, 3, 1, 0, 5, 1, 3, 2, 3, 2, 1, 1, 2,\n", - " 5, 4, 0, 4, 3, 1, 0, 3, 1, 0, 2, 1, 1, 4, 3, 0, 0, 3, 1, 1, 3, 2, 1, 2,\n", - " 2, 1, 2, 3, 1, 3, 1, 1, 1, 1, 0, 1, 1, 5, 3, 4, 1, 1, 2, 3, 1, 3, 1, 1,\n", - " 2, 3, 1, 1, 3, 5, 3, 4, 3, 1, 4, 1, 1, 3, 2, 2, 1, 4, 5, 1, 1, 1, 1, 4,\n", - " 2, 4, 2, 3, 0, 2, 1, 2, 3, 1, 4, 2, 1, 4, 2, 3, 0, 4, 3, 1, 4, 1, 5, 3,\n", - " 4, 3, 3, 4, 3, 1, 2, 1, 4, 1, 2, 4, 0, 2, 1, 5, 2, 3, 4, 4, 1, 1, 2, 5,\n", - " 3, 1, 1, 1, 3, 1, 2, 5, 4, 0, 3, 1, 4, 1, 4, 1, 4, 0, 4, 0, 2, 1, 1, 3,\n", - " 3, 3, 5, 2, 3, 3, 1, 1, 1, 3, 4, 2, 3, 0, 1, 2, 3, 1, 1, 2, 3, 2, 1, 3,\n", - " 0, 4, 5, 0, 1, 1, 1, 1], device='cuda:0')\n", - "tensor([[-2.4079, -2.5133, -1.6095, -2.6331, -3.9294, -5.1521],\n", - " [-2.4079, -2.5133, -2.0915, -1.8078, -3.4673, -5.4527],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.5659, -4.3494],\n", - " ...,\n", - " [-2.4079, -2.5133, -4.3731, -6.5964, -6.3215, -6.8714],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.2576, -2.1033],\n", - " [-2.4079, -2.5133, -2.6187, -1.7220, -2.6234, -3.9801]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([2, 3, 4, 0, 2, 1, 3, 3, 3, 1, 0, 1, 4, 0, 3, 5, 3, 1, 3, 3, 4, 2, 4, 1,\n", - " 1, 3, 0, 2, 1, 0, 1, 4, 1, 3, 1, 2, 1, 0, 1, 5, 1, 3, 3, 4, 5, 1, 5, 2,\n", - " 1, 1, 3, 1, 2, 1, 4, 3, 0, 3, 4, 1, 1, 1, 3, 2, 2, 5, 3, 3, 4, 2, 3, 5,\n", - " 1, 3, 3, 2, 5, 3, 4, 1, 2, 2, 1, 3, 0, 1, 4, 2, 1, 0, 5, 1, 3, 1, 4, 1,\n", - " 3, 4, 5, 5, 1, 4, 1, 2, 2, 2, 4, 4, 3, 3, 0, 0, 1, 0, 3, 3, 1, 2, 5, 1,\n", - " 4, 0, 2, 3, 0, 2, 0, 1, 4, 1, 3, 3, 0, 1, 4, 1, 2, 4, 0, 2, 5, 0, 2, 3,\n", - " 2, 3, 4, 4, 1, 2, 4, 3, 3, 1, 2, 1, 3, 4, 1, 3, 1, 5, 2, 1, 2, 5, 1, 1,\n", - " 5, 4, 1, 4, 2, 1, 1, 1, 1, 1, 5, 2, 4, 1, 4, 1, 1, 4, 1, 4, 2, 3, 3, 1,\n", - " 3, 3, 5, 1, 1, 1, 5, 3], device='cuda:0')\n", - "tensor([[-2.4079, -2.5133, -1.7037, -2.0140, -3.6864, -5.3751],\n", - " [-2.4079, -2.5133, -2.2185, -2.5984, -1.9986, -3.1578],\n", - " [-1.5206, -2.1395, -4.1148, -4.0327, -6.2684, -8.3363],\n", - " ...,\n", - " [-2.4079, -1.5172, -3.3793, -4.9799, -4.9927, -6.0050],\n", - " [-2.4079, -2.5133, -4.8159, -4.9213, -5.0266, -5.1320],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.9015, -3.0066]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([1, 5, 0, 2, 1, 3, 2, 1, 2, 1, 4, 3, 2, 4, 2, 4, 1, 1, 2, 1, 5, 2, 5, 3,\n", - " 3, 0, 2, 4, 3, 3, 0, 4, 0, 1, 3, 1, 4, 0, 2, 4, 4, 1, 1, 1, 5, 5, 5, 3,\n", - " 1, 5, 2, 3, 3, 1, 4, 3, 2, 0, 3, 1, 1, 4, 4, 0, 2, 0, 1, 3, 3, 5, 1, 4,\n", - " 5, 0, 1, 0, 3, 0, 3, 4, 0, 0, 3, 1, 0, 1, 1, 3, 5, 4, 1, 1, 3, 3, 1, 0,\n", - " 1, 1, 1, 2, 1, 5, 4, 3, 2, 2, 1, 2, 1, 3, 4, 1, 2, 3, 0, 1, 2, 1, 1, 1,\n", - " 0, 2, 1, 3, 5, 4, 1, 1, 1, 0, 3, 2, 2, 4, 5, 1, 3, 2, 1, 0, 3, 3, 1, 3,\n", - " 4, 4, 1, 1, 3, 4, 1, 1, 4, 2, 3, 2, 1, 0, 2, 2, 4, 5, 2, 4, 3, 0, 2, 3,\n", - " 1, 2, 2, 0, 3, 2, 4, 4, 3, 1, 4, 3, 0, 3, 0, 1, 3, 5, 1, 0, 4, 5, 4, 3,\n", - " 3, 1, 2, 1, 3, 2, 0, 4], device='cuda:0')\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Train progress: 100%|████████████████████████| 4/4 [00:00<00:00, 4.15it/s, accuracy=408, loss=1.27]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[-2.4079, -2.5133, -2.6187, -1.8227, -2.1180, -3.3021],\n", - " [-2.4079, -2.5133, -1.6879, -2.6541, -3.4797, -4.8191],\n", - " [-2.4079, -2.5133, -4.7065, -4.9361, -4.2972, -4.3975],\n", - " [-2.4079, -2.5133, -2.6187, -1.7396, -2.4785, -3.9880],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.0113, -2.1452],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8518, -3.1593],\n", - " [-2.4079, -1.4936, -3.1630, -4.4447, -5.2642, -6.2910],\n", - " [-2.4079, -2.5133, -4.0907, -6.2154, -6.3207, -6.4261],\n", - " [-2.4079, -2.5133, -2.4350, -1.7648, -3.0279, -4.8055],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8791, -2.9727],\n", - " [-2.4079, -1.5563, -3.0742, -5.2104, -4.4916, -5.0077],\n", - " [-2.4079, -2.5133, -2.6187, -1.7970, -3.7745, -5.6535],\n", - " [-2.4079, -2.5133, -2.6187, -1.7255, -2.5811, -4.1919],\n", - " [-2.4079, -2.5133, -1.6831, -2.1322, -3.9457, -5.7231],\n", - " [-2.4079, -2.5133, -2.6187, -1.7217, -2.5694, -4.2367],\n", - " [-2.4079, -2.5133, -2.6187, -2.3851, -1.8631, -2.6357],\n", - " [-2.4079, -2.5133, -2.6187, -1.8214, -2.7686, -4.5519],\n", - " [-2.4079, -3.7416, -5.6608, -4.7769, -6.6561, -8.9587],\n", - " [-2.4079, -2.5133, -2.6187, -1.7237, -3.5743, -5.5587],\n", - " [-2.4079, -2.5133, -2.3752, -2.0307, -2.9406, -4.7268],\n", - " [-2.4079, -2.0634, -4.2850, -5.9182, -5.8641, -7.2660],\n", - " [-2.4079, -2.5133, -2.6187, -1.8374, -2.6719, -4.2820],\n", - " [-2.4079, -1.5197, -2.4202, -2.8877, -4.2097, -5.5393],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8552, -2.4426],\n", - " [-2.4079, -1.9956, -4.1984, -5.4933, -6.2990, -7.1216],\n", - " [-2.4079, -2.5133, -2.6187, -2.1938, -1.9548, -3.0694],\n", - " [-2.4079, -2.5133, -2.6187, -1.7391, -3.2433, -5.0264],\n", - " [-2.4079, -2.3993, -4.6864, -6.9890, -6.1291, -7.2427],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.9457, -2.3724],\n", - " [-2.4079, -2.5133, -4.8159, -7.1185, -6.8879, -7.3837],\n", - " [-2.4079, -1.8194, -2.5649, -3.9830, -4.3976, -5.4860],\n", - " [-2.4079, -1.6833, -2.4180, -3.6445, -4.5441, -5.4743],\n", - " [-2.4079, -2.5133, -2.6187, -2.5276, -1.8388, -2.8562],\n", - " [-2.4079, -2.5133, -4.8159, -7.1185, -6.7606, -8.2814],\n", - " [-2.4079, -1.5771, -3.5436, -5.2005, -7.3599, -8.6443],\n", - " [-2.4079, -2.5133, -2.6187, -2.2337, -2.4992, -4.0652],\n", - " [-2.4079, -2.5133, -4.8159, -4.9213, -4.1778, -5.0958],\n", - " [-2.4079, -2.5133, -2.6187, -2.4357, -1.8557, -2.7608],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.9457, -2.3083],\n", - " [-2.4079, -2.5133, -2.6187, -1.7059, -2.4908, -4.0385],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8914, -2.4011],\n", - " [-2.4079, -2.5133, -2.6187, -1.7028, -3.1149, -5.0296],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8078, -2.6513],\n", - " [-2.4079, -2.5133, -2.6187, -1.9822, -4.0985, -6.0157],\n", - " [-2.4079, -2.5133, -2.2469, -2.7863, -2.2024, -3.4795],\n", - " [-2.4079, -2.5133, -2.6187, -2.3612, -1.8704, -2.7687],\n", - " [-2.4079, -1.9453, -2.7340, -4.0539, -4.8744, -6.0019],\n", - " [-2.4079, -4.7105, -7.0131, -7.1185, -6.3026, -7.0270],\n", - " [-2.4079, -1.9007, -2.7494, -4.8900, -4.3017, -5.3080],\n", - " [-2.4079, -2.5133, -2.5122, -4.0233, -4.3360, -5.4871],\n", - " [-2.4079, -2.5133, -2.3858, -1.8877, -3.9201, -5.7820],\n", - " [-2.4079, -2.5133, -1.7372, -1.9984, -3.4170, -5.0393],\n", - " [-2.4079, -1.5533, -2.6839, -2.6164, -3.8441, -4.8234],\n", - " [-2.4079, -2.5133, -2.6187, -1.7364, -2.3103, -3.5918],\n", - " [-2.2430, -2.3065, -2.1833, -2.3523, -2.0773, -2.7205],\n", - " [-2.4079, -2.5133, -2.3469, -1.7634, -3.5565, -5.6472],\n", - " [-2.4079, -1.8605, -2.7647, -4.0539, -5.2612, -6.4940],\n", - " [-2.4079, -2.5133, -2.6187, -2.0417, -1.9899, -3.0864],\n", - " [-2.4079, -1.4938, -3.1609, -4.4570, -5.2061, -6.3250],\n", - " [-2.4079, -1.4978, -3.2881, -4.5764, -5.4549, -6.7396],\n", - " [-2.4079, -2.5133, -2.6187, -2.0840, -1.9624, -2.9312],\n", - " [-2.4079, -1.5615, -2.9755, -4.3274, -4.8027, -5.7814],\n", - " [-2.4079, -2.5133, -1.7868, -2.8325, -3.6159, -4.7145],\n", - " [-2.4079, -2.5133, -2.6187, -2.2113, -2.1652, -3.5604],\n", - " [-2.4079, -2.5133, -2.6187, -1.7324, -3.0044, -4.7473],\n", - " [-2.4079, -2.5133, -2.6187, -1.8864, -2.1483, -3.5137],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8216, -2.4246],\n", - " [-2.4079, -2.0397, -1.9731, -2.3440, -2.5303, -3.7052],\n", - " [-2.4079, -2.5133, -2.6187, -1.7704, -2.3466, -3.8142],\n", - " [-1.6030, -2.7021, -4.1953, -4.3006, -3.5272, -4.8391],\n", - " [-2.4079, -2.5119, -2.6188, -4.9214, -5.0268, -6.8832],\n", - " [-2.4079, -2.5133, -2.1365, -2.0431, -2.3285, -2.5240],\n", - " [-2.4079, -2.5133, -2.6187, -1.8118, -2.1167, -2.9939],\n", - " [-2.4079, -2.5133, -1.6664, -3.3652, -3.4806, -5.0244],\n", - " [-2.4079, -2.5133, -2.1614, -2.0321, -2.1667, -3.3806],\n", - " [-2.4079, -2.5133, -2.6187, -1.7061, -3.3579, -5.3463],\n", - " [-2.4079, -2.5133, -2.6187, -1.9216, -4.0036, -5.9986],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.1207, -3.5108],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.3806, -1.9955],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.4956, -1.9734]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([3, 2, 1, 2, 4, 4, 1, 0, 1, 4, 1, 3, 1, 2, 3, 3, 3, 0, 1, 2, 0, 3, 1, 4,\n", - " 1, 4, 3, 1, 5, 1, 1, 1, 1, 0, 0, 3, 0, 3, 4, 3, 5, 2, 4, 2, 4, 2, 2, 0,\n", - " 1, 2, 3, 2, 1, 3, 1, 3, 1, 3, 1, 1, 3, 1, 1, 4, 3, 3, 4, 2, 3, 0, 0, 3,\n", - " 3, 2, 2, 4, 2, 4, 4, 5], device='cuda:0')\n", "\n", "Train Confusion matrix :\n", - "[[ 59 8 3 0 1 0]\n", - " [ 38 83 46 34 4 0]\n", - " [ 2 10 64 30 4 0]\n", - " [ 1 0 22 106 15 1]\n", - " [ 2 0 4 25 65 5]\n", - " [ 0 1 1 5 10 31]]\n", + "[[ 60 14 0 0 0 0]\n", + " [ 23 162 3 11 5 1]\n", + " [ 9 44 7 36 13 2]\n", + " [ 7 17 0 105 14 0]\n", + " [ 3 9 0 17 51 20]\n", + " [ 2 0 0 1 7 37]]\n", "\n" ] }, @@ -1092,8 +641,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 3/5\n", - "Train loss: 1.279222, Train accuracy: 0.6000\n", - "Val loss: 2.259112, Val accuracy: 0.2645\n", + "Train loss: 1.251675, Train accuracy: 0.6206\n", + "Val loss: 2.030842, Val accuracy: 0.1901\n", "\n" ] }, @@ -1101,170 +650,21 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 25%|██████ | 1/4 [00:00<00:02, 1.33it/s, accuracy=292, loss=1.17]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[-2.4079, -2.5133, -1.6101, -2.3296, -3.8890, -5.6538],\n", - " [-2.4079, -3.7057, -5.2823, -5.3877, -4.4726, -5.4674],\n", - " [-2.4079, -1.5488, -3.4753, -5.7779, -4.9131, -6.2376],\n", - " ...,\n", - " [-2.3433, -4.6374, -4.7427, -4.8481, -7.1507, -9.4533],\n", - " [-2.4079, -2.5133, -2.6187, -2.0777, -2.1504, -3.6780],\n", - " [-2.4079, -2.5133, -1.6828, -2.8385, -4.9148, -7.0869]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([2, 0, 1, 3, 1, 0, 4, 5, 5, 1, 4, 3, 2, 1, 0, 1, 2, 1, 2, 4, 0, 3, 0, 1,\n", - " 4, 3, 0, 1, 3, 3, 5, 1, 3, 1, 5, 3, 2, 4, 3, 3, 5, 5, 4, 1, 4, 2, 1, 3,\n", - " 3, 4, 1, 0, 2, 4, 2, 1, 4, 1, 3, 0, 4, 2, 2, 1, 1, 5, 4, 0, 1, 5, 1, 2,\n", - " 1, 1, 3, 4, 1, 4, 2, 2, 4, 1, 1, 1, 2, 4, 3, 3, 1, 2, 5, 1, 0, 2, 2, 2,\n", - " 3, 5, 0, 1, 1, 4, 0, 1, 2, 0, 3, 5, 2, 2, 5, 3, 2, 2, 2, 0, 4, 1, 0, 1,\n", - " 1, 0, 2, 2, 5, 4, 4, 1, 1, 2, 2, 1, 4, 3, 2, 0, 4, 3, 2, 2, 3, 2, 1, 3,\n", - " 1, 4, 1, 3, 2, 4, 3, 0, 3, 1, 2, 3, 2, 1, 2, 2, 1, 4, 5, 0, 1, 4, 0, 5,\n", - " 4, 4, 4, 3, 2, 3, 1, 1, 1, 0, 1, 2, 5, 3, 1, 5, 0, 2, 4, 5, 4, 1, 3, 4,\n", - " 3, 1, 1, 2, 3, 1, 3, 2], device='cuda:0')\n", - "tensor([[-2.4079, -2.5133, -2.6187, -1.7088, -2.6687, -4.0928],\n", - " [-2.4079, -2.5133, -2.1350, -3.6345, -5.7648, -7.4803],\n", - " [-2.4079, -2.5133, -1.9665, -3.2365, -5.1329, -6.7075],\n", - " ...,\n", - " [-2.4079, -2.5133, -1.6598, -2.0881, -3.5310, -5.2570],\n", - " [-2.4079, -2.5133, -2.6187, -1.9118, -2.2301, -3.7819],\n", - " [-2.4079, -2.5133, -1.6756, -2.7839, -4.9683, -6.9759]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([3, 2, 2, 3, 1, 2, 1, 3, 3, 0, 1, 5, 3, 0, 1, 1, 1, 1, 1, 4, 4, 3, 5, 4,\n", - " 1, 3, 2, 5, 1, 4, 5, 4, 1, 3, 5, 1, 1, 3, 4, 3, 3, 3, 1, 3, 3, 1, 2, 1,\n", - " 1, 2, 2, 0, 4, 0, 2, 1, 1, 1, 4, 1, 2, 1, 2, 4, 0, 4, 4, 2, 3, 4, 2, 1,\n", - " 2, 0, 3, 5, 5, 4, 1, 2, 1, 3, 4, 1, 4, 4, 4, 3, 0, 1, 1, 1, 5, 1, 3, 0,\n", - " 1, 3, 4, 1, 4, 0, 1, 0, 4, 1, 4, 3, 4, 2, 2, 5, 1, 2, 0, 1, 2, 3, 1, 1,\n", - " 3, 2, 4, 3, 1, 2, 1, 4, 3, 1, 1, 1, 3, 1, 3, 1, 1, 1, 0, 4, 0, 1, 0, 1,\n", - " 1, 1, 3, 3, 2, 1, 1, 3, 4, 3, 1, 2, 4, 3, 2, 3, 2, 2, 1, 5, 3, 1, 5, 3,\n", - " 5, 1, 3, 1, 2, 3, 3, 5, 0, 2, 1, 1, 2, 0, 3, 3, 1, 2, 0, 3, 1, 1, 3, 4,\n", - " 1, 4, 5, 1, 2, 3, 4, 2], device='cuda:0')\n", - "tensor([[-2.4079, -2.5133, -2.6187, -2.3587, -4.6004, -6.6545],\n", - " [-2.4079, -2.5133, -1.7340, -3.0434, -5.1746, -7.4771],\n", - " [-2.4079, -2.5133, -2.6187, -3.8997, -5.6253, -7.2381],\n", - " ...,\n", - " [-2.4079, -2.2434, -2.3921, -3.6354, -5.2905, -6.8298],\n", - " [-2.4079, -1.5163, -3.3762, -5.2964, -5.2313, -6.3796],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8092, -2.6503]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([1, 2, 0, 4, 3, 3, 1, 4, 3, 4, 1, 4, 4, 4, 4, 3, 1, 1, 1, 4, 1, 3, 1, 2,\n", - " 4, 1, 2, 0, 2, 1, 4, 3, 1, 3, 0, 1, 4, 2, 3, 3, 1, 1, 3, 3, 2, 4, 0, 5,\n", - " 1, 1, 4, 0, 1, 3, 1, 3, 2, 3, 3, 3, 2, 4, 3, 3, 1, 1, 3, 2, 4, 5, 1, 1,\n", - " 1, 0, 1, 5, 1, 1, 1, 4, 3, 1, 3, 2, 4, 5, 3, 1, 3, 1, 2, 3, 2, 4, 1, 0,\n", - " 5, 2, 1, 0, 0, 5, 5, 1, 3, 3, 1, 2, 0, 2, 3, 2, 5, 1, 1, 4, 3, 4, 3, 2,\n", - " 0, 4, 1, 1, 3, 1, 2, 1, 2, 0, 3, 3, 1, 1, 4, 3, 2, 0, 1, 0, 1, 1, 5, 3,\n", - " 1, 1, 1, 1, 4, 2, 2, 0, 2, 1, 1, 1, 3, 3, 1, 5, 2, 1, 3, 2, 4, 1, 2, 1,\n", - " 2, 3, 1, 3, 3, 4, 2, 4, 0, 0, 4, 1, 2, 4, 1, 1, 3, 1, 2, 0, 3, 1, 3, 3,\n", - " 4, 3, 3, 3, 0, 2, 1, 4], device='cuda:0')\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Train progress: 100%|████████████████████████| 4/4 [00:01<00:00, 3.69it/s, accuracy=488, loss=1.12]" + "Train progress: 100%|████████████████████████| 4/4 [00:09<00:00, 2.40s/it, accuracy=479, loss=1.22]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "tensor([[-2.4079, -1.5351, -2.9437, -4.8126, -5.1631, -6.4160],\n", - " [-2.4079, -2.5133, -1.5980, -2.4511, -4.4640, -6.7665],\n", - " [-2.4079, -2.5133, -2.6187, -1.7155, -3.5383, -5.6180],\n", - " [-2.4079, -2.5133, -2.6187, -1.7973, -3.7753, -6.0412],\n", - " [-2.4079, -1.6034, -3.5997, -5.9022, -5.8738, -7.6495],\n", - " [-2.4079, -1.7508, -3.8562, -6.1588, -5.2864, -6.2990],\n", - " [-2.4079, -2.5133, -2.0059, -1.9386, -3.9282, -5.9226],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8161, -2.9066],\n", - " [-2.4079, -1.5590, -2.9792, -4.5884, -5.5504, -6.9908],\n", - " [-2.4079, -1.7422, -2.8206, -4.2968, -5.4649, -6.7461],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8958, -2.9846],\n", - " [-2.4079, -4.7105, -7.0131, -7.1185, -6.2093, -7.2920],\n", - " [-2.4079, -1.6373, -3.6656, -5.9682, -5.1699, -7.1719],\n", - " [-2.4079, -2.5133, -1.6038, -2.4714, -4.5774, -6.7572],\n", - " [-2.4079, -2.5133, -2.6187, -1.8305, -2.3529, -3.7258],\n", - " [-2.4079, -2.5133, -2.6187, -1.7095, -3.3362, -5.6388],\n", - " [-2.4079, -2.5133, -2.6187, -3.9018, -5.5688, -7.0261],\n", - " [-2.4079, -1.6340, -2.6386, -4.3807, -3.6725, -3.8538],\n", - " [-2.4079, -2.5133, -2.6187, -1.7259, -3.3350, -5.4232],\n", - " [-2.4079, -2.5133, -2.6187, -1.8936, -2.1953, -3.4917],\n", - " [-2.4079, -2.5133, -2.6187, -1.8959, -3.9609, -6.2635],\n", - " [-2.4079, -4.7105, -7.0131, -7.1185, -6.5471, -8.4135],\n", - " [-2.4079, -2.5133, -1.6686, -2.6364, -4.3800, -6.1621],\n", - " [-2.4079, -2.5133, -2.6187, -1.7833, -3.7439, -6.0207],\n", - " [-2.4079, -2.5133, -2.6187, -1.7586, -3.2071, -5.2429],\n", - " [-2.4079, -2.5133, -2.6187, -1.7174, -2.8311, -4.7053],\n", - " [-2.4079, -2.5133, -2.6187, -1.9363, -4.0273, -6.2204],\n", - " [-2.4079, -1.8192, -2.7822, -5.0117, -5.1997, -6.5318],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.2248, -2.0412],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.9412, -2.2220],\n", - " [-2.4079, -2.5133, -2.6187, -1.7045, -3.2477, -5.4557],\n", - " [-2.4079, -2.5133, -1.5975, -2.5998, -4.7404, -6.8536],\n", - " [-2.4079, -2.5133, -2.6187, -2.5204, -1.8521, -2.7974],\n", - " [-2.4079, -2.5133, -2.6187, -2.2296, -4.0455, -5.8271],\n", - " [-2.4079, -2.5133, -2.3254, -2.7699, -1.8836, -2.8561],\n", - " [-2.4079, -2.4201, -2.6311, -4.0007, -5.3041, -6.6227],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.3697, -2.0071],\n", - " [-2.4079, -1.6276, -3.6473, -5.9139, -5.6630, -7.8976],\n", - " [-2.4079, -2.5133, -2.6187, -1.8077, -3.7971, -5.9698],\n", - " [-2.4079, -2.5133, -2.6187, -1.7234, -3.0506, -4.9395],\n", - " [-2.4079, -2.5133, -2.0040, -3.3784, -5.4639, -7.3888],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.4431, -2.0006],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8109, -2.6589],\n", - " [-2.4079, -2.5133, -2.6187, -1.7026, -2.7019, -4.1800],\n", - " [-2.4079, -2.5133, -2.6187, -1.7163, -3.5423, -5.4583],\n", - " [-2.4079, -2.5133, -1.6084, -2.2693, -4.2291, -6.2048],\n", - " [-2.4079, -1.4978, -3.2882, -5.4210, -4.8103, -6.0853],\n", - " [-1.7437, -2.1158, -2.8784, -2.9837, -5.2863, -7.5889],\n", - " [-2.4079, -1.4917, -3.2133, -4.6744, -5.1688, -6.4417],\n", - " [-2.4079, -1.7483, -3.6532, -4.9769, -5.3990, -6.2733],\n", - " [-2.4079, -2.5133, -2.3851, -4.2293, -6.4575, -8.4195],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -5.0266, -7.3292],\n", - " [-2.4079, -2.5133, -1.6791, -2.1193, -4.0744, -5.9179],\n", - " [-2.4079, -1.9755, -2.3237, -3.5463, -4.7300, -5.9395],\n", - " [-2.4079, -2.5133, -2.6187, -2.0791, -4.2388, -6.3246],\n", - " [-2.4079, -2.2202, -2.6645, -3.9490, -5.7216, -7.1460],\n", - " [-2.4079, -1.5824, -3.5553, -5.3949, -5.5339, -6.9876],\n", - " [-2.4079, -2.5133, -4.8159, -7.1185, -6.2023, -7.9065],\n", - " [-2.4079, -2.5133, -2.6187, -1.8370, -3.1083, -4.7324],\n", - " [-2.4079, -2.5133, -2.5835, -2.5587, -2.8580, -4.5894],\n", - " [-2.4079, -2.2023, -2.6681, -4.1973, -4.8793, -6.2427],\n", - " [-2.4079, -1.6283, -3.5617, -5.3014, -5.3146, -6.4526],\n", - " [-2.4079, -2.5133, -2.6187, -2.3222, -1.8828, -2.7233],\n", - " [-2.4079, -2.5133, -1.6844, -2.1302, -4.1073, -6.3644],\n", - " [-2.4079, -2.5133, -2.0823, -2.2295, -4.4060, -6.7086],\n", - " [-2.4079, -2.4383, -2.6100, -3.8898, -5.5695, -7.0818],\n", - " [-2.4079, -2.5133, -2.6187, -1.9196, -2.4712, -4.0804],\n", - " [-2.4079, -4.7105, -7.0131, -7.1185, -7.2238, -6.4155],\n", - " [-2.4079, -1.8095, -3.9442, -6.1063, -5.6757, -7.4787],\n", - " [-2.4079, -2.5133, -2.6187, -1.8698, -3.6788, -5.7943],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.8294, -1.9132],\n", - " [-2.4079, -1.5705, -2.9633, -4.5212, -5.5589, -6.8884],\n", - " [-2.4079, -2.5133, -2.6187, -1.7675, -3.7063, -6.0089],\n", - " [-2.4079, -2.5133, -2.6187, -2.6412, -1.9243, -2.4200],\n", - " [-1.7959, -3.9680, -6.2706, -6.3759, -6.4019, -6.0537],\n", - " [-2.4079, -2.5133, -2.6187, -1.8437, -3.3103, -5.2227],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.8294, -1.9138],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.8294, -1.9133],\n", - " [-2.4079, -2.5133, -3.7943, -5.5055, -4.8354, -5.9207],\n", - " [-2.4079, -1.8049, -3.9374, -6.2400, -5.5606, -5.7180]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([1, 2, 1, 3, 0, 1, 3, 4, 1, 1, 4, 0, 1, 2, 0, 3, 2, 0, 3, 3, 1, 0, 2, 2,\n", - " 4, 3, 3, 0, 5, 4, 3, 2, 4, 3, 4, 1, 5, 0, 3, 3, 1, 5, 4, 3, 3, 1, 1, 0,\n", - " 1, 1, 1, 0, 3, 1, 3, 1, 1, 0, 3, 1, 1, 1, 4, 1, 3, 1, 3, 0, 1, 3, 5, 1,\n", - " 3, 4, 0, 3, 5, 5, 0, 1], device='cuda:0')\n", "\n", "Train Confusion matrix :\n", - "[[ 62 7 1 1 0 0]\n", - " [ 29 115 29 30 2 0]\n", - " [ 8 16 63 21 2 0]\n", - " [ 1 0 12 125 6 1]\n", - " [ 0 0 2 13 80 6]\n", - " [ 0 0 0 1 4 43]]\n", + "[[ 63 11 0 0 0 0]\n", + " [ 13 174 7 9 2 0]\n", + " [ 6 31 28 38 8 0]\n", + " [ 4 2 13 118 5 1]\n", + " [ 7 4 0 18 60 11]\n", + " [ 0 2 0 0 9 36]]\n", "\n" ] }, @@ -1280,8 +680,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 4/5\n", - "Train loss: 1.170248, Train accuracy: 0.7176\n", - "Val loss: 1.596575, Val accuracy: 0.3884\n", + "Train loss: 1.168523, Train accuracy: 0.7044\n", + "Val loss: 1.750762, Val accuracy: 0.2562\n", "\n" ] }, @@ -1289,170 +689,21 @@ "name": "stderr", "output_type": "stream", "text": [ - "Train progress: 25%|██████ | 1/4 [00:00<00:02, 1.11it/s, accuracy=305, loss=1.11]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[-2.4079, -2.5133, -2.6187, -3.9344, -5.8187, -8.1213],\n", - " [-2.4079, -2.5133, -1.6202, -2.8567, -5.0261, -7.0107],\n", - " [-2.4079, -2.5133, -2.6187, -1.8406, -2.2038, -3.8148],\n", - " ...,\n", - " [-2.4079, -2.5133, -2.6187, -1.7557, -2.4511, -4.0959],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.2739, -2.0757],\n", - " [-2.4079, -2.5133, -2.6187, -1.9429, -4.0377, -6.3079]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([2, 2, 3, 2, 1, 0, 3, 1, 1, 1, 5, 1, 2, 2, 4, 0, 1, 3, 1, 2, 0, 3, 1, 2,\n", - " 1, 0, 1, 5, 3, 0, 3, 1, 3, 1, 3, 0, 2, 3, 2, 1, 1, 3, 3, 3, 1, 5, 1, 4,\n", - " 3, 4, 3, 0, 4, 0, 4, 4, 1, 0, 1, 1, 3, 5, 3, 1, 1, 3, 0, 1, 1, 1, 0, 4,\n", - " 3, 3, 3, 0, 1, 2, 2, 2, 1, 3, 1, 1, 2, 3, 4, 1, 4, 2, 3, 2, 3, 2, 2, 1,\n", - " 4, 3, 0, 3, 1, 1, 0, 2, 4, 3, 4, 1, 3, 2, 3, 4, 0, 1, 2, 3, 1, 0, 1, 1,\n", - " 3, 2, 1, 1, 1, 5, 1, 4, 4, 4, 0, 2, 1, 4, 1, 5, 1, 5, 0, 1, 0, 1, 3, 3,\n", - " 1, 1, 2, 5, 1, 3, 4, 5, 4, 3, 2, 3, 1, 2, 2, 0, 3, 3, 1, 1, 5, 0, 2, 2,\n", - " 4, 2, 0, 1, 1, 2, 4, 1, 2, 1, 5, 4, 5, 1, 0, 2, 1, 4, 3, 4, 2, 4, 1, 0,\n", - " 1, 1, 2, 1, 2, 3, 5, 3], device='cuda:0')\n", - "tensor([[-2.4079, -2.5133, -2.6187, -2.3586, -2.5366, -4.3780],\n", - " [-2.4079, -1.5728, -3.5337, -5.1647, -5.9502, -7.4470],\n", - " [-2.4079, -2.5133, -1.6966, -3.6799, -5.9825, -8.2850],\n", - " ...,\n", - " [-2.4079, -4.7105, -7.0131, -7.1185, -6.4248, -6.6167],\n", - " [-2.4079, -2.5133, -1.5996, -2.5449, -4.6734, -6.9760],\n", - " [-2.4079, -2.4316, -2.0745, -3.2725, -5.1413, -6.6721]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([4, 1, 2, 2, 3, 4, 5, 4, 3, 1, 3, 3, 5, 2, 5, 2, 2, 3, 3, 0, 4, 3, 1, 1,\n", - " 3, 4, 4, 5, 1, 3, 2, 2, 1, 1, 1, 3, 4, 3, 0, 3, 3, 4, 3, 4, 3, 3, 3, 0,\n", - " 2, 1, 4, 1, 2, 1, 3, 0, 4, 1, 1, 4, 5, 1, 4, 3, 1, 0, 1, 1, 2, 3, 1, 3,\n", - " 2, 4, 1, 1, 1, 0, 3, 5, 1, 4, 1, 1, 2, 3, 5, 2, 3, 3, 0, 3, 2, 1, 4, 1,\n", - " 1, 0, 0, 1, 1, 5, 5, 2, 1, 4, 2, 4, 2, 2, 1, 2, 5, 3, 4, 2, 1, 1, 2, 2,\n", - " 0, 3, 3, 1, 4, 4, 0, 1, 5, 4, 1, 3, 4, 4, 2, 0, 4, 1, 3, 1, 3, 1, 1, 1,\n", - " 1, 5, 1, 1, 3, 4, 1, 3, 1, 1, 0, 4, 1, 2, 2, 1, 1, 3, 1, 2, 1, 4, 3, 4,\n", - " 2, 5, 0, 1, 3, 2, 3, 1, 1, 4, 2, 3, 2, 0, 1, 0, 5, 0, 3, 2, 2, 3, 5, 0,\n", - " 4, 1, 4, 1, 2, 0, 2, 2], device='cuda:0')\n", - "tensor([[-2.4079, -2.5133, -2.6187, -1.7190, -3.2964, -5.5990],\n", - " [-2.4079, -2.5133, -2.6187, -1.9767, -2.5392, -4.6580],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8137, -2.6181],\n", - " ...,\n", - " [-2.4079, -1.7365, -2.8239, -4.2913, -5.4743, -6.8122],\n", - " [-2.4079, -1.4925, -3.1771, -5.2636, -5.4136, -6.6655],\n", - " [-2.4079, -1.5755, -3.5400, -5.0286, -6.1760, -7.5651]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([3, 3, 4, 2, 0, 3, 1, 1, 0, 0, 1, 1, 1, 4, 5, 3, 5, 0, 5, 0, 2, 2, 3, 1,\n", - " 4, 3, 2, 3, 2, 3, 1, 1, 1, 2, 3, 4, 1, 3, 1, 1, 4, 1, 3, 5, 3, 3, 2, 3,\n", - " 1, 2, 1, 3, 4, 3, 3, 3, 1, 1, 3, 1, 4, 0, 2, 1, 0, 2, 4, 3, 1, 0, 1, 3,\n", - " 4, 4, 2, 1, 4, 3, 3, 5, 2, 3, 5, 1, 3, 1, 0, 4, 1, 4, 3, 1, 3, 0, 3, 3,\n", - " 1, 4, 0, 4, 4, 1, 0, 1, 2, 3, 3, 5, 0, 1, 3, 1, 3, 1, 4, 3, 5, 3, 3, 2,\n", - " 1, 4, 1, 3, 2, 2, 1, 1, 3, 3, 1, 3, 2, 4, 1, 3, 1, 0, 3, 3, 0, 1, 3, 1,\n", - " 1, 1, 2, 3, 2, 4, 2, 4, 2, 2, 1, 1, 3, 3, 4, 2, 1, 1, 1, 3, 5, 1, 2, 2,\n", - " 4, 4, 3, 4, 1, 2, 0, 5, 3, 0, 0, 1, 1, 0, 2, 2, 0, 3, 1, 0, 0, 0, 1, 4,\n", - " 1, 1, 1, 4, 3, 1, 1, 1], device='cuda:0')\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Train progress: 100%|████████████████████████| 4/4 [00:01<00:00, 3.14it/s, accuracy=525, loss=1.13]" + "Train progress: 100%|████████████████████████| 4/4 [00:08<00:00, 2.23s/it, accuracy=534, loss=1.12]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "tensor([[-2.4079, -2.5133, -2.3843, -2.7590, -5.0616, -7.3641],\n", - " [-2.4079, -1.4922, -3.2309, -5.5335, -5.2030, -6.4739],\n", - " [-2.4079, -1.5695, -2.9646, -5.2672, -4.4221, -5.5711],\n", - " [-2.4079, -1.4918, -3.1957, -5.4983, -5.6037, -7.3999],\n", - " [-2.4079, -2.5133, -1.7193, -2.1326, -4.1510, -6.4536],\n", - " [-2.4079, -3.6930, -5.3418, -5.4472, -4.6986, -6.0549],\n", - " [-2.4079, -2.5133, -1.9900, -1.8425, -3.6143, -5.9169],\n", - " [-2.4079, -2.5133, -2.6187, -2.6963, -1.8397, -2.6284],\n", - " [-2.4079, -1.4967, -3.2803, -4.7091, -5.9607, -7.2425],\n", - " [-2.4079, -2.5133, -2.3606, -4.0126, -6.1848, -8.0438],\n", - " [-2.4079, -2.5133, -2.6187, -1.9447, -2.0538, -3.6498],\n", - " [-2.4079, -2.5133, -2.6187, -2.2005, -2.1588, -4.2106],\n", - " [-2.4079, -1.5362, -3.0169, -4.3081, -5.9160, -7.2833],\n", - " [-2.4079, -2.5133, -2.6187, -2.2347, -2.6431, -4.8954],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.5862, -2.0420],\n", - " [-2.4079, -2.5133, -2.6187, -1.8866, -2.8726, -5.1460],\n", - " [-2.4079, -1.8637, -4.0215, -6.3241, -5.9832, -8.1897],\n", - " [-2.4079, -2.5133, -2.6187, -2.6263, -2.3908, -4.5020],\n", - " [-2.4079, -2.5133, -2.4253, -1.9231, -2.3095, -3.9349],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8155, -2.7202],\n", - " [-2.4079, -1.4917, -3.2084, -5.4502, -5.6243, -7.0936],\n", - " [-2.4079, -2.5133, -2.3403, -1.7980, -3.7169, -6.0195],\n", - " [-2.4079, -2.5133, -2.6187, -2.5921, -2.0611, -3.6138],\n", - " [-2.4079, -1.5036, -3.3216, -5.6242, -5.7295, -7.0136],\n", - " [-2.4079, -2.5133, -2.4601, -4.1053, -6.2688, -8.0280],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.8294, -1.9775],\n", - " [-2.4079, -2.5133, -2.6187, -1.7815, -3.7398, -6.0424],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.8294, -1.9667],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.8294, -1.9689],\n", - " [-2.4079, -1.4923, -3.1820, -5.4779, -5.5908, -6.8860],\n", - " [-2.4079, -2.5133, -1.5972, -2.4674, -4.5111, -6.5279],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.9978, -3.6921],\n", - " [-2.4079, -1.5019, -2.7133, -3.9678, -5.2833, -6.9595],\n", - " [-2.4079, -1.5016, -3.1116, -5.4141, -5.5195, -6.9415],\n", - " [-2.4079, -2.5133, -2.1178, -2.4933, -2.2127, -3.7022],\n", - " [-2.4079, -2.5133, -2.6187, -1.8549, -3.8890, -6.1916],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.8294, -1.9294],\n", - " [-1.4080, -3.2594, -5.0660, -5.7608, -4.8526, -5.7301],\n", - " [-2.4079, -2.5133, -1.8264, -2.8954, -4.5606, -6.8321],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.7979, -1.9693],\n", - " [-2.4079, -1.4962, -3.2765, -5.5791, -5.6845, -7.3091],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.8294, -1.9569],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -1.8320, -2.8165],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.0249, -3.6283],\n", - " [-2.4079, -1.5432, -3.0043, -4.3851, -5.7313, -7.0609],\n", - " [-2.4079, -2.5133, -4.8159, -7.1185, -6.2972, -7.6538],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.1536, -2.2487],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.8294, -1.9442],\n", - " [-2.4079, -2.5133, -2.3668, -1.7445, -3.3940, -5.4882],\n", - " [-2.4079, -2.5133, -2.6187, -1.9626, -2.1001, -3.7886],\n", - " [-2.4079, -1.8449, -2.7711, -4.0641, -5.6629, -7.1123],\n", - " [-2.4079, -2.5133, -1.8560, -2.1446, -4.2419, -6.5445],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -2.8294, -1.9186],\n", - " [-2.4079, -2.5133, -2.6187, -2.7240, -5.0266, -7.3292],\n", - " [-2.4079, -2.5133, -2.1486, -1.8715, -3.2552, -5.5578],\n", - " [-2.4079, -2.5133, -2.6187, -1.9572, -2.2034, -3.9757],\n", - " [-2.4079, -2.5133, -2.6187, -2.2041, -4.4063, -6.7088],\n", - " [-2.4079, -2.5133, -2.6187, -1.7973, -3.7752, -6.0778],\n", - " [-2.4079, -1.8423, -3.9913, -6.2939, -5.8286, -7.9417],\n", - " [-2.4079, -2.5133, -2.6187, -2.1555, -2.4151, -4.6143],\n", - " [-2.4079, -2.5133, -2.6187, -1.7829, -3.7432, -6.0458],\n", - " [-2.4079, -2.5133, -2.6187, -1.7078, -3.4939, -5.7965],\n", - " [-2.4079, -1.9230, -2.7415, -4.3418, -6.4772, -7.8941],\n", - " [-2.4079, -1.4942, -3.2583, -4.6378, -6.6200, -7.9050],\n", - " [-2.4079, -2.5133, -1.7778, -2.0695, -4.0748, -6.3773],\n", - " [-2.4079, -2.5133, -1.6987, -2.3602, -4.5050, -6.8076],\n", - " [-2.4079, -2.5133, -1.6499, -3.1073, -5.4099, -7.7125],\n", - " [-2.4079, -4.7105, -6.0639, -7.4718, -6.7407, -7.9631],\n", - " [-2.4079, -1.5752, -3.5393, -4.8816, -6.3166, -7.6484],\n", - " [-2.4079, -2.5133, -2.4020, -1.8372, -3.8242, -6.1268],\n", - " [-2.4079, -4.7105, -7.0131, -7.1185, -7.2238, -9.5264],\n", - " [-2.4079, -2.5133, -2.6187, -2.5333, -1.8815, -3.3387],\n", - " [-2.4079, -2.0470, -2.7040, -2.8093, -5.1119, -7.4145],\n", - " [-2.4079, -2.5133, -4.8159, -7.1185, -6.2481, -7.7248],\n", - " [-2.4079, -1.6060, -3.6049, -5.6180, -6.0580, -7.6977],\n", - " [-2.4079, -2.5133, -2.6187, -2.0471, -2.6937, -4.8436],\n", - " [-1.8604, -4.0542, -6.3567, -6.4621, -5.6062, -6.8197],\n", - " [-2.4079, -2.5133, -1.6041, -2.2390, -4.1112, -6.1669],\n", - " [-2.4079, -2.5133, -2.1163, -2.0823, -4.2010, -6.5036],\n", - " [-2.4079, -2.5133, -1.7952, -3.0823, -5.1772, -7.1941]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([2, 1, 1, 1, 2, 1, 3, 4, 1, 2, 4, 4, 1, 4, 5, 4, 1, 4, 4, 4, 1, 3, 4, 1,\n", - " 2, 5, 3, 5, 5, 1, 1, 4, 1, 1, 4, 3, 5, 0, 2, 5, 1, 5, 4, 4, 1, 0, 5, 5,\n", - " 3, 4, 1, 2, 5, 2, 1, 4, 3, 1, 1, 4, 3, 3, 1, 1, 2, 2, 2, 0, 1, 3, 0, 4,\n", - " 3, 1, 1, 1, 0, 2, 3, 2], device='cuda:0')\n", "\n", "Train Confusion matrix :\n", - "[[ 66 2 2 1 0 0]\n", - " [ 31 128 20 24 2 0]\n", - " [ 10 12 72 15 1 0]\n", - " [ 0 1 8 133 2 1]\n", - " [ 1 0 2 14 83 1]\n", - " [ 0 0 0 0 5 43]]\n", + "[[ 69 5 0 0 0 0]\n", + " [ 15 178 7 4 1 0]\n", + " [ 10 27 42 27 5 0]\n", + " [ 1 1 11 124 6 0]\n", + " [ 2 3 2 8 82 3]\n", + " [ 0 0 0 0 8 39]]\n", "\n" ] }, @@ -1468,8 +719,8 @@ "output_type": "stream", "text": [ "[INFO] EPOCH: 5/5\n", - "Train loss: 1.106344, Train accuracy: 0.7721\n", - "Val loss: 1.767812, Val accuracy: 0.4132\n", + "Train loss: 1.091701, Train accuracy: 0.7853\n", + "Val loss: 1.488901, Val accuracy: 0.3554\n", "\n", "[INFO] Network evaluation ...\n" ] @@ -1478,7 +729,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Test progress: 100%|███████████████████████████████████████| 2/2 [00:00<00:00, 2.56it/s, loss=3.22]" + "Test progress: 100%|███████████████████████████████████████| 2/2 [00:01<00:00, 1.36it/s, loss=1.63]" ] }, { @@ -1487,24 +738,24 @@ "text": [ "\n", "Confusion matrix :\n", - "[[24 1 0 1 0 0]\n", - " [29 25 5 2 0 0]\n", - " [ 3 7 14 10 0 0]\n", - " [ 5 5 10 13 5 1]\n", - " [ 2 4 10 8 2 3]\n", - " [ 2 0 2 2 1 5]]\n", + "[[22 0 0 0 0 0]\n", + " [33 22 4 1 0 0]\n", + " [ 5 16 4 6 2 0]\n", + " [ 8 9 3 14 6 2]\n", + " [ 2 1 3 9 12 3]\n", + " [ 2 0 0 0 8 4]]\n", "\n", - "MS: 0.0690\n", + "MS: 0.1212\n", "\n", - "QWK: 0.5978\n", + "QWK: 0.6697\n", "\n", - "MAE: 0.8905\n", + "MAE: 0.8806\n", "\n", - "CCR: 0.4129\n", + "CCR: 0.3881\n", "\n", - "1-off: 0.8060\n", + "1-off: 0.8259\n", "\n", - "[INFO] Total training time: 10.62s\n" + "[INFO] Total training time: 55.39s\n" ] }, { @@ -1579,7 +830,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.8.19" }, "orig_nbformat": 4, "vscode": {