WIP: Add the initial part of MPI tablegen for adjoint generator #1325

Draft
wants to merge 8 commits into
base: main

PragmaTwice
Collaborator

WIP; comments and a PR description will be added soon.

@PragmaTwice PragmaTwice requested a review from ZuseZ4 July 11, 2023 17:51
@ZuseZ4
Member

ZuseZ4 commented Jul 12, 2023

As discussed, here is a slight adjustment of your PR that you could use to add your MPI code as part of InstructionDerivatives.td instead of copying the style of BlasDerivatives.td.

It might have some mistakes, so make sure to check it, but if you can get something in this style to compile,
we should be able to improve from there.

class mpiPattern<dag patternToMatch, list<string> funcNames, string overwrittenArg, list<dag> resultOps, dag forwardOps> {
  dag PatternToMatch = patternToMatch;
  string overwritten = overwrittenArg;
  list<string> names = funcNames;
  list<dag> ArgDerivatives = resultOps;
  dag ArgDuals = forwardOps;
}
                    
def : mpiPattern<(Op $sendbuf, $sendcount, $sendtype, $recvbuf, $recvcount, $recvtype, $root, $comm),
                  ["MPI_Gather", "PMPI_Gather"],
                  "recvbuf",
                  [
                      (b<"MPI_Scatter"> (Shadow $recvbuf), $recvcount, $recvtype, $buf, $sendcount, $sendtype, $root, $comm), // sendbuf
                      (InactiveArg), // sendcount
                      (InactiveArg), // sendtype
                      // TODO: if root, Zero diff(recvbuffer) [memset to 0]
                      // (Select (ifRoot diff(recvbuffer)), (ifNotRoot doSomethingElse))
                      // (Select (FCmpOLT $x, $y), (SelectIfActive $x, (Shadow $x), (Zero $x)), (SelectIfActive $y, (Shadow $y), (Zero $y)))
                      (MemCopyFloats $buf, (Shadow $sendbuf), (Mul $sendcount, (MPITySize $sendtype))), // recvbuf
                      (InactiveArg), // recvcount
                      (InactiveArg), // recvtype
                      (InactiveArg), // root
                      (InactiveArg), // comm
                  ],
                  (ForwardFromSummedReverse)
                  >;
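
For intuition about why the reverse rule for MPI_Gather uses MPI_Scatter, here is a minimal single-process sketch that does not use MPI at all. The function `gatherReverse` and its parameters are hypothetical and only simulate the data movement: each rank receives its slice of the root's shadow recvbuf into a temporary `buf`, adds it to its shadow sendbuf, and the root's shadow recvbuf is then zeroed because the primal gather overwrote recvbuf. Accumulating (rather than overwriting) the shadow sendbuf is an assumption of this sketch.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Simulated reverse step of a gather with `count` values per rank.
// dSend[r] stands in for rank r's shadow sendbuf; dRecv stands in for the
// root's shadow recvbuf and must hold dSend.size() * count values.
void gatherReverse(std::vector<std::vector<double>> &dSend,
                   std::vector<double> &dRecv, std::size_t count) {
  for (std::size_t r = 0; r < dSend.size(); ++r) {
    // "Scatter": rank r receives its slice of the root's shadow recvbuf.
    std::vector<double> buf(count);
    for (std::size_t i = 0; i < count; ++i)
      buf[i] = dRecv[r * count + i];
    // Add the received slice to rank r's shadow sendbuf (assumed accumulation).
    for (std::size_t i = 0; i < count; ++i)
      dSend[r][i] += buf[i];
  }
  // recvbuf was overwritten by the primal gather, so its shadow is zeroed at root.
  std::fill(dRecv.begin(), dRecv.end(), 0.0);
}
```

With two ranks, `count = 2`, and a root shadow recvbuf of `{1, 2, 3, 4}`, rank 0 would pick up `{1, 2}` and rank 1 would pick up `{3, 4}`.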

@ZuseZ4
Member

ZuseZ4 commented Jul 12, 2023

Based on your second question, this could be a shorter design, which might need a few more adjustments,
so we can try this later.

class mpiPattern<dag patternToMatch, list<string> funcNames, list<string> actArgs, string overwrittenArg, list<dag> resultOps, dag forwardOps> {
  dag PatternToMatch = patternToMatch;
  list<string> activeArgs = actArgs;
  string overwritten = overwrittenArg;
  list<string> names = funcNames;
  list<dag> ArgDerivatives = resultOps;
  dag ArgDuals = forwardOps;
}
                    
def : mpiPattern<(Op $sendbuf, $sendcount, $sendtype, $recvbuf, $recvcount, $recvtype, $root, $comm),
                  ["MPI_Gather", "PMPI_Gather"],
                  ["sendbuf", "recvbuf"],
                  "recvbuf",
                  [
                      (b<"MPI_Scatter"> (Shadow $recvbuf), $recvcount, $recvtype, $buf, $sendcount, $sendtype, $root, $comm), // sendbuf
                      // TODO: if root, Zero diff(recvbuffer) [memset to 0]
                      // (Select (ifRoot diff(recvbuffer)), (ifNotRoot doSomethingElse))
                      // (Select (FCmpOLT $x, $y), (SelectIfActive $x, (Shadow $x), (Zero $x)), (SelectIfActive $y, (Shadow $y), (Zero $y)))
                      (MemCopyFloats $buf, (Shadow $sendbuf), (Mul $sendcount, (MPITySize $sendtype))), // recvbuf
                  ],
                  (ForwardFromSummedReverse)
                  >;

@wsmoses
Member

wsmoses commented Jul 12, 2023

I would heavily recommend against making a new MPI-specific tablegen infrastructure, as opposed to extending and using the existing call infrastructure with whatever new operations you need.

@ZuseZ4
Member

ZuseZ4 commented Jul 12, 2023

@wsmoses I agree that we should not have a third mpi-tg next to enzyme-tg and blas-tg,
but I was less concerned about having 4 or 5 different classes inside enzyme-tg.
Still, most of the extensions I had in mind (e.g. potentially active args vs. always-inactive args) can also be handled
by extending the call class and setting some default values, such as all-active.

@PragmaTwice Using the existing call class should make it easier to get a first compiling version, so maybe focus on that. Most rules will likely be too complex and rely on missing features, so you won't be able to emit all of the code required to handle MPI yet. However, once it compiles we can see which features are missing and add them one by one.

@PragmaTwice
Collaborator Author

I will try to construct a CallMPIPattern that inherits from CallPattern, so that a CallMPIPattern can be treated as a CallPattern while we add some additional arguments for other uses.

@ZuseZ4
Member

ZuseZ4 commented Jul 14, 2023

def : mpiPattern<(Op $sendbuf, $sendcount, $sendtype, $recvbuf, $recvcount, $recvtype, $root, $comm),
                  ["MPI_Gather", "PMPI_Gather"],
                  "recvbuf",
                  [
                      (b<"MPI_Scatter"> (Shadow $recvbuf), $recvcount, $recvtype, $buf, $sendcount, $sendtype, $root, $comm), // sendbuf

                      // TODO: if root, Zero diff(recvbuffer) [memset to 0]
                      // (Select (ifRoot diff(recvbuffer)), (ifNotRoot doSomethingElse))
                      // (Select (FCmpOLT $x, $y), (SelectIfActive $x, (Shadow $x), (Zero $x)), (SelectIfActive $y, (Shadow $y), (Zero $y)))
                      (MemCopyFloats $buf, (Shadow $sendbuf), (Mul $sendcount, (MPITySize $sendtype))), // recvbuf
                  ],
                  (ForwardFromSummedReverse)
                  >;

So if you do want a type system as an extension of the CallPattern, you could use it to extend the list.
The default CallPattern expects one dag per input argument. So every time your type system returns buffer, you pick the next dag from your own list, and every time you have a type other than buffer, you add (InactiveArg) as the dag rule.

The above example should therefore translate into:

def : mpiPattern<(Op $sendbuf, $sendcount, $sendtype, $recvbuf, $recvcount, $recvtype, $root, $comm),
                  ["MPI_Gather", "PMPI_Gather"],
                  [buf, size, datatype, buf, size, datatype, integer, comm],
                  "recvbuf",
                  [
                      (b<"MPI_Scatter"> (Shadow $recvbuf), $recvcount, $recvtype, $buf, $sendcount, $sendtype, $root, $comm), // sendbuf
                      (InactiveArg), // sendcount
                      (InactiveArg), // sendtype
                      // TODO: if root, Zero diff(recvbuffer) [memset to 0]
                      // (Select (ifRoot diff(recvbuffer)), (ifNotRoot doSomethingElse))
                      // (Select (FCmpOLT $x, $y), (SelectIfActive $x, (Shadow $x), (Zero $x)), (SelectIfActive $y, (Shadow $y), (Zero $y)))
                      (MemCopyFloats $buf, (Shadow $sendbuf), (Mul $sendcount, (MPITySize $sendtype))), // recvbuf
                      (InactiveArg), // recvcount
                      (InactiveArg), // recvtype
                      (InactiveArg), // root
                      (InactiveArg), // comm
                  ],
                  (ForwardFromSummedReverse)
                  >;
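
The expansion described above (take the next dag for every buffer, emit (InactiveArg) for everything else) could be sketched roughly as follows. This is plain standalone C++ rather than an actual TableGen backend; `ArgKind` and `expandRules` are hypothetical names, and the rules are represented as strings purely for illustration.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical argument kinds, mirroring the list
// [buf, size, datatype, buf, size, datatype, integer, comm].
enum class ArgKind { Buf, Size, Datatype, Integer, Comm };

// Expand a short list of buffer rules into one rule per argument:
// each Buf argument consumes the next rule, every other kind gets "(InactiveArg)".
std::vector<std::string> expandRules(const std::vector<ArgKind> &kinds,
                                     const std::vector<std::string> &bufRules) {
  std::vector<std::string> out;
  out.reserve(kinds.size());
  std::size_t next = 0;
  for (ArgKind kind : kinds) {
    if (kind == ArgKind::Buf) {
      if (next == bufRules.size())
        throw std::invalid_argument("fewer rules than buffer arguments");
      out.push_back(bufRules[next++]);
    } else {
      out.push_back("(InactiveArg)");
    }
  }
  if (next != bufRules.size())
    throw std::invalid_argument("more rules than buffer arguments");
  return out;
}
```

For the MPI_Gather signature, two buffer rules would expand to eight entries, with the rules landing at positions 0 and 3 and (InactiveArg) everywhere else.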

Simplification:
Add the following lines at the beginning of all MPI functions:

        Value *rank = MPI_COMM_RANK(comm, Builder2, root->getType());
        Value *tysize = MPI_TYPE_SIZE(sendtype, Builder2, call.getType());

You can also try to generate the following names and helpers for each input argument:

  const int pos_x = 1;
  const auto orig_x = call.getArgOperand(pos_x);
  auto arg_x = gutils->getNewFromOriginal(orig_x);

and for all the shadow arguments (this might already be done by tablegen, just check):

        Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
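
As a rough illustration of generating these helpers per argument, the sketch below emits the declarations as C++ source text from a list of argument names. It is a standalone stand-in for what a generator might print, not Enzyme's actual tablegen backend; `emitArgHelpers` and the `shadowed` flags are hypothetical, and positions are taken to be zero-based argument indices.

```cpp
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Emit the per-argument helper declarations as C++ source text.
// shadowed[i] marks arguments that additionally need a shadow lookup.
std::string emitArgHelpers(const std::vector<std::string> &names,
                           const std::vector<bool> &shadowed) {
  std::ostringstream os;
  for (std::size_t i = 0; i < names.size(); ++i) {
    const std::string &n = names[i];
    os << "const int pos_" << n << " = " << i << ";\n";
    os << "const auto orig_" << n << " = call.getArgOperand(pos_" << n << ");\n";
    os << "auto arg_" << n << " = gutils->getNewFromOriginal(orig_" << n << ");\n";
    if (i < shadowed.size() && shadowed[i])
      os << "Value *shadow_" << n << " = gutils->invertPointerM(orig_" << n
         << ", Builder2);\n";
  }
  return os.str();
}
```

Calling it with the names `sendbuf` and `sendcount`, and only the first one marked as shadowed, would produce the three basic helpers for both arguments plus a `shadow_sendbuf` line.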

@ZuseZ4
Member

ZuseZ4 commented Jul 14, 2023

Simplification 1: Mark all MPI_Gather arguments as InactiveArg.
Then specify the forward pass as ForwardFromSummedReverse.
Create all the helper args I mentioned above.
For every argument that is a buffer according to your type system, look up the shadow, as in the example above.
Then create the primal call and emit it, replacing your buffer args with the shadows of those buffer args:

        if (forwardMode) {
          Value *args[] = {
              /*sendbuf*/ shadow_sendbuf,
              /*sendcount*/ sendcount,
              /*sendtype*/ sendtype,
              /*recvbuf*/ shadow_recvbuf,
              /*recvcount*/ recvcount,
              /*recvtype*/ recvtype,
              /*root*/ root,
              /*comm*/ comm,
          };

          auto Defs = gutils->getInvertedBundles(
              &call,
              {ValueType::Shadow, ValueType::Primal, ValueType::Primal,
               ValueType::Shadow, ValueType::Primal, ValueType::Primal,
               ValueType::Primal, ValueType::Primal},
              Builder2, /*lookup*/ false);

#if LLVM_VERSION_MAJOR >= 11
          auto callval = call.getCalledOperand();
#else
          auto callval = call.getCalledValue();
#endif
          Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
          return;
        }

Besides the 3 functions which you already added, this logic would also be able to handle the forward mode of

    if (funcName == "PMPI_Isend" || funcName == "MPI_Isend" ||
        funcName == "PMPI_Irecv" || funcName == "MPI_Irecv") {

So I do think it is worth starting with it :)
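
The forward-mode recipe above (pass the shadow for every buffer argument and the primal for everything else) can be sketched generically as below. This is standalone C++ with no LLVM dependency: the `ValueType` enum here is only a local stand-in for the one used in the snippet above, and `ForwardCall` and `buildForwardCall` are hypothetical names.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Local stand-in for the Shadow/Primal tags used in the snippet above.
enum class ValueType { Primal, Shadow };

// Hypothetical result: the argument names to pass to the call and the
// matching per-argument tags.
struct ForwardCall {
  std::vector<std::string> args;
  std::vector<ValueType> bundleTypes;
};

// Buffer arguments are replaced by their shadows; all others stay primal.
ForwardCall buildForwardCall(const std::vector<std::string> &names,
                             const std::vector<bool> &isBuffer) {
  ForwardCall fc;
  for (std::size_t i = 0; i < names.size(); ++i) {
    if (i < isBuffer.size() && isBuffer[i]) {
      fc.args.push_back("shadow_" + names[i]);
      fc.bundleTypes.push_back(ValueType::Shadow);
    } else {
      fc.args.push_back(names[i]);
      fc.bundleTypes.push_back(ValueType::Primal);
    }
  }
  return fc;
}
```

For the eight MPI_Gather arguments with `sendbuf` and `recvbuf` flagged as buffers, this would reproduce the argument order and the Shadow/Primal pattern of the handwritten code above.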

MilesCranmer pushed a commit to MilesCranmer/Enzyme that referenced this pull request Jul 24, 2024