ConstraintViolationError #6782

Open · adonnini opened this issue Nov 12, 2024 · 16 comments

Labels: module: exir (Issues related to Export IR), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

@adonnini
Hi,
I just upgraded to executorch 0.4 and ran my code which previously failed as described in
#1350

Now it fails with
ConstraintViolationError

Please find the error log below.

Please let me know if you need additional information.

Thanks

ERROR LOG

Traceback (most recent call last):
  File "/home/adonnini1/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 670, in _export_to_aten_ir
    produce_guards_callback(gm)
  File "/home/adonnini1/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 1655, in _produce_guards_callback
    return produce_guards_and_solve_constraints(
  File "/home/adonnini1/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/_export/non_strict_utils.py", line 287, in produce_guards_and_solve_constraints
    raise constraint_violation_error
  File "/home/adonnini1/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/_export/non_strict_utils.py", line 270, in produce_guards_and_solve_constraints
    shape_env.produce_guards(
  File "/home/adonnini1/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 3788, in produce_guards
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: L['args'][0][1].size()[1] = 12 is not equal to L['args'][0][0].size()[1] = 7

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/adonnini1/Development/ContextQSourceCode/NeuralNetworks/adonnini-trajectory-prediction-transformers-masterContextQ/train-minimum.py", line 443, in <module>
    pre_autograd_aten_dialect = torch.export._trace._export(
  File "/home/adonnini1/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 1017, in wrapper
    raise e
  File "/home/adonnini1/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 990, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/adonnini1/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/export/exported_program.py", line 114, in wrapper
    return fn(*args, **kwargs)
  File "/home/adonnini1/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 1880, in _export
    export_artifact = export_func(  # type: ignore[operator]
  File "/home/adonnini1/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 1683, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "/home/adonnini1/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 672, in _export_to_aten_ir
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: L['args'][0][1].size()[1] = 12 is not equal to L['args'][0][0].size()[1] = 7
@GregoryComer (Member)

Hi, @adonnini. Can you provide the dynamic shape constraints that you pass into export, as well as information about the model that you are exporting? Thanks!

@adonnini (Author) commented Nov 12, 2024

@GregoryComer, thanks for the follow-up.

You can find the entire model, with instructions on how to run it, here:

https://github.com/adonnini/adonnini-trajectory-prediction-transformers-masterContextQ/tree/main

Here is the code where I pass the model and the inputs:

        dim1_x = Dim("dim1_x", min=1, max=100000)
        dynamic_shapes = {"enc_input": {1: dim1_x}, "dec_input": {1: dim1_x}, "dec_source_mask": {1: dim1_x}, "dec_target_mask": {1: dim1_x}}

        pre_autograd_aten_dialect = torch.export._trace._export(
            m,
            (enc_input, dec_input, dec_source_mask, dec_target_mask),
            dynamic_shapes=dynamic_shapes,
            pre_dispatch=True,
            strict=False
        )
        # pre_autograd_aten_dialect = capture_pre_autograd_graph(m,
        #                                                        (enc_input, dec_input, dec_source_mask, dec_target_mask), dynamic_shapes=dynamic_shapes)
        aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect,
                                               (enc_input, dec_input, dec_source_mask, dec_target_mask), strict=False)

where enc_input, dec_input, dec_source_mask, and dec_target_mask are defined earlier in train-minimum.py.

Thanks

@larryliu0820 (Contributor)

L['args'][0][1].size()[1] = 12 is not equal to L['args'][0][0].size()[1] = 7

According to this error message, it seems like one of your input tensors has an unexpected size. Can you please print out the tensor shapes for (enc_input, dec_input, dec_source_mask, dec_target_mask)?

It's a bit hard to tell from reading your code: https://github.com/adonnini/adonnini-trajectory-prediction-transformers-masterContextQ/tree/main

@adonnini (Author) commented Nov 13, 2024

@larryliu0820, thanks for getting back to me.
I added print statements for the four variables as you requested (also to train-minimum.py in the repository).

The printout of the four variables themselves is very big. I thought the information about the shapes is what you really wanted, given your question.

Below, you will find the output of the four print statements of the shapes of the four variables.

Please note that the model runs successfully when excluding the executorch-related code. Even with the executorch code, the first epoch completes successfully.

By the way, the fact that the shapes of dim 1 of enc_input and dec_input are different is not an issue for the model. As I said, train-minimum.py runs successfully without the executorch code. You can try it yourself.

Thanks

 - train_minimum - Lowering the Whole Module - enc_input.shape -  torch.Size([27, 7, 2])
 - train_minimum - Lowering the Whole Module - dec_input.shape -  torch.Size([27, 12, 3])
 - train_minimum - Lowering the Whole Module - dec_source_mask.shape -  torch.Size([27, 1, 7])
 - train_minimum - Lowering the Whole Module - dec_target_mask.shape -  torch.Size([27, 12, 12])

@larryliu0820 (Contributor)

@angelayi do you have any suggestions? Is it worth filing an issue in pytorch/pytorch?

@angelayi (Contributor)

By specifying dynamic_shapes={"enc_input": {1: dim1_x}, "dec_input": {1: dim1_x}} where dim1_x is the same for both enc_input and dec_input, you're claiming that enc_input.shape[1] == dec_input.shape[1]. However, your input shapes for the two differ (enc_input.shape[1] = 7 and dec_input.shape[1] = 12), which results in the ConstraintViolationError.

To fix this you can make each of the dims a separate Dim (see the sketch below), or you can try the new Dim.AUTO: dynamic_shapes = {"enc_input": {1: Dim.AUTO}, "dec_input": {1: Dim.AUTO}, "dec_source_mask": {1: Dim.AUTO}, "dec_target_mask": {1: Dim.AUTO}}
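
For the first option, a minimal sketch; the dim names here are made up, and the mapping of the mask dims onto the encoder/decoder lengths is an assumption read off the shapes printed above:

    from torch.export import Dim

    # Separate symbolic dims, since the encoder and decoder sequence lengths differ (7 vs. 12)
    enc_len = Dim("enc_len", min=1, max=100000)
    dec_len = Dim("dec_len", min=1, max=100000)
    dynamic_shapes = {
        "enc_input": {1: enc_len},                    # [27, 7, 2]
        "dec_input": {1: dec_len},                    # [27, 12, 3]
        "dec_source_mask": {2: enc_len},              # [27, 1, 7]
        "dec_target_mask": {1: dec_len, 2: dec_len},  # [27, 12, 12]
    }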

@adonnini (Author) commented Nov 13, 2024

@angelayi Thanks for your suggestion. After making the change you suggested (I should have seen the error myself, thanks), execution fails, producing the error log reported below.

@larryliu0820 I think the ball may be back in the executorch court.

FYI, below you will also find the executorch-related code (in train-minimum.py).

ERROR LOG

Traceback (most recent call last):
  File "/home/adonnini1/Development/ContextQSourceCode/NeuralNetworks/adonnini-trajectory-prediction-transformers-masterContextQ/train-minimum.py", line 461, in <module>
    aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect,
  File "/home/adonnini1/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/export/__init__.py", line 261, in export
    raise ValueError(
ValueError: Expected `mod` to be an instance of `torch.nn.Module`, got <class 'torch.export.exported_program.ExportedProgram'>.

CODE

        dim1_x = Dim("dim1_x", min=1, max=100000)

        dynamic_shapes = {"enc_input": {1: Dim.AUTO}, "dec_input": {1: Dim.AUTO}, "dec_source_mask": {1: Dim.AUTO}, "dec_target_mask": {1: Dim.AUTO}}

        print(" - train_minimum - Lowering the Whole Module - enc_input - ", enc_input)
        print(" - train_minimum - Lowering the Whole Module - dec_input - ", dec_input)
        print(" - train_minimum - Lowering the Whole Module - dec_source_mask - ", dec_source_mask)
        print(" - train_minimum - Lowering the Whole Module - dec_target_mask - ", dec_target_mask)

        print(" - train_minimum - Lowering the Whole Module - enc_input.shape - ", enc_input.shape)
        print(" - train_minimum - Lowering the Whole Module - dec_input.shape - ", dec_input.shape)
        print(" - train_minimum - Lowering the Whole Module - dec_source_mask.shape - ", dec_source_mask.shape)
        print(" - train_minimum - Lowering the Whole Module - dec_target_mask.shape - ", dec_target_mask.shape)

        pre_autograd_aten_dialect = torch.export._trace._export(
            m,
            (enc_input, dec_input, dec_source_mask, dec_target_mask),
            dynamic_shapes=dynamic_shapes,
            pre_dispatch=True,
            strict=False
        )

        aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect,
                                               (enc_input, dec_input, dec_source_mask, dec_target_mask), strict=False)

        print(" - train_minimum - Lowering the Whole Module - ATen Dialect Graph")
        print(" - train_minimum - Lowering the Whole Module - aten_dialect - ", aten_dialect)

        edge_program: EdgeProgramManager = to_edge(aten_dialect)
        to_be_lowered_module = edge_program.exported_program()

        from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend

        lowered_module = edge_program.to_backend(XnnpackPartitioner())

        print(" - train_minimum - Lowering the Whole Module - lowered_module - ", lowered_module)

        # Serialize and save it to a file
        save_path = "/home/adonnini1/Development/ContextQSourceCode/NeuralNetworks/trajectory-prediction-transformers-master/models/tpt_delegate.pte"
        with open(save_path, "wb") as f:
            f.write(lowered_module.to_executorch().buffer)

@larryliu0820 (Contributor)

pre_autograd_aten_dialect = torch.export._trace._export(
    m,
    (enc_input, dec_input, dec_source_mask, dec_target_mask),
    dynamic_shapes=dynamic_shapes,
    pre_dispatch=True,
    strict=False
)

aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect,
                                       (enc_input, dec_input, dec_source_mask, dec_target_mask), strict=False)

Can you use the new torch.export.export API?

ep = torch.export.export(
    m,
    (enc_input, dec_input, dec_source_mask, dec_target_mask),
    dynamic_shapes=dynamic_shapes,
    strict=False
)

and the rest should be the same

@adonnini (Author)

@larryliu0820, I made the change you suggested. Execution failed with the same error. Please find below the error log and the code change I made. Did I make a mistake in making the change?
Thanks

ERROR LOG

Traceback (most recent call last):
  File "/home/adonnini1/Development/ContextQSourceCode/NeuralNetworks/adonnini-trajectory-prediction-transformers-masterContextQ/train-minimum.py", line 475, in <module>
    aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect,
  File "/home/adonnini1/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/export/__init__.py", line 261, in export
    raise ValueError(
ValueError: Expected `mod` to be an instance of `torch.nn.Module`, got <class 'torch.export.exported_program.ExportedProgram'>.

CODE

        pre_autograd_aten_dialect = torch.export.export(
            m,
            (enc_input, dec_input, dec_source_mask, dec_target_mask),
            dynamic_shapes=dynamic_shapes,
            strict=False
        )

@larryliu0820 (Contributor)

What I meant is you can skip this line:

aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect,

And call to_edge() directly on pre_autograd_aten_dialect.
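
A minimal sketch of the intended flow, reusing the variable names from your code (and assuming to_edge and EdgeProgramManager are already imported from executorch.exir):

    ep = torch.export.export(
        m,
        (enc_input, dec_input, dec_source_mask, dec_target_mask),
        dynamic_shapes=dynamic_shapes,
        strict=False
    )
    edge_program: EdgeProgramManager = to_edge(ep)  # no second export() call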

@adonnini (Author)

@larryliu0820 Sorry. I re-ran with the corrected code. There was progress (I think), but execution still failed.
You will find the error log and updated code below. Hopefully I did not make any mistakes this time.
Thanks

ERROR LOG

Traceback (most recent call last):
  File "/home/adonnini1/Development/ContextQSourceCode/NeuralNetworks/adonnini-trajectory-prediction-transformers-masterContextQ/train-minimum.py", line 511, in <module>
    f.write(lowered_module.to_executorch().buffer)
  File "/home/adonnini1/Development/ContextQSourceCode/NeuralNetworks/adonnini-trajectory-prediction-transformers-masterContextQ/executorch/exir/program/_program.py", line 1322, in to_executorch
    new_gm_res = p(new_gm)
  File "/home/adonnini1/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/fx/passes/infra/pass_base.py", line 41, in __call__
    res = self.call(graph_module)
  File "/home/adonnini1/Development/ContextQSourceCode/NeuralNetworks/adonnini-trajectory-prediction-transformers-masterContextQ/executorch/exir/passes/replace_view_copy_with_view_pass.py", line 289, in call
    node.meta["spec"] = _ViewSpec(base.meta["spec"], shape)
  File "/home/adonnini1/Development/ContextQSourceCode/NeuralNetworks/adonnini-trajectory-prediction-transformers-masterContextQ/executorch/exir/passes/replace_view_copy_with_view_pass.py", line 156, in __init__
    raise Exception(
Exception: _ViewSpec is incompatible with its base on creation.  It has shape_dynamism=0, but its base has shape_dynamism=1.

CODE

 
        dim1_x = Dim("dim1_x", min=1, max=100000)

        dynamic_shapes = {"enc_input": {1: Dim.AUTO}, "dec_input": {1: Dim.AUTO}, "dec_source_mask": {1: Dim.AUTO}, "dec_target_mask": {1: Dim.AUTO}}

        print(" - train_minimum - Lowering the Whole Module - enc_input - ", enc_input)
        print(" - train_minimum - Lowering the Whole Module - dec_input - ", dec_input)
        print(" - train_minimum - Lowering the Whole Module - dec_source_mask - ", dec_source_mask)
        print(" - train_minimum - Lowering the Whole Module - dec_target_mask - ", dec_target_mask)

        print(" - train_minimum - Lowering the Whole Module - enc_input.shape - ", enc_input.shape)
        print(" - train_minimum - Lowering the Whole Module - dec_input.shape - ", dec_input.shape)
        print(" - train_minimum - Lowering the Whole Module - dec_source_mask.shape - ", dec_source_mask.shape)
        print(" - train_minimum - Lowering the Whole Module - dec_target_mask.shape - ", dec_target_mask.shape)

        pre_autograd_aten_dialect = torch.export.export(
            m,
            (enc_input, dec_input, dec_source_mask, dec_target_mask),
            dynamic_shapes=dynamic_shapes,
            strict=False
        )

        edge_program: EdgeProgramManager = to_edge(pre_autograd_aten_dialect)

        to_be_lowered_module = edge_program.exported_program()

        from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend

        lowered_module = edge_program.to_backend(XnnpackPartitioner())

        print(" - train_minimum - Lowering the Whole Module - lowered_module - ", lowered_module)

        # Serialize and save it to a file
        save_path = "/home/adonnini1/Development/ContextQSourceCode/NeuralNetworks/trajectory-prediction-transformers-master/models/tpt_delegate.pte"
        with open(save_path, "wb") as f:
            f.write(lowered_module.to_executorch().buffer)

@larryliu0820 (Contributor)

@metascroy This goes into ReplaceViewCopyWithViewPass. Any thoughts on why this failed?

In the meantime, @adonnini, we can also skip this pass by changing your code to:

f.write(lowered_module.to_executorch(ExecutorchBackendConfig(remove_view_copy=False)).buffer)
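
If it isn't already imported, ExecutorchBackendConfig should be available from executorch.exir (assuming executorch 0.4):

    from executorch.exir import ExecutorchBackendConfig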

@adonnini (Author) commented Nov 13, 2024

@larryliu0820 I made the change. Please find the updated code and the (big) error log below.
Thanks
P.S. I am going to bed now (past 11 PM here). I will respond to any further comments first thing tomorrow.

CODE


        dim1_x = Dim("dim1_x", min=1, max=100000)

        dynamic_shapes = {"enc_input": {1: Dim.AUTO}, "dec_input": {1: Dim.AUTO}, "dec_source_mask": {1: Dim.AUTO}, "dec_target_mask": {1: Dim.AUTO}}

        print(" - train_minimum - Lowering the Whole Module - enc_input - ", enc_input)
        print(" - train_minimum - Lowering the Whole Module - dec_input - ", dec_input)
        print(" - train_minimum - Lowering the Whole Module - dec_source_mask - ", dec_source_mask)
        print(" - train_minimum - Lowering the Whole Module - dec_target_mask - ", dec_target_mask)

        print(" - train_minimum - Lowering the Whole Module - enc_input.shape - ", enc_input.shape)
        print(" - train_minimum - Lowering the Whole Module - dec_input.shape - ", dec_input.shape)
        print(" - train_minimum - Lowering the Whole Module - dec_source_mask.shape - ", dec_source_mask.shape)
        print(" - train_minimum - Lowering the Whole Module - dec_target_mask.shape - ", dec_target_mask.shape)

        pre_autograd_aten_dialect = torch.export.export(
            m,
            (enc_input, dec_input, dec_source_mask, dec_target_mask),
            dynamic_shapes=dynamic_shapes,
            strict=False
        )

        edge_program: EdgeProgramManager = to_edge(pre_autograd_aten_dialect)

        to_be_lowered_module = edge_program.exported_program()

        from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend

        lowered_module = edge_program.to_backend(XnnpackPartitioner())

        print(" - train_minimum - Lowering the Whole Module - lowered_module - ", lowered_module)

        # Serialize and save it to a file
        save_path = "/home/adonnini1/Development/ContextQSourceCode/NeuralNetworks/trajectory-prediction-transformers-master/models/tpt_delegate.pte"
        with open(save_path, "wb") as f:
            f.write(lowered_module.to_executorch(ExecutorchBackendConfig(remove_view_copy=False)).buffer)

ERROR LOG

(see attached file)
6782-errorLogLatest.txt

@larryliu0820 (Contributor)

@adonnini :) this is not helpful for debugging because it seems like it's just a graph dump. Someone from our side needs to repro this locally and debug. I don't have the time right now, but if you want to wait I may get back to it next week.

@adonnini (Author)

@larryliu0820 Thanks for your help. I really appreciate it. I understand about your priorities. Would you mind if I check in with you in 1-2 weeks?
https://github.com/adonnini/adonnini-trajectory-prediction-transformers-masterContextQ/tree/main
has all the code, instructions for running it, and instructions for a local repro.
Please let me know if you need anything else.
Thanks. Have a good day

@larryliu0820 (Contributor)

@adonnini yeah, ping me next week and I'll spend some time reproing your issue.
