[Tripy] Permit eval while tracing, but do not update trace. #443

Closed · wants to merge 2 commits

Conversation

slyubomirsky (Collaborator)

Addresses #409. Evaluation while tracing still gives a warning but does not alter the graph, so it does not produce any errors.

slyubomirsky added the tripy label (Pull request for the tripy project) on Dec 12, 2024
@slyubomirsky (Collaborator, Author)

I am not fully certain how the storage mechanism described in the subsequent comment on #409 should work. If the main issue is recompilation caused by the fact that we do not update the trace graph when we evaluate while tracing, perhaps one approach could be to insert the Storage op anyway but restore the original graph after compilation.
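To make the idea concrete, here is a minimal sketch of what that could look like (hypothetical helper names; only `Storage.build_internal`, `trace_tensor`, and `producer` come from the actual diff):

```python
# Hypothetical sketch, not the PR's actual code: eval during tracing still
# inserts a Storage op (so repeated evals reuse the materialized data), but
# we remember each displaced producer so it can be restored after compile.
reverted_producers = []  # (trace_tensor, original_producer) pairs

def eval_during_trace(tensor, data):
    reverted_producers.append((tensor.trace_tensor, tensor.trace_tensor.producer))
    Storage.build_internal([], [tensor.trace_tensor], data)

def restore_graph_after_compile():
    # Undo the substitutions so the compiled trace matches the original graph.
    for trace_tensor, original_producer in reverted_producers:
        trace_tensor.producer = original_producer
    reverted_producers.clear()
```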

```diff
@@ -265,8 +269,15 @@ def __repr__(self) -> str:
         data_list = self.tolist()

-        assert isinstance(self.trace_tensor.producer, Storage)
-        data_shape = self.trace_tensor.producer.shape
+        if isinstance(self.trace_tensor.producer, Storage):
```
Collaborator:

Does the memref have shape information? If so, it would probably be cleaner to use the memref directly rather than tolist().

Comment on lines 230 to 231:

```python
if not self.trace_tensor.is_compile_tracer:
    Storage.build_internal([], [self.trace_tensor], data)
```
Collaborator:

The problem with this is that if we evaluate multiple tensors in a compiled graph, we'll get quadratic time complexity. What I suggested in #409 (comment) is to store the evaluated result but use it only for evaluation, not for tracing. I'm not sure exactly how that would work, but I expect it would require changes to how we trace during eval().
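For intuition about the quadratic blow-up, here is a toy calculation (pure Python, no Tripy APIs): if eval results are never cached in the trace, evaluating each of N intermediate tensors re-executes its entire prefix of the graph.

```python
# Toy model of a linear chain of N ops where eval results are not cached:
# evaluating the k-th tensor re-runs all k ops that lead up to it.
ops_executed = 0
N = 100

for k in range(1, N + 1):
    ops_executed += k  # eval of tensor k re-executes ops 1..k

# Total work is 1 + 2 + ... + N = N * (N + 1) / 2, i.e. O(N^2).
assert ops_executed == N * (N + 1) // 2
```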

@pranavm-nvidia (Collaborator)

> I am not fully certain how the storage mechanism described in the subsequent comment on #409 should work. If the main issue is recompilation caused by the fact that we do not update the trace graph when we evaluate while tracing, perhaps one approach could be to insert the Storage op anyway but restore the original graph after compilation.

Could you elaborate on this approach? Sounds promising.

@slyubomirsky (Collaborator, Author) commented Dec 12, 2024

I haven't tried implementing this yet, but if we could record which ops we changed during the trace, we could avoid recompilation during the trace and then change them back afterwards.

slyubomirsky force-pushed the no-error-for-eval-while-tracing branch from dbeccd0 to f89fd44 on December 17, 2024 at 04:39
Comment on lines +235 to +236:

```python
if REVERT_GRAPH_AFTER_COMPILING is not None:
    REVERT_GRAPH_AFTER_COMPILING.append((self.trace_tensor, self.trace_tensor.producer))
```
slyubomirsky (Author):

This might be a bit of a clumsy approach to keeping the old producers around (it required adding a global variable to the compiler). I wasn't sure where else the data structure could live, since evaluation is handled as a method on tensors. It's a tricky issue because the tensor being evaluated could be anywhere in the middle of the graph, and the interface to compile is an opaque Callable.
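For context, a rough sketch of how the compiler side might drive this global (only `REVERT_GRAPH_AFTER_COMPILING` and `producer` come from the diff; `trace_function` and `lower_and_compile` are placeholders, not Tripy's real API):

```python
# Hypothetical sketch of the compiler-side bookkeeping.
REVERT_GRAPH_AFTER_COMPILING = None

def compile(func, *args):
    global REVERT_GRAPH_AFTER_COMPILING
    REVERT_GRAPH_AFTER_COMPILING = []  # tensors evaluated mid-trace append here
    try:
        trace = trace_function(func, *args)
        executable = lower_and_compile(trace)
    finally:
        # Restore the producers that eval() replaced with Storage ops, so the
        # user-visible graph is unchanged once compilation finishes.
        for trace_tensor, original_producer in REVERT_GRAPH_AFTER_COMPILING:
            trace_tensor.producer = original_producer
        REVERT_GRAPH_AFTER_COMPILING = None  # disabled outside compilation
    return executable
```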

Collaborator:

What about storing it on the trace tensors themselves? original_producer or something like that?
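A hedged sketch of that alternative, with `original_producer` stored directly on the trace tensor (`Storage.build_internal`, `trace_tensor`, `producer`, and `is_compile_tracer` follow the diffs above; the rest is illustrative):

```python
# Illustrative sketch: stash the displaced producer on the trace tensor
# itself instead of in a compiler-owned global.
def eval_during_trace(tensor, data):
    tt = tensor.trace_tensor
    if tt.is_compile_tracer and getattr(tt, "original_producer", None) is None:
        tt.original_producer = tt.producer  # remember what eval displaces
    Storage.build_internal([], [tt], data)
```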

Collaborator:

I guess in that case we would need to do a DFS from the output and swap the producer back before stepping further.
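Something like the following (an illustrative sketch; the `inputs` attribute on ops and `original_producer` on trace tensors are assumptions):

```python
# Illustrative DFS: walk back from the outputs, swap each tensor's producer
# back to its original_producer, then continue through the restored op's
# inputs so the traversal sees the pre-eval graph.
def restore_original_producers(outputs):
    seen = set()
    stack = list(outputs)
    while stack:
        trace_tensor = stack.pop()
        if id(trace_tensor) in seen:
            continue
        seen.add(id(trace_tensor))
        original = getattr(trace_tensor, "original_producer", None)
        if original is not None:
            trace_tensor.producer = original  # undo the Storage substitution
            trace_tensor.original_producer = None
        if trace_tensor.producer is not None:
            stack.extend(trace_tensor.producer.inputs)
```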

slyubomirsky (Author):

Yeah, searching the graph might be necessary if there isn't anywhere else we could record the tensors that have been evaluated.

@slyubomirsky (Collaborator, Author)

Closing per discussion: we couldn't think of any cases where evaluating while tracing is actually useful, especially since evaluating without using the result (e.g., just printing) already worked. We can revisit if such a case does present itself.
