Skip to content

Commit f48ea84

Browse files
authored
Dectests (#91)
This PR refactors the `scripts/dectest.jl` script so that it is easier to support more operations in the future. All dectests are also generated fresh with the unified format.
1 parent fac9bc8 commit f48ea84

File tree

13 files changed

+4839
-9501
lines changed

13 files changed

+4839
-9501
lines changed

scripts/dectest.jl

Lines changed: 148 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,109 @@
1-
function _precision(line)
1+
function (@main)(args=ARGS)
2+
name, dectest_path, output_path = args
3+
4+
open(output_path, "w") do io
5+
println(io, """
6+
using Decimals
7+
using Test
8+
using Decimals: @with_context
9+
10+
@testset \"$name\" begin""")
11+
12+
translate(io, dectest_path)
13+
14+
println(io, "end")
15+
end
16+
end
17+
18+
function translate(io, dectest_path)
19+
directives = Dict{Symbol, Any}()
20+
21+
for line in eachline(dectest_path)
22+
line = strip(line)
23+
24+
isempty(line) && continue
25+
startswith(line, "--") && continue
26+
27+
line = lowercase(line)
28+
29+
if startswith(line, "version:")
30+
# ...
31+
elseif startswith(line, "extended:")
32+
# ...
33+
elseif startswith(line, "clamp:")
34+
# ...
35+
elseif startswith(line, "precision:")
36+
directives[:precision] = parse_precision(line)
37+
elseif startswith(line, "rounding:")
38+
directives[:rounding] = parse_rounding(line)
39+
elseif startswith(line, "maxexponent:")
40+
directives[:Emax] = parse_maxexponent(line)
41+
elseif startswith(line, "minexponent:")
42+
directives[:Emin] = parse_minexponent(line)
43+
else
44+
if directives[:rounding] == RoundingMode{:Unsupported}
45+
continue
46+
end
47+
48+
test = parse_test(line)
49+
any(isspecial, test.operands) && continue
50+
isspecial(test.result) && continue
51+
52+
dectest = decimal_test(test, directives)
53+
println(io, dectest)
54+
end
55+
end
56+
end
57+
58+
function isspecial(value)
59+
value = lowercase(value)
60+
return occursin(r"(inf|nan|#|\?)", value)
61+
end
62+
63+
function parse_precision(line)
264
m = match(r"^precision:\s*(\d+)$", line)
365
isnothing(m) && throw(ArgumentError(line))
466
return parse(Int, m[1])
567
end
668

7-
function _rounding(line)
69+
function parse_rounding(line)
870
m = match(r"^rounding:\s*(\w+)$", line)
971
isnothing(m) && throw(ArgumentError(line))
1072
r = m[1]
1173
if r == "ceiling"
12-
return "RoundUp"
74+
return RoundUp
1375
elseif r == "down"
14-
return "RoundToZero"
76+
return RoundToZero
1577
elseif r == "floor"
16-
return "RoundDown"
78+
return RoundDown
1779
elseif r == "half_even"
18-
return "RoundNearest"
80+
return RoundNearest
1981
elseif r == "half_up"
20-
return "RoundNearestTiesAway"
82+
return RoundNearestTiesAway
2183
elseif r == "up"
22-
return "RoundFromZero"
84+
return RoundFromZero
2385
elseif r == "half_down"
24-
return "RoundHalfDownUnsupported"
86+
return RoundingMode{:Unsupported}
2587
elseif r == "05up"
26-
return "Round05UpUnsupported"
88+
return RoundingMode{:Unsupported}
2789
else
2890
throw(ArgumentError(r))
2991
end
3092
end
3193

32-
function _maxexponent(line)
94+
function parse_maxexponent(line)
3395
m = match(r"^maxexponent:\s*\+?(\d+)$", line)
3496
isnothing(m) && throw(ArgumentError(line))
3597
return parse(Int, m[1])
3698
end
3799

38-
function _minexponent(line)
100+
function parse_minexponent(line)
39101
m = match(r"^minexponent:\s*(-\d+)$", line)
40102
isnothing(m) && throw(ArgumentError(line))
41103
return parse(Int, m[1])
42104
end
43105

44-
function _test(line)
106+
function parse_test(line)
45107
occursin("->", line) || throw(ArgumentError(line))
46108
lhs, rhs = split(line, "->")
47109
id, operation, operands... = split(lhs)
@@ -50,134 +112,97 @@ function _test(line)
50112
return (;id, operation, operands, result, conditions)
51113
end
52114

53-
function decimal(x)
115+
function clean(@nospecialize ex)
116+
if isa(ex, Expr)
117+
if Meta.isexpr(ex, :macrocall)
118+
return Expr(:macrocall, ex.args[1], nothing, map(clean, ex.args[3:end])...)
119+
else
120+
return Expr(ex.head, map(clean, ex.args)...)
121+
end
122+
elseif isa(ex, LineNumberNode)
123+
return nothing
124+
else
125+
return ex
126+
end
127+
end
128+
129+
function decimal_test(test, directives)
130+
ctxt = decimal_context(directives)
131+
op = decimal_operation(test.operation, test.operands)
132+
res = operation_result(test.operation, test.result)
133+
134+
if :overflow in test.conditions
135+
ex = :(@with_context($ctxt, @test_throws OverflowError $op))
136+
elseif :division_undefined in test.conditions
137+
ex = :(@with_context($ctxt, @test_throws UndefinedDivisionError $op))
138+
elseif :division_by_zero in test.conditions
139+
ex = :(@with_context($ctxt, @test_throws DivisionByZeroError $op))
140+
else
141+
ex = :(@with_context($ctxt, @test $op == $(res)))
142+
end
143+
return clean(ex)
144+
end
145+
146+
function dec(x)
54147
x = strip(x, ['\'', '\"'])
55-
return "dec\"$x\""
148+
return :(@dec_str $("$x"))
149+
end
150+
151+
function decimal_context(directives)
152+
names = Tuple(sort!(collect(keys(directives))))
153+
values = Tuple([directives[name] for name in names])
154+
params = NamedTuple{names}(values)
155+
return params
156+
end
157+
158+
function operation_result(operation, result)
159+
if operation == "compare"
160+
return parse(Int, result)
161+
else
162+
return dec(result)
163+
end
56164
end
57165

58-
function print_operation(io, operation, operands)
166+
function decimal_operation(operation, operands)
59167
if operation == "abs"
60-
print_abs(io, operands...)
168+
return decimal_abs(operands...)
61169
elseif operation == "add"
62-
print_add(io, operands...)
170+
return decimal_add(operands...)
63171
elseif operation == "apply"
64-
print_apply(io, operands...)
172+
return decimal_apply(operands...)
65173
elseif operation == "compare"
66-
print_compare(io, operands...)
174+
return decimal_compare(operands...)
67175
elseif operation == "divide"
68-
print_divide(io, operands...)
176+
return decimal_divide(operands...)
69177
elseif operation == "max"
70-
print_max(io, operands...)
178+
return decimal_max(operands...)
71179
elseif operation == "min"
72-
print_min(io, operands...)
180+
return decimal_min(operands...)
73181
elseif operation == "minus"
74-
print_minus(io, operands...)
182+
return decimal_minus(operands...)
75183
elseif operation == "multiply"
76-
print_multiply(io, operands...)
184+
return decimal_multiply(operands...)
77185
elseif operation == "plus"
78-
print_plus(io, operands...)
186+
return decimal_plus(operands...)
79187
elseif operation == "reduce"
80-
print_reduce(io, operands...)
188+
return decimal_reduce(operands...)
81189
elseif operation == "subtract"
82-
print_subtract(io, operands...)
190+
return decimal_subtract(operands...)
83191
else
84192
throw(ArgumentError(operation))
85193
end
86194
end
87-
print_abs(io, x) = print(io, "abs(", decimal(x), ")")
88-
print_add(io, x, y) = print(io, decimal(x), " + ", decimal(y))
89-
print_apply(io, x) = print(io, decimal(x))
90-
print_compare(io, x, y) = print(io, "cmp(", decimal(x), ", ", decimal(y), ")")
91-
print_divide(io, x, y) = print(io, decimal(x), " / ", decimal(y))
92-
print_max(io, x, y) = print(io, "max(", decimal(x), ", ", decimal(y), ")")
93-
print_min(io, x, y) = print(io, "min(", decimal(x), ", ", decimal(y), ")")
94-
print_minus(io, x) = print(io, "-(", decimal(x), ")")
95-
print_multiply(io, x, y) = print(io, decimal(x), " * ", decimal(y))
96-
print_plus(io, x) = print(io, "+(", decimal(x), ")")
97-
print_reduce(io, x) = print(io, "reduce(", decimal(x), ")")
98-
print_subtract(io, x, y) = print(io, decimal(x), " - ", decimal(y))
99-
100-
function print_test(io, test, directives)
101-
println(io, " # $(test.id)")
102-
103-
names = sort!(collect(keys(directives)))
104-
params = join(("$k=$(directives[k])" for k in names), ", ")
105-
print(io, " @with_context ($params) ")
106-
107-
if :overflow test.conditions
108-
print(io, "@test_throws OverflowError ")
109-
print_operation(io, test.operation, test.operands)
110-
println(io)
111-
elseif :division_undefined test.conditions
112-
print(io, "@test_throws UndefinedDivisionError ")
113-
print_operation(io, test.operation, test.operands)
114-
println(io)
115-
elseif :division_by_zero test.conditions
116-
print(io, "@test_throws DivisionByZeroError ")
117-
print_operation(io, test.operation, test.operands)
118-
println(io)
119-
else
120-
print(io, "@test ")
121-
print_operation(io, test.operation, test.operands)
122-
print(io, " == ")
123-
println(io, decimal(test.result))
124-
end
125-
end
126-
127-
function isspecial(value)
128-
value = lowercase(value)
129-
return occursin(r"(inf|nan|#)", value)
130-
end
131195

132-
function translate(io, dectest_path)
133-
directives = Dict{String, Any}()
134-
135-
for line in eachline(dectest_path)
136-
line = strip(line)
137-
138-
isempty(line) && continue
139-
startswith(line, "--") && continue
140-
141-
line = lowercase(line)
142-
143-
if startswith(line, "version:")
144-
# ...
145-
elseif startswith(line, "extended:")
146-
# ...
147-
elseif startswith(line, "clamp:")
148-
# ...
149-
elseif startswith(line, "precision:")
150-
directives["precision"] = _precision(line)
151-
elseif startswith(line, "rounding:")
152-
directives["rounding"] = _rounding(line)
153-
elseif startswith(line, "maxexponent:")
154-
directives["Emax"] = _maxexponent(line)
155-
elseif startswith(line, "minexponent:")
156-
directives["Emin"] = _minexponent(line)
157-
else
158-
test = _test(line)
159-
any(isspecial, test.operands) && continue
160-
occursin("Unsupported", directives["rounding"]) && continue
161-
print_test(io, test, directives)
162-
end
163-
end
164-
end
165-
166-
function (@main)(args=ARGS)
167-
name, dectest_path, output_path = args
168-
169-
open(output_path, "w") do io
170-
println(io, """
171-
using Decimals
172-
using ScopedValues
173-
using Test
174-
using Decimals: @with_context
175-
176-
@testset \"$name\" begin""")
177-
178-
translate(io, dectest_path)
179-
180-
println(io, "end")
181-
end
182-
end
196+
decimal_abs(x) = :(abs($(dec(x))))
197+
decimal_add(x, y) = :($(dec(x)) + $(dec(y)))
198+
decimal_apply(x) = dec(x)
199+
decimal_compare(x, y) = :(cmp($(dec(x)), $(dec(y))))
200+
decimal_divide(x, y) = :($(dec(x)) / $(dec(y)))
201+
decimal_max(x, y) = :(max($(dec(x)), $(dec(y))))
202+
decimal_min(x, y) = :(min($(dec(x)), $(dec(y))))
203+
decimal_minus(x) = :(-($(dec(x))))
204+
decimal_multiply(x, y) = :($(dec(x)) * $(dec(y)))
205+
decimal_plus(x) = :(+($(dec(x))))
206+
decimal_reduce(x) = :(normalize($(dec(x))))
207+
decimal_subtract(x, y) = :($(dec(x)) - $(dec(y)))
183208

0 commit comments

Comments
 (0)