Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement forward- and backward mode AD in the interpreter #2153

Draft
wants to merge 14 commits into
base: master
Choose a base branch
from

Conversation

vox9
Copy link

@vox9 vox9 commented May 30, 2024

I'm solely pushing this for code review - I'd need to revise the commits, and clean up the code before I do a real pull request :)

@athas athas marked this pull request as draft May 31, 2024 05:45
Copy link
Member

@athas athas left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually pretty good and I am happy with how relatively little code it requires. There is however a pretty long list of style issues to address (and a little functionality things missing, too). I suggest make check.

Comment on lines +1345 to +1351
adToValue v = case v of
AD.Seed s -> ValueSeed s
AD.Primal v -> ValuePrim $ putV v
valueToAD v = case v of
ValuePrim v' -> primToAD v'
ValueSeed x' -> AD.Seed x'
_ -> error "Invalid AD value"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Write this equationally:

adToValue (AD.Seed s) = ...
...

Comment on lines +1608 to +1610
let primOnly = all (\x -> case x of
ValueSeed _ -> False
_ -> True) $ fromJust $ fromTuple x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defining an isValueSeed function instead of inlining multi-line lambdas (except for forM and similar that take the lambda as the last argument).

Comment on lines 1999 to 2004
-- Impregnate the values
let addSeeds i v = case v of
ValuePrim v' -> (i + 1, ValueSeed $ AD.VjpSeed depth $ AD.TapeId i $ AD.Primal $ fromJust $ getV v')
ValueSeed v' -> (i + 1, ValueSeed $ AD.VjpSeed depth $ AD.TapeId i $ AD.Seed v')
ValueRecord m -> second ValueRecord $ M.mapAccum addSeeds i m
_ -> error "Not implemented (gs3g3ss)" -- TODO: Arrays?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to improve the way this is written (and to handle arrays), I suggest defining a utility function for performing a monadic traversal across the "leaves" of an interpreter value. Maybe something with this type:

traverseValue :: Monad f => (PrimValue -> f (Value m)) -> (AD.ADSeed -> f (Value m)) -> Value m -> f (Value m)

Comment on lines 2059 to 2066
let deriveSeeds v = case v of
ValuePrim v' -> ValuePrim $ putV $ AD.valueAsType (fromJust $ getV v') 0
ValueSeed s@(AD.JvpSeed d _ d') ->
if d == depth then -- TODO: Fix add op
adToValue d'
else ValuePrim $ putV $ AD.valueAsType (AD.value $ AD.primal s) 0
ValueRecord m -> ValueRecord $ M.map deriveSeeds m
_ -> error "Not implemented (d19hasd782)" -- TODO: Arrays?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a hunch that some functions (like this one) should be factored into their own top-level definitions, to make it easier to understand.

import Data.Maybe (fromJust)
import Data.Foldable (foldlM)

data ADSeed
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment remarking on what the meaning of "seed" is (both here and in your implementation in general).

primal (VjpSeed _ v) = tapeValue v
primal (JvpSeed _ v _) = v

valueAsType :: PrimValue -> Rational -> PrimValue
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You only ever call this function with a zero value for the Rational. There is a notion of blankPrimValue that you could use instead.

Comment on lines 111 to 114
addFor op = OpBin $ fromJust $ case opReturnType op of
IntType t -> Just $ Add t OverflowUndef
FloatType t -> Just $ FAdd t
_ -> Nothing
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using fromJust and getting a bad error message, just call error directly with an explanation of which op failed.

FloatType t -> Just $ FMul t
_ -> Nothing

typeAsValue :: PrimType -> Rational -> PrimValue
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is almost identical to valueAsType. Surely some code can be deduplicated?

show (OpFn fn) = show fn
show (OpConv op) = show op

value :: ADValue -> PrimValue
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Call this function primalValue. Everything is a "value" in the interpreter!

-- VJP

data Tape
= TapeId Int ADValue
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You convinced me that the Int here is necessary (and it is obvious in retrospect), but it is really subtle. I suggest coming up with a name for it (too bad "seed" is taken - or is it?).

@athas
Copy link
Member

athas commented Sep 17, 2024

Will this see any further changes?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants