Skip to content

Commit 5deb07a

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents 587a9c5 + cf1a80f commit 5deb07a

File tree

8 files changed

+86
-19
lines changed

8 files changed

+86
-19
lines changed

django/contrib/gis/db/backends/base/features.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from django.contrib.gis.db import models
44

5+
from .operations import BaseSpatialOperations
6+
57

68
class BaseSpatialFeatures:
79
gis_enabled = True
@@ -107,5 +109,11 @@ def __getattr__(self, name):
107109
m = re.match(r"has_(\w*)_function$", name)
108110
if m:
109111
func_name = m[1]
112+
if func_name not in BaseSpatialOperations.unsupported_functions:
113+
raise ValueError(
114+
f"DatabaseFeatures.has_{func_name}_function isn't valid. "
115+
f'Is "{func_name}" missing from '
116+
"BaseSpatialOperations.unsupported_functions?"
117+
)
110118
return func_name not in self.connection.ops.unsupported_functions
111119
raise AttributeError

django/contrib/gis/db/backends/base/operations.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,17 @@ def select_extent(self):
3939
"AsGML",
4040
"AsKML",
4141
"AsSVG",
42+
"AsWKB",
43+
"AsWKT",
4244
"Azimuth",
4345
"BoundingCircle",
4446
"Centroid",
4547
"ClosestPoint",
4648
"Difference",
4749
"Distance",
50+
"DistanceSpheroid",
4851
"Envelope",
52+
"ForcePolygonCW",
4953
"FromWKB",
5054
"FromWKT",
5155
"GeoHash",

