PYTHON-3636 MongoClient should perform SRV resolution lazily #2191

Merged on Mar 25, 2025 (64 commits; this view shows changes from 55 commits)

Commits
ead780a
WIP (not cleaned up)
sleepyStick Mar 7, 2025
79c09ea
this might be broken? unsure....
sleepyStick Mar 10, 2025
3afd732
Merge branch 'mongodb:master' into PYTHON-3636
sleepyStick Mar 10, 2025
7d771cb
keep parse_uri as is and have it call two different functions instead
sleepyStick Mar 10, 2025
0f64689
cleanup
sleepyStick Mar 10, 2025
ed50141
some refactoring to reduce code duplication
sleepyStick Mar 11, 2025
ed25867
fix typing
sleepyStick Mar 11, 2025
8d48f44
remove copied doc string
sleepyStick Mar 11, 2025
1a3efed
move init_background to only be called upon client connection
sleepyStick Mar 11, 2025
d94743b
only define topology after uri resolution
sleepyStick Mar 11, 2025
ad20606
okay turns out it was *too* lazy HAHA
sleepyStick Mar 11, 2025
dfa0639
cleanup
sleepyStick Mar 11, 2025
57edcbc
more cleanup
sleepyStick Mar 12, 2025
58a58a0
more cleanup
sleepyStick Mar 12, 2025
35a41e9
fix fork tests
sleepyStick Mar 12, 2025
d343311
fix typing
sleepyStick Mar 12, 2025
d03c78f
determine is_srv differently
sleepyStick Mar 12, 2025
e1d091f
fix test
sleepyStick Mar 12, 2025
8efd549
fix encrypter
sleepyStick Mar 12, 2025
4c06dec
undoing unintended changes
sleepyStick Mar 12, 2025
511fcc4
bringing back a previously deleted test
sleepyStick Mar 12, 2025
40509a1
undoing unintended changes
sleepyStick Mar 12, 2025
97e0778
some refactoring
sleepyStick Mar 12, 2025
1de56d4
Merge branch 'master' into PYTHON-3636
sleepyStick Mar 12, 2025
2653a56
fix typing
sleepyStick Mar 12, 2025
4c23ee0
Update pymongo/asynchronous/mongo_client.py
sleepyStick Mar 13, 2025
bc61199
Update pymongo/asynchronous/mongo_client.py
sleepyStick Mar 13, 2025
8c2b368
respond to comments and move srv_resolver to async
sleepyStick Mar 13, 2025
e38c2ad
refactor part 2
sleepyStick Mar 13, 2025
3c1bb28
Merge branch 'master' into PYTHON-3636
sleepyStick Mar 13, 2025
94fec44
fix circular import
sleepyStick Mar 13, 2025
99a07fe
Merge branch 'PYTHON-3636' of github.com:sleepyStick/mongo-python-dri…
sleepyStick Mar 13, 2025
32fabb9
fix tests
sleepyStick Mar 13, 2025
af568da
fix test and repr
sleepyStick Mar 13, 2025
82bcd38
fix test
sleepyStick Mar 13, 2025
60bf17d
fix import for test
sleepyStick Mar 13, 2025
63ba7be
change helpers import
sleepyStick Mar 13, 2025
7585e04
fix uri_parser
sleepyStick Mar 13, 2025
2c69412
fix srv_resolver
sleepyStick Mar 13, 2025
f834b89
add missing awaits
sleepyStick Mar 13, 2025
d450457
add missing await
sleepyStick Mar 13, 2025
2a8b1b2
Update test/auth_aws/test_auth_aws.py
sleepyStick Mar 14, 2025
c82cf50
Update test/asynchronous/helpers.py
sleepyStick Mar 14, 2025
9256808
Update test/auth_oidc/test_auth_oidc.py
sleepyStick Mar 14, 2025
c6d2ceb
address comments
sleepyStick Mar 14, 2025
d8d2c26
Merge branch 'PYTHON-3636' of github.com:sleepyStick/mongo-python-dri…
sleepyStick Mar 14, 2025
8927a27
undo import change in helpers
sleepyStick Mar 14, 2025
76a68b2
change client eq and hash
sleepyStick Mar 17, 2025
b60eb60
address comments part 1
sleepyStick Mar 17, 2025
0b6d303
address comment ish - remove first
sleepyStick Mar 18, 2025
0ca6afd
re-order call to super's init
sleepyStick Mar 18, 2025
259d36b
Merge branch 'main' into PYTHON-3636
sleepyStick Mar 18, 2025
d616135
update link to use https based on prev commit on main
sleepyStick Mar 18, 2025
379dfb6
fix typing
sleepyStick Mar 18, 2025
5466484
oops fix typing pt2
sleepyStick Mar 18, 2025
2900718
address comments
sleepyStick Mar 21, 2025
63676b6
fix patch string
sleepyStick Mar 21, 2025
cd9bd92
address comments pt1
sleepyStick Mar 21, 2025
f33d091
Merge branch 'main' into PYTHON-3636
sleepyStick Mar 21, 2025
a7c090d
add test for repr and change changelog
sleepyStick Mar 21, 2025
93bc3c9
fix test
sleepyStick Mar 21, 2025
afd82f4
Update doc/changelog.rst
ShaneHarvey Mar 24, 2025
99a5c8a
Address review
NoahStapp Mar 24, 2025
21d3f58
Merge branch 'master' into PYTHON-3636
ShaneHarvey Mar 25, 2025

1 change: 1 addition & 0 deletions pymongo/asynchronous/client_session.py
@@ -996,6 +996,7 @@ def _txn_read_preference(self) -> Optional[_ServerMode]:
     def _materialize(self, logical_session_timeout_minutes: Optional[int] = None) -> None:
         if isinstance(self._server_session, _EmptyServerSession):
             old = self._server_session
+            assert self._client._topology is not None
             self._server_session = self._client._topology.get_server_session(
                 logical_session_timeout_minutes
             )
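
The added assertion suggests that `_topology` may now be `None` until the client has resolved its URI. That is an inference, since the mongo_client.py diff is not rendered on this page. A minimal sketch of why such an assertion is useful, with a hypothetical `Client` and `Topology` that are not PyMongo's classes:

```python
from __future__ import annotations

from typing import Optional


class Topology:
    """Hypothetical stand-in for a topology object."""

    def get_server_session(self, timeout_minutes: Optional[int]) -> str:
        return f"session(timeout={timeout_minutes})"


class Client:
    """Hypothetical client whose topology is created lazily."""

    def __init__(self) -> None:
        self._topology: Optional[Topology] = None

    def connect(self) -> None:
        self._topology = Topology()


def materialize(client: Client, timeout_minutes: Optional[int] = None) -> str:
    # Narrow Optional[Topology] to Topology for the type checker; at runtime
    # this fails loudly if the topology has not been created yet.
    assert client._topology is not None
    return client._topology.get_server_session(timeout_minutes)
```

Calling `materialize` before `connect` raises `AssertionError` instead of an `AttributeError` on `None`, which makes the ordering requirement explicit.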
2 changes: 1 addition & 1 deletion pymongo/asynchronous/encryption.py
@@ -87,7 +87,7 @@
 from pymongo.results import BulkWriteResult, DeleteResult
 from pymongo.ssl_support import get_ssl_context
 from pymongo.typings import _DocumentType, _DocumentTypeArg
-from pymongo.uri_parser import parse_host
+from pymongo.uri_parser_shared import parse_host
 from pymongo.write_concern import WriteConcern

 if TYPE_CHECKING:
313 changes: 228 additions & 85 deletions pymongo/asynchronous/mongo_client.py

Large diffs are not rendered by default.
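
Because this diff is not rendered, the following is only a sketch of the general pattern named in the PR title (defer SRV resolution from construction to first connection), not the code in mongo_client.py. `LazyClient`, `fake_resolve_srv`, and `connect` are hypothetical names, and the resolver is a stub that performs no DNS lookup:

```python
import asyncio
from typing import Optional

RESOLVE_CALLS = 0


async def fake_resolve_srv(fqdn: str) -> list[tuple[str, int]]:
    """Stub for an SRV lookup; a real client would query DNS here."""
    global RESOLVE_CALLS
    RESOLVE_CALLS += 1
    return [(f"shard0.{fqdn.split('.', 1)[1]}", 27017)]


class LazyClient:
    """Hypothetical client that resolves its seed list on first use."""

    def __init__(self, fqdn: str) -> None:
        # Construction does no I/O: only the unresolved name is stored.
        self._fqdn = fqdn
        self._seeds: Optional[list[tuple[str, int]]] = None

    async def connect(self) -> list[tuple[str, int]]:
        # Resolve once, on the first call, and cache the result.
        if self._seeds is None:
            self._seeds = await fake_resolve_srv(self._fqdn)
        return self._seeds


async def main() -> list[tuple[str, int]]:
    client = LazyClient("cluster.example.com")
    assert RESOLVE_CALLS == 0  # nothing resolved at construction time
    await client.connect()
    return await client.connect()  # second call reuses the cached seeds


seeds = asyncio.run(main())
```

The design point is that constructing the client cannot block or fail on DNS; any resolution error surfaces on the first operation instead.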

8 changes: 4 additions & 4 deletions pymongo/asynchronous/monitor.py
@@ -25,6 +25,7 @@

 from pymongo import common, periodic_executor
 from pymongo._csot import MovingMinimum
+from pymongo.asynchronous.srv_resolver import _SrvResolver
 from pymongo.errors import NetworkTimeout, _OperationCancelled
 from pymongo.hello import Hello
 from pymongo.lock import _async_create_lock
@@ -33,7 +34,6 @@
 from pymongo.pool_options import _is_faas
 from pymongo.read_preferences import MovingAverage
 from pymongo.server_description import ServerDescription
-from pymongo.srv_resolver import _SrvResolver

 if TYPE_CHECKING:
     from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext
@@ -395,7 +395,7 @@ async def _run(self) -> None:
         # Don't poll right after creation, wait 60 seconds first
         if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
             return
-        seedlist = self._get_seedlist()
+        seedlist = await self._get_seedlist()
         if seedlist:
             self._seedlist = seedlist
             try:
@@ -404,7 +404,7 @@ async def _run(self) -> None:
                 # Topology was garbage-collected.
                 await self.close()

-    def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
+    async def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
         """Poll SRV records for a seedlist.

         Returns a list of ServerDescriptions.
@@ -415,7 +415,7 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
                 self._settings.pool_options.connect_timeout,
                 self._settings.srv_service_name,
             )
-            seedlist, ttl = resolver.get_hosts_and_min_ttl()
+            seedlist, ttl = await resolver.get_hosts_and_min_ttl()
             if len(seedlist) == 0:
                 # As per the spec: this should be treated as a failure.
                 raise Exception
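
Once `_get_seedlist` becomes a coroutine function, every caller has to await it; the "add missing awaits" commits in this PR address that. A small self-contained illustration with a hypothetical `get_seedlist` stub, showing what a caller receives with and without `await`:

```python
import asyncio
import inspect


async def get_seedlist() -> list[tuple[str, int]]:
    """Hypothetical stand-in for an async SRV poll."""
    await asyncio.sleep(0)  # yield to the event loop, as real I/O would
    return [("db0.example.com", 27017)]


async def poll() -> tuple[bool, list[tuple[str, int]]]:
    # Without await, the call only creates a coroutine object; no lookup runs.
    pending = get_seedlist()
    was_coroutine = inspect.iscoroutine(pending)
    # Awaiting the coroutine runs it and yields the actual seed list.
    seedlist = await pending
    return was_coroutine, seedlist


was_coroutine, seedlist = asyncio.run(poll())
```

A forgotten `await` is easy to miss here because a coroutine object is truthy, so a check such as `if seedlist:` would still pass.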
1 change: 1 addition & 0 deletions pymongo/asynchronous/pool.py
@@ -1315,6 +1315,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
             raise

         if handler:
+            assert handler.client._topology is not None
             await handler.client._topology.receive_cluster_time(conn._cluster_time)

         return conn
158 changes: 158 additions & 0 deletions pymongo/asynchronous/srv_resolver.py
@@ -0,0 +1,158 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

"""Support for resolving hosts and options from mongodb+srv:// URIs."""
from __future__ import annotations

import ipaddress
import random
from typing import TYPE_CHECKING, Any, Optional, Union

from pymongo.common import CONNECT_TIMEOUT
from pymongo.errors import ConfigurationError

if TYPE_CHECKING:
    from dns import resolver

_IS_SYNC = False


def _have_dnspython() -> bool:
    try:
        import dns  # noqa: F401

        return True
    except ImportError:
        return False


# dnspython can return bytes or str from various parts
# of its API depending on version. We always want str.
def maybe_decode(text: Union[str, bytes]) -> str:
    if isinstance(text, bytes):
        return text.decode()
    return text


# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
    if _IS_SYNC:
        from dns import resolver

        if hasattr(resolver, "resolve"):
            # dnspython >= 2
            return resolver.resolve(*args, **kwargs)
        # dnspython 1.X
        return resolver.query(*args, **kwargs)
    else:
        from dns import asyncresolver

        if hasattr(asyncresolver, "resolve"):
            # dnspython >= 2
            return await asyncresolver.resolve(*args, **kwargs)  # type:ignore[return-value]
        raise ConfigurationError("Upgrade to dnspython version >= 2.0")
Contributor

This error message should explicitly inform users that they are attempting to use the async API with an old dnspython version. Telling them only to upgrade without any other information is inconsistent with the underlying reason.

Member

@ShaneHarvey ShaneHarvey Mar 24, 2025

Could you add a commit suggestion here?

Contributor

Suggested change
-        raise ConfigurationError("Upgrade to dnspython version >= 2.0")
+        raise ConfigurationError("Upgrade to dnspython version >= 2.0 to use AsyncMongoClient with mongodb+srv:// connections.")
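
A sketch of the kind of capability check discussed in this thread, written against a stand-in module object so that it runs without dnspython installed. `require_async_resolve` and the local `ConfigurationError` class are hypothetical; the message is the one from the suggested change:

```python
from types import SimpleNamespace
from typing import Any


class ConfigurationError(Exception):
    """Local stand-in for pymongo.errors.ConfigurationError."""


def require_async_resolve(asyncresolver: Any) -> Any:
    """Return asyncresolver.resolve, or raise a descriptive error."""
    if hasattr(asyncresolver, "resolve"):
        return asyncresolver.resolve
    raise ConfigurationError(
        "Upgrade to dnspython version >= 2.0 to use AsyncMongoClient "
        "with mongodb+srv:// connections."
    )


# Stand-ins for a module without and with an async resolve() function.
old_module = SimpleNamespace()
new_module = SimpleNamespace(resolve=lambda *args, **kwargs: "answer")

try:
    require_async_resolve(old_module)
    message = ""
except ConfigurationError as exc:
    message = str(exc)
```

Naming both the feature (`AsyncMongoClient` with `mongodb+srv://`) and the required version tells the user why the upgrade is needed, which is the point raised in the review comment.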



_INVALID_HOST_MSG = (
    "Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. "
    "Did you mean to use 'mongodb://'?"
)


class _SrvResolver:
    def __init__(
        self,
        fqdn: str,
        connect_timeout: Optional[float],
        srv_service_name: str,
        srv_max_hosts: int = 0,
    ):
        self.__fqdn = fqdn
        self.__srv = srv_service_name
        self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
        self.__srv_max_hosts = srv_max_hosts or 0
        # Validate the fully qualified domain name.
        try:
            ipaddress.ip_address(fqdn)
            raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
        except ValueError:
            pass

        try:
            self.__plist = self.__fqdn.split(".")[1:]
        except Exception:
            raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
        self.__slen = len(self.__plist)
        if self.__slen < 2:
            raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))

    async def get_options(self) -> Optional[str]:
        from dns import resolver

        try:
            results = await _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout)
        except (resolver.NoAnswer, resolver.NXDOMAIN):
            # No TXT records
            return None
        except Exception as exc:
            raise ConfigurationError(str(exc)) from None
        if len(results) > 1:
            raise ConfigurationError("Only one TXT record is supported")
        return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8")  # type: ignore[attr-defined]

    async def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer:
        try:
            results = await _resolve(
                "_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
            )
        except Exception as exc:
            if not encapsulate_errors:
                # Raise the original error.
                raise
            # Else, raise all errors as ConfigurationError.
            raise ConfigurationError(str(exc)) from None
        return results

    async def _get_srv_response_and_hosts(
        self, encapsulate_errors: bool
    ) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
        results = await self._resolve_uri(encapsulate_errors)

        # Construct address tuples
        nodes = [
            (maybe_decode(res.target.to_text(omit_final_dot=True)), res.port)  # type: ignore[attr-defined]
            for res in results
        ]

        # Validate hosts
        for node in nodes:
            try:
                nlist = node[0].lower().split(".")[1:][-self.__slen :]
            except Exception:
                raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
            if self.__plist != nlist:
                raise ConfigurationError(f"Invalid SRV host: {node[0]}")
        if self.__srv_max_hosts:
            nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
        return results, nodes

    async def get_hosts(self) -> list[tuple[str, Any]]:
        _, nodes = await self._get_srv_response_and_hosts(True)
        return nodes

    async def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]:
        results, nodes = await self._get_srv_response_and_hosts(False)
        rrset = results.rrset
        ttl = rrset.ttl if rrset else 0
        return nodes, ttl
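
The two validation steps in `_SrvResolver` (the FQDN check in `__init__` and the parent-domain check on returned SRV targets) can be tried in isolation. The sketch below reimplements them as plain functions so they run without DNS access; `parent_labels` and `is_valid_srv_host` are hypothetical names, and `ValueError` replaces `ConfigurationError`:

```python
import ipaddress


def parent_labels(fqdn: str) -> list[str]:
    """Return the labels after the first one, rejecting IPs and short names."""
    try:
        ipaddress.ip_address(fqdn)
    except ValueError:
        pass  # not an IP address, which is what we want
    else:
        raise ValueError(f"{fqdn} is an IP address, not a hostname")
    labels = fqdn.split(".")[1:]
    if len(labels) < 2:
        raise ValueError(f"{fqdn} needs at least three dot-separated parts")
    return labels


def is_valid_srv_host(fqdn: str, srv_host: str) -> bool:
    """Check that a returned SRV target shares the parent domain of fqdn."""
    parent = parent_labels(fqdn)
    candidate = srv_host.lower().split(".")[1:][-len(parent):]
    return candidate == parent
```

For `cluster0.example.com`, a target such as `shard-00.example.com` passes, while `shard.evil.org` is rejected because its trailing labels differ from `example.com`.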
172 changes: 172 additions & 0 deletions pymongo/asynchronous/uri_parser.py
@@ -0,0 +1,172 @@
from __future__ import annotations

from typing import Any, Optional
from urllib.parse import unquote_plus

from pymongo.asynchronous.srv_resolver import _SrvResolver
from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary
from pymongo.errors import ConfigurationError, InvalidURI
from pymongo.uri_parser_shared import (
    _ALLOWED_TXT_OPTS,
    DEFAULT_PORT,
    SCHEME,
    SCHEME_LEN,
    SRV_SCHEME_LEN,
    _check_options,
    _validate_uri,
    split_hosts,
    split_options,
)

_IS_SYNC = False


async def parse_uri(
    uri: str,
    default_port: Optional[int] = DEFAULT_PORT,
    validate: bool = True,
    warn: bool = False,
    normalize: bool = True,
    connect_timeout: Optional[float] = None,
    srv_service_name: Optional[str] = None,
    srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
    """Parse and validate a MongoDB URI.

    Returns a dict of the form::

        {
            'nodelist': <list of (host, port) tuples>,
            'username': <username> or None,
            'password': <password> or None,
            'database': <database name> or None,
            'collection': <collection name> or None,
            'options': <dict of MongoDB URI options>,
            'fqdn': <fqdn of the MongoDB+SRV URI> or None
        }

    If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
    to build nodelist and options.

    :param uri: The MongoDB URI to parse.
    :param default_port: The port number to use when one wasn't specified
        for a host in the URI.
    :param validate: If ``True`` (the default), validate and
        normalize all options. Default: ``True``.
    :param warn: When validating, if ``True`` then will warn
        the user then ignore any invalid options or values. If ``False``,
        validation will error when options are unsupported or values are
        invalid. Default: ``False``.
    :param normalize: If ``True``, convert names of URI options
        to their internally-used names. Default: ``True``.
    :param connect_timeout: The maximum time in milliseconds to
        wait for a response from the DNS server.
    :param srv_service_name: A custom SRV service name

    .. versionchanged:: 4.6
       The delimiting slash (``/``) between hosts and connection options is now optional.
       For example, "mongodb://example.com?tls=true" is now a valid URI.

    .. versionchanged:: 4.0
       To better follow RFC 3986, unquoted percent signs ("%") are no longer
       supported.

    .. versionchanged:: 3.9
       Added the ``normalize`` parameter.

    .. versionchanged:: 3.6
       Added support for mongodb+srv:// URIs.

    .. versionchanged:: 3.5
       Return the original value of the ``readPreference`` MongoDB URI option
       instead of the validated read preference mode.

    .. versionchanged:: 3.1
       ``warn`` added so invalid options can be ignored.
    """
    result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts)
    result.update(
        await _parse_srv(
            uri,
            default_port,
            validate,
            warn,
            normalize,
            connect_timeout,
            srv_service_name,
            srv_max_hosts,
        )
    )
    return result


async def _parse_srv(
    uri: str,
    default_port: Optional[int] = DEFAULT_PORT,
    validate: bool = True,
    warn: bool = False,
    normalize: bool = True,
    connect_timeout: Optional[float] = None,
    srv_service_name: Optional[str] = None,
    srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
    if uri.startswith(SCHEME):
        is_srv = False
        scheme_free = uri[SCHEME_LEN:]
    else:
        is_srv = True
        scheme_free = uri[SRV_SCHEME_LEN:]

    options = _CaseInsensitiveDictionary()

    host_plus_db_part, _, opts = scheme_free.partition("?")
    if "/" in host_plus_db_part:
        host_part, _, _ = host_plus_db_part.partition("/")
    else:
        host_part = host_plus_db_part

    if opts:
        options.update(split_options(opts, validate, warn, normalize))
    if srv_service_name is None:
        srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
    if "@" in host_part:
        _, _, hosts = host_part.rpartition("@")
    else:
        hosts = host_part

    hosts = unquote_plus(hosts)
    srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
    if is_srv:
        nodes = split_hosts(hosts, default_port=None)
        fqdn, port = nodes[0]

        # Use the connection timeout. connectTimeoutMS passed as a keyword
        # argument overrides the same option passed in the connection string.
        connect_timeout = connect_timeout or options.get("connectTimeoutMS")
        dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
        nodes = await dns_resolver.get_hosts()
        dns_options = await dns_resolver.get_options()
        if dns_options:
            parsed_dns_options = split_options(dns_options, validate, warn, normalize)
            if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
                raise ConfigurationError(
                    "Only authSource, replicaSet, and loadBalanced are supported from DNS"
                )
            for opt, val in parsed_dns_options.items():
                if opt not in options:
                    options[opt] = val
        if options.get("loadBalanced") and srv_max_hosts:
            raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
        if options.get("replicaSet") and srv_max_hosts:
            raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
        if "tls" not in options and "ssl" not in options:
            options["tls"] = True if validate else "true"
    else:
        nodes = split_hosts(hosts, default_port=default_port)

    _check_options(nodes, options)

    return {
        "nodelist": nodes,
        "options": options,
    }
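
How options from the TXT record combine with options from the URI can be sketched with plain dictionaries. `merge_txt_options` is a hypothetical helper; the allowed set is taken from the error message in `_parse_srv` rather than from `_ALLOWED_TXT_OPTS` itself (whose definition is not shown on this page); `ValueError` stands in for `ConfigurationError`; and a plain case-sensitive `dict` replaces `_CaseInsensitiveDictionary`:

```python
ALLOWED_TXT_OPTS = frozenset({"authSource", "replicaSet", "loadBalanced"})


def merge_txt_options(uri_options: dict, txt_options: dict) -> dict:
    """Merge TXT-record options into URI options; URI values win."""
    unsupported = set(txt_options) - ALLOWED_TXT_OPTS
    if unsupported:
        raise ValueError(
            "Only authSource, replicaSet, and loadBalanced are supported from DNS"
        )
    merged = dict(uri_options)
    for opt, val in txt_options.items():
        # Keep the value from the URI when both sources set the option.
        merged.setdefault(opt, val)
    return merged


merged = merge_txt_options(
    {"authSource": "admin"},
    {"authSource": "other", "replicaSet": "rs0"},
)
```

Here `authSource` keeps the value from the URI and `replicaSet` is filled in from DNS, mirroring the `if opt not in options` guard in `_parse_srv`.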