diff --git a/compiler/backend/ccgstmts.nim b/compiler/backend/ccgstmts.nim index bd918b3d4f2..81dbebafc39 100644 --- a/compiler/backend/ccgstmts.nim +++ b/compiler/backend/ccgstmts.nim @@ -12,8 +12,6 @@ const RangeExpandLimit = 256 # do not generate ranges # over 'RangeExpandLimit' elements - stringCaseThreshold = 8 - # above X strings a hash-switch for strings is generated proc startBlockInternal(p: BProc) = let result = p.blocks.len @@ -213,52 +211,6 @@ template genIfForCaseUntil(p: BProc, t: CgNode, else: linefmt(p, cpsStmts, "goto $1;$n", [t[i][^1].label]) -template genCaseGeneric(p: BProc, t: CgNode, - rangeFormat, eqFormat: FormatStr) = - var a: TLoc - initLocExpr(p, t[0], a) - genIfForCaseUntil(p, t, rangeFormat, eqFormat, t.len-1, a) - -proc genCaseStringBranch(p: BProc, b: CgNode, e: TLoc, labl: BlockId, - branches: var openArray[Rope]) = - var x: TLoc - for i in 0.. stringCaseThreshold: - var bitMask = math.nextPowerOfTwo(strings) - 1 - var branches: seq[Rope] - newSeq(branches, bitMask + 1) - var a: TLoc - initLocExpr(p, t[0], a) # fist pass: generate ifs+goto: - for i in 1..= $2 && $1 <= $3) goto $4;$n", - "if ($1 == $2) goto $3;$n") + if t[0].kind == cnkLocal and sfGoto in p.body[t[0].local].flags: + genGotoForCase(p, t) else: - if t[0].kind == cnkLocal and sfGoto in p.body[t[0].local].flags: - genGotoForCase(p, t) - else: - genOrdinalCase(p, t) + genOrdinalCase(p, t) proc bodyCanRaise(p: BProc; n: CgNode): bool = case n.kind diff --git a/compiler/mir/mirconstr.nim b/compiler/mir/mirconstr.nim index 3b22dcc2105..2be7aeb3bc7 100644 --- a/compiler/mir/mirconstr.nim +++ b/compiler/mir/mirconstr.nim @@ -452,6 +452,10 @@ func join*(bu: var MirBuilder, label: LabelId) = bu.subTree mnkJoin: bu.add MirNode(kind: mnkLabel, label: label) +func goto*(bu: var MirBuilder, label: LabelId) = + bu.subTree mnkGoto: + bu.add MirNode(kind: mnkLabel, label: label) + template pathNamed*(bu: var MirBuilder, t: TypeId, f: int32, body: untyped) = ## Emits a ``mnkPathNamed`` expression. bu.subTree MirNode(kind: mnkPathNamed, typ: t): diff --git a/compiler/mir/mirpasses.nim b/compiler/mir/mirpasses.nim index c9b3d498899..40c438431a3 100644 --- a/compiler/mir/mirpasses.nim +++ b/compiler/mir/mirpasses.nim @@ -38,6 +38,9 @@ import idioms ] +from std/math import nextPowerOfTwo +from compiler/backend/ccgutils import hashString + # for type-based alias analysis from compiler/sem/aliases import isPartOf, TAnalysisResult @@ -58,6 +61,10 @@ const LocSkip = abstractRange + tyUserTypeClasses ## types to skip to arrive at the underlying concrete value type +template addCompilerProc(env: var MirEnv, graph: ModuleGraph, + name: string): ProcedureId = + env.procedures.add(graph.getCompilerProc(name)) + template subTree(bu: var MirBuilder, k: MirNodeKind, t: TypeId, body: untyped) = bu.subTree MirNode(kind: k, typ: t): @@ -652,6 +659,137 @@ proc lowerMove(tree: MirTree, changes: var Changeset) = else: discard "not relevant" +proc lowerCase(tree: MirTree, graph: ModuleGraph, env: var MirEnv, + changes: var Changeset) = + ## Lowers case statements with string or float selectors. For large string- + ## case statements, a hash-table optimization is used. + const stringCaseThreshold = 8 + ## above X strings a hash-switch for strings is generated + + iterator targets(tree: MirTree, n: NodePosition): (LabelId, NodePosition) = + ## Returns all comparison candidate together with their associated jump + ## target. + for it in tree.subNodes(n, 1): + let target = tree[tree.last(it)].label + var x = tree.child(it, 0) + for _ in 0..<(tree[it].len - 1): # -1 for the label node + yield (target, x) + x = tree.sibling(x) + + proc genericCase(bu: var MirBuilder, tree: MirTree, n: NodePosition, + eq: TMagic, sel: Value) {.nimcall.} = + for (target, it) in tree.targets(n): + if tree[it].kind == mnkRange: + # only float case-statements can use ranges, so we know that the + # operands are floats here + var cond = bu.wrapTemp BoolType: + bu.buildMagicCall mLeF64, BoolType: + bu.emitByVal bu.inline(tree, tree.child(it, 0)) + bu.emitByVal sel + + bu.buildIf (;bu.use cond): + cond = bu.wrapTemp BoolType: + bu.buildMagicCall mLeF64, BoolType: + bu.emitByVal sel + bu.emitByVal bu.inline(tree, tree.child(it, 1)) + + # jump to the branch body if the run-time value is within the given + # range + bu.buildIf (;bu.use cond): + bu.goto target + else: + # single comparison + let cond = bu.wrapTemp BoolType: + bu.buildMagicCall eq, BoolType: + bu.emitByVal sel + bu.emitByVal bu.inline(tree, it) + + bu.buildIf (;bu.use cond): + bu.goto target + + bu.goto tree[tree.last(tree.last(n))].label # jump to else branch + + for n in search(tree, {mnkCase}): + case env.types[tree[n, 0].typ].skipTypes(abstractInst).kind + of tyFloat, tyFloat64, tyFloat32: + # simple: use the generic lowering + changes.replaceMulti(tree, n, bu): + let sel = bu.inline(tree, tree.child(n, 0)) + genericCase(bu, tree, n, mEqF64, sel) + of tyString: + # count the number of strings: + var numStrings = 0 + for it in tree.subNodes(n, start=1): + numStrings += (tree.len(it) - 1) # -1 for the target label + + if numStrings < stringCaseThreshold: + # compare against every string + changes.replaceMulti(tree, n, bu): + let sel = bu.inline(tree, tree.child(n, 0)) + genericCase(bu, tree, n, mEqStr, sel) + else: + # reduce the number of string comparisons through usage of a hash + # table + changes.replaceMulti(tree, n, bu): + let bitMask = nextPowerOfTwo(numStrings) - 1 + var branches: seq[tuple[label: LabelId, + strings: seq[(NodePosition, LabelId)]]] + branches.newSeq(bitMask + 1) + + # sort the string operands into buckets (`branches`) based on their + # hash: + for (target, it) in tree.targets(n): + let bI = hashString(graph.config, env[tree[it].strVal]) and bitMask + if branches[bI].strings.len == 0: + # the label is allocated on demand + branches[bI].label = bu.allocLabel() + + branches[bI].strings.add (it, target) + + let + elseLabel = tree[tree.last(tree.last(n))].label + typ = env.types.sizeType + sel = bu.inline(tree, tree.child(n, 0)) + var hash: Value + + # emit the hash computation: + hash = bu.wrapTemp typ: + bu.buildCall env.addCompilerProc(graph, "hashString"), typ: + bu.emitByVal sel + hash = bu.wrapTemp typ: + bu.buildMagicCall mBitandI, typ: + bu.emitByVal hash + bu.emitByVal: + literal(mnkIntLit, env.getOrIncl(BiggestInt bitMask), typ) + + # emit the dispatcher over the hash value: + bu.subTree mnkCase: + bu.use hash + for i, b in branches.pairs: + bu.subTree mnkBranch: + bu.use literal(mnkIntLit, env.getOrIncl(BiggestInt i), typ) + if b.strings.len == 0: + bu.add MirNode(kind: mnkLabel, label: elseLabel) + else: + bu.add MirNode(kind: mnkLabel, label: b.label) + + # emit the string comparisons: + for b in branches.items: + if b.strings.len > 0: + bu.join b.label + for (str, target) in b.strings.items: + let cond = bu.wrapTemp BoolType: + bu.buildMagicCall mEqStr, BoolType: + bu.emitByVal sel + bu.emitByVal bu.inline(tree, str) + + bu.buildIf (;bu.use cond): + bu.goto target + + bu.goto elseLabel # jump to the 'else' branch + else: + discard "keep as is" + proc applyPasses*(body: var MirBody, prc: PSym, env: var MirEnv, graph: ModuleGraph, target: TargetBackend) = ## Applies all applicable MIR passes to the body (`tree` and `source`) of @@ -687,6 +825,7 @@ proc applyPasses*(body: var MirBody, prc: PSym, env: var MirEnv, lowerNew(body.code, graph, env, c) lowerChecks(body, graph, env, c) injectStrPreparation(body.code, graph, env, c) + lowerCase(body.code, graph, env, c) # instrument the body with profiler calls after all lowerings, but before # optimization diff --git a/compiler/mir/mirtrees.nim b/compiler/mir/mirtrees.nim index 00f9b5b70e5..ad27128e6a8 100644 --- a/compiler/mir/mirtrees.nim +++ b/compiler/mir/mirtrees.nim @@ -550,10 +550,11 @@ iterator pairs*(tree: MirTree): (NodePosition, lent MirNode) = yield (i.NodePosition, tree[i]) inc i -iterator subNodes*(tree: MirTree, n: NodePosition): NodePosition = - ## Iterates over and yields all direct child nodes of `n` +iterator subNodes*(tree: MirTree, n: NodePosition; start = 0): NodePosition = + ## Returns in order of apperance all direct child nodes of `n`, starting with + ## `start`. let L = tree[n].len - var n = tree.child(n, 0) + var n = tree.child(n, start) for _ in 0..