django/db/models/expressions.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1711,14 +1711,24 @@ def as_sql(
17111711
except EmptyResultSet:
17121712
continue
17131713
except FullResultSet:
1714-
default_sql, default_params = compiler.compile(case.result)
1714+
default = case.result
17151715
break
17161716
case_parts.append(case_sql)
17171717
sql_params.extend(case_params)
17181718
else:
1719-
default_sql, default_params = compiler.compile(self.default)
1720-
if not case_parts:
1721-
return default_sql, default_params
1719+
default = self.default
1720+
if case_parts:
1721+
default_sql, default_params = compiler.compile(default)
1722+
else:
1723+
if (
1724+
isinstance(default, Value)
1725+
and (output_field := default._output_field_or_none) is not None
1726+
):
1727+
from django.db.models.functions import Cast
1728+
1729+
default = Cast(default, output_field)
1730+
return compiler.compile(default)
1731+
17221732
case_joiner = case_joiner or self.case_joiner
17231733
template_params["cases"] = case_joiner.join(case_parts)
17241734
template_params["default"] = default_sql

django/test/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -728,12 +728,13 @@ def __enter__(self):
728728
self.connection.ensure_connection()
729729
self.initial_queries = len(self.connection.queries_log)
730730
self.final_queries = None
731-
request_started.disconnect(reset_queries)
731+
self.reset_queries_disconnected = request_started.disconnect(reset_queries)
732732
return self
733733

734734
def __exit__(self, exc_type, exc_value, traceback):
735735
self.connection.force_debug_cursor = self.force_debug_cursor
736-
request_started.connect(reset_queries)
736+
if self.reset_queries_disconnected:
737+
request_started.connect(reset_queries)
737738
if exc_type is not None:
738739
return
739740
self.final_queries = len(self.connection.queries_log)

tests/gis_tests/geoapp/test_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ def test_memsize(self):
559559
# Exact value depends on database and version.
560560
self.assertTrue(20 <= ptown.size <= 105)
561561

562-
@skipUnlessDBFeature("has_NumGeom_function")
562+
@skipUnlessDBFeature("has_NumGeometries_function")
563563
def test_num_geom(self):
564564
# Both 'countries' only have two geometries.
565565
for c in Country.objects.annotate(num_geom=functions.NumGeometries("mpoly")):
@@ -576,7 +576,7 @@ def test_num_geom(self):
576576
else:
577577
self.assertEqual(1, city.num_geom)
578578

579-
@skipUnlessDBFeature("has_NumPoint_function")
579+
@skipUnlessDBFeature("has_NumPoints_function")
580580
def test_num_points(self):
581581
coords = [(-95.363151, 29.763374), (-95.448601, 29.713803)]
582582
Track.objects.create(name="Foo", line=LineString(coords))

tests/gis_tests/tests.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import unittest
22

33
from django.core.exceptions import ImproperlyConfigured
4-
from django.db import ProgrammingError
4+
from django.db import ProgrammingError, connection
55
from django.db.backends.base.base import NO_DB_ALIAS
6+
from django.test import TestCase
67

78
try:
89
from django.contrib.gis.db.backends.postgis.operations import PostGISOperations
@@ -12,6 +13,16 @@
1213
HAS_POSTGRES = False
1314

1415

16+
class BaseSpatialFeaturesTests(TestCase):
17+
def test_invalid_has_func_function(self):
18+
msg = (
19+
'DatabaseFeatures.has_Invalid_function isn\'t valid. Is "Invalid" '
20+
"missing from BaseSpatialOperations.unsupported_functions?"
21+
)
22+
with self.assertRaisesMessage(ValueError, msg):
23+
connection.features.has_Invalid_function
24+
25+
1526
if HAS_POSTGRES:
1627

1728
class FakeConnection:

tests/ordering/tests.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,18 @@
33

44
from django.core.exceptions import FieldError
55
from django.db.models import (
6+
Case,
67
CharField,
78
Count,
89
DateTimeField,
910
F,
11+
IntegerField,
1012
Max,
1113
OrderBy,
1214
OuterRef,
1315
Subquery,
1416
Value,
17+
When,
1518
)
1619
from django.db.models.functions import Length, Upper
1720
from django.test import TestCase
@@ -526,6 +529,17 @@ def test_order_by_constant_value(self):
526529
qs = Article.objects.order_by(Value("1", output_field=CharField()), "-headline")
527530
self.assertSequenceEqual(qs, [self.a4, self.a3, self.a2, self.a1])
528531

532+
def test_order_by_case_when_constant_value(self):
533+
qs = Article.objects.order_by(
534+
Case(
535+
When(pk__in=[], then=Value(1)),
536+
default=Value(0),
537+
output_field=IntegerField(),
538+
).desc(),
539+
"pk",
540+
)
541+
self.assertSequenceEqual(qs, [self.a1, self.a2, self.a3, self.a4])
542+
529543
def test_related_ordering_duplicate_table_reference(self):
530544
"""
531545
An ordering referencing a model with an ordering referencing a model

tests/test_utils/tests.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ class CaptureQueriesContextManagerTests(TestCase):
386386
@classmethod
387387
def setUpTestData(cls):
388388
cls.person_pk = str(Person.objects.create(name="test").pk)
389+
cls.url = f"/test_utils/get_person/{cls.person_pk}/"
389390

390391
def test_simple(self):
391392
with CaptureQueriesContext(connection) as captured_queries:
@@ -418,25 +419,38 @@ def test_failure(self):
418419

419420
def test_with_client(self):
420421
with CaptureQueriesContext(connection) as captured_queries:
421-
self.client.get("/test_utils/get_person/%s/" % self.person_pk)
422+
self.client.get(self.url)
422423
self.assertEqual(len(captured_queries), 1)
423424
self.assertIn(self.person_pk, captured_queries[0]["sql"])
424425

425426
with CaptureQueriesContext(connection) as captured_queries:
426-
self.client.get("/test_utils/get_person/%s/" % self.person_pk)
427+
self.client.get(self.url)
427428
self.assertEqual(len(captured_queries), 1)
428429
self.assertIn(self.person_pk, captured_queries[0]["sql"])
429430

430431
with CaptureQueriesContext(connection) as captured_queries:
431-
self.client.get("/test_utils/get_person/%s/" % self.person_pk)
432-
self.client.get("/test_utils/get_person/%s/" % self.person_pk)
432+
self.client.get(self.url)
433+
self.client.get(self.url)
433434
self.assertEqual(len(captured_queries), 2)
434435
self.assertIn(self.person_pk, captured_queries[0]["sql"])
435436
self.assertIn(self.person_pk, captured_queries[1]["sql"])
436437

438+
def test_with_client_nested(self):
439+
with CaptureQueriesContext(connection) as captured_queries:
440+
Person.objects.count()
441+
with CaptureQueriesContext(connection):
442+
pass
443+
self.client.get(self.url)
444+
self.assertEqual(2, len(captured_queries))
445+
437446

438447
@override_settings(ROOT_URLCONF="test_utils.urls")
439448
class AssertNumQueriesContextManagerTests(TestCase):
449+
@classmethod
450+
def setUpTestData(cls):
451+
cls.person_pk = str(Person.objects.create(name="test").pk)
452+
cls.url = f"/test_utils/get_person/{cls.person_pk}/"
453+
440454
def test_simple(self):
441455
with self.assertNumQueries(0):
442456
pass
@@ -459,17 +473,22 @@ def test_failure(self):
459473
raise TypeError
460474

461475
def test_with_client(self):
462-
person = Person.objects.create(name="test")
463-
464476
with self.assertNumQueries(1):
465-
self.client.get("/test_utils/get_person/%s/" % person.pk)
477+
self.client.get(self.url)
466478

467479
with self.assertNumQueries(1):
468-
self.client.get("/test_utils/get_person/%s/" % person.pk)
480+
self.client.get(self.url)
469481

470482
with self.assertNumQueries(2):
471-
self.client.get("/test_utils/get_person/%s/" % person.pk)
472-
self.client.get("/test_utils/get_person/%s/" % person.pk)
483+
self.client.get(self.url)
484+
self.client.get(self.url)
485+
486+
def test_with_client_nested(self):
487+
with self.assertNumQueries(2):
488+
Person.objects.count()
489+
with self.assertNumQueries(0):
490+
pass
491+
self.client.get(self.url)
473492

474493

475494
@override_settings(ROOT_URLCONF="test_utils.urls")

0 commit comments

Comments
 (0)