Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Decimal scalar inaccurate conversion of float to decimal #1594

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions graphene/types/decimal.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from decimal import Decimal as _Decimal

from graphql import Undefined
from graphql.language.ast import StringValueNode, IntValueNode
from graphql.language.ast import StringValueNode, IntValueNode, FloatValueNode

from .scalars import Scalar

Expand All @@ -22,13 +22,13 @@ def serialize(dec):

@classmethod
def parse_literal(cls, node, _variables=None):
if isinstance(node, (StringValueNode, IntValueNode)):
if isinstance(node, (StringValueNode, IntValueNode, FloatValueNode)):
return cls.parse_value(node.value)
return Undefined

@staticmethod
def parse_value(value):
try:
return _Decimal(value)
return _Decimal(str(value))
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the same implementation as done at Strawberry

except Exception:
return Undefined
52 changes: 46 additions & 6 deletions graphene/types/tests/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,26 @@ def test_decimal_string_query():
assert decimal.Decimal(result.data["decimal"]) == decimal_value


def test_decimal_float_query():
float_value = 1969.1974
decimal_value = decimal.Decimal(str(float_value))
result = schema.execute("""{ decimal(input: %s) }""" % float_value)
assert not result.errors
assert not result.errors
assert result.data == {"decimal": str(decimal_value)}
assert decimal.Decimal(result.data["decimal"]) == decimal_value


def test_decimal_int_query():
int_value = 1234
decimal_value = decimal.Decimal(str(int_value))
result = schema.execute("""{ decimal(input: %s) }""" % int_value)
assert not result.errors
assert not result.errors
assert result.data == {"decimal": str(decimal_value)}
assert decimal.Decimal(result.data["decimal"]) == decimal_value


def test_decimal_string_query_variable():
decimal_value = decimal.Decimal("1969.1974")

Expand All @@ -35,6 +55,32 @@ def test_decimal_string_query_variable():
assert decimal.Decimal(result.data["decimal"]) == decimal_value


def test_decimal_float_query_variable():
float_value = 1969.1974
decimal_value = decimal.Decimal(str(float_value))

result = schema.execute(
"""query Test($decimal: Decimal){ decimal(input: $decimal) }""",
variables={"decimal": float_value},
)
assert not result.errors
assert result.data == {"decimal": str(decimal_value)}
assert decimal.Decimal(result.data["decimal"]) == decimal_value


def test_decimal_int_query_variable():
int_value = 1234
decimal_value = decimal.Decimal(str(int_value))

result = schema.execute(
"""query Test($decimal: Decimal){ decimal(input: $decimal) }""",
variables={"decimal": int_value},
)
assert not result.errors
assert result.data == {"decimal": str(decimal_value)}
assert decimal.Decimal(result.data["decimal"]) == decimal_value


def test_bad_decimal_query():
not_a_decimal = "Nobody expects the Spanish Inquisition!"

Expand All @@ -53,12 +99,6 @@ def test_bad_decimal_query():
assert result.data is None
assert result.errors[0].message == "Expected value of type 'Decimal', found true."

result = schema.execute("{ decimal(input: 1.2) }")
assert result.errors
assert len(result.errors) == 1
assert result.data is None
assert result.errors[0].message == "Expected value of type 'Decimal', found 1.2."


def test_decimal_string_query_integer():
decimal_value = 1
Expand Down