Skip to content

Commit

Permalink
Dectests (#91)
Browse files Browse the repository at this point in the history
This PR refactors the `scripts/dectest.jl` script so that it is easier to support more operations in the future. All dectests are also generated fresh with the unified format.
  • Loading branch information
barucden authored Nov 14, 2024
1 parent fac9bc8 commit f48ea84
Show file tree
Hide file tree
Showing 13 changed files with 4,839 additions and 9,501 deletions.
271 changes: 148 additions & 123 deletions scripts/dectest.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,109 @@
function _precision(line)
function (@main)(args=ARGS)
name, dectest_path, output_path = args

open(output_path, "w") do io
println(io, """
using Decimals
using Test
using Decimals: @with_context
@testset \"$name\" begin""")

translate(io, dectest_path)

println(io, "end")
end
end

function translate(io, dectest_path)
directives = Dict{Symbol, Any}()

for line in eachline(dectest_path)
line = strip(line)

isempty(line) && continue
startswith(line, "--") && continue

line = lowercase(line)

if startswith(line, "version:")
# ...
elseif startswith(line, "extended:")
# ...
elseif startswith(line, "clamp:")
# ...
elseif startswith(line, "precision:")
directives[:precision] = parse_precision(line)
elseif startswith(line, "rounding:")
directives[:rounding] = parse_rounding(line)
elseif startswith(line, "maxexponent:")
directives[:Emax] = parse_maxexponent(line)
elseif startswith(line, "minexponent:")
directives[:Emin] = parse_minexponent(line)
else
if directives[:rounding] == RoundingMode{:Unsupported}
continue
end

test = parse_test(line)
any(isspecial, test.operands) && continue
isspecial(test.result) && continue

dectest = decimal_test(test, directives)
println(io, dectest)
end
end
end

function isspecial(value)
value = lowercase(value)
return occursin(r"(inf|nan|#|\?)", value)
end

function parse_precision(line)
m = match(r"^precision:\s*(\d+)$", line)
isnothing(m) && throw(ArgumentError(line))
return parse(Int, m[1])
end

function _rounding(line)
function parse_rounding(line)
m = match(r"^rounding:\s*(\w+)$", line)
isnothing(m) && throw(ArgumentError(line))
r = m[1]
if r == "ceiling"
return "RoundUp"
return RoundUp
elseif r == "down"
return "RoundToZero"
return RoundToZero
elseif r == "floor"
return "RoundDown"
return RoundDown
elseif r == "half_even"
return "RoundNearest"
return RoundNearest
elseif r == "half_up"
return "RoundNearestTiesAway"
return RoundNearestTiesAway
elseif r == "up"
return "RoundFromZero"
return RoundFromZero
elseif r == "half_down"
return "RoundHalfDownUnsupported"
return RoundingMode{:Unsupported}
elseif r == "05up"
return "Round05UpUnsupported"
return RoundingMode{:Unsupported}
else
throw(ArgumentError(r))
end
end

function _maxexponent(line)
function parse_maxexponent(line)
m = match(r"^maxexponent:\s*\+?(\d+)$", line)
isnothing(m) && throw(ArgumentError(line))
return parse(Int, m[1])
end

function _minexponent(line)
function parse_minexponent(line)
m = match(r"^minexponent:\s*(-\d+)$", line)
isnothing(m) && throw(ArgumentError(line))
return parse(Int, m[1])
end

function _test(line)
function parse_test(line)
occursin("->", line) || throw(ArgumentError(line))
lhs, rhs = split(line, "->")
id, operation, operands... = split(lhs)
Expand All @@ -50,134 +112,97 @@ function _test(line)
return (;id, operation, operands, result, conditions)
end

function decimal(x)
function clean(@nospecialize ex)
if isa(ex, Expr)
if Meta.isexpr(ex, :macrocall)
return Expr(:macrocall, ex.args[1], nothing, map(clean, ex.args[3:end])...)
else
return Expr(ex.head, map(clean, ex.args)...)
end
elseif isa(ex, LineNumberNode)
return nothing
else
return ex
end
end

function decimal_test(test, directives)
ctxt = decimal_context(directives)
op = decimal_operation(test.operation, test.operands)
res = operation_result(test.operation, test.result)

if :overflow in test.conditions
ex = :(@with_context($ctxt, @test_throws OverflowError $op))
elseif :division_undefined in test.conditions
ex = :(@with_context($ctxt, @test_throws UndefinedDivisionError $op))
elseif :division_by_zero in test.conditions
ex = :(@with_context($ctxt, @test_throws DivisionByZeroError $op))
else
ex = :(@with_context($ctxt, @test $op == $(res)))
end
return clean(ex)
end

function dec(x)
x = strip(x, ['\'', '\"'])
return "dec\"$x\""
return :(@dec_str $("$x"))
end

function decimal_context(directives)
names = Tuple(sort!(collect(keys(directives))))
values = Tuple([directives[name] for name in names])
params = NamedTuple{names}(values)
return params
end

function operation_result(operation, result)
if operation == "compare"
return parse(Int, result)
else
return dec(result)
end
end

function print_operation(io, operation, operands)
function decimal_operation(operation, operands)
if operation == "abs"
print_abs(io, operands...)
return decimal_abs(operands...)
elseif operation == "add"
print_add(io, operands...)
return decimal_add(operands...)
elseif operation == "apply"
print_apply(io, operands...)
return decimal_apply(operands...)
elseif operation == "compare"
print_compare(io, operands...)
return decimal_compare(operands...)
elseif operation == "divide"
print_divide(io, operands...)
return decimal_divide(operands...)
elseif operation == "max"
print_max(io, operands...)
return decimal_max(operands...)
elseif operation == "min"
print_min(io, operands...)
return decimal_min(operands...)
elseif operation == "minus"
print_minus(io, operands...)
return decimal_minus(operands...)
elseif operation == "multiply"
print_multiply(io, operands...)
return decimal_multiply(operands...)
elseif operation == "plus"
print_plus(io, operands...)
return decimal_plus(operands...)
elseif operation == "reduce"
print_reduce(io, operands...)
return decimal_reduce(operands...)
elseif operation == "subtract"
print_subtract(io, operands...)
return decimal_subtract(operands...)
else
throw(ArgumentError(operation))
end
end
print_abs(io, x) = print(io, "abs(", decimal(x), ")")
print_add(io, x, y) = print(io, decimal(x), " + ", decimal(y))
print_apply(io, x) = print(io, decimal(x))
print_compare(io, x, y) = print(io, "cmp(", decimal(x), ", ", decimal(y), ")")
print_divide(io, x, y) = print(io, decimal(x), " / ", decimal(y))
print_max(io, x, y) = print(io, "max(", decimal(x), ", ", decimal(y), ")")
print_min(io, x, y) = print(io, "min(", decimal(x), ", ", decimal(y), ")")
print_minus(io, x) = print(io, "-(", decimal(x), ")")
print_multiply(io, x, y) = print(io, decimal(x), " * ", decimal(y))
print_plus(io, x) = print(io, "+(", decimal(x), ")")
print_reduce(io, x) = print(io, "reduce(", decimal(x), ")")
print_subtract(io, x, y) = print(io, decimal(x), " - ", decimal(y))

function print_test(io, test, directives)
println(io, " # $(test.id)")

names = sort!(collect(keys(directives)))
params = join(("$k=$(directives[k])" for k in names), ", ")
print(io, " @with_context ($params) ")

if :overflow test.conditions
print(io, "@test_throws OverflowError ")
print_operation(io, test.operation, test.operands)
println(io)
elseif :division_undefined test.conditions
print(io, "@test_throws UndefinedDivisionError ")
print_operation(io, test.operation, test.operands)
println(io)
elseif :division_by_zero test.conditions
print(io, "@test_throws DivisionByZeroError ")
print_operation(io, test.operation, test.operands)
println(io)
else
print(io, "@test ")
print_operation(io, test.operation, test.operands)
print(io, " == ")
println(io, decimal(test.result))
end
end

function isspecial(value)
value = lowercase(value)
return occursin(r"(inf|nan|#)", value)
end

function translate(io, dectest_path)
directives = Dict{String, Any}()

for line in eachline(dectest_path)
line = strip(line)

isempty(line) && continue
startswith(line, "--") && continue

line = lowercase(line)

if startswith(line, "version:")
# ...
elseif startswith(line, "extended:")
# ...
elseif startswith(line, "clamp:")
# ...
elseif startswith(line, "precision:")
directives["precision"] = _precision(line)
elseif startswith(line, "rounding:")
directives["rounding"] = _rounding(line)
elseif startswith(line, "maxexponent:")
directives["Emax"] = _maxexponent(line)
elseif startswith(line, "minexponent:")
directives["Emin"] = _minexponent(line)
else
test = _test(line)
any(isspecial, test.operands) && continue
occursin("Unsupported", directives["rounding"]) && continue
print_test(io, test, directives)
end
end
end

function (@main)(args=ARGS)
name, dectest_path, output_path = args

open(output_path, "w") do io
println(io, """
using Decimals
using ScopedValues
using Test
using Decimals: @with_context
@testset \"$name\" begin""")

translate(io, dectest_path)

println(io, "end")
end
end
decimal_abs(x) = :(abs($(dec(x))))
decimal_add(x, y) = :($(dec(x)) + $(dec(y)))
decimal_apply(x) = dec(x)
decimal_compare(x, y) = :(cmp($(dec(x)), $(dec(y))))
decimal_divide(x, y) = :($(dec(x)) / $(dec(y)))
decimal_max(x, y) = :(max($(dec(x)), $(dec(y))))
decimal_min(x, y) = :(min($(dec(x)), $(dec(y))))
decimal_minus(x) = :(-($(dec(x))))
decimal_multiply(x, y) = :($(dec(x)) * $(dec(y)))
decimal_plus(x) = :(+($(dec(x))))
decimal_reduce(x) = :(normalize($(dec(x))))
decimal_subtract(x, y) = :($(dec(x)) - $(dec(y)))

Loading

0 comments on commit f48ea84

Please sign in to comment.