Skip to content

Commit 9990ebe

Browse files
authored
Merge pull request #119 from bksaiki/fixes
Fixes and `tree_sum`
2 parents bbdfcaf + 8c2dccd commit 9990ebe

File tree

4 files changed

+54
-1
lines changed

4 files changed

+54
-1
lines changed

fpy2/interpret/default.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,21 @@ def _visit_block(self, block: StmtBlock, ctx: Context):
840840
for stmt in block.stmts:
841841
self._visit_statement(stmt, ctx)
842842

843+
def _cvt_return(self, x: Value):
844+
match x:
845+
case bool() | Float() | Context():
846+
return x
847+
case Fraction():
848+
return Float.from_rational(x) if is_dyadic(x) else x
849+
case tuple():
850+
return tuple(self._cvt_return(v) for v in x)
851+
case list():
852+
for i in range(len(x)):
853+
x[i] = self._cvt_return(x[i])
854+
return x
855+
case _:
856+
raise RuntimeError('unreachable')
857+
843858
def _visit_function(self, func: FuncDef, ctx: Context):
844859
# process free variables
845860
for var in func.free_vars:
@@ -851,7 +866,7 @@ def _visit_function(self, func: FuncDef, ctx: Context):
851866
self._visit_block(func.body, ctx)
852867
raise RuntimeError('no return statement encountered')
853868
except FunctionReturnError as e:
854-
return e.value
869+
return self._cvt_return(e.value)
855870

856871
def _visit_expr(self, e: Expr, ctx: Context) -> Value:
857872
return super()._visit_expr(e, ctx)

fpy2/libraries/core.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,42 @@ def max_e(xs: list[fp.Real]) -> tuple[fp.Real, bool]:
234234

235235
return (largest_e, any_non_zero)
236236

237+
############################################################
238+
# Arithmetic
239+
240+
@fp.fpy
241+
def tree_sum(xs: list[fp.Real]):
242+
"""
243+
Sums the elements of xs in a tree.
244+
Each sum is rounded under the current rounding context.
245+
246+
Args:
247+
xs: A list of real numbers. The length of xs must be positive and a power of 2.
248+
249+
Returns:
250+
The sum of the elements of xs.
251+
"""
252+
253+
with fp.INTEGER:
254+
n: fp.Real = len(xs)
255+
assert n > 0, "Length of xs must be positive"
256+
257+
depth = fp.log2(n)
258+
assert fp.pow(2, depth) == n, "Length of xs must be a power of 2"
259+
260+
t = [x for x in xs]
261+
for _ in range(depth):
262+
with fp.INTEGER:
263+
n /= 2
264+
265+
for i in range(n): # type: ignore[arg-type]
266+
with fp.INTEGER:
267+
j = 2 * i
268+
k = 2 * i + 1
269+
t[i] = t[j] + t[k]
270+
271+
return t[0]
272+
237273
############################################################
238274
# Context operations
239275

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ requires-python = ">=3.11"
1919
dependencies = [
2020
"titanfp==0.1.1",
2121
"gmpy2==2.2",
22+
"matplotlib"
2223
]
2324

2425
[project.optional-dependencies]

tests/infra/backend/cpp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def _test_unit(output_dir: Path, no_cc: bool = False):
189189
'_modf_spec',
190190
'isinteger',
191191
'_ldexp_spec',
192+
'tree_sum',
192193
# eft
193194
'ideal_2sum',
194195
'ideal_2mul',

0 commit comments

Comments
 (0)