4 changes: 3 additions & 1 deletion Project.toml
@@ -12,6 +12,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -45,6 +46,7 @@ HDF5 = "0.17"
IterTools = "1.10.0"
JSON = "0.21"
NPZ = "0.4"
OrderedCollections = "1.8.0"
PrecompileTools = "1"
Preferences = "v1.4.3"
Requires = "1"
@@ -61,4 +63,4 @@ julia = "1.10"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TensorMarket = "8b7d4fe7-0b45-4d0d-9dd8-5cc9b23b4b77"
TensorMarket = "8b7d4fe7-0b45-4d0d-9dd8-5cc9b23b4b77"
1 change: 1 addition & 0 deletions src/Finch.jl
@@ -79,6 +79,7 @@ const FINCH_VERSION = VersionNumber(
TOML.parsefile(joinpath(dirname(@__DIR__), "Project.toml"))["version"]
)

include("util/stable_set.jl")
include("util/convenience.jl")
include("util/special_functions.jl")
include("util/shims.jl")
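Note: the body of util/stable_set.jl is not shown in this diff. For orientation, here is a minimal sketch, not the actual implementation, of an insertion-ordered set that would satisfy the call sites in the hunks below (untyped and typed constructors, push!, ∪, setdiff, ∩, iteration). The sketch leans on the OrderedCollections dependency added above; the real file may be organized differently.

# Minimal sketch only; the real util/stable_set.jl may differ.
import OrderedCollections: OrderedDict

struct StableSet{T} <: AbstractSet{T}
    dict::OrderedDict{T,Nothing}
    StableSet{T}() where {T} = new{T}(OrderedDict{T,Nothing}())
end
StableSet() = StableSet{Any}()
StableSet{T}(itr) where {T} = foldl(push!, itr; init=StableSet{T}())
StableSet(itr) = StableSet{eltype(itr)}(itr)

Base.push!(s::StableSet, x) = (s.dict[x] = nothing; s)
Base.delete!(s::StableSet, x) = (delete!(s.dict, x); s)
Base.in(x, s::StableSet) = haskey(s.dict, x)
Base.length(s::StableSet) = length(s.dict)
function Base.iterate(s::StableSet, state...)
    it = iterate(s.dict, state...)
    return it === nothing ? nothing : (first(it[1]), it[2])
end
# Route the generic set algebra (union, setdiff, intersect) back to
# StableSet instead of Base.Set, so results keep a deterministic order.
Base.emptymutable(s::StableSet{T}, ::Type{U}=T) where {T,U} = StableSet{U}()
Base.copymutable(s::StableSet{T}) where {T} = union!(StableSet{T}(), s)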
4 changes: 2 additions & 2 deletions src/Galley/FinchCompat/translate.jl
@@ -120,13 +120,13 @@ function remove_reorders(prgm::LogicNode)
)(
expr
)
bc_idxs = OrderedSet()
bc_idxs = StableSet()
for n in PostOrderDFS(expr)
if n.kind == reorder
bc_idxs = bc_idxs ∪ setdiff(n.idxs, getfields(n.arg))
end
end
table_idxs = OrderedSet()
table_idxs = StableSet()
for n in PostOrderDFS(expr)
if n.kind == table
table_idxs = table_idxs ∪ n.idxs
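With a set like the sketch above, the ∪/setdiff pattern in this hunk yields a deterministic element order; a toy run, index names hypothetical:

bc_idxs = StableSet()
bc_idxs = bc_idxs ∪ setdiff([:i, :j, :k], (:j,))
collect(bc_idxs)   # [:i, :k], in first-insertion order on every run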
2 changes: 1 addition & 1 deletion src/Galley/Galley.jl
@@ -13,7 +13,7 @@ using Statistics
using Finch
using Finch: Element, SparseListLevel, SparseDict, Dense, SparseCOO, fsparse_impl,
compute_parse,
isimmediate, set_options, flatten_plans, initmax, initmin
isimmediate, set_options, flatten_plans, initmax, initmin, StableSet
using Finch.FinchNotation: index_instance, variable_instance, tag_instance,
literal_instance,
access_instance, reader_instance, updater_instance, assign_instance,
2 changes: 1 addition & 1 deletion src/Galley/LogicalOptimizer/LogicalOptimizer.jl
@@ -123,7 +123,7 @@ function high_level_optimize_query(
if check_dnf
min_cost = cnf_cost
min_query = canonicalize(plan_copy(q), false)
visited_queries = OrderedSet()
visited_queries = StableSet()
finished = false
while !finished
finished = true
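visited_queries above serves as a seen-set inside a fixed-point rewrite loop, so a deterministic set presumably makes the visit order, and any cost tie-breaking, reproducible across runs. A self-contained schematic of the pattern, not Galley's actual loop:

# Seen-set fixed-point schematic (not the real optimizer loop).
visited = StableSet{Int}()
frontier = [1, 2]
while !isempty(frontier)
    q = popfirst!(frontier)
    q in visited && continue
    push!(visited, q)
    q < 4 && push!(frontier, q + 1)   # stand-in for "generate rewrites"
end
collect(visited)   # [1, 2, 3, 4], same order every run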
18 changes: 9 additions & 9 deletions src/Galley/LogicalOptimizer/annotated-query.jl
@@ -12,7 +12,7 @@ mutable struct AnnotatedQuery
parent_idxs::OrderedDict{IndexExpr,Vector{IndexExpr}} # Index orders that must be respected
original_idx::OrderedDict{IndexExpr,IndexExpr} # When an index is split into many, we track their relationship.
connected_components::Vector{Vector{IndexExpr}}
connected_idxs::OrderedDict{IndexExpr,OrderedSet{IndexExpr}}
connected_idxs::OrderedDict{IndexExpr,StableSet{IndexExpr}}
end

function copy_aq(aq::AnnotatedQuery)
@@ -229,7 +229,7 @@ function AnnotatedQuery(q::PlanNode, ST)
end

parent_idxs = OrderedDict(i => [] for i in reduce_idxs)
connected_idxs = OrderedDict(i => OrderedSet{IndexExpr}() for i in reduce_idxs)
connected_idxs = OrderedDict(i => StableSet{IndexExpr}() for i in reduce_idxs)
for idx1 in reduce_idxs
idx1_op = idx_op[idx1]
idx1_bottom_root = id_to_node[idx_lowest_root[idx1]]
@@ -283,8 +283,8 @@ function get_reduce_query(reduce_idx, aq)
root_node_id = aq.idx_lowest_root[reduce_idx]
root_node = aq.id_to_node[root_node_id]
query_expr = nothing
idxs_to_be_reduced = OrderedSet([reduce_idx])
nodes_to_remove = OrderedSet()
idxs_to_be_reduced = StableSet([reduce_idx])
nodes_to_remove = StableSet()
node_to_replace = -1
reducible_idxs = get_reducible_idxs(aq)
if root_node.kind === MapJoin && isdistributive(root_node.op.val, reduce_op)
@@ -343,7 +343,7 @@ function get_reduce_query(reduce_idx, aq)
end
end
end
final_idxs_to_be_reduced = OrderedSet(
final_idxs_to_be_reduced = StableSet(
Index(aq.original_idx[idx]) for idx in idxs_to_be_reduced
)
reduced_idxs = idxs_to_be_reduced
@@ -356,7 +356,7 @@ function get_reduce_query(reduce_idx, aq)
query_expr.stats = reduce_tensor_stats(
query_expr.op.val,
query_expr.init.val,
OrderedSet([idx for idx in final_idxs_to_be_reduced]),
StableSet([idx for idx in final_idxs_to_be_reduced]),
query_expr.arg.stats,
)
query = Query(Alias(galley_gensym("A")), query_expr)
@@ -492,7 +492,7 @@ function reduce_idx!(reduce_idx, aq; do_condense=false)
new_idx_op = OrderedDict{IndexExpr,Any}()
new_idx_init = OrderedDict{IndexExpr,Any}()
new_parent_idxs = OrderedDict{IndexExpr,Vector{IndexExpr}}()
new_connected_idxs = OrderedDict{IndexExpr,OrderedSet{IndexExpr}}()
new_connected_idxs = OrderedDict{IndexExpr,StableSet{IndexExpr}}()
for idx in keys(aq.idx_lowest_root)
if idx in reduced_idxs
continue
@@ -523,7 +523,7 @@ function reduce_idx!(reduce_idx, aq; do_condense=false)
new_components = get_idx_connected_components(new_parent_idxs, new_connected_idxs)

# Here, we update the statistics for all nodes above the affected nodes
rel_child_nodes = OrderedSet{Int}(n for n in nodes_to_remove)
rel_child_nodes = StableSet{Int}(n for n in nodes_to_remove)
push!(rel_child_nodes, node_to_replace)
for n in PostOrderDFS(new_point_expr)
if n.node_id == node_to_replace
@@ -571,7 +571,7 @@ end

# Given a node in the tree, return all indices which can be reduced after computing that subtree.
function get_reducible_idxs(aq, n)
reduce_idxs = OrderedSet{IndexExpr}()
reduce_idxs = StableSet{IndexExpr}()
for idx in aq.reduce_idxs
idx_root = aq.idx_lowest_root[idx]
if intree(idx_root, n)
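For orientation: connected_idxs above appears to map each reducible index to the deterministic set of indices it co-occurs with. A hedged miniature of that dict-of-sets construction, with IndexExpr assumed Symbol-like and OrderedDict from OrderedCollections:

using OrderedCollections: OrderedDict
reduce_idxs = [:i, :j]
connected_idxs = OrderedDict(i => StableSet{Symbol}() for i in reduce_idxs)
push!(connected_idxs[:i], :j)   # :i and :j co-occur in some input
push!(connected_idxs[:j], :i)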
6 changes: 3 additions & 3 deletions src/Galley/LogicalOptimizer/query-splitter.jl
@@ -1,5 +1,5 @@
function count_index_occurences(nodes)
vars = OrderedSet()
vars = StableSet()
occurences = 0
for n in nodes
for c in PostOrderDFS(n)
@@ -89,7 +89,7 @@ function split_query(q::PlanNode, ST, max_kernel_size, alias_stats, verbose)
while cur_occurences > max_kernel_size
nodes_to_remove = nothing
new_query = nothing
new_agg_idxs = OrderedSet()
new_agg_idxs = StableSet()
min_cost = Inf
for node in PostOrderDFS(pe)
if node.kind in (Value, Input, Alias, Index) ||
@@ -127,7 +127,7 @@ function split_query(q::PlanNode, ST, max_kernel_size, alias_stats, verbose)
cache_key = sort([n.node_id for n in s])
if !haskey(cost_cache, cache_key)
s_stat = merge_tensor_stats(node.op.val, [n.stats for n in s]...)
s_reduce_idxs = OrderedSet{IndexExpr}()
s_reduce_idxs = StableSet{IndexExpr}()
for idx in n_reduce_idxs
if !any([
idx ∈ get_index_set(n.stats) for n in setdiff(node.args, s)
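The cost_cache in this hunk keys on sorted node ids, so the same node subset always hits one entry regardless of discovery order; the keying in isolation, ids hypothetical:

using OrderedCollections: OrderedDict
cost_cache = OrderedDict{Vector{Int},Float64}()
cache_key = sort([7, 3])           # [3, 7], canonical for the subset {3, 7}
cost_cache[cache_key] = 1.5
haskey(cost_cache, sort([3, 7]))   # true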
4 changes: 2 additions & 2 deletions src/Galley/PhysicalOptimizer/format-selector.jl
@@ -45,13 +45,13 @@ function select_output_format(output_stats::TensorStats,
approx_sparsity = approx_nnz_per / get_dim_size(output_stats, prefix[1])
dense_memory_footprint = prev_nnz * get_dim_size(output_stats, prefix[1])
if approx_sparsity > 0.5 && dense_memory_footprint < 3 * 10^10
# if get_dim_space_size(output_stats, OrderedSet(prefix)) > 10^10
# if get_dim_space_size(output_stats, StableSet(prefix)) > 10^10
# throw(OutOfMemoryError())
# end
push!(formats, t_dense)
elseif approx_sparsity > 0.05 && dense_memory_footprint < 3 * 10^10 &&
(length(formats) == 0 ? true : formats[end] != t_bytemap) # TODO: Check out finch double bytemap bug
# if get_dim_space_size(output_stats, OrderedSet(prefix)) > 10^10
# if get_dim_space_size(output_stats, StableSet(prefix)) > 10^10
# throw(OutOfMemoryError())
# end
push!(formats, t_bytemap)
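A worked instance of the two thresholds above, with hypothetical numbers:

# Hypothetical numbers illustrating the branch above.
approx_nnz_per = 6e5; dim = 1e6; prev_nnz = 1e3
approx_sparsity = approx_nnz_per / dim    # 0.6 > 0.5
dense_memory_footprint = prev_nnz * dim   # 1e9 < 3e10, so push t_dense
# At approx_sparsity of, say, 0.1 the first branch fails and the 0.05
# threshold selects t_bytemap instead (barring the bytemap caveat).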
10 changes: 5 additions & 5 deletions src/Galley/PhysicalOptimizer/loop-ordering.jl
@@ -43,14 +43,14 @@ function needs_reformat(stat::TensorStats, prefix::Vector{IndexExpr})
end

function get_reformat_set(input_stats::Vector{TensorStats}, prefix::Vector{IndexExpr})
ref_set = OrderedSet()
ref_set = StableSet()
for i in eachindex(input_stats)
needs_reformat(input_stats[i], prefix) && push!(ref_set, i)
end
return ref_set
end

PLAN_CLASS = Tuple{OrderedSet{IndexExpr},OrderedSet{Int}}
PLAN_CLASS = Tuple{StableSet{IndexExpr},StableSet{Int}}
PLAN = Tuple{Vector{IndexExpr},Float64}

function cost_of_plan_class(pc::PLAN_CLASS, reformat_costs, output_size)
@@ -98,13 +98,13 @@ function get_join_loop_order_bounded(disjunct_and_conjunct_stats,
reformat_costs = OrderedDict(
i => cost_of_reformat(transposable_stats[i]) for i in eachindex(transposable_stats)
)
PLAN_CLASS = Tuple{OrderedSet{IndexExpr},OrderedSet{Int}}
PLAN_CLASS = Tuple{StableSet{IndexExpr},StableSet{Int}}
PLAN = Tuple{Vector{IndexExpr},Float64}
optimal_plans = OrderedDict{PLAN_CLASS,PLAN}()
for var in all_vars
prefix = [var]
rf_set = get_reformat_set(transposable_stats, prefix)
class = (OrderedSet(prefix), rf_set)
class = (StableSet(prefix), rf_set)
cost = get_prefix_cost(prefix, output_vars, conjunct_stats, disjunct_stats)
optimal_plans[class] = (prefix, cost)
end
@@ -116,7 +116,7 @@ function get_join_loop_order_bounded(
prefix = plan[1]
cost = plan[2]
# We only consider extensions that don't result in cross products
potential_vars = OrderedSet{IndexExpr}()
potential_vars = StableSet{IndexExpr}()
for stat in all_stats
index_set = get_index_set(stat)
if length(∩(index_set, prefix_set)) > 0
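The search above keys candidate plans by PLAN_CLASS, the prefix as a set plus the reformat set, so prefixes that differ only in order share one entry and only the cheaper plan survives. A hedged miniature, indices hypothetical, assuming the StableSet sketch:

using OrderedCollections: OrderedDict
PLAN_CLASS = Tuple{StableSet{Symbol},StableSet{Int}}
optimal_plans = OrderedDict{PLAN_CLASS,Tuple{Vector{Symbol},Float64}}()
optimal_plans[(StableSet([:i, :j]), StableSet([1]))] = ([:i, :j], 42.0)
# Set equality ignores order, so prefix [:j, :i] lands on the same class.
haskey(optimal_plans, (StableSet([:j, :i]), StableSet([1])))   # true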
4 changes: 2 additions & 2 deletions src/Galley/PhysicalOptimizer/physical-optimizer.jl
@@ -113,11 +113,11 @@ function logical_query_to_physical_queries(

agg_op = nothing
agg_init = nothing
reduce_idxs = OrderedSet{IndexExpr}()
reduce_idxs = StableSet{IndexExpr}()
if expr.kind == Aggregate
agg_op = expr.op
agg_init = expr.init
reduce_idxs = OrderedSet{IndexExpr}([i.name for i in expr.idxs])
reduce_idxs = StableSet{IndexExpr}([i.name for i in expr.idxs])
expr = expr.arg
end

8 changes: 4 additions & 4 deletions src/Galley/PhysicalOptimizer/validate.jl
@@ -11,7 +11,7 @@ function get_input_indices(n::PlanNode)
elseif n.kind == Alias
get_index_set(n.stats)
elseif n.kind == Value
OrderedSet{IndexExpr}()
StableSet{IndexExpr}()
elseif n.kind == Aggregate
get_input_indices(n.arg)
elseif n.kind == MapJoin
@@ -121,9 +121,9 @@ function validate_physical_query(q::PlanNode)
end
@assert idx_dim[idx] == dim "idx:$idx dim:$dim query:$q "
end
output_indices = OrderedSet([idx.name for idx in q.expr.idx_order])
@assert input_indices ∪ output_indices == OrderedSet([idx.name for idx in q.loop_order])
@assert OrderedSet(output_indices) == OrderedSet([idx.name for idx in q.expr.idx_order])
output_indices = StableSet([idx.name for idx in q.expr.idx_order])
@assert input_indices ∪ output_indices == StableSet([idx.name for idx in q.loop_order])
@assert StableSet(output_indices) == StableSet([idx.name for idx in q.expr.idx_order])
check_sorted_inputs(q.expr, [idx.name for idx in q.loop_order])
check_protocols(q.expr)
check_formats(q.expr)
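The first assertion above checks that the loop order covers exactly the input and output indices; the same invariant in isolation, names hypothetical:

input_indices  = StableSet([:i, :k])
output_indices = StableSet([:i, :j])
loop_order     = [:i, :j, :k]
@assert input_indices ∪ output_indices == StableSet(loop_order)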
2 changes: 1 addition & 1 deletion src/Galley/PlanAST/canonicalize.jl
@@ -83,7 +83,7 @@ function _insert_statistics!(
expr.stats = reduce_tensor_stats(
expr.op.val,
expr.init.val,
OrderedSet{IndexExpr}([idx.name for idx in expr.idxs]),
StableSet{IndexExpr}([idx.name for idx in expr.idxs]),
expr.arg.stats,
)
elseif expr.kind === Materialize
2 changes: 1 addition & 1 deletion src/Galley/PlanAST/plan.jl
@@ -325,7 +325,7 @@ function Base.:(==)(a::PlanNode, b::PlanNode)
return b.kind === Index && a.name == b.name
elseif a.kind == Aggregate
return b.kind === Aggregate && a.op == b.op && a.init == b.init &&
OrderedSet(a.idxs) == OrderedSet(b.idxs) && a.arg == b.arg
StableSet(a.idxs) == StableSet(b.idxs) && a.arg == b.arg
elseif istree(a)
return a.kind === b.kind && a.children == b.children
else
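The Aggregate branch above compares idxs as sets because reduction-index order is irrelevant to plan equality, while child order elsewhere is not; concretely:

StableSet([:i, :j]) == StableSet([:j, :i])   # true,  sets ignore order
[:i, :j] == [:j, :i]                         # false, vectors do not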
2 changes: 1 addition & 1 deletion src/Galley/TensorStats/StaticBitset.jl
@@ -25,7 +25,7 @@ function SmallBitSet(ints::Vector{Int})
return s
end

function SmallBitSet(ints::OrderedSet{Int})
function SmallBitSet(ints::StableSet{Int})
s = SmallBitSet()
for i in ints
s = _setint(s, i, true)
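A hedged usage of the updated constructor, with SmallBitSet as defined earlier in this file:

s = SmallBitSet(StableSet([1, 3, 5]))   # same bits as SmallBitSet([1, 3, 5])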
4 changes: 2 additions & 2 deletions src/Galley/TensorStats/cost-estimates.jl
@@ -12,7 +12,7 @@ const DenseAllocateCost = 0.5
const SparseAllocateCost = 60

# We estimate the prefix cost based on the number of iterations in that prefix.
function get_loop_lookups(vars::OrderedSet{IndexExpr}, rel_conjuncts, rel_disjuncts)
function get_loop_lookups(vars::StableSet{IndexExpr}, rel_conjuncts, rel_disjuncts)
# This tensor stats doesn't actually correspond to a particular place in the expr tree,
# so we unfortunately have to mangle the statistics interface a bit.
rel_conjuncts = map(stat -> set_fill_value!(stat, false), rel_conjuncts)
@@ -45,7 +45,7 @@ function get_prefix_cost(
new_prefix::Vector{IndexExpr}, output_vars, conjunct_stats, disjunct_stats
)
new_var = new_prefix[end]
prefix_set = OrderedSet(new_prefix)
prefix_set = StableSet(new_prefix)
rel_conjuncts = [
stat for stat in conjunct_stats if !isempty(get_index_set(stat) ∩ prefix_set)
]
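get_prefix_cost above keeps only statistics whose index set meets the current prefix; the filtering pattern in isolation, index sets hypothetical:

prefix_set = StableSet([:i, :j])
index_sets = [StableSet([:i, :k]), StableSet([:m])]
rel = [s for s in index_sets if !isempty(s ∩ prefix_set)]   # keeps only the first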
18 changes: 9 additions & 9 deletions src/Galley/TensorStats/propagate-stats.jl
@@ -10,7 +10,7 @@ function merge_tensor_stats_union(op, all_stats::Vararg{TensorStats})
end

function reduce_tensor_stats(
op, init, reduce_indices::OrderedSet{IndexExpr}, stats::TensorStats
op, init, reduce_indices::StableSet{IndexExpr}, stats::TensorStats
)
throw(error("reduce_tensor_stats not implemented for: ", typeof(stats)))
end
@@ -38,7 +38,7 @@ function merge_tensor_def(op, all_defs::Vararg{TensorDef})
)
end

function reduce_tensor_def(op, init, reduce_indices::OrderedSet{IndexExpr}, def::TensorDef)
function reduce_tensor_def(op, init, reduce_indices::StableSet{IndexExpr}, def::TensorDef)
op = op isa PlanNode ? op.val : op
init = init isa PlanNode ? init.val : init
if isnothing(init)
@@ -110,10 +110,10 @@ function merge_tensor_stats(op::PlanNode, all_stats::Vararg{ST}) where {ST<:TensorStats}
end

function reduce_tensor_stats(
op, init, reduce_indices::Union{Vector{PlanNode},OrderedSet{PlanNode}}, stats::ST
op, init, reduce_indices::Union{Vector{PlanNode},StableSet{PlanNode}}, stats::ST
) where {ST<:TensorStats}
return reduce_tensor_stats(
op, init, OrderedSet{IndexExpr}([idx.name for idx in reduce_indices]), stats
op, init, StableSet{IndexExpr}([idx.name for idx in reduce_indices]), stats
)
end

@@ -154,7 +154,7 @@ function merge_tensor_stats_union(op, new_def::TensorDef, all_stats::Vararg{NaiveStats})
end

function reduce_tensor_stats(
op, init, reduce_indices::OrderedSet{IndexExpr}, stats::NaiveStats
op, init, reduce_indices::StableSet{IndexExpr}, stats::NaiveStats
)
if length(reduce_indices) == 0
return copy_stats(stats)
@@ -227,7 +227,7 @@ function merge_tensor_stats_join(op, new_def::TensorDef, all_stats::Vararg{DCStats})
new_def,
final_idx_2_int,
final_int_2_idx,
OrderedSet{DC}(DC(key.X, key.Y, d) for (key, d) in new_dc_dict),
StableSet{DC}(DC(key.X, key.Y, d) for (key, d) in new_dc_dict),
)
return new_stats
end
@@ -284,19 +284,19 @@ function merge_tensor_stats_union(op, new_def::TensorDef, all_stats::Vararg{DCStats})
#=
for Y in subsets(collect(get_index_set(new_def)))
proj_dc_key = (X=BitSet(), Y=idxs_to_bitset(final_idx_2_int, Y))
new_dcs[proj_dc_key] = min(get(new_dcs, proj_dc_key, typemax(UInt)/2), get_dim_space_size(new_def, OrderedSet(Y)))
new_dcs[proj_dc_key] = min(get(new_dcs, proj_dc_key, typemax(UInt)/2), get_dim_space_size(new_def, StableSet(Y)))
end
=#
return DCStats(
new_def,
final_idx_2_int,
final_int_2_idx,
OrderedSet{DC}(DC(key.X, key.Y, d) for (key, d) in new_dcs),
StableSet{DC}(DC(key.X, key.Y, d) for (key, d) in new_dcs),
)
end

function reduce_tensor_stats(
op, init, reduce_indices::OrderedSet{IndexExpr}, stats::DCStats
op, init, reduce_indices::StableSet{IndexExpr}, stats::DCStats
)
if length(reduce_indices) == 0
return copy_stats(stats)
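The Union{Vector{PlanNode},StableSet{PlanNode}} overload above only normalizes PlanNode collections to plain name sets before the stats-specific methods run; schematically, with FakeNode standing in for PlanNode:

struct FakeNode; name::Symbol; end                  # stand-in for PlanNode
nodes = [FakeNode(:i), FakeNode(:j)]
reduce_indices = StableSet{Symbol}([n.name for n in nodes])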