Dear community,

I would like to ask about guidelines for using timm models as a boilerplate: retrieving the architecture of a selected model and then training it from scratch. Let's assume that I use the following code to retrieve ResNet34:
```python
import timm

model = timm.create_model('resnet34', pretrained=False, num_classes=10)
```
Afterwards, I would like to write my own functions to split the data into batches, set up the optimizer, and run training and evaluation. Let's assume that I use the following code for each epoch (it is simplified; some basic operations, such as placing the data on the GPU, are abstracted away):
```python
import torch
import torch.nn as nn


def train(net, l_data, optimizer):
    """Run one training epoch; returns (mean per-batch loss, accuracy)."""
    criterion = nn.CrossEntropyLoss()
    running_loss = 0.0
    total_correct = 0
    total = 0
    net.train()  # BatchNorm uses batch statistics, dropout is active
    for dic in l_data:
        data = dic['data']
        target = dic['target']
        if isinstance(data, list):
            data = data[0]

        optimizer.zero_grad()  # clear gradients from the previous step (once is enough)
        outputs = net(data)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        # Bookkeeping on detached values so no graph is kept alive
        predicted = outputs.detach().argmax(dim=1)
        total_correct += (predicted == target).sum().item()
        running_loss += loss.item()
        total += target.size(0)

    # `self.traindata` does not exist in a plain function; average over batches instead
    loss = running_loss / len(l_data)
    accuracy = total_correct / total
    return loss, accuracy
```
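The question also mentions evaluation. A matching per-epoch evaluation pass might look like the sketch below. It assumes the same batch format as the training loop (dicts with `'data'` and `'target'` keys), and the name `evaluate` is hypothetical. The two differences that matter are `net.eval()`, which makes BatchNorm use its running statistics and disables dropout, and `torch.no_grad()`, which skips gradient tracking.

```python
import torch
import torch.nn as nn


def evaluate(net, l_data):
    """Hypothetical evaluation pass; returns (mean per-batch loss, accuracy)."""
    criterion = nn.CrossEntropyLoss()
    running_loss = 0.0
    total_correct = 0
    total = 0
    net.eval()  # BatchNorm uses running stats, dropout is disabled
    with torch.no_grad():  # no autograd graph is built, saving memory
        for dic in l_data:
            data, target = dic['data'], dic['target']
            if isinstance(data, list):
                data = data[0]
            outputs = net(data)
            running_loss += criterion(outputs, target).item()
            total_correct += (outputs.argmax(dim=1) == target).sum().item()
            total += target.size(0)
    return running_loss / len(l_data), total_correct / total
```

Forgetting `net.eval()` is a common source of silently wrong validation numbers with BatchNorm-based models such as ResNet34, and remember to switch back with `net.train()` before the next training epoch.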
As far as I can tell, this works fine, since the model class defines a fully functional `.forward` method that encapsulates all the logic needed for training. However, since I could not find a direct answer to this question in the documentation (and most use cases I have seen rely on the official training scripts), I am not entirely sure whether I am running into an issue that may pop up later because of my rather limited understanding of the library.
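To make the remaining pieces (batching and optimizer setup) concrete, here is a minimal, self-contained sketch that trains a small model from scratch on synthetic data. Everything in it is an assumption for illustration: the `DictDataset` class, the toy network standing in for the timm model, the SGD hyperparameters, and the number of epochs. It relies on PyTorch's default collate function, which stacks a list of dict samples into one dict of tensors, so each batch has the same `{'data': ..., 'target': ...}` layout the loop above expects.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


class DictDataset(Dataset):
    """Hypothetical dataset yielding {'data': tensor, 'target': tensor} samples."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return {'data': self.x[idx], 'target': self.y[idx]}


torch.manual_seed(0)
# Synthetic 2-class problem: label is 1 when the feature sum is positive
x = torch.randn(256, 8)
y = (x.sum(dim=1) > 0).long()

# The default collate function turns a list of dicts into a dict of stacked tensors
loader = DataLoader(DictDataset(x, y), batch_size=32, shuffle=True)

# Stand-in for the timm model; any nn.Module with parameters works the same way
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.05, momentum=0.9)
criterion = nn.CrossEntropyLoss()

epoch_losses = []
for epoch in range(5):
    net.train()
    running_loss = 0.0
    for batch in loader:
        optimizer.zero_grad()
        loss = criterion(net(batch['data']), batch['target'])
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_losses.append(running_loss / len(loader))
    print(f"epoch {epoch}: loss {epoch_losses[-1]:.4f}")
```

Swapping the toy network for the `timm.create_model(...)` result changes nothing else in the loop, since both are plain `nn.Module` instances.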
I would be grateful for all of your help.