-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
53 lines (37 loc) · 1.19 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# All pretrained ResNet-50 models expect 3 x 224 x 224 inputs.
# Since the dataset is split into one folder per class, we take each folder and use the first x% of its images.
# For each class, build the training matrices of values plus labels -----> we assign the labels ourselves.
# Shuffle the training set once it has been fully collected.
# Basic imports
import os # paths
import numpy as np # arrays
import matplotlib.pyplot as plt
#from osgeo import gdal from tif to rgb
from tqdm import tqdm # progress bar
# torch
import torch
from torchvision.models import resnet50, ResNet50_Weights
# Choose the compute backend: prefer CUDA, then Apple's MPS backend,
# and fall back to the CPU when neither accelerator is available.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Tensors created without an explicit device will be allocated on DEVICE.
torch.set_default_device(DEVICE)
print(f"Using {DEVICE} device")
def main(folders_path='../dataset/'):
    """Build the class-name -> label mapping for the dataset and load ResNet-50.

    The dataset is expected to contain one sub-folder per class directly
    under ``folders_path``; each sub-folder's name is used as a class name.

    Args:
        folders_path: Directory holding one sub-folder per class. Defaults to
            ``'../dataset/'``, which is resolved relative to the current
            working directory (not relative to this file).
    """
    # os.listdir returns entries in arbitrary, filesystem-dependent order, so
    # sort them to make the class -> label assignment reproducible. Only
    # directories count as classes; stray files (e.g. .DS_Store) are skipped.
    classes = sorted(
        entry
        for entry in os.listdir(folders_path)
        if os.path.isdir(os.path.join(folders_path, entry))
    )
    # Labels are assigned starting from 1.
    # NOTE(review): PyTorch losses such as CrossEntropyLoss expect labels in
    # [0, num_classes); confirm 1-based labels are intended.
    dict_class = {name: label for label, name in enumerate(classes, start=1)}

    # Not used yet below; kept for the upcoming training code.
    num_classes = len(classes)
    bands = 13  # NOTE(review): presumably the number of bands per image — confirm

    print(classes)
    print(dict_class)

    # RESNET50 with torchvision's default pretrained weights.
    model = resnet50(weights=ResNet50_Weights.DEFAULT)
def taskb():
    """Placeholder for "task B"; not implemented yet.

    The original definition had no body, which is a syntax error that kept the
    whole script from running. This explicit stub makes the file parse and
    fails loudly if it is called before it is implemented.

    Raises:
        NotImplementedError: always, until the task is written.
    """
    raise NotImplementedError("taskb is not implemented yet")
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()