docs: Added roadmap/todo (#5)
I added this file and called it `ROADMAP.md` (`TODO.md` would be fine
as well) to keep track of open work and to help guide the
development.
philip-paul-mueller authored May 13, 2024
1 parent 09bcb3f commit 0bc7ced
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions ROADMAP.md
@@ -0,0 +1,67 @@
# Roadmap

This roadmap gives a rough idea of how the project will continue and which features will be implemented; it serves as a todo list.

- [x] Being able to perform _some_ translations [PR#3](https://github.com/GridTools/jace/pull/3).
- [ ] Basic functionalities:
- [ ] Annotation `@jace.jit`.
- [ ] Composable with Jax, i.e. being able to take the Jax derivative of a Jace-annotated function.
- [ ] Implementing the `stages` model that is supported by Jax.
- [ ] Handling Jax arrays as native input (only on single host).
- [ ] Cache the compilation and lowering results for later reuse.
In Jax these parts (together with the dispatch) are written in C++; thus, in the beginning, we will use a self-made cache.
- [ ] Implementing some basic `PrimitiveTranslators` that allow us to run some early tests, such as:
- [ ] Backporting the ones from the prototype.
- [ ] Implement the `scatter` primitive (needed for pyhpc).
- [ ] Implement the `scan` primitive (needed for pyhpc).
- [ ] _Initial_ optimization pipeline
In order to do benchmarks, we need to perform optimizations first.
However, the ones offered by DaCe did not work that well, so we should, for now, backport the ones from the prototype.
- [ ] Support GPU code (relatively simple, but needs some detection logic).
- [ ] Initial benchmark:
In the beginning we will not have the same dispatching performance as Jax.
But passing these benchmarks could give us a better hint of how to proceed in this matter.
- [ ] Passing the [pyhpc-benchmark](https://github.com/dionhaefner/pyhpc-benchmarks)
- [ ] Possibly: passing Felix' fluid project.
- [ ] Support of static arguments.
- [ ] Stop relying on `jax.make_jaxpr()`.
Look at the `jax._src.pjit.make_jit()` function for how to hijack the staging process.
- [ ] Implementing more advanced primitives:
- [ ] Handling pytrees as arguments.
- [ ] Implement random numbers.
- [ ] `jax.numpy`.
- [ ] `jax.scipy`.
- [ ] Passing the single host Jax unit tests.
- [ ] Multi-Device capabilities, i.e. multiple GPUs but all on the same host.
- [ ] Passing the associated Jax unit tests.
- [ ] Multi-Host capabilities, i.e. MPI.
- [ ] Passing the associated Jax unit tests.
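
For reference, the `stages` model and `jax.make_jaxpr()` mentioned in the list above look roughly like this in plain Jax. This is only a sketch of the Jax side that Jace would have to mirror; the function `f` and its inputs are made up for illustration, and nothing here uses Jace itself.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Illustrative function only; any jittable function would do.
    return jnp.sin(x) * y + 1.0


x = jnp.ones((4,), dtype=jnp.float32)
y = jnp.full((4,), 2.0, dtype=jnp.float32)

# Stage 1: wrap the function.
wrapped = jax.jit(f)

# Stage 2: lower it for concrete argument shapes and dtypes.
lowered = wrapped.lower(x, y)

# Stage 3: compile the lowered computation.
compiled = lowered.compile()

# Stage 4: execute the compiled object.
result = compiled(x, y)

# The Jaxpr that `jax.make_jaxpr()` produces; its equations name the
# primitives that a translator would have to handle.
closed_jaxpr = jax.make_jaxpr(f)(x, y)
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
```

A Jace implementation of the same model would presumably expose analogous wrapped, lowered and compiled objects, with the lowering step producing an SDFG instead.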

## General

These are more general topics that should be addressed at one point.

- [ ] Integrating better with Jax
- [ ] Support its array type (probably implement this in DaCe).
- [ ] Increase the dispatching speed + Cache
Jax does this in C++, which is impossible to beat in Python, thus we have to go that route as well.
- [ ] Debugging information.
- [ ] Dynamic shapes
This could be done by making the inputs fully dynamic and then using the primitives to simplify.
For example, in an addition the shapes of the two inputs and of the output are the same.
That knowledge is inherent to the primitives themselves.
However, the compiled object must know how to extract the sizes itself.
- [ ] Defining a Logo:
It should be green with a nice curly font.
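
A self-made cache, as mentioned for the compilation and lowering results, could start out as a plain Python dictionary keyed on the function and the abstract description (shape and dtype) of its arguments. The following is a minimal sketch under that assumption; `ArraySpec`, `abstract_key`, `TranslationCache` and `fake_translate` are hypothetical names and not part of Jace, Jax or DaCe.

```python
from dataclasses import dataclass
from typing import Any, Callable, Hashable


@dataclass(frozen=True)
class ArraySpec:
    """Hypothetical stand-in for an array: only shape and dtype matter here."""

    shape: tuple
    dtype: str


def abstract_key(*args: Any) -> Hashable:
    # Two calls share a cache entry if all arguments agree in shape and dtype.
    return tuple((tuple(a.shape), str(a.dtype)) for a in args)


class TranslationCache:
    """Maps (function, abstract arguments) to an already translated object."""

    def __init__(self, translate: Callable[..., Any]) -> None:
        self._translate = translate
        self._entries: dict = {}
        self.misses = 0

    def get(self, fun: Callable[..., Any], *args: Any) -> Any:
        key = (fun, abstract_key(*args))
        if key not in self._entries:
            # Cache miss: run the (expensive) translation exactly once.
            self.misses += 1
            self._entries[key] = self._translate(fun, *args)
        return self._entries[key]


def fake_translate(fun: Callable[..., Any], *args: Any) -> str:
    # Placeholder for the real lowering and compilation step.
    return f"compiled<{fun.__name__}:{abstract_key(*args)}>"


def add(a, b):
    return a + b


cache = TranslationCache(fake_translate)
small = ArraySpec((10,), "float64")
large = ArraySpec((20,), "float64")

first = cache.get(add, small, small)   # miss: translated
second = cache.get(add, small, small)  # hit: reused
third = cache.get(add, large, large)   # miss: different shape
```

Such a cache only avoids repeated translation; the per-call overhead of computing the key in Python remains, which is the dispatch cost the list above refers to.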

# Optimization & Transformations

The SDFGs generated by Jace have a very particular structure, thus we could and probably should write some highly targeted optimization passes for them.
Our experiments with the prototype showed that the most important transformation is Map fusion, and the one in DaCe is essentially broken.

- [ ] Modified state fusion.
Because of the structure we have, this could make `Simplify` much more efficient.
- [ ] Trivial Tasklet removal.
Since we will work a lot with Maps that are trivial (probably the best structure for fusing), we will end up with some trivial Tasklets, i.e. `__out = __in`.
Thus, we should have a good way to get rid of them.
- [ ] Modified Map fusion transformation.
We should still support parallel and serial fusion as the prototype did, but focus on serial fusion.
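
The idea behind trivial Tasklet removal can be illustrated on a toy model. The sketch below does not use the DaCe API; `Tasklet`, `is_trivial` and `remove_trivial_tasklets` are hypothetical, and it assumes that the Tasklets are given in topological order and that every data name is written exactly once.

```python
from dataclasses import dataclass


@dataclass
class Tasklet:
    """Toy Tasklet: reads the data named in `inputs`, writes `output`."""

    name: str
    inputs: list
    output: str
    code: str


def is_trivial(tasklet: Tasklet) -> bool:
    # A trivial Tasklet only copies its single input to its output.
    return len(tasklet.inputs) == 1 and tasklet.code.replace(" ", "") == "__out=__in"


def remove_trivial_tasklets(tasklets: list, keep: set) -> list:
    """Drop copy-only Tasklets and let their consumers read the source directly.

    `keep` holds data names that must survive, e.g. the program outputs.
    """
    alias = {}  # removed data name -> data name that replaces it
    result = []
    for tasklet in tasklets:
        # Redirect reads of already removed data to their replacement.
        inputs = [alias.get(name, name) for name in tasklet.inputs]
        if is_trivial(tasklet) and tasklet.output not in keep:
            alias[tasklet.output] = inputs[0]
            continue
        result.append(Tasklet(tasklet.name, inputs, tasklet.output, tasklet.code))
    return result


tasklets = [
    Tasklet("t0", ["a", "b"], "tmp0", "__out = __in0 + __in1"),
    Tasklet("t1", ["tmp0"], "tmp1", "__out = __in"),
    Tasklet("t2", ["tmp1"], "tmp2", "__out = __in"),
    Tasklet("t3", ["tmp2"], "out", "__out = sin(__in)"),
]

simplified = remove_trivial_tasklets(tasklets, keep={"out"})
remaining_names = [t.name for t in simplified]
```

In this example the two copies `t1` and `t2` disappear and `t3` reads `tmp0` directly. A real pass on SDFGs would additionally have to handle Maps, memlets and data that is read in several places.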
