diff --git a/Project.toml b/Project.toml index 10397f622..5bae53269 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -45,6 +46,7 @@ HDF5 = "0.17" IterTools = "1.10.0" JSON = "0.21" NPZ = "0.4" +OrderedCollections = "1.8.0" PrecompileTools = "1" Preferences = "v1.4.3" Requires = "1" @@ -61,4 +63,4 @@ julia = "1.10" HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -TensorMarket = "8b7d4fe7-0b45-4d0d-9dd8-5cc9b23b4b77" \ No newline at end of file +TensorMarket = "8b7d4fe7-0b45-4d0d-9dd8-5cc9b23b4b77" diff --git a/src/Finch.jl b/src/Finch.jl index 43e6eb32a..aba935ae5 100644 --- a/src/Finch.jl +++ b/src/Finch.jl @@ -79,6 +79,7 @@ const FINCH_VERSION = VersionNumber( TOML.parsefile(joinpath(dirname(@__DIR__), "Project.toml"))["version"] ) +include("util/stable_set.jl") include("util/convenience.jl") include("util/special_functions.jl") include("util/shims.jl") diff --git a/src/Galley/FinchCompat/translate.jl b/src/Galley/FinchCompat/translate.jl index 4225a2611..0789a3595 100644 --- a/src/Galley/FinchCompat/translate.jl +++ b/src/Galley/FinchCompat/translate.jl @@ -120,13 +120,13 @@ function remove_reorders(prgm::LogicNode) )( expr ) - bc_idxs = OrderedSet() + bc_idxs = StableSet() for n in PostOrderDFS(expr) if n.kind == reorder bc_idxs = bc_idxs ∪ setdiff(n.idxs, getfields(n.arg)) end end - table_idxs = OrderedSet() + table_idxs = StableSet() for n in PostOrderDFS(expr) if n.kind == table table_idxs = table_idxs ∪ n.idxs diff --git a/src/Galley/Galley.jl b/src/Galley/Galley.jl index 4c1fd7763..2483b55bd 100644 --- a/src/Galley/Galley.jl +++ b/src/Galley/Galley.jl @@ -13,7 +13,7 @@ using Statistics using Finch using Finch: Element, SparseListLevel, SparseDict, Dense, SparseCOO, fsparse_impl, compute_parse, - isimmediate, set_options, flatten_plans, initmax, initmin + isimmediate, set_options, flatten_plans, initmax, initmin, StableSet using Finch.FinchNotation: index_instance, variable_instance, tag_instance, literal_instance, access_instance, reader_instance, updater_instance, assign_instance, diff --git a/src/Galley/LogicalOptimizer/LogicalOptimizer.jl b/src/Galley/LogicalOptimizer/LogicalOptimizer.jl index 852a5a095..3cc5513be 100644 --- a/src/Galley/LogicalOptimizer/LogicalOptimizer.jl +++ b/src/Galley/LogicalOptimizer/LogicalOptimizer.jl @@ -123,7 +123,7 @@ function high_level_optimize_query( if check_dnf min_cost = cnf_cost min_query = canonicalize(plan_copy(q), false) - visited_queries = OrderedSet() + visited_queries = StableSet() finished = false while !finished finished = true diff --git a/src/Galley/LogicalOptimizer/annotated-query.jl b/src/Galley/LogicalOptimizer/annotated-query.jl index af7967fe6..190e9b50c 100644 --- a/src/Galley/LogicalOptimizer/annotated-query.jl +++ b/src/Galley/LogicalOptimizer/annotated-query.jl @@ -12,7 +12,7 @@ mutable struct AnnotatedQuery parent_idxs::OrderedDict{IndexExpr,Vector{IndexExpr}} # Index orders that must be respected original_idx::OrderedDict{IndexExpr,IndexExpr} # When an index is split into many, we track their relationship. connected_components::Vector{Vector{IndexExpr}} - connected_idxs::OrderedDict{IndexExpr,OrderedSet{IndexExpr}} + connected_idxs::OrderedDict{IndexExpr,StableSet{IndexExpr}} end function copy_aq(aq::AnnotatedQuery) @@ -229,7 +229,7 @@ function AnnotatedQuery(q::PlanNode, ST) end parent_idxs = OrderedDict(i => [] for i in reduce_idxs) - connected_idxs = OrderedDict(i => OrderedSet{IndexExpr}() for i in reduce_idxs) + connected_idxs = OrderedDict(i => StableSet{IndexExpr}() for i in reduce_idxs) for idx1 in reduce_idxs idx1_op = idx_op[idx1] idx1_bottom_root = id_to_node[idx_lowest_root[idx1]] @@ -283,8 +283,8 @@ function get_reduce_query(reduce_idx, aq) root_node_id = aq.idx_lowest_root[reduce_idx] root_node = aq.id_to_node[root_node_id] query_expr = nothing - idxs_to_be_reduced = OrderedSet([reduce_idx]) - nodes_to_remove = OrderedSet() + idxs_to_be_reduced = StableSet([reduce_idx]) + nodes_to_remove = StableSet() node_to_replace = -1 reducible_idxs = get_reducible_idxs(aq) if root_node.kind === MapJoin && isdistributive(root_node.op.val, reduce_op) @@ -343,7 +343,7 @@ function get_reduce_query(reduce_idx, aq) end end end - final_idxs_to_be_reduced = OrderedSet( + final_idxs_to_be_reduced = StableSet( Index(aq.original_idx[idx]) for idx in idxs_to_be_reduced ) reduced_idxs = idxs_to_be_reduced @@ -356,7 +356,7 @@ function get_reduce_query(reduce_idx, aq) query_expr.stats = reduce_tensor_stats( query_expr.op.val, query_expr.init.val, - OrderedSet([idx for idx in final_idxs_to_be_reduced]), + StableSet([idx for idx in final_idxs_to_be_reduced]), query_expr.arg.stats, ) query = Query(Alias(galley_gensym("A")), query_expr) @@ -492,7 +492,7 @@ function reduce_idx!(reduce_idx, aq; do_condense=false) new_idx_op = OrderedDict{IndexExpr,Any}() new_idx_init = OrderedDict{IndexExpr,Any}() new_parent_idxs = OrderedDict{IndexExpr,Vector{IndexExpr}}() - new_connected_idxs = OrderedDict{IndexExpr,OrderedSet{IndexExpr}}() + new_connected_idxs = OrderedDict{IndexExpr,StableSet{IndexExpr}}() for idx in keys(aq.idx_lowest_root) if idx in reduced_idxs continue @@ -523,7 +523,7 @@ function reduce_idx!(reduce_idx, aq; do_condense=false) new_components = get_idx_connected_components(new_parent_idxs, new_connected_idxs) # Here, we update the statistics for all nodes above the affected nodes - rel_child_nodes = OrderedSet{Int}(n for n in nodes_to_remove) + rel_child_nodes = StableSet{Int}(n for n in nodes_to_remove) push!(rel_child_nodes, node_to_replace) for n in PostOrderDFS(new_point_expr) if n.node_id == node_to_replace @@ -571,7 +571,7 @@ end # Given a node in the tree, return all indices which can be reduced after computing that subtree. function get_reducible_idxs(aq, n) - reduce_idxs = OrderedSet{IndexExpr}() + reduce_idxs = StableSet{IndexExpr}() for idx in aq.reduce_idxs idx_root = aq.idx_lowest_root[idx] if intree(idx_root, n) diff --git a/src/Galley/LogicalOptimizer/query-splitter.jl b/src/Galley/LogicalOptimizer/query-splitter.jl index 631fbc396..81a24e2c0 100644 --- a/src/Galley/LogicalOptimizer/query-splitter.jl +++ b/src/Galley/LogicalOptimizer/query-splitter.jl @@ -1,5 +1,5 @@ function count_index_occurences(nodes) - vars = OrderedSet() + vars = StableSet() occurences = 0 for n in nodes for c in PostOrderDFS(n) @@ -89,7 +89,7 @@ function split_query(q::PlanNode, ST, max_kernel_size, alias_stats, verbose) while cur_occurences > max_kernel_size nodes_to_remove = nothing new_query = nothing - new_agg_idxs = OrderedSet() + new_agg_idxs = StableSet() min_cost = Inf for node in PostOrderDFS(pe) if node.kind in (Value, Input, Alias, Index) || @@ -127,7 +127,7 @@ function split_query(q::PlanNode, ST, max_kernel_size, alias_stats, verbose) cache_key = sort([n.node_id for n in s]) if !haskey(cost_cache, cache_key) s_stat = merge_tensor_stats(node.op.val, [n.stats for n in s]...) - s_reduce_idxs = OrderedSet{IndexExpr}() + s_reduce_idxs = StableSet{IndexExpr}() for idx in n_reduce_idxs if !any([ idx ∈ get_index_set(n.stats) for n in setdiff(node.args, s) diff --git a/src/Galley/PhysicalOptimizer/format-selector.jl b/src/Galley/PhysicalOptimizer/format-selector.jl index 966907b41..9478c2fce 100644 --- a/src/Galley/PhysicalOptimizer/format-selector.jl +++ b/src/Galley/PhysicalOptimizer/format-selector.jl @@ -45,13 +45,13 @@ function select_output_format(output_stats::TensorStats, approx_sparsity = approx_nnz_per / get_dim_size(output_stats, prefix[1]) dense_memory_footprint = prev_nnz * get_dim_size(output_stats, prefix[1]) if approx_sparsity > 0.5 && dense_memory_footprint < 3 * 10^10 - # if get_dim_space_size(output_stats, OrderedSet(prefix)) > 10^10 + # if get_dim_space_size(output_stats, StableSet(prefix)) > 10^10 # throw(OutOfMemoryError()) # end push!(formats, t_dense) elseif approx_sparsity > 0.05 && dense_memory_footprint < 3 * 10^10 && (length(formats) == 0 ? true : formats[end] != t_bytemap) # TODO: Check out finch double bytemap bug - # if get_dim_space_size(output_stats, OrderedSet(prefix)) > 10^10 + # if get_dim_space_size(output_stats, StableSet(prefix)) > 10^10 # throw(OutOfMemoryError()) # end push!(formats, t_bytemap) diff --git a/src/Galley/PhysicalOptimizer/loop-ordering.jl b/src/Galley/PhysicalOptimizer/loop-ordering.jl index 5ff627882..7d11e02d3 100644 --- a/src/Galley/PhysicalOptimizer/loop-ordering.jl +++ b/src/Galley/PhysicalOptimizer/loop-ordering.jl @@ -43,14 +43,14 @@ function needs_reformat(stat::TensorStats, prefix::Vector{IndexExpr}) end function get_reformat_set(input_stats::Vector{TensorStats}, prefix::Vector{IndexExpr}) - ref_set = OrderedSet() + ref_set = StableSet() for i in eachindex(input_stats) needs_reformat(input_stats[i], prefix) && push!(ref_set, i) end return ref_set end -PLAN_CLASS = Tuple{OrderedSet{IndexExpr},OrderedSet{Int}} +PLAN_CLASS = Tuple{StableSet{IndexExpr},StableSet{Int}} PLAN = Tuple{Vector{IndexExpr},Float64} function cost_of_plan_class(pc::PLAN_CLASS, reformat_costs, output_size) @@ -98,13 +98,13 @@ function get_join_loop_order_bounded(disjunct_and_conjunct_stats, reformat_costs = OrderedDict( i => cost_of_reformat(transposable_stats[i]) for i in eachindex(transposable_stats) ) - PLAN_CLASS = Tuple{OrderedSet{IndexExpr},OrderedSet{Int}} + PLAN_CLASS = Tuple{StableSet{IndexExpr},StableSet{Int}} PLAN = Tuple{Vector{IndexExpr},Float64} optimal_plans = OrderedDict{PLAN_CLASS,PLAN}() for var in all_vars prefix = [var] rf_set = get_reformat_set(transposable_stats, prefix) - class = (OrderedSet(prefix), rf_set) + class = (StableSet(prefix), rf_set) cost = get_prefix_cost(prefix, output_vars, conjunct_stats, disjunct_stats) optimal_plans[class] = (prefix, cost) end @@ -116,7 +116,7 @@ function get_join_loop_order_bounded(disjunct_and_conjunct_stats, prefix = plan[1] cost = plan[2] # We only consider extensions that don't result in cross products - potential_vars = OrderedSet{IndexExpr}() + potential_vars = StableSet{IndexExpr}() for stat in all_stats index_set = get_index_set(stat) if length(∩(index_set, prefix_set)) > 0 diff --git a/src/Galley/PhysicalOptimizer/physical-optimizer.jl b/src/Galley/PhysicalOptimizer/physical-optimizer.jl index 5b7b1498f..58f5f4268 100644 --- a/src/Galley/PhysicalOptimizer/physical-optimizer.jl +++ b/src/Galley/PhysicalOptimizer/physical-optimizer.jl @@ -113,11 +113,11 @@ function logical_query_to_physical_queries( agg_op = nothing agg_init = nothing - reduce_idxs = OrderedSet{IndexExpr}() + reduce_idxs = StableSet{IndexExpr}() if expr.kind == Aggregate agg_op = expr.op agg_init = expr.init - reduce_idxs = OrderedSet{IndexExpr}([i.name for i in expr.idxs]) + reduce_idxs = StableSet{IndexExpr}([i.name for i in expr.idxs]) expr = expr.arg end diff --git a/src/Galley/PhysicalOptimizer/validate.jl b/src/Galley/PhysicalOptimizer/validate.jl index 7701fcfb8..9f0d64ab7 100644 --- a/src/Galley/PhysicalOptimizer/validate.jl +++ b/src/Galley/PhysicalOptimizer/validate.jl @@ -11,7 +11,7 @@ function get_input_indices(n::PlanNode) elseif n.kind == Alias get_index_set(n.stats) elseif n.kind == Value - OrderedSet{IndexExpr}() + StableSet{IndexExpr}() elseif n.kind == Aggregate get_input_indices(n.arg) elseif n.kind == MapJoin @@ -121,9 +121,9 @@ function validate_physical_query(q::PlanNode) end @assert idx_dim[idx] == dim "idx:$idx dim:$dim query:$q " end - output_indices = OrderedSet([idx.name for idx in q.expr.idx_order]) - @assert input_indices ∪ output_indices == OrderedSet([idx.name for idx in q.loop_order]) - @assert OrderedSet(output_indices) == OrderedSet([idx.name for idx in q.expr.idx_order]) + output_indices = StableSet([idx.name for idx in q.expr.idx_order]) + @assert input_indices ∪ output_indices == StableSet([idx.name for idx in q.loop_order]) + @assert StableSet(output_indices) == StableSet([idx.name for idx in q.expr.idx_order]) check_sorted_inputs(q.expr, [idx.name for idx in q.loop_order]) check_protocols(q.expr) check_formats(q.expr) diff --git a/src/Galley/PlanAST/canonicalize.jl b/src/Galley/PlanAST/canonicalize.jl index a0adba249..bb58948d8 100644 --- a/src/Galley/PlanAST/canonicalize.jl +++ b/src/Galley/PlanAST/canonicalize.jl @@ -83,7 +83,7 @@ function _insert_statistics!( expr.stats = reduce_tensor_stats( expr.op.val, expr.init.val, - OrderedSet{IndexExpr}([idx.name for idx in expr.idxs]), + StableSet{IndexExpr}([idx.name for idx in expr.idxs]), expr.arg.stats, ) elseif expr.kind === Materialize diff --git a/src/Galley/PlanAST/plan.jl b/src/Galley/PlanAST/plan.jl index 202ce852e..e1b9bd4f9 100644 --- a/src/Galley/PlanAST/plan.jl +++ b/src/Galley/PlanAST/plan.jl @@ -325,7 +325,7 @@ function Base.:(==)(a::PlanNode, b::PlanNode) return b.kind === Index && a.name == b.name elseif a.kind == Aggregate return b.kind === Aggregate && a.op == b.op && a.init == b.init && - OrderedSet(a.idxs) == OrderedSet(b.idxs) && a.arg == b.arg + StableSet(a.idxs) == StableSet(b.idxs) && a.arg == b.arg elseif istree(a) return a.kind === b.kind && a.children == b.children else diff --git a/src/Galley/TensorStats/StaticBitset.jl b/src/Galley/TensorStats/StaticBitset.jl index 82d21d729..59a10bd3e 100644 --- a/src/Galley/TensorStats/StaticBitset.jl +++ b/src/Galley/TensorStats/StaticBitset.jl @@ -25,7 +25,7 @@ function SmallBitSet(ints::Vector{Int}) return s end -function SmallBitSet(ints::OrderedSet{Int}) +function SmallBitSet(ints::StableSet{Int}) s = SmallBitSet() for i in ints s = _setint(s, i, true) diff --git a/src/Galley/TensorStats/cost-estimates.jl b/src/Galley/TensorStats/cost-estimates.jl index ca0d9a30d..f233090db 100644 --- a/src/Galley/TensorStats/cost-estimates.jl +++ b/src/Galley/TensorStats/cost-estimates.jl @@ -12,7 +12,7 @@ const DenseAllocateCost = 0.5 const SparseAllocateCost = 60 # We estimate the prefix cost based on the number of iterations in that prefix. -function get_loop_lookups(vars::OrderedSet{IndexExpr}, rel_conjuncts, rel_disjuncts) +function get_loop_lookups(vars::StableSet{IndexExpr}, rel_conjuncts, rel_disjuncts) # This tensor stats doesn't actually correspond to a particular place in the expr tree, # so we unfortunately have to mangle the statistics interface a bit. rel_conjuncts = map(stat -> set_fill_value!(stat, false), rel_conjuncts) @@ -45,7 +45,7 @@ function get_prefix_cost( new_prefix::Vector{IndexExpr}, output_vars, conjunct_stats, disjunct_stats ) new_var = new_prefix[end] - prefix_set = OrderedSet(new_prefix) + prefix_set = StableSet(new_prefix) rel_conjuncts = [ stat for stat in conjunct_stats if !isempty(get_index_set(stat) ∩ prefix_set) ] diff --git a/src/Galley/TensorStats/propagate-stats.jl b/src/Galley/TensorStats/propagate-stats.jl index b9b148774..7390fc47f 100644 --- a/src/Galley/TensorStats/propagate-stats.jl +++ b/src/Galley/TensorStats/propagate-stats.jl @@ -10,7 +10,7 @@ function merge_tensor_stats_union(op, all_stats::Vararg{TensorStats}) end function reduce_tensor_stats( - op, init, reduce_indices::OrderedSet{IndexExpr}, stats::TensorStats + op, init, reduce_indices::StableSet{IndexExpr}, stats::TensorStats ) throw(error("reduce_tensor_stats not implemented for: ", typeof(stats))) end @@ -38,7 +38,7 @@ function merge_tensor_def(op, all_defs::Vararg{TensorDef}) ) end -function reduce_tensor_def(op, init, reduce_indices::OrderedSet{IndexExpr}, def::TensorDef) +function reduce_tensor_def(op, init, reduce_indices::StableSet{IndexExpr}, def::TensorDef) op = op isa PlanNode ? op.val : op init = init isa PlanNode ? init.val : init if isnothing(init) @@ -110,10 +110,10 @@ function merge_tensor_stats(op::PlanNode, all_stats::Vararg{ST}) where {ST<:Tens end function reduce_tensor_stats( - op, init, reduce_indices::Union{Vector{PlanNode},OrderedSet{PlanNode}}, stats::ST + op, init, reduce_indices::Union{Vector{PlanNode},StableSet{PlanNode}}, stats::ST ) where {ST<:TensorStats} return reduce_tensor_stats( - op, init, OrderedSet{IndexExpr}([idx.name for idx in reduce_indices]), stats + op, init, StableSet{IndexExpr}([idx.name for idx in reduce_indices]), stats ) end @@ -154,7 +154,7 @@ function merge_tensor_stats_union(op, new_def::TensorDef, all_stats::Vararg{Naiv end function reduce_tensor_stats( - op, init, reduce_indices::OrderedSet{IndexExpr}, stats::NaiveStats + op, init, reduce_indices::StableSet{IndexExpr}, stats::NaiveStats ) if length(reduce_indices) == 0 return copy_stats(stats) @@ -227,7 +227,7 @@ function merge_tensor_stats_join(op, new_def::TensorDef, all_stats::Vararg{DCSta new_def, final_idx_2_int, final_int_2_idx, - OrderedSet{DC}(DC(key.X, key.Y, d) for (key, d) in new_dc_dict), + StableSet{DC}(DC(key.X, key.Y, d) for (key, d) in new_dc_dict), ) return new_stats end @@ -284,19 +284,19 @@ function merge_tensor_stats_union(op, new_def::TensorDef, all_stats::Vararg{DCSt #= for Y in subsets(collect(get_index_set(new_def))) proj_dc_key = (X=BitSet(), Y=idxs_to_bitset(final_idx_2_int, Y)) - new_dcs[proj_dc_key] = min(get(new_dcs, proj_dc_key, typemax(UInt)/2), get_dim_space_size(new_def, OrderedSet(Y))) + new_dcs[proj_dc_key] = min(get(new_dcs, proj_dc_key, typemax(UInt)/2), get_dim_space_size(new_def, StableSet(Y))) end =# return DCStats( new_def, final_idx_2_int, final_int_2_idx, - OrderedSet{DC}(DC(key.X, key.Y, d) for (key, d) in new_dcs), + StableSet{DC}(DC(key.X, key.Y, d) for (key, d) in new_dcs), ) end function reduce_tensor_stats( - op, init, reduce_indices::OrderedSet{IndexExpr}, stats::DCStats + op, init, reduce_indices::StableSet{IndexExpr}, stats::DCStats ) if length(reduce_indices) == 0 return copy_stats(stats) diff --git a/src/Galley/TensorStats/tensor-stats.jl b/src/Galley/TensorStats/tensor-stats.jl index 615f50e6c..14340da7e 100644 --- a/src/Galley/TensorStats/tensor-stats.jl +++ b/src/Galley/TensorStats/tensor-stats.jl @@ -3,7 +3,7 @@ # `Nothing` is considered a part of the physical definition which may be undefined for logical # intermediates but is required to be defined for the inputs to an executable query. @auto_hash_equals mutable struct TensorDef - index_set::OrderedSet{IndexExpr} + index_set::StableSet{IndexExpr} dim_sizes::OrderedDict{IndexExpr,Float64} fill_val::Any level_formats::Union{Nothing,Vector{LevelFormat}} @@ -12,7 +12,7 @@ end function TensorDef(x) TensorDef( - OrderedSet{IndexExpr}(), + StableSet{IndexExpr}(), OrderedDict{IndexExpr,Float64}(), x, IndexExpr[], @@ -22,7 +22,7 @@ function TensorDef(x) end function copy_def(def::TensorDef) - TensorDef(OrderedSet{IndexExpr}(x for x in def.index_set), + TensorDef(StableSet{IndexExpr}(x for x in def.index_set), OrderedDict{IndexExpr,Float64}(x for x in def.dim_sizes), def.fill_val, isnothing(def.level_formats) ? nothing : [x for x in def.level_formats], @@ -65,7 +65,7 @@ function TensorDef(tensor::Tensor, indices) ) fill_val = Finch.fill_value(tensor) return TensorDef( - OrderedSet{IndexExpr}(indices), dim_size, fill_val, level_formats, indices, nothing + StableSet{IndexExpr}(indices), dim_size, fill_val, level_formats, indices, nothing ) end @@ -75,7 +75,7 @@ function reindex_def(indices, def::TensorDef) for i in eachindex(indices) rename_dict[def.index_order[i]] = indices[i] end - new_index_set = OrderedSet{IndexExpr}() + new_index_set = StableSet{IndexExpr}() for idx in def.index_set push!(new_index_set, rename_dict[idx]) end @@ -206,7 +206,7 @@ end function estimate_nnz( stat::NaiveStats; indices=get_index_set(stat), - conditional_indices=OrderedSet{IndexExpr}(), + conditional_indices=StableSet{IndexExpr}(), ) return stat.cardinality / get_dim_space_size(stat, conditional_indices) end @@ -224,7 +224,7 @@ end function NaiveStats(x) def = TensorDef( - OrderedSet{IndexExpr}(), OrderedDict{IndexExpr,Int}(), x, nothing, nothing, nothing + StableSet{IndexExpr}(), OrderedDict{IndexExpr,Int}(), x, nothing, nothing, nothing ) return NaiveStats(def, 1) end @@ -262,7 +262,7 @@ end def::TensorDef idx_2_int::OrderedDict{IndexExpr,Int} int_2_idx::OrderedDict{Int,IndexExpr} - dcs::OrderedSet{DC} + dcs::StableSet{DC} DCStats(def, idx_2_int, int_2_idx, dcs) = new(def, idx_2_int, int_2_idx, dcs) @@ -273,7 +273,7 @@ end def = TensorDef(tensor, indices) idx_2_int = OrderedDict{IndexExpr,Int}() int_2_idx = OrderedDict{Int,IndexExpr}() - for (i, idx) in enumerate(OrderedSet(indices)) + for (i, idx) in enumerate(StableSet(indices)) idx_2_int[idx] = i int_2_idx[i] = idx end @@ -298,19 +298,19 @@ function copy_stats(stat::DCStats) copy_def(stat.def), copy(stat.idx_2_int), copy(stat.int_2_idx), - OrderedSet{DC}(dc for dc in stat.dcs), + StableSet{DC}(dc for dc in stat.dcs), ) end function DCStats(x) DCStats( TensorDef(x), OrderedDict{IndexExpr,Int}(), OrderedDict{Int,IndexExpr}(), - OrderedSet{DC}(), + StableSet{DC}(), ) end # Return a stats object where values have been geometrically rounded. function get_cannonical_stats(stat::DCStats, rel_granularity=4) - new_dcs = OrderedSet{DC}() + new_dcs = StableSet{DC}() for dc in stat.dcs push!(new_dcs, DC(dc.X, dc.Y, geometric_round(rel_granularity, dc.d))) end @@ -356,7 +356,7 @@ function idxs_to_bitset(idx_2_int::OrderedDict{IndexExpr,Int}, indices) end bitset_to_idxs(stat::DCStats, bitset) = bitset_to_idxs(stat.int_2_idx, bitset) function bitset_to_idxs(int_2_idx::OrderedDict{Int,IndexExpr}, bitset) - OrderedSet{IndexExpr}(int_2_idx[idx] for idx in bitset) + StableSet{IndexExpr}(int_2_idx[idx] for idx in bitset) end function add_dummy_idx!(stats::DCStats, i::IndexExpr; idx_pos=-1) @@ -364,13 +364,13 @@ function add_dummy_idx!(stats::DCStats, i::IndexExpr; idx_pos=-1) new_idx_int = maximum(values(stats.idx_2_int); init=0) + 1 stats.idx_2_int[i] = new_idx_int stats.int_2_idx[new_idx_int] = i - Y = idxs_to_bitset(stats, OrderedSet([i])) + Y = idxs_to_bitset(stats, StableSet([i])) push!(stats.dcs, DC(BitSet(), Y, 1)) end function fix_cardinality!(stat::DCStats, card) had_dc = false - new_dcs = OrderedSet{DC}() + new_dcs = StableSet{DC}() for dc in stat.dcs if length(dc.X) == 0 && dc.Y == get_index_bitset(stat) push!(new_dcs, DC(BitSet(), get_index_bitset(stat), min(card, dc.d))) @@ -400,7 +400,7 @@ end # When we're only attempting to infer for nnz estimation, we only need to consider # left dcs which have X = {}. -function _infer_dcs(dcs::OrderedSet{DC}; timeout=Inf, strength=0) +function _infer_dcs(dcs::StableSet{DC}; timeout=Inf, strength=0) all_dcs = OrderedDict{DCKey,Float64}() for dc in dcs all_dcs[(X=dc.X, Y=dc.Y)] = dc.d @@ -447,7 +447,7 @@ function _infer_dcs(dcs::OrderedSet{DC}; timeout=Inf, strength=0) finished = true end end - final_dcs = OrderedSet{DC}() + final_dcs = StableSet{DC}() for (dc_key, dc) in all_dcs push!(final_dcs, DC(dc_key.X, dc_key.Y, dc)) end @@ -480,7 +480,7 @@ function condense_stats!(stat::DCStats; timeout=100000, cheap=false) ) end - end_dcs = OrderedSet{DC}() + end_dcs = StableSet{DC}() for (dc_key, d) in min_dcs push!(end_dcs, DC(dc_key.X, dc_key.Y, d)) end @@ -489,7 +489,7 @@ function condense_stats!(stat::DCStats; timeout=100000, cheap=false) end function estimate_nnz( - stat::DCStats; indices=get_index_set(stat), conditional_indices=OrderedSet{IndexExpr}() + stat::DCStats; indices=get_index_set(stat), conditional_indices=StableSet{IndexExpr}() ) if length(indices) == 0 return 1 @@ -499,11 +499,11 @@ function estimate_nnz( current_weights = OrderedDict{BitSet,Float64}( conditional_indices_bitset => 1, BitSet() => 1 ) - frontier = OrderedSet{BitSet}([BitSet(), conditional_indices_bitset]) + frontier = StableSet{BitSet}([BitSet(), conditional_indices_bitset]) finished = false while !finished current_bound::Float64 = get(current_weights, indices_bitset, typemax(Float64)) - new_frontier = OrderedSet{BitSet}() + new_frontier = StableSet{BitSet}() finished = true for x in frontier weight = current_weights[x] @@ -542,10 +542,10 @@ function estimate_nnz( return min_weight end -DCStats() = DCStats(TensorDef(), OrderedSet()) +DCStats() = DCStats(TensorDef(), StableSet()) function _calc_dc_from_structure( - X::OrderedSet{IndexExpr}, Y::OrderedSet{IndexExpr}, indices::Vector{IndexExpr}, + X::StableSet{IndexExpr}, Y::StableSet{IndexExpr}, indices::Vector{IndexExpr}, s::Tensor, ) Z = [i for i in indices if i ∉ ∪(X, Y)] # Indices that we want to project out before counting @@ -577,7 +577,7 @@ function _vector_structure_to_dcs(indices::Vector{Int}, s::Tensor) d_i[] += s[i] end end - return OrderedSet{DC}([DC(BitSet(), BitSet(indices), d_i[])]) + return StableSet{DC}([DC(BitSet(), BitSet(indices), d_i[])]) end function _matrix_structure_to_dcs(indices::Vector{Int}, s::Tensor) @@ -622,7 +622,7 @@ function _matrix_structure_to_dcs(indices::Vector{Int}, s::Tensor) end i = indices[2] j = indices[1] - return OrderedSet{DC}([DC(BitSet(), BitSet([i]), d_i[]), + return StableSet{DC}([DC(BitSet(), BitSet([i]), d_i[]), DC(BitSet(), BitSet([j]), d_j[]), DC(BitSet([i]), BitSet([j]), d_i_j[]), DC(BitSet([j]), BitSet([i]), d_j_i[]), @@ -693,7 +693,7 @@ function _3d_structure_to_dcs(indices::Vector{Int}, s::Tensor) i = indices[3] j = indices[2] k = indices[1] - return OrderedSet{DC}([DC(BitSet(), BitSet([i]), d_i[]), + return StableSet{DC}([DC(BitSet(), BitSet([i]), d_i[]), DC(BitSet(), BitSet([j]), d_j[]), DC(BitSet(), BitSet([k]), d_k[]), DC(BitSet([i]), BitSet([j, k]), d_i_jk[]), @@ -783,7 +783,7 @@ function _4d_structure_to_dcs(indices::Vector{Int}, s::Tensor) j = indices[3] k = indices[2] l = indices[1] - return OrderedSet{DC}([DC(BitSet(), BitSet([i]), d_i[]), + return StableSet{DC}([DC(BitSet(), BitSet([i]), d_i[]), DC(BitSet(), BitSet([j]), d_j[]), DC(BitSet(), BitSet([k]), d_k[]), DC(BitSet(), BitSet([l]), d_l[]), @@ -805,7 +805,7 @@ function _structure_to_dcs(int_2_idx, indices::Vector{Int}, s::Tensor) elseif length(indices) == 4 return _4d_structure_to_dcs(indices, s) end - dcs = OrderedSet{DC}() + dcs = StableSet{DC}() # Calculate DCs for all combinations of X and Y for X in subsets(indices) X = BitSet(X) @@ -820,7 +820,7 @@ function _structure_to_dcs(int_2_idx, indices::Vector{Int}, s::Tensor) push!(dcs, DC(X, Y, d)) d = _calc_dc_from_structure( - OrderedSet{IndexExpr}(), + StableSet{IndexExpr}(), bitset_to_idxs(int_2_idx, Y), [int_2_idx[i] for i in indices], s, @@ -831,7 +831,7 @@ function _structure_to_dcs(int_2_idx, indices::Vector{Int}, s::Tensor) end function dense_dcs(def, int_2_idx, indices::Vector{Int}) - dcs = OrderedSet() + dcs = StableSet() for X in subsets(indices) Y = setdiff(indices, X) for Z in subsets(Y) diff --git a/src/Galley/utility-funcs.jl b/src/Galley/utility-funcs.jl index 0c13ed4bd..6e2ad31a6 100644 --- a/src/Galley/utility-funcs.jl +++ b/src/Galley/utility-funcs.jl @@ -23,7 +23,7 @@ function relative_sort(indices::Vector{IndexExpr}, index_order; rev=false) end end -function relative_sort(indices::OrderedSet{IndexExpr}, index_order; rev=false) +function relative_sort(indices::StableSet{IndexExpr}, index_order; rev=false) return relative_sort(collect(indices), index_order; rev=rev) end @@ -157,11 +157,11 @@ end # This function determines whether any ordering of the `l_set` is a prefix of `r_vec`. # If r_vec is smaller than l_set, we just check whether r_vec is a subset of l_set. -function set_compat_with_loop_prefix(tensor_order::OrderedSet, loop_prefix::Vector) +function set_compat_with_loop_prefix(tensor_order::StableSet, loop_prefix::Vector) if length(tensor_order) > length(loop_prefix) - return OrderedSet(loop_prefix) ⊆ tensor_order + return StableSet(loop_prefix) ⊆ tensor_order else - return tensor_order == OrderedSet(loop_prefix[1:length(tensor_order)]) + return tensor_order == StableSet(loop_prefix[1:length(tensor_order)]) end end @@ -255,3 +255,5 @@ end function geometric_round(b, x) b^(floor(log(b, x)) + 0.5) end + + diff --git a/src/util/stable_set.jl b/src/util/stable_set.jl new file mode 100644 index 000000000..008cc99fa --- /dev/null +++ b/src/util/stable_set.jl @@ -0,0 +1,31 @@ +struct StableSet{T} <: AbstractSet{T} + data::OrderedSet{T} + StableSet{T}(arg) where {T} = new(OrderedSet(arg)) +end + +StableSet(arg) = StableSet(OrderedSet(arg)) +StableSet(args...) = StableSet(OrderedSet(args...)) +StableSet{T}(args...) where {T} = StableSet{T}(OrderedSet(args...)) + +Base.push!(s::StableSet, x) = push!(s.data, x) +Base.pop!(s::StableSet) = pop!(s.data) +Base.iterate(s::StableSet) = iterate(s.data) +Base.iterate(s::StableSet, i) = iterate(s.data, i) +Base.intersect!(s::StableSet, x...) = intersect!(s.data, x...) +Base.union!(s::StableSet, x...) = union!(s.data, x...) +Base.setdiff!(s::StableSet, x...) = setdiff!(s.data, x...) +Base.intersect(s::StableSet, x...) = intersect(s.data, x...) +Base.union(s::StableSet, x...) = union(s.data, x...) +Base.setdiff(s::StableSet, x...) = setdiff(s.data, x...) +Base.length(s::StableSet) = length(s.data) +Base.in(s::StableSet, x) = in(s.data, x) +Base.delete!(s::StableSet, x) = delete!(s.data, x) +Base.empty!(s::StableSet) = empty!(s.data) +function Base.hash(s::StableSet{T}, h::UInt) where {T} + h = hash(hash(StableSet{T}, h), h) + h_2 = UInt(0) + for k in s.data + h_2 ⊻= hash(k, h) + end + h_2 +end