-
Notifications
You must be signed in to change notification settings - Fork 165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement forward- and backward mode AD in the interpreter #2153
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is actually pretty good and I am happy with how relatively little code it requires. There is however a pretty long list of style issues to address (and a little functionality things missing, too). I suggest make check
.
adToValue v = case v of | ||
AD.Seed s -> ValueSeed s | ||
AD.Primal v -> ValuePrim $ putV v | ||
valueToAD v = case v of | ||
ValuePrim v' -> primToAD v' | ||
ValueSeed x' -> AD.Seed x' | ||
_ -> error "Invalid AD value" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Write this equationally:
adToValue (AD.Seed s) = ...
...
let primOnly = all (\x -> case x of | ||
ValueSeed _ -> False | ||
_ -> True) $ fromJust $ fromTuple x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Defining an isValueSeed
function instead of inlining multi-line lambdas (except for forM
and similar that take the lambda as the last argument).
src/Language/Futhark/Interpreter.hs
Outdated
-- Impregnate the values | ||
let addSeeds i v = case v of | ||
ValuePrim v' -> (i + 1, ValueSeed $ AD.VjpSeed depth $ AD.TapeId i $ AD.Primal $ fromJust $ getV v') | ||
ValueSeed v' -> (i + 1, ValueSeed $ AD.VjpSeed depth $ AD.TapeId i $ AD.Seed v') | ||
ValueRecord m -> second ValueRecord $ M.mapAccum addSeeds i m | ||
_ -> error "Not implemented (gs3g3ss)" -- TODO: Arrays? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In order to improve the way this is written (and to handle arrays), I suggest defining a utility function for performing a monadic traversal across the "leaves" of an interpreter value. Maybe something with this type:
traverseValue :: Monad f => (PrimValue -> f (Value m)) -> (AD.ADSeed -> f (Value m)) -> Value m -> f (Value m)
src/Language/Futhark/Interpreter.hs
Outdated
let deriveSeeds v = case v of | ||
ValuePrim v' -> ValuePrim $ putV $ AD.valueAsType (fromJust $ getV v') 0 | ||
ValueSeed s@(AD.JvpSeed d _ d') -> | ||
if d == depth then -- TODO: Fix add op | ||
adToValue d' | ||
else ValuePrim $ putV $ AD.valueAsType (AD.value $ AD.primal s) 0 | ||
ValueRecord m -> ValueRecord $ M.map deriveSeeds m | ||
_ -> error "Not implemented (d19hasd782)" -- TODO: Arrays? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a hunch that some functions (like this one) should be factored into their own top-level definitions, to make it easier to understand.
import Data.Maybe (fromJust) | ||
import Data.Foldable (foldlM) | ||
|
||
data ADSeed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a comment remarking on what the meaning of "seed" is (both here and in your implementation in general).
primal (VjpSeed _ v) = tapeValue v | ||
primal (JvpSeed _ v _) = v | ||
|
||
valueAsType :: PrimValue -> Rational -> PrimValue |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You only ever call this function with a zero value for the Rational
. There is a notion of blankPrimValue
that you could use instead.
addFor op = OpBin $ fromJust $ case opReturnType op of | ||
IntType t -> Just $ Add t OverflowUndef | ||
FloatType t -> Just $ FAdd t | ||
_ -> Nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of using fromJust
and getting a bad error message, just call error
directly with an explanation of which op
failed.
FloatType t -> Just $ FMul t | ||
_ -> Nothing | ||
|
||
typeAsValue :: PrimType -> Rational -> PrimValue |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is almost identical to valueAsType
. Surely some code can be deduplicated?
show (OpFn fn) = show fn | ||
show (OpConv op) = show op | ||
|
||
value :: ADValue -> PrimValue |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Call this function primalValue
. Everything is a "value" in the interpreter!
-- VJP | ||
|
||
data Tape | ||
= TapeId Int ADValue |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You convinced me that the Int
here is necessary (and it is obvious in retrospect), but it is really subtle. I suggest coming up with a name for it (too bad "seed" is taken - or is it?).
Will this see any further changes? |
I'm solely pushing this for code review - I'd need to revise the commits, and clean up the code before I do a real pull request :)