-
Notifications
You must be signed in to change notification settings - Fork 109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
initial rust mpi support #2025
base: main
Are you sure you want to change the base?
initial rust mpi support #2025
Conversation
We need to teach Enzyme to recognize
Because right now Enzyme yells about unknown mpi_allreduce op %5, I assume we already do that to e.g. recognize enyzme_const. |
See my comment on the other thread, but no that should not be the way it is handled. Speciically if it is a generic reduction op, the derivative code can similarly be arbitrary |
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %7) | ||
store double 0.000000e+00, ptr %7, align 8 | ||
tail call void @llvm.experimental.noalias.scope.decl(metadata !7) | ||
%10 = load ptr, ptr @RSMPI_DOUBLE, align 8, !noalias !10, !noundef !13 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is unused atm so this won't test any thing. Can you make this a minimal runnable case. Maybe a different mpi fn?
Also get rid of the other stuff like enzyme_type etc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jedbrown I can't run things, can you test if some other function works on this branch?
if (GV->getName() == "ompi_mpi_op_sum" || | ||
GV->getName() == "RSMPI_SUM") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know how to resolve this.
source_id: DefId(0:12 ~ dot_enzyme[f53a]::dot_parallel)
num_fnc_args: 5
input_activity.len(): 5
error: <unknown>:0:0: in function preprocess__ZN10dot_enzyme12dot_parallel17hd37f1f8a2c8de07dE double (ptr, ptr, i64, ptr, i64): Enzyme: cannot compute with global variable that doesn't have marked shadow global
@RSMPI_SUM = external local_unnamed_addr global ptr
This is the relevant code: rsmpi/rsmpi@840e01c#diff-9a676b0d0c142cd1e89e8174ddb007db982d8602bd374a04e40e9f6a421acaebR216-R228
Run with
$ RUSTFLAGS='-Z unstable-options' cargo +enzyme r --example=dot_enzyme --release
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Me neither, but can you share the module you got from ENZYME_OPT=1? Then I can experiment around to see if I find the right changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jedbrown you need to add RSMPI_SUM to the ActivityAnalysis.cpp code [the message (poorly) warns that the global variable is differentiable, but Enzyme is unable to determine a differentiable version of the global. Of course it makes no sense to differentiate wrt MPI_SUM so we can mark that in ActivityAnalysis.cpp]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. Why didn't ompi_mpi_op_sum
need to be there?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think one has an extra pointer indirection causing a load which needs to be analyzed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah in this case, now the issue is that the op of MPI_allreduce cannot be detected (the earlier check if the argument was a literal global no longer applies, since this is a load of RSMPI_SUM). Changing the MPI_Allreduce check to consider something along the lines of:
if (LI = dyn_cast<LoadInst>(...)))
if (auto GV = dyn_cast<GlobalVariable>(LI->getPointerOperand()))
if (GV->getName() == "RSMPI_SUM")
legal = true;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I pushed that (cb2d739; assuming it was the correct place), but now I have
source_id: DefId(0:12 ~ dot_enzyme[42dc]::dot_parallel)
num_fnc_args: 5
input_activity.len(): 5
error: <unknown>:0:0: in function preprocess__ZN10dot_enzyme12dot_parallel17h2f3ed146b457ca09E double (ptr, ptr, i64, ptr, i64): Enzyme: cannot compute with global variable that doesn't have marked shadow global
@RSMPI_COMM_WORLD = external local_unnamed_addr global ptr
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okay that's the same as the first issue [a global which cannot be proven non-differentiable]. ActivityAnalysis.cpp is again the right place to add that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the hand-holding. I think this output is correct now.
$ RUSTFLAGS='-Z unstable-options' cargo +enzyme mpirun -n 2 --example=dot_enzyme --release | sort
[0] bx: [0.0, 2.0, 4.0, 6.0, 8.0], by: [0.0, 2.0, 4.0, 6.0, 8.0]
[0] local: 30
[1] bx: [200.0, 202.0, 204.0, 206.0, 208.0], by: [20.0, 22.0, 24.0, 26.0, 28.0]
[1] local: 6130
global: 6160
global: 6160
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that --release
is required, otherwise I see
invertedPointers:
invertedPointers[ptr %0] = <badref> = load ptr, ptr %4, align 8
end invertedPointers
<badref> = load ptr, ptr %4, align 8
rustc: /home/jed/src/rust-enzyme/src/tools/enzyme/enzyme/Enzyme/GradientUtils.cpp:8489: virtual void InvertedPointerVH::deleted(): Assertion `0 && "erasing something in invertedPointers map"' failed.
error: could not compile `mpi` (example "dot_enzyme")
Given that this has a test that passes at least in release mode, can we merge this (and maybe create an issue with further TODOs, to make sure the progress here doesn't get lost? |
I don’t think there’s a test in this PR, no? (From my earlier comment which doesn’t look resolved the llvm test doesn’t test this?) |
No description provided.