From 491f2f0a80d065d8ea537c3a6862d3e7556c960d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 21 Oct 2020 14:02:25 +0100 Subject: [PATCH] Fix and test asymmetric_quadratic example --- examples/asymmetric_quadratic.jl | 15 +++++++-------- test/runtests.jl | 5 +++++ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/asymmetric_quadratic.jl b/examples/asymmetric_quadratic.jl index d5abc6c0..63aef7f1 100644 --- a/examples/asymmetric_quadratic.jl +++ b/examples/asymmetric_quadratic.jl @@ -1,9 +1,11 @@ using Nabla +using LinearAlgebra +using Random # Generate the values required to compute a matrix quadratic form. N = 5 B = randn(5, 5) -A = B.'B + UniformScaling(1e-6) +A = B'B + UniformScaling(1e-6) # Low-level API computation of derivatives. @@ -12,13 +14,10 @@ x_, y_ = randn(N), randn(N) x, y = Leaf.(Tape(), (x_, y_)) # Compute the forward pass. -z = x.' * (A * y) # Temporary bracketting because we don't support RowVectors yet. +z = x' * (A * y) # Temporary bracketting because we don't support RowVectors yet. -println("Output of the forward pass is:") -println(z) -println() -println("y is $(Nabla.unbox(z)).") -println() +println("Output of the forward pass is:\n $z\n") +println("y is $(Nabla.unbox(z)).\n") # Get the reverse tape. z̄ = ∇(z) @@ -42,7 +41,7 @@ println("Gradient of z w.r.t. y at $y_ is $ȳ") # Define the function to be differentiated. Parameters w.r.t. which we want gradients must # be arguments. Parameters that we don't want gradients w.r.t. should be passed in via a # closure. -@unionise f(x::AbstractVector, y::AbstractVector) = x.'A * y +@unionise f(x::AbstractVector, y::AbstractVector) = x'A * y # Compute a function `∇f` which computes the derivative of `f` w.r.t. the inputs. ∇f = ∇(f) diff --git a/test/runtests.jl b/test/runtests.jl index 92e57405..c2cc6ed0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -51,4 +51,9 @@ end include("checkpointing.jl") end +@testset "examples" begin + # make sure the examples don't throw errors + include("../examples/asymmetric_quadratic.jl") +end + end