
Cranelift: improve egraph cost function / extraction to "understand" sharing / reusing sub-expressions #12156

@fitzgen

Description


Note: I am forking off a new issue from #12106. See that issue's discussion for a description of how enode costs saturate to infinity because extraction does not "understand" shared structure, and how this problem is equivalent to weighted set cover and is therefore NP-hard. +cc @cfallin @bongjunj

Chris and I were talking about extraction again yesterday. A couple things came out of it, which I will note down for our future selves and posterity:

  • We could divide/amortize the cost by the number of users of any given value. ... This maintains the phased aspect

    I kind of tried to explain this point earlier, but I don't think I was very clear, so I will try again: the problem with this algorithm is that "number of users" == "number of incoming edges in the (a)egraph", and for this to be a good algorithm in practice, the number of incoming edges in the (a)egraph must approximate the number of incoming edges in the set of expressions we actually extract. It is not clear to me that this property generally holds. In fact, I suspect that the more we do "egraph-y" things, like adding more canonicalizations and neutral rewrites in the hope that they unlock further beneficial rewrites, the less this property will hold!
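
    A minimal sketch of that amortization, assuming hypothetical cost_of and num_users helpers (not Cranelift's actual API):

    fn amortized_cost(value: Value) -> Cost {
        // Full cost of computing this value's best enode from scratch.
        let full_cost = cost_of(value);
        // Split the cost evenly across the value's users in the aegraph.
        // This is only a good approximation when the aegraph's in-degree
        // tracks the in-degree of the expressions we actually extract.
        let users = num_users(value).max(1);
        full_cost / users
    }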

  • I had the realization that if we topo-sort values before computing best costs, then we will reach a fixpoint in a single iteration. This has some nice implications for the "If we aren't maintaining phasing" algorithm I sketched above. Bear with me for a second.

    Our current algorithm to compute the best enode (specific value) for each eclass (canonical value) is a fixpoint loop that looks roughly like this:

    fn compute_best_values() {
        let mut keep_going = true;
        while keep_going {
            // Reset at the start of each pass so the loop terminates
            // once a full pass changes nothing.
            keep_going = false;
            for value in dfg.all_values() {
                let orig_best_value = best[value];
                best[value] = best_value(value, dfg);
                keep_going |= best[value] != orig_best_value;
            }
        }
    }

    It will do n + 1 iterations before returning, where n is the length of the longest chain of vNN -> vMM edges in the dataflow graph where NN < MM (and the + 1 is to determine that we did actually reach the fixpoint and nothing else is changing anymore).1 For example, if v5 uses v9, then a pass in numeric order computes v5's best cost before v9's, so v5 needs another pass to pick up v9's final cost. Note that at most n == len(values), so in the worst case we do a pass over all n values n times, leading to O(n^2) runtime. But even in the best case it does two passes over all n values: one to compute the best values and another to verify that we reached the fixpoint and nothing else is changing.

    Topologically sorting the values based on the DFG is also one pass over all n values. Then we can do a single iteration of the inner loop from the previous algorithm, which is again one pass over all n values, and in this case we know we will have reached the fixpoint immediately because (a) the DFG is acyclic (we don't look through block params) and (b) every value's operands are processed before the value itself. So this gives us an O(n) worst-case (and best-case) runtime algorithm. Strictly better (in terms of worst-case runtime) than what we do today!

    fn compute_best_values() {
        let sorted = topo_sort_values(dfg);
        compute_best_values_already_sorted(&sorted);
    }

    fn compute_best_values_already_sorted(sorted: &[Value]) {
        // Operands always precede their users in `sorted`, so one pass
        // computes the final best value for every eclass.
        for value in sorted {
            best[value] = best_value(value, dfg);
        }
    }
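
    For reference, topo_sort_values could be a simple post-order walk over each value's operands. A minimal sketch, assuming hypothetical Dfg::all_values and Dfg::operands helpers (real code would probably use an explicit stack rather than recursion):

    fn topo_sort_values(dfg: &Dfg) -> Vec<Value> {
        let mut sorted = Vec::new();
        let mut visited = HashSet::new();
        for root in dfg.all_values() {
            visit(dfg, root, &mut visited, &mut sorted);
        }
        sorted
    }

    fn visit(dfg: &Dfg, value: Value, visited: &mut HashSet<Value>, sorted: &mut Vec<Value>) {
        // Skip values we have already emitted.
        if !visited.insert(value) {
            return;
        }
        // Visit operands first, so every value's operands precede it in
        // the final order. This terminates because the DFG is acyclic.
        for operand in dfg.operands(value) {
            visit(dfg, operand, visited, sorted);
        }
        sorted.push(value);
    }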

    But this now plays nicely with incrementally updating costs to reflect certain sub-expressions getting cheaper because we've already extracted them! The following brings us back to O(n^2), but I think it should give us a much better extraction algorithm that "understands" shared subexpressions:

    let sorted = topo_sort_values(dfg);

    for inst in skeleton {
        // Compute the best values. On the first iteration of the loop, this
        // is equivalent to what we had before. On subsequent iterations, it
        // updates the best values based on the adjusted costs of values we
        // have already extracted.
        compute_best_values_already_sorted(&sorted);
    
        // Elaborate this instruction, extracting the best values for its
        // operands and inserting their defining instructions back into the
        // CFG's basic blocks.
        elaborate_inst(inst);
    
        // Update the cost for each value we extracted. It is ~free (modulo
        // register pressure) to reuse this value, since we have already computed
        // it and it is now available. The next iteration of the outer loop will
        // recompute best values to reflect these updated costs.
        for value in transitive_values_used_by(inst) {
            // Or could do something very low but non-zero, like `1`, to reflect
            // the cost of additional register pressure. Worth experimenting
            // with, but is also somewhat orthogonal...
            let new_cost = 0;
            set_cost(value, new_cost);
        }
    }
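
    Here transitive_values_used_by is also just a sketch: one plausible shape, assuming hypothetical operands_of and operands_of_best_enode helpers, is a worklist walk from the instruction's operands that follows each best (extracted) enode's operand edges, since those are the values that are actually materialized and available for reuse:

    fn transitive_values_used_by(inst: Inst) -> Vec<Value> {
        let mut stack: Vec<Value> = operands_of(inst).collect();
        let mut seen = HashSet::new();
        let mut reachable = Vec::new();
        while let Some(value) = stack.pop() {
            // Skip values we have already marked as available.
            if !seen.insert(value) {
                continue;
            }
            reachable.push(value);
            // Follow the operands of the *best* (extracted) enode, since
            // those are the values we actually computed.
            stack.extend(operands_of_best_enode(value));
        }
        reachable
    }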

    It could be that this is fast enough to just be what we do by default. If not, it might be something we only enable at higher optimization levels.

Footnotes

  1. Typically n is fairly small; it grows with how many new instructions we create and splice into the middle of the DFG after the CLIF producer gives us the initial CLIF. That is something we generally don't do a ton of, except in our legalization pass and NaN canonicalization, and neither of those does it much.
