Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial rust mpi support #2025

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open

initial rust mpi support #2025

wants to merge 7 commits into from

Conversation

ZuseZ4
Copy link
Member

@ZuseZ4 ZuseZ4 commented Jul 30, 2024

No description provided.

@ZuseZ4 ZuseZ4 requested a review from wsmoses July 30, 2024 22:41
@ZuseZ4
Copy link
Member Author

ZuseZ4 commented Jul 30, 2024

We need to teach Enzyme to recognize

fn main( ) {
%5 = load ptr %MPI_SUM
%6 = call diffe_test_mpi_call(..., %5)
}

test_mpi_call(... ptr %5) {
%7 = call MPI_Allreduce(..., %5)
}

Because right now Enzyme yells about unknown mpi_allreduce op %5, I assume we already do that to e.g. recognize enyzme_const.

@wsmoses
Copy link
Member

wsmoses commented Jul 30, 2024

See my comment on the other thread, but no that should not be the way it is handled. Speciically if it is a generic reduction op, the derivative code can similarly be arbitrary

call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %7)
store double 0.000000e+00, ptr %7, align 8
tail call void @llvm.experimental.noalias.scope.decl(metadata !7)
%10 = load ptr, ptr @RSMPI_DOUBLE, align 8, !noalias !10, !noundef !13
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unused atm so this won't test any thing. Can you make this a minimal runnable case. Maybe a different mpi fn?

Also get rid of the other stuff like enzyme_type etc

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jedbrown I can't run things, can you test if some other function works on this branch?

Comment on lines +1153 to +1154
if (GV->getName() == "ompi_mpi_op_sum" ||
GV->getName() == "RSMPI_SUM") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how to resolve this.

source_id: DefId(0:12 ~ dot_enzyme[f53a]::dot_parallel)
num_fnc_args: 5
input_activity.len(): 5
error: <unknown>:0:0: in function preprocess__ZN10dot_enzyme12dot_parallel17hd37f1f8a2c8de07dE double (ptr, ptr, i64, ptr, i64): Enzyme: cannot compute with global variable that doesn't have marked shadow global
@RSMPI_SUM = external local_unnamed_addr global ptr

This is the relevant code: rsmpi/rsmpi@840e01c#diff-9a676b0d0c142cd1e89e8174ddb007db982d8602bd374a04e40e9f6a421acaebR216-R228

Run with

$ RUSTFLAGS='-Z unstable-options' cargo +enzyme r --example=dot_enzyme --release

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Me neither, but can you share the module you got from ENZYME_OPT=1? Then I can experiment around to see if I find the right changes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jedbrown you need to add RSMPI_SUM to the ActivityAnalysis.cpp code [the message (poorly) warns that the global variable is differentiable, but Enzyme is unable to determine a differentiable version of the global. Of course it makes no sense to differentiate wrt MPI_SUM so we can mark that in ActivityAnalysis.cpp]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Why didn't ompi_mpi_op_sum need to be there?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think one has an extra pointer indirection causing a load which needs to be analyzed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah in this case, now the issue is that the op of MPI_allreduce cannot be detected (the earlier check if the argument was a literal global no longer applies, since this is a load of RSMPI_SUM). Changing the MPI_Allreduce check to consider something along the lines of:

if (LI = dyn_cast<LoadInst>(...)))
  if (auto GV = dyn_cast<GlobalVariable>(LI->getPointerOperand()))
    if (GV->getName() == "RSMPI_SUM")
      legal = true;

Copy link
Collaborator

@jedbrown jedbrown Aug 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pushed that (cb2d739; assuming it was the correct place), but now I have

source_id: DefId(0:12 ~ dot_enzyme[42dc]::dot_parallel)
num_fnc_args: 5
input_activity.len(): 5
error: <unknown>:0:0: in function preprocess__ZN10dot_enzyme12dot_parallel17h2f3ed146b457ca09E double (ptr, ptr, i64, ptr, i64): Enzyme: cannot compute with global variable that doesn't have marked shadow global
@RSMPI_COMM_WORLD = external local_unnamed_addr global ptr

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay that's the same as the first issue [a global which cannot be proven non-differentiable]. ActivityAnalysis.cpp is again the right place to add that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the hand-holding. I think this output is correct now.

$ RUSTFLAGS='-Z unstable-options' cargo +enzyme mpirun -n 2 --example=dot_enzyme --release | sort
[0] bx: [0.0, 2.0, 4.0, 6.0, 8.0], by: [0.0, 2.0, 4.0, 6.0, 8.0]
[0] local: 30
[1] bx: [200.0, 202.0, 204.0, 206.0, 208.0], by: [20.0, 22.0, 24.0, 26.0, 28.0]
[1] local: 6130
global: 6160
global: 6160

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that --release is required, otherwise I see

invertedPointers:
   invertedPointers[ptr %0] =   <badref> = load ptr, ptr %4, align 8
end invertedPointers
  <badref> = load ptr, ptr %4, align 8
rustc: /home/jed/src/rust-enzyme/src/tools/enzyme/enzyme/Enzyme/GradientUtils.cpp:8489: virtual void InvertedPointerVH::deleted(): Assertion `0 && "erasing something in invertedPointers map"' failed.
error: could not compile `mpi` (example "dot_enzyme")

@ZuseZ4
Copy link
Member Author

ZuseZ4 commented Sep 9, 2024

Given that this has a test that passes at least in release mode, can we merge this (and maybe create an issue with further TODOs, to make sure the progress here doesn't get lost?

@wsmoses
Copy link
Member

wsmoses commented Sep 9, 2024

I don’t think there’s a test in this PR, no? (From my earlier comment which doesn’t look resolved the llvm test doesn’t test this?)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants