
Trace fails sanity checks during export to Lite Interpreter for Android: torch.jit.trace() #64

Open
rounakskm opened this issue Jul 29, 2022 · 0 comments


Code:

import cv2
import math
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torchsummary import summary
from torch.utils.mobile_optimizer import optimize_for_mobile

device = None
seed = 0
lr = 1e-4
weight_decay = 1e-4
lr_step_period = 15
model_name = "r2plus1d_18"
weights = "file_path.pt"

# Seed NumPy and PyTorch for reproducibility.
np.random.seed(seed)
torch.manual_seed(seed)

# Use the GPU when one is available.
if device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# R(2+1)D-18 video model with a single-output regression head.
model = torchvision.models.video.__dict__[model_name](pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, 1)
model.fc.bias.data[0] = 55.6

# On GPU the model is wrapped in torch.nn.DataParallel.
if device.type == "cuda":
    model = torch.nn.DataParallel(model)
model.to(device)

# Load the trained weights onto the selected device.
if weights is not None:
    checkpoint = torch.load(weights, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

model.eval()

# input_size excludes the batch dimension: (channels, frames, height, width).
summary(model, (3, 1, 112, 112), device=device.type)

# Trace with a random one-frame clip, optimize for mobile, and save for the Lite Interpreter.
example = torch.rand((1, 3, 1, 112, 112))
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter("path_to_save.ptl")

The sanity-check failure can be avoided by passing check_trace=False to torch.jit.trace(), but then _save_for_lite_interpreter() fails with:

RuntimeError:
Could not export Python function call 'Scatter'. Remove calls to Python functions before export.
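The 'Scatter' named in this error is most likely the scatter step of torch.nn.DataParallel, which the script above applies whenever CUDA is available; that step runs as a Python function and cannot be exported. One possible workaround is to trace the unwrapped model (for example model.module) instead of the DataParallel wrapper. The sketch below goes a step further and never wraps the model at all, assuming (not confirmed in this issue) that the checkpoint was saved from a DataParallel model, so its keys start with "module.". The helper names strip_module_prefix and export_for_lite, and the default clip length of 32 frames, are hypothetical.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict):
    """Return a copy of state_dict with one leading 'module.' removed from each key.

    DataParallel stores the wrapped model under the attribute 'module', so a
    checkpoint saved from a wrapped model has keys such as 'module.fc.weight'.
    Keys without the prefix are kept unchanged.
    """
    prefix = "module."
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def export_for_lite(weights_path, output_path, frames=32):
    """Hypothetical export routine: trace the bare model on CPU and save it.

    'frames' is an assumed clip length; use the value the model was trained with.
    """
    # Imported here so strip_module_prefix can be used without torch installed.
    import torch
    import torchvision
    from torch.utils.mobile_optimizer import optimize_for_mobile

    # Same architecture as in the issue, but never wrapped in DataParallel.
    model = torchvision.models.video.r2plus1d_18()
    model.fc = torch.nn.Linear(model.fc.in_features, 1)

    # Load on CPU and drop any DataParallel prefix from the parameter names.
    checkpoint = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(strip_module_prefix(checkpoint["state_dict"]))
    model.eval()

    # Video models take (batch, channels, frames, height, width).
    example = torch.rand(1, 3, frames, 112, 112)
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
    optimized = optimize_for_mobile(traced)
    optimized._save_for_lite_interpreter(output_path)
    return optimized
```

With the paths from the issue this would be called as export_for_lite("file_path.pt", "path_to_save.ptl"). Because the traced module is a plain VideoResNet, no DataParallel scatter call ends up in the graph.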

I am following the PyTorch documentation to export the EF (ejection fraction) prediction model for Android.

Is this an issue with the input shape? What exact input shape should the first layer receive?
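On the shape question: torchvision's video models, including r2plus1d_18, take 5-D input laid out as (batch, channels, frames, height, width), which matches the (1, 3, 1, 112, 112) example above with a single frame. It is probably safest to trace with the clip length the model was trained on and that the app will supply. As a sketch of preparing real input, the hypothetical helper below stacks frames read with OpenCV (height x width x 3 each) into that layout. It assumes 112 x 112 frames and deliberately skips BGR-to-RGB conversion and pixel normalization, which depend on how the model was trained.

```python
import numpy as np


def frames_to_model_input(frames):
    """Stack H x W x 3 frames into a (1, 3, T, H, W) float32 batch.

    Frames read with OpenCV are (height, width, channels) in BGR order; this
    sketch only reorders axes and does not convert colour order or normalize.
    """
    clip = np.stack(frames, axis=0)  # (T, H, W, C)
    if clip.ndim != 4 or clip.shape[-1] != 3:
        raise ValueError(f"expected a list of H x W x 3 frames, got shape {clip.shape}")
    clip = clip.transpose(3, 0, 1, 2)  # (C, T, H, W)
    # Add the batch axis and make the buffer contiguous for torch.from_numpy.
    return np.ascontiguousarray(clip[np.newaxis], dtype=np.float32)  # (1, C, T, H, W)
```

The result can be wrapped with torch.from_numpy(...) and passed as the example input to torch.jit.trace(); with the single-output fc head the model then returns a tensor of shape (1, 1).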
