diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 0000000..e27152a --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,67 @@ +# Roadmap + +A kind of roadmap that gives a rough idea about how the project will be continued and which features will be implemented, in the sense of a todo list. + +- [x] Being able to perform _some_ translations [PR#3](https://github.com/GridTools/jace/pull/3). +- [ ] Basic functionalities: + - [ ] Annotation `@jace.jit`. + - [ ] Composable with Jax, i.e. take the Jax derivative of a Jace annotated function. + - [ ] Implementing the `stages` model that is supported by Jax. + - [ ] Handling Jax arrays as native input (only on single host). + - [ ] Cache the compilation and lowering results for later reuse. + In Jax these parts (together with the dispatch) are actually written in C++, thus in the beginning we will use a self made cache. +- [ ] Implementing some basic `PrimitiveTranslators`, that allows us to run some early tests, such as: + - [ ] Backporting the ones from the prototype. + - [ ] Implement the `scatter` primitive (needed for pyhpc). + - [ ] Implement the `scan` primitive (needed for pyhpc). +- [ ] _Initial_ optimization pipeline + In order to do benchmarks, we need to perform optimizations first. + However, the one offered by DaCe were not that well, so we should, for now, backport the ones from the prototype. +- [ ] Support GPU code (relatively simple, but needs some detection logic). +- [ ] Initial benchmark: + In the beginning we will not have the same dispatching performance as Jax. + But passing these benchmarks could give us some better hint of how to proceed in this matter. + - [ ] Passing the [pyhpc-benchmark](https://github.com/dionhaefner/pyhpc-benchmarks) + - [ ] Passing Felix' fluid project; possibility. +- [ ] Support of static arguments. +- [ ] Stop relying on `jax.make_jaxpr()`. + Look at the `jax._src.pjit.make_jit()` function for how to hijack the staging process. +- [ ] Implementing more advanced primitives: + - [ ] Handling pytrees as arguments. + - [ ] Implement random numbers. + - [ ] `jax.numpy`. + - [ ] `jax.scipy`. +- [ ] Passing the single host Jax unit tests. +- [ ] Multi-Device capabilities, i.e. multiple GPUs but all on the same host. + - [ ] Passing the associated Jax unit tests. +- [ ] Multi-Host capabilities, i.e. MPI. + - [ ] Passing the associated Jax unit tests. + +## General + +These are more general topics that should be addressed at one point. + +- [ ] Integrating better with Jax + - [ ] Support its array type (probably implement this in DaCe). +- [ ] Increase the dispatching speed + Cache + Jax does this in C++, which is impossible to beat in Python, thus we have to go that root as well. +- [ ] Debugging information. +- [ ] Dynamic shapes + This could be done by making the inputs fully dynamic, and then use the primitives to simplify. + For example in an addition the shape of the two inputs and the outputs are the same. + That is knowledge that is inherent to the primitives itself. + However, the compiled object must know how to extract the sizes itself. +- [ ] Defining a Logo: + It should be green with a nice curly font. + +# Optimization & Transformations + +The SDFG generated by Jace have a very particular structure, thus we could and probably should write some highly targeted optimization passes for them. +Our experiments with the prototype showed that the most important transformation is Map fusion and the one in DaCe is essentially broken. + +- [ ] Modified state fusion; Because of the structure we have, this could make `Simplify` much more efficient. +- [ ] Trivial Tasklet removal. + Since we will work a lot with Maps that are trivial (probably the best structure for fusing) we will end up with some of trivial Tasklets, i.e. `__out = __in`. + Thus, we should have a good way to get rid of them. +- [ ] Modified Map fusion transformation. + We should still support parallel and serial fusion as the prototype did, but focusing on serial.