Skip to content

Commit

Permalink
Merge pull request #245 from ReactiveBayes/multivariate-input-to-f
Browse files Browse the repository at this point in the history
Multivariate input to function
  • Loading branch information
bvdmitri authored Jul 3, 2024
2 parents bd5f843 + efd655d commit d83383a
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
11 changes: 10 additions & 1 deletion src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1666,11 +1666,20 @@ function add_edge!(
return add_edge!(model, factor_node_id, factor_node_propeties, variable_node_id, interface_name, 1)
end

"""
    add_edge!(model, factor_node_id, factor_node_propeties, variable_node_id, interface_name, index)

Forwarding method for indirect variable identifiers (`ProxyLabel` or `VariableRef`):
resolves the identifier to a concrete label via `unroll` and delegates to the
`add_edge!` method that accepts the resolved label, preserving the interface index.
"""
function add_edge!(
    model::Model,
    factor_node_id::NodeLabel,
    factor_node_propeties::FactorNodeProperties,
    variable_node_id::Union{ProxyLabel, VariableRef},
    interface_name::Symbol,
    index
)
    # Resolve the proxy/reference before delegating so downstream methods
    # only ever deal with concrete node labels.
    resolved_variable = unroll(variable_node_id)
    return add_edge!(model, factor_node_id, factor_node_propeties, resolved_variable, interface_name, index)
end

function add_edge!(
model::Model,
factor_node_id::NodeLabel,
factor_node_propeties::FactorNodeProperties,
variable_node_id::Union{ProxyLabel, NodeLabel, VariableRef},
variable_node_id::Union{NodeLabel},
interface_name::Symbol,
index
)
Expand Down
46 changes: 45 additions & 1 deletion test/graph_construction_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1770,4 +1770,48 @@ end
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." create_model(
test_model(y = 1)
)
end
end

@testitem "Multivariate input to function" begin
    # Regression test: a multivariate (array-valued) variable can be passed as an
    # input to a submodel / deterministic function during graph construction.
    using GraphPPL
    import GraphPPL: create_model, getorcreate!, datalabel

    include("testutils.jl")
    # Placeholder generic functions used as deterministic node types below;
    # no methods are needed since only the graph structure is tested.
    function dot end
    function relu end

    # Single neuron: one weight per input entry, plus a bias, followed by a
    # deterministic dot-product and a ReLU activation.
    @model function neuron(in, out)
        local w
        for i in 1:(length(in))
            w[i] ~ Normal(0.0, 1.0)
        end
        bias ~ Normal(0.0, 1.0)
        # `in` is passed as a whole (multivariate) argument to `dot` — the
        # behavior this testitem exercises.
        unactivated := dot(in, w) + bias
        out := relu(unactivated)
    end

    # A layer of `n` neurons, each receiving the full multivariate `in`.
    @model function neural_network_layer(in, out, n)
        for i in 1:n
            out[i] ~ neuron(in = in)
        end
    end

    # Three-layer network (10 -> 16 -> 2) over a softened copy of the input.
    @model function neural_net(in, out)
        local softin
        for i in 1:length(in)
            softin[i] ~ Normal(in[i], 1.0)
        end
        h1 ~ neural_network_layer(in = softin, n = 10)
        h2 ~ neural_network_layer(in = h1, n = 16)
        out ~ neural_network_layer(in = h2, n = 2)
    end

    # `in` has 3 entries and `out` has 2, matching the last layer's width.
    model = create_model(neural_net()) do model, ctx
        in = datalabel(model, ctx, GraphPPL.NodeCreationOptions(kind = :data), :in, rand(3))
        out = datalabel(model, ctx, GraphPPL.NodeCreationOptions(kind = :data), :out, randn(2))
        return (in = in, out = out)
    end
    # Expected node counts for this fixed architecture (3 inputs, layers of
    # 10/16/2 neurons): 253 Normal nodes, 28 dot nodes (one per neuron), and
    # 3 variables named `in` (one per input entry).
    @test length(collect(filter(as_node(Normal), model))) == 253
    @test length(collect(filter(as_node(dot), model))) == 28
    @test length(collect(filter(as_variable(:in), model))) == 3
end

0 comments on commit d83383a

Please sign in to comment.