Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit f5b102a

Browse files
authoredJul 3, 2024
PYTHON-4525 Transition the existing test_database.py test to be asynchronous (#1716)
1 parent cfa215c commit f5b102a

File tree

5 files changed

+857
-44
lines changed

5 files changed

+857
-44
lines changed
 

‎test/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,12 @@ def require_no_auth(self, func):
552552
func=func,
553553
)
554554

555+
def require_no_fips(self, func):
556+
"""Run a test only if the host does not have FIPS enabled."""
557+
return self._require(
558+
lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func
559+
)
560+
555561
def require_replica_set(self, func):
556562
"""Run a test only if the client is connected to a replica set."""
557563
return self._require(lambda: self.is_rs, "Not connected to a replica set", func=func)

‎test/asynchronous/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,12 @@ def require_no_auth(self, func):
554554
func=func,
555555
)
556556

557+
def require_no_fips(self, func):
558+
"""Run a test only if the host does not have FIPS enabled."""
559+
return self._require(
560+
lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func
561+
)
562+
557563
def require_replica_set(self, func):
558564
"""Run a test only if the client is connected to a replica set."""
559565
return self._require(lambda: self.is_rs, "Not connected to a replica set", func=func)

‎test/asynchronous/test_database.py

Lines changed: 772 additions & 0 deletions
Large diffs are not rendered by default.

‎test/test_database.py

Lines changed: 72 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
from bson.regex import Regex
4040
from bson.son import SON
4141
from pymongo import helpers_shared
42-
from pymongo.asynchronous import auth
4342
from pymongo.errors import (
4443
CollectionInvalid,
4544
ExecutionTimeout,
@@ -50,11 +49,15 @@
5049
)
5150
from pymongo.read_concern import ReadConcern
5251
from pymongo.read_preferences import ReadPreference
52+
from pymongo.synchronous import auth
5353
from pymongo.synchronous.collection import Collection
5454
from pymongo.synchronous.database import Database
55+
from pymongo.synchronous.helpers import next
5556
from pymongo.synchronous.mongo_client import MongoClient
5657
from pymongo.write_concern import WriteConcern
5758

59+
_IS_SYNC = True
60+
5861

5962
class TestDatabaseNoConnect(unittest.TestCase):
6063
"""Test Database features on a client that does not connect."""
@@ -140,32 +143,38 @@ def test_get_coll(self):
140143
self.assertEqual(db.test.mike, db["test.mike"])
141144

142145
def test_repr(self):
146+
name = "Database"
143147
self.assertEqual(
144148
repr(Database(self.client, "pymongo_test")),
145-
"Database({!r}, {})".format(self.client, repr("pymongo_test")),
149+
"{}({!r}, {})".format(name, self.client, repr("pymongo_test")),
146150
)
147151

148152
def test_create_collection(self):
149153
db = Database(self.client, "pymongo_test")
150154

151155
db.test.insert_one({"hello": "world"})
152-
self.assertRaises(CollectionInvalid, db.create_collection, "test")
156+
with self.assertRaises(CollectionInvalid):
157+
db.create_collection("test")
153158

154159
db.drop_collection("test")
155160

156-
self.assertRaises(TypeError, db.create_collection, 5)
157-
self.assertRaises(TypeError, db.create_collection, None)
158-
self.assertRaises(InvalidName, db.create_collection, "coll..ection")
161+
with self.assertRaises(TypeError):
162+
db.create_collection(5) # type: ignore[arg-type]
163+
with self.assertRaises(TypeError):
164+
db.create_collection(None) # type: ignore[arg-type]
165+
with self.assertRaises(InvalidName):
166+
db.create_collection("coll..ection") # type: ignore[arg-type]
159167

160168
test = db.create_collection("test")
161169
self.assertTrue("test" in db.list_collection_names())
162170
test.insert_one({"hello": "world"})
163-
self.assertEqual(db.test.find_one()["hello"], "world") # type: ignore
171+
self.assertEqual((db.test.find_one())["hello"], "world")
164172

165173
db.drop_collection("test.foo")
166174
db.create_collection("test.foo")
167175
self.assertTrue("test.foo" in db.list_collection_names())
168-
self.assertRaises(CollectionInvalid, db.create_collection, "test.foo")
176+
with self.assertRaises(CollectionInvalid):
177+
db.create_collection("test.foo")
169178

170179
def test_list_collection_names(self):
171180
db = Database(self.client, "pymongo_test")
@@ -274,11 +283,11 @@ def test_list_collections(self):
274283
else:
275284
self.assertTrue(False)
276285

