docs: Added roadmap/todo #5

Merged
merged 3 commits into from
May 13, 2024
67 changes: 67 additions & 0 deletions ROADMAP.md
@@ -0,0 +1,67 @@
# Roadmap

A rough roadmap, in the sense of a todo list, that outlines how the project will continue and which features will be implemented.

- [x] Being able to perform _some_ translations [PR#3](https://github.com/GridTools/jace/pull/3).
- [ ] Basic functionalities:
- [ ] Annotation `@jace.jit`.
- [ ] Composable with Jax, i.e. being able to take the Jax derivative of a Jace-annotated function.
- [ ] Implementing the `stages` model that is supported by Jax.
- [ ] Handling Jax arrays as native input (only on single host).
- [ ] Cache the compilation and lowering results for later reuse.
In Jax these parts (together with the dispatch) are written in C++; thus, in the beginning, we will use a self-made cache.
- [ ] Implementing some basic `PrimitiveTranslators` that allow us to run some early tests, such as:
- [ ] Backporting the ones from the prototype.
- [ ] Implement the `scatter` primitive (needed for pyhpc).
- [ ] Implement the `scan` primitive (needed for pyhpc).
- [ ] _Initial_ optimization pipeline
In order to run benchmarks, we need to perform optimizations first.
However, the ones offered by DaCe did not work that well, so we should, for now, backport the ones from the prototype.
- [ ] Support GPU code (relatively simple, but needs some detection logic).
- [ ] Initial benchmark:
In the beginning we will not have the same dispatching performance as Jax.
But passing these benchmarks could give us a better hint of how to proceed in this matter.
- [ ] Passing the [pyhpc-benchmark](https://github.com/dionhaefner/pyhpc-benchmarks).
- [ ] Passing Felix' fluid project (a possibility).
- [ ] Support of static arguments.
- [ ] Stop relying on `jax.make_jaxpr()`.
Look at the `jax._src.pjit.make_jit()` function for how to hijack the staging process.
- [ ] Implementing more advanced primitives:
- [ ] Handling pytrees as arguments.
- [ ] Implement random numbers.
- [ ] `jax.numpy`.
- [ ] `jax.scipy`.
- [ ] Passing the single host Jax unit tests.
- [ ] Multi-Device capabilities, i.e. multiple GPUs but all on the same host.
- [ ] Passing the associated Jax unit tests.
- [ ] Multi-Host capabilities, i.e. MPI.
- [ ] Passing the associated Jax unit tests.
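
The `stages` model and the self-made cache listed above could fit together roughly as follows. This is a minimal pure-Python sketch under stated assumptions, not Jace's implementation: `toy_jit`, `ToyWrapped`, `ToyLowered`, `ToyCompiled`, and `abstract_key` are hypothetical names that only mirror the wrapped, lowered, compiled progression (with `lower()` and `compile()` methods) of Jax's stages model, and keying the cache on argument shapes and dtypes is an assumption.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical cache key: one (shape, dtype-name) pair per argument.
CacheKey = Tuple[Tuple[Tuple[int, ...], str], ...]


def abstract_key(args: Tuple[List[Any], ...]) -> CacheKey:
    """Describe the arguments by shape and dtype only (assumption: flat lists)."""
    return tuple(
        ((len(a),), type(a[0]).__name__ if a else "empty") for a in args
    )


class ToyCompiled:
    """Final stage: an executable object."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self._fn = fn

    def __call__(self, *args: List[Any]) -> Any:
        return self._fn(*args)


class ToyLowered:
    """Middle stage: stands in for the translated representation."""

    def __init__(self, fn: Callable[..., Any], key: CacheKey) -> None:
        self._fn = fn
        self.key = key

    def compile(self) -> ToyCompiled:
        # A real implementation would generate and build code here.
        return ToyCompiled(self._fn)


class ToyWrapped:
    """First stage: the annotated function plus a plain-dict cache."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self._fn = fn
        self._cache: Dict[CacheKey, ToyCompiled] = {}
        self.compile_count = 0  # only here to make cache hits observable

    def lower(self, *args: List[Any]) -> ToyLowered:
        return ToyLowered(self._fn, abstract_key(args))

    def __call__(self, *args: List[Any]) -> Any:
        key = abstract_key(args)
        compiled = self._cache.get(key)
        if compiled is None:
            # Cache miss: run the lowering and compilation stages once.
            compiled = self.lower(*args).compile()
            self._cache[key] = compiled
            self.compile_count += 1
        return compiled(*args)


def toy_jit(fn: Callable[..., Any]) -> ToyWrapped:
    return ToyWrapped(fn)


@toy_jit
def add(a: List[float], b: List[float]) -> List[float]:
    return [x + y for x, y in zip(a, b)]


add([1.0, 2.0], [3.0, 4.0])  # cache miss: lowers and compiles
add([5.0, 6.0], [7.0, 8.0])  # same shapes and dtypes: cache hit
add([1.0], [2.0])            # new shape: compiles again
```

Keying on abstract information rather than on concrete values is what lets the second call reuse the first compilation; a real cache would additionally have to account for static arguments and the target device.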

## General

These are more general topics that should be addressed at some point.

- [ ] Integrating better with Jax
- [ ] Support its array type (probably implement this in DaCe).
- [ ] Increase the dispatching speed + Cache
Jax does this in C++, which is impossible to beat in Python, thus we have to go that route as well.
- [ ] Debugging information.
- [ ] Dynamic shapes
This could be done by making the inputs fully dynamic and then using the primitives to simplify.
For example, in an addition the shapes of the two inputs and of the output are the same.
That knowledge is inherent to the primitives themselves.
However, the compiled object must know how to extract the sizes itself.
- [ ] Defining a Logo:
It should be green with a nice curly font.
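
The dynamic-shapes idea above can be illustrated with a small sketch: every input dimension starts as its own symbol, and each primitive contributes the equalities it knows about. The union-find helper and the names `ShapeSolver`, `fresh_shape`, `add_rule`, and `canonical` are hypothetical and not part of Jace, Jax, or DaCe; broadcasting is deliberately ignored.

```python
from typing import Dict, Tuple

Shape = Tuple[str, ...]  # a shape is a tuple of symbolic dimension names


class ShapeSolver:
    """Tracks which symbolic dimensions are known to be equal."""

    def __init__(self) -> None:
        self._parent: Dict[str, str] = {}
        self._counter = 0

    def fresh_shape(self, rank: int) -> Shape:
        """Create a fully dynamic shape: one new symbol per dimension."""
        names = []
        for _ in range(rank):
            name = f"s{self._counter}"
            self._counter += 1
            self._parent[name] = name
            names.append(name)
        return tuple(names)

    def find(self, sym: str) -> str:
        """Return the representative symbol of `sym` (with path halving)."""
        while self._parent[sym] != sym:
            self._parent[sym] = self._parent[self._parent[sym]]
            sym = self._parent[sym]
        return sym

    def unify(self, a: str, b: str) -> None:
        """Record that the two symbols denote the same size."""
        root_a, root_b = self.find(a), self.find(b)
        if root_a != root_b:
            self._parent[root_b] = root_a

    def canonical(self, shape: Shape) -> Shape:
        return tuple(self.find(s) for s in shape)

    def add_rule(self, lhs: Shape, rhs: Shape) -> Shape:
        """Elementwise addition: both inputs and the output share one shape."""
        if len(lhs) != len(rhs):
            raise ValueError("rank mismatch")
        for a, b in zip(lhs, rhs):
            self.unify(a, b)
        return self.canonical(lhs)


solver = ShapeSolver()
x = solver.fresh_shape(2)    # ("s0", "s1")
y = solver.fresh_shape(2)    # ("s2", "s3")
out = solver.add_rule(x, y)  # the primitive collapses four symbols into two
```

After the addition rule has run, only two independent sizes remain, so the compiled object would need to extract just those two from its arguments at call time.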

# Optimization & Transformations

The SDFGs generated by Jace have a very particular structure, thus we could and probably should write some highly targeted optimization passes for them.
Our experiments with the prototype showed that the most important transformation is Map fusion, and the one in DaCe is essentially broken.
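
To make the term concrete: serial Map fusion merges a producer Map and a consumer Map over the same iteration range into a single Map, so that the intermediate array is never materialized. The plain-Python sketch below only illustrates that effect on lists; it does not use DaCe, and `unfused` and `fused` are hypothetical names.

```python
from typing import List


def unfused(a: List[float]) -> List[float]:
    # First Map: tmp[i] = a[i] + 1.0, written to a full intermediate array.
    tmp = [x + 1.0 for x in a]
    # Second Map: out[i] = tmp[i] * 2.0, reading the intermediate back.
    return [t * 2.0 for t in tmp]


def fused(a: List[float]) -> List[float]:
    # Single Map: the intermediate value lives in a scalar per iteration.
    out = []
    for x in a:
        t = x + 1.0
        out.append(t * 2.0)
    return out
```

Both functions compute the same result; the fused version makes one pass over the data and needs no temporary array.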

- [ ] Modified state fusion: because of the structure we have, this could make `Simplify` much more efficient.
- [ ] Trivial Tasklet removal.
Since we will work a lot with Maps that are trivial (probably the best structure for fusing), we will end up with a number of trivial Tasklets, i.e. `__out = __in`.
Thus, we should have a good way to get rid of them.
- [ ] Modified Map fusion transformation.
We should still support parallel and serial fusion as the prototype did, but focus on serial fusion.
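
The trivial Tasklet removal listed above can be sketched on a toy dataflow graph. This is an illustration under stated assumptions, not a DaCe transformation: the graph is a plain dict of Tasklet names to code strings plus a list of edges, and `remove_trivial_tasklets` is a hypothetical helper.

```python
from typing import Dict, List, Tuple

Edge = Tuple[str, str]  # (source node, destination node)


def remove_trivial_tasklets(
    tasklets: Dict[str, str], edges: List[Edge]
) -> Tuple[Dict[str, str], List[Edge]]:
    """Drop Tasklets whose code is `__out = __in` and reconnect their neighbours."""
    trivial = sorted(
        name
        for name, code in tasklets.items()
        if code.replace(" ", "") == "__out=__in"
    )
    kept = {name: code for name, code in tasklets.items() if name not in trivial}

    new_edges = list(edges)
    for name in trivial:
        preds = [src for src, dst in new_edges if dst == name]
        succs = [dst for src, dst in new_edges if src == name]
        # Remove every edge touching the Tasklet, then bridge over it.
        new_edges = [e for e in new_edges if name not in e]
        new_edges += [(p, s) for p in preds for s in succs]
    return kept, new_edges


tasklets = {
    "t_add": "__out = __in0 + __in1",
    "t_copy": "__out = __in",  # trivial: only forwards its input
}
edges = [
    ("A", "t_add"),
    ("B", "t_add"),
    ("t_add", "tmp"),
    ("tmp", "t_copy"),
    ("t_copy", "C"),
]
new_tasklets, new_edges = remove_trivial_tasklets(tasklets, edges)
```

In this toy graph the copy Tasklet disappears and `tmp` feeds `C` directly. A real pass would additionally have to decide whether the remaining data-to-data copy can be eliminated as well.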