Skip to content

Commit

Permalink
Change saving of neighbors between variable and factor nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
wouterwln committed Nov 27, 2023
1 parent 2363882 commit 14a9852
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
1 change: 0 additions & 1 deletion src/constraints_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,6 @@ function materialize_constraints!(model::Model, node_label::NodeLabel, node_data
edges = GraphPPL.edges(model, node_label)
constraint = Tuple(sort!(collect(constraint_set), by = first))
constraint = map(clusters -> Tuple(getindex.(Ref(edges), clusters)), constraint)


node_data.factorization_constraint = constraint
end
Expand Down
18 changes: 12 additions & 6 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ to_symbol(label::NodeLabel) = Symbol(String(label.name) * "_" * string(label.glo

Base.show(io::IO, label::NodeLabel) = print(io, label.name, "_", label.global_counter)


struct EdgeLabel
name::Symbol
index::Union{Int, Nothing}
Expand Down Expand Up @@ -278,9 +277,17 @@ Graphs.ne(model::Model) = Graphs.ne(model.graph)
Graphs.edges(model::Model) = Graphs.edges(model.graph)
MetaGraphsNext.label_for(model::Model, node_id::Int) = MetaGraphsNext.label_for(model.graph, node_id)

Graphs.neighbors(model::Model, node::NodeLabel) = map(neighbor -> neighbor[1], model[node].neighbors)
Graphs.neighbors(model::Model, node::NodeLabel) = Graphs.neighbors(model, node, model[node])
Graphs.neighbors(model::Model, node::NodeLabel, nodedata::FactorNodeData) = map(neighbor -> neighbor[1], nodedata.neighbors)
Graphs.neighbors(model::Model, node::NodeLabel, nodedata::VariableNodeData) = MetaGraphsNext.neighbor_labels(model.graph, node)
Graphs.neighbors(model::Model, nodes::AbstractArray{<:NodeLabel}) = Iterators.flatten(map(node -> Graphs.neighbors(model, node), nodes))
Graphs.edges(model::Model, node::NodeLabel) = map(edge -> edge[2], model[node].neighbors)

Graphs.edges(model::Model, node::NodeLabel) = Graphs.edges(model, node, model[node])
Graphs.edges(model::Model, node::NodeLabel, nodedata::FactorNodeData) = map(neighbor -> neighbor[2], nodedata.neighbors)
function Graphs.edges(model::Model, node::NodeLabel, nodedata::VariableNodeData)
return Tuple(model[node, dst] for dst in MetaGraphsNext.neighbor_labels(model.graph, node))
end
Graphs.edges(model::Model, nodes::AbstractArray{<:NodeLabel}) = Iterators.flatten(map(node -> Graphs.edges(model, node), nodes))

Check warning on line 290 in src/graph_engine.jl

View check run for this annotation

Codecov / codecov/patch

src/graph_engine.jl#L290

Added line #L290 was not covered by tests

abstract type AbstractModelFilterPredicate end

Expand Down Expand Up @@ -768,9 +775,8 @@ iterator(interfaces::NamedTuple) = zip(keys(interfaces), values(interfaces))

function add_edge!(model::Model, factor_node_id::NodeLabel, variable_node_id::Union{ProxyLabel, NodeLabel}, interface_name::Symbol; index = nothing)
label = EdgeLabel(interface_name, index)
# model.graph[unroll(variable_node_id), factor_node_id] = label
model[factor_node_id].neighbors = (model[factor_node_id].neighbors..., (unroll(variable_node_id), label))
# model[unroll(variable_node_id)].neighbors = (model[unroll(variable_node_id)].neighbors..., (factor_node_id, label))
model.graph[unroll(variable_node_id), factor_node_id] = label
end

function add_edge!(model::Model, factor_node_id::NodeLabel, variable_nodes::Union{AbstractArray, Tuple, NamedTuple}, interface_name::Symbol; index = 1)
Expand All @@ -783,7 +789,7 @@ increase_index(any) = 1
increase_index(x::AbstractArray) = length(x)

function add_factorization_constraint!(model::Model, factor_node_id::NodeLabel)
out_degree = length(model[factor_node_id].neighbors)
out_degree = length(model[factor_node_id].neighbors)
constraint = BitSetTuple(out_degree)
set_factorization_constraint!(model[factor_node_id], constraint)
end
Expand Down
15 changes: 7 additions & 8 deletions test/graph_engine_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,12 @@ end
b = NodeLabel(:b, 2)
model[a] = VariableNodeData(:a, VariableNodeOptions(), nothing, nothing, nothing, ())
model[b] = VariableNodeData(:b, VariableNodeOptions(), nothing, nothing, nothing, ())
add_edge!(model, a, b, :edge; index=1)
add_edge!(model, a, b, :edge; index = 1)
@test length(edges(model)) == 1

c = NodeLabel(:c, 2)
model[NodeLabel(:c, 2)] = VariableNodeData(:b, VariableNodeOptions(), nothing, nothing, nothing, ())
add_edge!(model, a, c, :edge; index=2)
add_edge!(model, a, c, :edge; index = 2)
@test length(edges(model)) == 2

# Test 2: Test getting all edges from a model with a specific node
Expand All @@ -345,7 +345,7 @@ end
b = NodeLabel(:b, 2)
model[a] = VariableNodeData(:a, VariableNodeOptions(), nothing, nothing, __context__, ())
model[b] = VariableNodeData(:b, VariableNodeOptions(), nothing, nothing, __context__, ())
add_edge!(model, a, b, :edge; index=1)
add_edge!(model, a, b, :edge; index = 1)
@test collect(neighbors(model, NodeLabel(:a, 1))) == [NodeLabel(:b, 2)]

model = create_model()
Expand All @@ -357,7 +357,7 @@ end
model[a[i]] = VariableNodeData(:a, VariableNodeOptions(), i, nothing, __context__, ())
b[i] = NodeLabel(:b, i)
model[b[i]] = VariableNodeData(:b, VariableNodeOptions(), i, nothing, __context__, ())
add_edge!(model, a[i], b[i], :edge; index=i)
add_edge!(model, a[i], b[i], :edge; index = i)
end
for n in b
@test n neighbors(model, a)
Expand Down Expand Up @@ -1294,8 +1294,8 @@ end
end

@testitem "sort_interfaces" begin
import GraphPPL: sort_interfaces
include("model_zoo.jl")
import GraphPPL: sort_interfaces
include("model_zoo.jl")

# Test 1: Test that sort_interfaces sorts the interfaces in the correct order
@test sort_interfaces(NormalMeanVariance, (μ = 1, σ = 1, out = 1)) == (out = 1, μ = 1, σ = 1)
Expand All @@ -1308,5 +1308,4 @@ end
@test sort_interfaces(NormalMeanPrecision, (τ = 1, μ = 1, out = 1)) == (out = 1, μ = 1, τ = 1)

@test_throws ErrorException sort_interfaces(NormalMeanVariance, (σ = 1, μ = 1, τ = 1))

end
end

0 comments on commit 14a9852

Please sign in to comment.