277-
colls = db.list_collections(filter={"name": {"$regex": "^test$"}})
278-
self.assertEqual(1, len(list(colls)))
286+
colls = (db.list_collections(filter={"name": {"$regex": "^test$"}})).to_list()
287+
self.assertEqual(1, len(colls))
279288

280-
colls = db.list_collections(filter={"name": {"$regex": "^test.mike$"}})
281-
self.assertEqual(1, len(list(colls)))
289+
colls = (db.list_collections(filter={"name": {"$regex": "^test.mike$"}})).to_list()
290+
self.assertEqual(1, len(colls))
282291

283292
db.drop_collection("test")
284293

@@ -326,8 +335,10 @@ def test_list_collection_names_single_socket(self):
326335
def test_drop_collection(self):
327336
db = Database(self.client, "pymongo_test")
328337

329-
self.assertRaises(TypeError, db.drop_collection, 5)
330-
self.assertRaises(TypeError, db.drop_collection, None)
338+
with self.assertRaises(TypeError):
339+
db.drop_collection(5) # type: ignore[arg-type]
340+
with self.assertRaises(TypeError):
341+
db.drop_collection(None) # type: ignore[arg-type]
331342

332343
db.test.insert_one({"dummy": "object"})
333344
self.assertTrue("test" in db.list_collection_names())
@@ -360,13 +371,17 @@ def test_drop_collection(self):
360371
def test_validate_collection(self):
361372
db = self.client.pymongo_test
362373

363-
self.assertRaises(TypeError, db.validate_collection, 5)
364-
self.assertRaises(TypeError, db.validate_collection, None)
374+
with self.assertRaises(TypeError):
375+
db.validate_collection(5) # type: ignore[arg-type]
376+
with self.assertRaises(TypeError):
377+
db.validate_collection(None) # type: ignore[arg-type]
365378

366379
db.test.insert_one({"dummy": "object"})
367380

368-
self.assertRaises(OperationFailure, db.validate_collection, "test.doesnotexist")
369-
self.assertRaises(OperationFailure, db.validate_collection, db.test.doesnotexist)
381+
with self.assertRaises(OperationFailure):
382+
db.validate_collection("test.doesnotexist")
383+
with self.assertRaises(OperationFailure):
384+
db.validate_collection(db.test.doesnotexist)
370385

371386
self.assertTrue(db.validate_collection("test"))
372387
self.assertTrue(db.validate_collection(db.test))
@@ -426,17 +441,21 @@ def test_cursor_command(self):
426441

427442
self.assertIsInstance(cursor, CommandCursor)
428443

429-
result_docs = list(cursor)
444+
result_docs = cursor.to_list()
430445
self.assertEqual(docs, result_docs)
431446

432447
def test_cursor_command_invalid(self):
433-
self.assertRaises(InvalidOperation, self.db.cursor_command, "usersInfo", "test")
448+
with self.assertRaises(InvalidOperation):
449+
self.db.cursor_command("usersInfo", "test")
434450

435451
@client_context.require_no_fips
436452
def test_password_digest(self):
437-
self.assertRaises(TypeError, auth._password_digest, 5)
438-
self.assertRaises(TypeError, auth._password_digest, True)
439-
self.assertRaises(TypeError, auth._password_digest, None)
453+
with self.assertRaises(TypeError):
454+
auth._password_digest(5) # type: ignore[arg-type, call-arg]
455+
with self.assertRaises(TypeError):
456+
auth._password_digest(True) # type: ignore[arg-type, call-arg]
457+
with self.assertRaises(TypeError):
458+
auth._password_digest(None) # type: ignore[arg-type, call-arg]
440459

