Exporting to ONNX: a painful road fraught with poor late-on-a-Friday decisions. #1144
MathijsdeBoer started this conversation in Show and tell
This is a story of a poor PhD student sitting in the office late on a Friday evening. "Can we serve this trained model up?" they had asked a few days ago. "Sure," the student had said, "we can package the trained model with a Dockerfile and run predictions on some bind-mounted directories!"
"But won't that require a GPU?", the response scared the student, ever so slightly. "Well yeah, but running on a CPU will probably take a very long time. Our GPU server with slightly old and underpowered hardware already takes about 5-7 minutes to predict on one sample!" "So no way then?" "Not with just the nnUNet code, no." A short silence fell.
"I suppose you could try exporting to ONNX, I'm sure there's a CPU runtime for that?"
Just give me the solution, old man!
Fine.
Minimal-ish example:
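Roughly this, reconstructed from the steps below; the model path, task name, and fold are placeholders, and it assumes the nnUNet v1 API:

```python
import torch

from nnunet.training.model_restore import load_model_and_checkpoint_files

# Placeholder: wherever your trained model lives
model_folder = "/path/to/nnUNet_trained_models/nnUNet/3d_fullres/TaskXXX_MyTask/nnUNetTrainerV2__nnUNetPlansv2.1"

# Restore the trainer and the checkpoint parameters (here: fold 0 only)
trainer, params = load_model_and_checkpoint_files(
    model_folder,
    folds=[0],
    mixed_precision=False,
    checkpoint_name="model_final_checkpoint",
)
trainer.load_checkpoint_ram(params[0], False)

model = trainer.network
model.eval()

# nnUNetTrainerV2 networks have deep supervision enabled by default, which
# would give the exported graph multiple outputs; turn it off if present.
if hasattr(model, "do_ds"):
    model.do_ds = False

# The moral of the story: the network lives on the GPU, so the dummy
# input used for tracing has to live there too.
input_shape = (1, trainer.num_input_channels, *[int(i) for i in trainer.patch_size])
dummy_input = torch.rand(input_shape).cuda()

torch.onnx.export(model, dummy_input, "nnunet.onnx")
```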
There have been a few discussions on this topic; however, the usual response has been "look at the code" (which is fair enough, the maintainers aren't obliged to support every weird thing people try with their code). I've been doing just that, and I've hit a bit of a snag. Because I'm only interested in exporting the actual model, I'm disregarding the pre- and postprocessing steps here for brevity.
To export to ONNX with PyTorch, you have to call:
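```python
# Schematically; model and dummy_input still need to be acquired
torch.onnx.export(model, dummy_input, "model.onnx")
```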
Simple enough: acquire the `nn.Module` object and a random input, and feed them both to the `export` function. I've gone through the code, and I've ended up acquiring the `trainer` and `params` objects with:
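A sketch, with the arguments cribbed from how `nnunet/inference/predict.py` does it; `model_folder`, the fold, and the checkpoint name depend on your setup:

```python
from nnunet.training.model_restore import load_model_and_checkpoint_files

# model_folder is a placeholder for your trained model directory
trainer, params = load_model_and_checkpoint_files(
    model_folder,
    folds=[0],
    mixed_precision=False,
    checkpoint_name="model_final_checkpoint",
)
```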
I reckon the model parameters are loaded with:
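```python
# params holds one checkpoint per requested fold; False means "not training"
trainer.load_checkpoint_ram(params[0], False)
```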
Finding the input shape for the model isn't too difficult, as it's in `trainer`, too:
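```python
# trainer.patch_size comes from the plans file (e.g. something like
# [64, 160, 160] for a 3d_fullres model); prepend a batch dimension
# and the number of input channels
input_shape = (1, trainer.num_input_channels, *[int(i) for i in trainer.patch_size])
```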
Generating a random input array with:
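```python
import torch

# The values are irrelevant for tracing, only the shape matters
dummy_input = torch.rand(input_shape)
```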
Getting the `nn.Module` object seems to work like this:
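```python
# The trainer holds the actual network, which is a plain nn.Module subclass
model = trainer.network
model.eval()  # we're exporting for inference, not training
```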
Now, `onnx.export()` will run the input `dummy_input` through the model, tracing each step as it goes to build the graph. Per the PyTorch documentation, this line should do:
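```python
# Straight out of the PyTorch docs (spoiler: this is about to go wrong)
torch.onnx.export(model, dummy_input, "nnunet.onnx")
```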
Unfortunately, that errors out with an impressive stacktrace:
*(stacktrace collapsed)*
Ok, maybe we made a mistake?
Let's try feeding the patch directly to the model:
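```python
# No torch.onnx, no jit; just a plain forward pass
model(dummy_input)
```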
*(stacktrace collapsed)*
Nope, same error. This is strange, because this time there's no weird `torch.onnx` or `jit` stuff that might get in our way. I've stepped through the code to see how the normal prediction command, `nnUNet_predict`, works. First it calls:
`nnunet/inference/predict_simple.py:main()`
That one calls:
`nnunet/inference/predict.py:predict_from_folder()`
to:
`nnunet/inference/predict.py:predict_cases()`
to:
`nnunet/training/network_training/nnUNetTrainer.py:nnUNetTrainer.predict_preprocessed_data_return_seg_and_softmax()`
to:
`nnunet/network_architecture/neural_network.py:SegmentationNetwork.predict_3D()`
Here we have a choice: tiled or untiled. As we're just feeding a single patch through, it makes sense to follow the untiled path:
`nnunet/network_architecture/neural_network.py:SegmentationNetwork._internal_predict_3D_3Dconv()`
Are we there yet?
This one calls:
`nnunet/network_architecture/neural_network.py:SegmentationNetwork._internal_maybe_mirror_and_pred_3D()`
Which contains the line:
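Quoting roughly from memory:

```python
pred = self.inference_apply_nonlin(self(x))
```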
A-ha!
But wait, `self(x)` is just the same call we already made with `model(dummy_input)`! Because at this point we're living on a Friday evening at about 18:15, my brain cells are a little boiled. But even a boiled stew is right sometimes, or whatever the saying was.
Could it be?
Could it be that I forgot to send my tensor to the GPU?
I try:
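```python
# Send the dummy input to the GPU, where the model has been living all along
model(dummy_input.cuda())
```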
[You can imagine your favorite expletive and paste it here]
I quickly change the command to:
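```python
# Create the dummy input on the GPU from the start
dummy_input = torch.rand(input_shape).cuda()
```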
So I plug that into the earlier `onnx.export()` call:
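```python
torch.onnx.export(model, dummy_input, "nnunet.onnx")
```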
Well now, that seems to have worked!
There is indeed an `.onnx` file written out, and there's only a minimal amount of scary red text. Note that I haven't tried to load the model in an actual ONNX runtime yet; it's currently nearing 19:00 on the same Friday evening, and I am ready for the weekend.
I hope my painful day helps some other people out, and that this may serve as a cautionary tale... Don't export nnUNets on a Friday evening after 18:00... They grow evil, or something like that.
I would also like to mention that I did this all to myself, none of the people referred to in the intro actually made me do this, I just really wanted to figure this puzzle out.