Skip to content

Commit

Permalink
lower string/float case statements with MIR pass (#1360)
Browse files Browse the repository at this point in the history
## Summary

Use an MIR pass to lower `string` and `float` case statements into `if`
chains, replacing the C code generator logic.

## Details

* add the `lowerCase` pass to `mirpasses`
* it's currently only enabled for the C target
* the lowering is a straightforward MIR port of the code generation
  previously implemented by `cgen` for `string`/`float` case statements
* the `cgen` handling for `string` and `float` case statements is
  removed

---------

Co-authored-by: Saem Ghani <[email protected]>
  • Loading branch information
zerbina and saem authored Jun 27, 2024
1 parent a21c4db commit 0ae0cd5
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 61 deletions.
61 changes: 3 additions & 58 deletions compiler/backend/ccgstmts.nim
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
const
RangeExpandLimit = 256 # do not generate ranges
# over 'RangeExpandLimit' elements
stringCaseThreshold = 8
# above X strings a hash-switch for strings is generated

proc startBlockInternal(p: BProc) =
let result = p.blocks.len
Expand Down Expand Up @@ -213,52 +211,6 @@ template genIfForCaseUntil(p: BProc, t: CgNode,
else:
linefmt(p, cpsStmts, "goto $1;$n", [t[i][^1].label])

template genCaseGeneric(p: BProc, t: CgNode,
                        rangeFormat, eqFormat: FormatStr) =
  ## Evaluates the selector expression `t[0]` into a location and hands it,
  ## together with both format strings and `t.len-1`, to
  ## `genIfForCaseUntil`, which emits the tests for the branches of `t`.
  var selector: TLoc
  initLocExpr(p, t[0], selector)
  genIfForCaseUntil(p, t, rangeFormat, eqFormat, t.len-1, selector)

proc genCaseStringBranch(p: BProc, b: CgNode, e: TLoc, labl: BlockId,
                         branches: var openArray[Rope]) =
  ## For every operand of the of-branch `b` (all nodes except the last one),
  ## appends an ``if (eqStrings(e, operand)) goto labl`` test to one entry of
  ## `branches`. The entry is selected by the compile-time hash of the
  ## operand's string value, masked with `high(branches)`.
  var lit: TLoc
  for i in 0..<b.len - 1:
    # only string literals are valid operands here; ranges are not
    assert(b[i].kind != cnkRange)
    initLocExpr(p, b[i], lit)
    assert(b[i].kind == cnkStrLit)
    let
      strHash = hashString(p.config, getString(p, b[i]))
      bucket = int(strHash and high(branches))
    appcg(p.module, branches[bucket],
          "if (#eqStrings($1, $2)) goto $3;$n",
          [rdLoc(e), rdLoc(lit), labl])

proc genStringCase(p: BProc, t: CgNode) =
  ## Emits the code for the string case statement `t`. With at most
  ## `stringCaseThreshold` string operands, a plain chain of `eqStrings`
  ## tests is generated; above that, the tests are sorted into buckets that
  ## are dispatched over with a ``switch`` on the masked run-time hash of
  ## the selector.
  # count how many constant strings there are in the case:
  var numLits = 0
  for i in 1..<t.len:
    if isOfBranch(t[i]):
      numLits += t[i].len - 1

  if numLits <= stringCaseThreshold:
    # few strings: compare against each one in order
    genCaseGeneric(p, t, "", "if (#eqStrings($1, $2)) goto $3;$n")
    return

  let mask = math.nextPowerOfTwo(numLits) - 1
  var buckets = newSeq[Rope](mask + 1)
  var sel: TLoc
  initLocExpr(p, t[0], sel)
  # first pass: generate the ifs+goto for every of-branch. An else branch
  # needs nothing at this point
  for i in 1..<t.len:
    if isOfBranch(t[i]):
      genCaseStringBranch(p, t[i], sel, t[i][^1].label, buckets)

  # second pass: emit the switch over the hash; empty buckets get no case
  linefmt(p, cpsStmts, "switch (#hashString($1) & $2) {$n",
          [rdLoc(sel), mask])
  for j, code in buckets.pairs:
    if code != "":
      lineF(p, cpsStmts, "case $1: $n$2break;$n",
            [intLiteral(j), code])
  lineF(p, cpsStmts, "}$n", [])

  # else statement:
  if not isOfBranch(t[^1]):
    lineCg(p, cpsStmts, "goto $1;$n", [t[^1][0].label])

proc branchHasTooBigRange(b: CgNode): bool =
for it in b:
Expand Down Expand Up @@ -332,17 +284,10 @@ proc genOrdinalCase(p: BProc, n: CgNode) =

proc genCase(p: BProc, t: CgNode) =
genLineDir(p, t)
case skipTypes(t[0].typ, abstractVarRange).kind
of tyString:
genStringCase(p, t)
of tyFloat..tyFloat64:
genCaseGeneric(p, t, "if ($1 >= $2 && $1 <= $3) goto $4;$n",
"if ($1 == $2) goto $3;$n")
if t[0].kind == cnkLocal and sfGoto in p.body[t[0].local].flags:
genGotoForCase(p, t)
else:
if t[0].kind == cnkLocal and sfGoto in p.body[t[0].local].flags:
genGotoForCase(p, t)
else:
genOrdinalCase(p, t)
genOrdinalCase(p, t)

proc bodyCanRaise(p: BProc; n: CgNode): bool =
case n.kind
Expand Down
4 changes: 4 additions & 0 deletions compiler/mir/mirconstr.nim
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,10 @@ func join*(bu: var MirBuilder, label: LabelId) =
bu.subTree mnkJoin:
bu.add MirNode(kind: mnkLabel, label: label)

func goto*(bu: var MirBuilder, label: LabelId) =
  ## Emits a ``mnkGoto`` sub-tree whose only child is the label node for
  ## `label`.
  let target = MirNode(kind: mnkLabel, label: label)
  bu.subTree mnkGoto:
    bu.add target

template pathNamed*(bu: var MirBuilder, t: TypeId, f: int32, body: untyped) =
## Emits a ``mnkPathNamed`` expression.
bu.subTree MirNode(kind: mnkPathNamed, typ: t):
Expand Down
139 changes: 139 additions & 0 deletions compiler/mir/mirpasses.nim
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ import
idioms
]

from std/math import nextPowerOfTwo
from compiler/backend/ccgutils import hashString

# for type-based alias analysis
from compiler/sem/aliases import isPartOf, TAnalysisResult

Expand All @@ -58,6 +61,10 @@ const
LocSkip = abstractRange + tyUserTypeClasses
## types to skip to arrive at the underlying concrete value type

template addCompilerProc(env: var MirEnv, graph: ModuleGraph,
                         name: string): ProcedureId =
  ## Looks up the compiler procedure called `name` in `graph`, adds it to
  ## `env.procedures`, and evaluates to the result of that `add`.
  add(env.procedures, getCompilerProc(graph, name))

template subTree(bu: var MirBuilder, k: MirNodeKind, t: TypeId,
body: untyped) =
bu.subTree MirNode(kind: k, typ: t):
Expand Down Expand Up @@ -652,6 +659,137 @@ proc lowerMove(tree: MirTree, changes: var Changeset) =
else:
discard "not relevant"

proc lowerCase(tree: MirTree, graph: ModuleGraph, env: var MirEnv,
               changes: var Changeset) =
  ## Lowers case statements with string or float selectors into a sequence
  ## of comparisons, each followed by a conditional jump to the label of the
  ## matching branch. Case statements with other selector types are left
  ## untouched.
  ##
  ## String case statements with more than `stringCaseThreshold` strings
  ## use a hash-table optimization: the selector's run-time hash picks a
  ## bucket, and only the strings within that bucket are compared against.
  ##
  ## All modifications are recorded in `changes`; `tree` itself is not
  ## modified.
  const stringCaseThreshold = 8
    ## above X strings a hash-switch for strings is generated

  iterator targets(tree: MirTree, n: NodePosition): (LabelId, NodePosition) =
    ## Returns all comparison candidates of the case statement at `n`
    ## together with their associated jump target. The selector (the first
    ## sub-node) is skipped.
    for it in tree.subNodes(n, 1):
      # the label is the last node of a branch
      let target = tree[tree.last(it)].label
      var x = tree.child(it, 0)
      for _ in 0..<(tree[it].len - 1): # -1 for the label node
        yield (target, x)
        x = tree.sibling(x)

  proc genericCase(bu: var MirBuilder, tree: MirTree, n: NodePosition,
                   eq: TMagic, sel: Value) {.nimcall.} =
    ## Emits, in order, one test of `sel` per comparison candidate of the
    ## case statement at `n`, jumping to the candidate's branch on success.
    ## Single values are compared with the `eq` magic, ranges with two
    ## `mLeF64` comparisons. The tests are followed by an unconditional
    ## jump to the label of the last branch.
    for (target, it) in tree.targets(n):
      if tree[it].kind == mnkRange:
        # only float case-statements can use ranges, so we know that the
        # operands are floats here
        # first test: lower bound <= selector
        var cond = bu.wrapTemp BoolType:
          bu.buildMagicCall mLeF64, BoolType:
            bu.emitByVal bu.inline(tree, tree.child(it, 0))
            bu.emitByVal sel

        bu.buildIf (;bu.use cond):
          # second test, only evaluated when the first one succeeded:
          # selector <= upper bound
          cond = bu.wrapTemp BoolType:
            bu.buildMagicCall mLeF64, BoolType:
              bu.emitByVal sel
              bu.emitByVal bu.inline(tree, tree.child(it, 1))

          # jump to the branch body if the run-time value is within the given
          # range
          bu.buildIf (;bu.use cond):
            bu.goto target
      else:
        # single comparison
        let cond = bu.wrapTemp BoolType:
          bu.buildMagicCall eq, BoolType:
            bu.emitByVal sel
            bu.emitByVal bu.inline(tree, it)

        bu.buildIf (;bu.use cond):
          bu.goto target

    bu.goto tree[tree.last(tree.last(n))].label # jump to else branch

  for n in search(tree, {mnkCase}):
    case env.types[tree[n, 0].typ].skipTypes(abstractInst).kind
    of tyFloat, tyFloat64, tyFloat32:
      # simple: use the generic lowering
      changes.replaceMulti(tree, n, bu):
        let sel = bu.inline(tree, tree.child(n, 0))
        genericCase(bu, tree, n, mEqF64, sel)
    of tyString:
      # count the number of strings:
      var numStrings = 0
      for it in tree.subNodes(n, start=1):
        numStrings += (tree.len(it) - 1) # -1 for the target label

      # the hash-switch is only used *above* the threshold; with exactly
      # `stringCaseThreshold` strings, the linear comparison is still used
      if numStrings <= stringCaseThreshold:
        # compare against every string
        changes.replaceMulti(tree, n, bu):
          let sel = bu.inline(tree, tree.child(n, 0))
          genericCase(bu, tree, n, mEqStr, sel)
      else:
        # reduce the number of string comparisons through usage of a hash
        # table
        changes.replaceMulti(tree, n, bu):
          # the number of buckets is a power of two, so that masking the
          # hash with `bitMask` always yields a valid bucket index
          let bitMask = nextPowerOfTwo(numStrings) - 1
          var branches: seq[tuple[label: LabelId,
                                  strings: seq[(NodePosition, LabelId)]]]
          branches.newSeq(bitMask + 1)

          # sort the string operands into buckets (`branches`) based on their
          # hash:
          for (target, it) in tree.targets(n):
            let bI = hashString(graph.config, env[tree[it].strVal]) and bitMask
            if branches[bI].strings.len == 0:
              # the label is allocated on demand
              branches[bI].label = bu.allocLabel()

            branches[bI].strings.add (it, target)

          let
            elseLabel = tree[tree.last(tree.last(n))].label
            typ       = env.types.sizeType
            sel       = bu.inline(tree, tree.child(n, 0))
          var hash: Value

          # emit the hash computation:
          hash = bu.wrapTemp typ:
            bu.buildCall env.addCompilerProc(graph, "hashString"), typ:
              bu.emitByVal sel
          # mask the run-time hash the same way the bucket indices were
          # computed above
          hash = bu.wrapTemp typ:
            bu.buildMagicCall mBitandI, typ:
              bu.emitByVal hash
              bu.emitByVal:
                literal(mnkIntLit, env.getOrIncl(BiggestInt bitMask), typ)

          # emit the dispatcher over the hash value:
          bu.subTree mnkCase:
            bu.use hash
            for i, b in branches.pairs:
              bu.subTree mnkBranch:
                bu.use literal(mnkIntLit, env.getOrIncl(BiggestInt i), typ)
                if b.strings.len == 0:
                  # no string hashes to this bucket; go to the else branch
                  # directly
                  bu.add MirNode(kind: mnkLabel, label: elseLabel)
                else:
                  bu.add MirNode(kind: mnkLabel, label: b.label)

          # emit the string comparisons:
          for b in branches.items:
            if b.strings.len > 0:
              bu.join b.label
              for (str, target) in b.strings.items:
                let cond = bu.wrapTemp BoolType:
                  bu.buildMagicCall mEqStr, BoolType:
                    bu.emitByVal sel
                    bu.emitByVal bu.inline(tree, str)

                bu.buildIf (;bu.use cond):
                  bu.goto target

              bu.goto elseLabel # jump to the 'else' branch
    else:
      discard "keep as is"

proc applyPasses*(body: var MirBody, prc: PSym, env: var MirEnv,
graph: ModuleGraph, target: TargetBackend) =
## Applies all applicable MIR passes to the body (`tree` and `source`) of
Expand Down Expand Up @@ -687,6 +825,7 @@ proc applyPasses*(body: var MirBody, prc: PSym, env: var MirEnv,
lowerNew(body.code, graph, env, c)
lowerChecks(body, graph, env, c)
injectStrPreparation(body.code, graph, env, c)
lowerCase(body.code, graph, env, c)

# instrument the body with profiler calls after all lowerings, but before
# optimization
Expand Down
7 changes: 4 additions & 3 deletions compiler/mir/mirtrees.nim
Original file line number Diff line number Diff line change
Expand Up @@ -550,10 +550,11 @@ iterator pairs*(tree: MirTree): (NodePosition, lent MirNode) =
yield (i.NodePosition, tree[i])
inc i

iterator subNodes*(tree: MirTree, n: NodePosition; start = 0): NodePosition =
  ## Returns in order of appearance all direct child nodes of `n`, starting
  ## with the child at index `start`. With the default `start` of 0, all
  ## direct child nodes are returned.
  let L = tree[n].len
  var n = tree.child(n, start)
  # `L` is the total number of direct child nodes. The first `start` ones
  # are skipped, so only ``L - start`` nodes are left to yield; iterating
  # `L` times would walk past the last child
  for _ in start..<L:
    yield n
    n = tree.sibling(n)
Expand Down

0 comments on commit 0ae0cd5

Please sign in to comment.