441460
self.assertTrue(isinstance(auth._password_digest("mike", "password"), str))
442461
self.assertEqual(
@@ -470,16 +489,20 @@ def test_deref(self):
470489
db = self.client.pymongo_test
471490
db.test.drop()
472491

473-
self.assertRaises(TypeError, db.dereference, 5)
474-
self.assertRaises(TypeError, db.dereference, "hello")
475-
self.assertRaises(TypeError, db.dereference, None)
492+
with self.assertRaises(TypeError):
493+
db.dereference(5) # type: ignore[arg-type]
494+
with self.assertRaises(TypeError):
495+
db.dereference("hello") # type: ignore[arg-type]
496+
with self.assertRaises(TypeError):
497+
db.dereference(None) # type: ignore[arg-type]
476498

477499
self.assertEqual(None, db.dereference(DBRef("test", ObjectId())))
478500
obj: dict[str, Any] = {"x": True}
479-
key = db.test.insert_one(obj).inserted_id
501+
key = (db.test.insert_one(obj)).inserted_id
480502
self.assertEqual(obj, db.dereference(DBRef("test", key)))
481503
self.assertEqual(obj, db.dereference(DBRef("test", key, "pymongo_test")))
482-
self.assertRaises(ValueError, db.dereference, DBRef("test", key, "foo"))
504+
with self.assertRaises(ValueError):
505+
db.dereference(DBRef("test", key, "foo"))
483506

484507
self.assertEqual(None, db.dereference(DBRef("test", 4)))
485508
obj = {"_id": 4}
@@ -504,7 +527,7 @@ def test_insert_find_one(self):
504527
db.test.drop()
505528

506529
a_doc = SON({"hello": "world"})
507-
a_key = db.test.insert_one(a_doc).inserted_id
530+
a_key = (db.test.insert_one(a_doc)).inserted_id
508531
self.assertTrue(isinstance(a_doc["_id"], ObjectId))
509532
self.assertEqual(a_doc["_id"], a_key)
510533
self.assertEqual(a_doc, db.test.find_one({"_id": a_doc["_id"]}))
@@ -531,12 +554,12 @@ def test_long(self):
531554
db = self.client.pymongo_test
532555
db.test.drop()
533556
db.test.insert_one({"x": 9223372036854775807})
534-
retrieved = db.test.find_one()["x"] # type: ignore
557+
retrieved = (db.test.find_one())["x"]
535558
self.assertEqual(Int64(9223372036854775807), retrieved)
536559
self.assertIsInstance(retrieved, Int64)
537560
db.test.delete_many({})
538561
db.test.insert_one({"x": Int64(1)})
539-
retrieved = db.test.find_one()["x"] # type: ignore
562+
retrieved = (db.test.find_one())["x"]
540563
self.assertEqual(Int64(1), retrieved)
541564
self.assertIsInstance(retrieved, Int64)
542565

@@ -578,7 +601,8 @@ def test_command_response_without_ok(self):
578601
# Sometimes (SERVER-10891) the server's response to a badly-formatted
579602
# command document will have no 'ok' field. We should raise
580603
# OperationFailure instead of KeyError.
581-
self.assertRaises(OperationFailure, helpers_shared._check_command_response, {}, None)
604+
with self.assertRaises(OperationFailure):
605+
helpers_shared._check_command_response({}, None)
582606

583607
try:
584608
helpers_shared._check_command_response({"$err": "foo"}, None)
@@ -624,22 +648,23 @@ def test_command_max_time_ms(self):
624648
try:
625649
db = self.client.pymongo_test
626650
db.command("count", "test")
627-
self.assertRaises(ExecutionTimeout, db.command, "count", "test", maxTimeMS=1)
651+
with self.assertRaises(ExecutionTimeout):
652+
db.command("count", "test", maxTimeMS=1)
628653
pipeline = [{"$project": {"name": 1, "count": 1}}]
629654
# Database command helper.
630655
db.command("aggregate", "test", pipeline=pipeline, cursor={})
631-
self.assertRaises(
632-
ExecutionTimeout,
633-
db.command,
634-
"aggregate",
635-
"test",
636-
pipeline=pipeline,
637-
cursor={},
638-
maxTimeMS=1,
639-
)
656+
with self.assertRaises(ExecutionTimeout):
657+
db.command(
658+
"aggregate",
659+
"test",
660+
pipeline=pipeline,
661+
cursor={},
662+
maxTimeMS=1,
663+
)
640664
# Collection helper.
641665
db.test.aggregate(pipeline=pipeline)
642-
self.assertRaises(ExecutionTimeout, db.test.aggregate, pipeline, maxTimeMS=1)
666+
with self.assertRaises(ExecutionTimeout):
667+
db.test.aggregate(pipeline, maxTimeMS=1)
643668
finally:
644669
self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off")
645670

@@ -723,7 +748,10 @@ def test_database_aggregation_fake_cursor(self):
723748
with self.assertRaises(StopIteration):
724749
next(cursor)
725750

726-
result = wait_until(output_coll.find_one, "read unacknowledged write")
751+
def lambda_fn():
752+
return output_coll.find_one()
753+
754+
result = wait_until(lambda_fn, "read unacknowledged write")
727755
self.assertEqual(result["dummy"], self.result["dummy"])
728756

729757
def test_bool(self):

‎tools/synchro.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@
132132
"__init__.py",
133133
"conftest.py",
134134
"test_collection.py",
135+
"test_database.py",
135136
]
136137

137138
sync_test_files = [

0 commit comments

Comments
 (0)
Please sign in to comment.