From 2da461831a9e94b89a67f9bc621c1fd38d2ba84b Mon Sep 17 00:00:00 2001 From: M Aswin Kishore <60577077+mak626@users.noreply.github.com> Date: Mon, 10 Feb 2025 21:01:28 +0530 Subject: [PATCH] fix[Decimal]: converting from float -> decimal resulted in inaccurate values Example:- Decimal(0.01) -> Decimal("0.01000000000000000020816681711721685132943093776702880859375") Additional: Decimal field now accepts string, int & float. Previously it was just string & int --- graphene/types/decimal.py | 6 ++-- graphene/types/tests/test_decimal.py | 52 ++++++++++++++++++++++++---- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/graphene/types/decimal.py b/graphene/types/decimal.py index 69952f96d..24bf71efc 100644 --- a/graphene/types/decimal.py +++ b/graphene/types/decimal.py @@ -1,7 +1,7 @@ from decimal import Decimal as _Decimal from graphql import Undefined -from graphql.language.ast import StringValueNode, IntValueNode +from graphql.language.ast import StringValueNode, IntValueNode, FloatValueNode from .scalars import Scalar @@ -22,13 +22,13 @@ def serialize(dec): @classmethod def parse_literal(cls, node, _variables=None): - if isinstance(node, (StringValueNode, IntValueNode)): + if isinstance(node, (StringValueNode, IntValueNode, FloatValueNode)): return cls.parse_value(node.value) return Undefined @staticmethod def parse_value(value): try: - return _Decimal(value) + return _Decimal(str(value)) except Exception: return Undefined diff --git a/graphene/types/tests/test_decimal.py b/graphene/types/tests/test_decimal.py index 1ba48bd1d..b5b5c0c3a 100644 --- a/graphene/types/tests/test_decimal.py +++ b/graphene/types/tests/test_decimal.py @@ -23,6 +23,26 @@ def test_decimal_string_query(): assert decimal.Decimal(result.data["decimal"]) == decimal_value +def test_decimal_float_query(): + float_value = 1969.1974 + decimal_value = decimal.Decimal(str(float_value)) + result = schema.execute("""{ decimal(input: %s) }""" % float_value) + assert not result.errors + assert not result.errors + assert result.data == {"decimal": str(decimal_value)} + assert decimal.Decimal(result.data["decimal"]) == decimal_value + + +def test_decimal_int_query(): + int_value = 1234 + decimal_value = decimal.Decimal(str(int_value)) + result = schema.execute("""{ decimal(input: %s) }""" % int_value) + assert not result.errors + assert not result.errors + assert result.data == {"decimal": str(decimal_value)} + assert decimal.Decimal(result.data["decimal"]) == decimal_value + + def test_decimal_string_query_variable(): decimal_value = decimal.Decimal("1969.1974") @@ -35,6 +55,32 @@ def test_decimal_string_query_variable(): assert decimal.Decimal(result.data["decimal"]) == decimal_value +def test_decimal_float_query_variable(): + float_value = 1969.1974 + decimal_value = decimal.Decimal(str(float_value)) + + result = schema.execute( + """query Test($decimal: Decimal){ decimal(input: $decimal) }""", + variables={"decimal": float_value}, + ) + assert not result.errors + assert result.data == {"decimal": str(decimal_value)} + assert decimal.Decimal(result.data["decimal"]) == decimal_value + + +def test_decimal_int_query_variable(): + int_value = 1234 + decimal_value = decimal.Decimal(str(int_value)) + + result = schema.execute( + """query Test($decimal: Decimal){ decimal(input: $decimal) }""", + variables={"decimal": int_value}, + ) + assert not result.errors + assert result.data == {"decimal": str(decimal_value)} + assert decimal.Decimal(result.data["decimal"]) == decimal_value + + def test_bad_decimal_query(): not_a_decimal = "Nobody expects the Spanish Inquisition!" @@ -53,12 +99,6 @@ def test_bad_decimal_query(): assert result.data is None assert result.errors[0].message == "Expected value of type 'Decimal', found true." - result = schema.execute("{ decimal(input: 1.2) }") - assert result.errors - assert len(result.errors) == 1 - assert result.data is None - assert result.errors[0].message == "Expected value of type 'Decimal', found 1.2." - def test_decimal_string_query_integer(): decimal_value = 1