Skip to content

Commit 4006d14

Browse files
improvement, python: clean up endpoint functions by centralizing logic (#3761)
1 parent a44db33 commit 4006d14

File tree

351 files changed

+25782
-34701
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

351 files changed

+25782
-34701
lines changed

generators/python/core_utilities/sdk/http_client.py

Lines changed: 325 additions & 20 deletions
Large diffs are not rendered by default.

generators/python/sdk/CHANGELOG.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [2.8.0] - 2024-06-03
9+
10+
- Improvement: Endpoint function request logic has been abstracted into the request function of the wrapped httpx client.
11+
812
## [2.7.0] - 2024-05-30
913

1014
- Improvement: The generator now outputs an `exampleId` alongside each generated snippet so that
11-
we can correlate snippets with the relevant examples. This is useful for retrieving examples from
12-
Fern's API and making sure that you can show multiple snippets in the generated docs.
15+
we can correlate snippets with the relevant examples. This is useful for retrieving examples from
16+
Fern's API and making sure that you can show multiple snippets in the generated docs.
1317

1418
## [2.6.1] - 2024-05-31
1519

generators/python/sdk/VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.7.0
1+
2.8.0

generators/python/src/fern_python/external_dependencies/httpx.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,21 @@ class HttpX:
3434
@staticmethod
3535
def make_request(
3636
*,
37-
url: AST.Expression,
37+
path: Optional[AST.Expression],
38+
url: Optional[AST.Expression],
3839
method: str,
3940
query_parameters: Optional[AST.Expression],
4041
request_body: Optional[AST.Expression],
4142
headers: Optional[AST.Expression],
4243
files: Optional[AST.Expression],
4344
content: Optional[AST.Expression],
44-
auth: Optional[AST.Expression],
45-
timeout: AST.Expression,
4645
response_variable_name: str,
46+
request_options_variable_name: str,
4747
is_async: bool,
4848
is_streaming: bool,
4949
response_code_writer: AST.CodeWriter,
5050
reference_to_client: AST.Expression,
51-
max_retries: AST.Expression,
51+
is_default_body_parameter_used: bool,
5252
) -> AST.Expression:
5353
def add_request_params(*, writer: AST.NodeWriter) -> None:
5454
if query_parameters is not None:
@@ -76,21 +76,10 @@ def add_request_params(*, writer: AST.NodeWriter) -> None:
7676
writer.write_node(headers)
7777
writer.write_line(",")
7878

79-
if auth is not None:
80-
writer.write("auth=")
81-
writer.write_node(auth)
82-
writer.write_line(",")
83-
84-
writer.write("timeout=")
85-
writer.write_node(timeout)
86-
writer.write_line(",")
87-
88-
writer.write("retries=0")
89-
writer.write_line(",")
79+
writer.write(f"request_options={request_options_variable_name},")
9080

91-
writer.write("max_retries=")
92-
writer.write_node(max_retries)
93-
writer.write_line(", # type: ignore")
81+
if is_default_body_parameter_used:
82+
writer.write_line("omit=OMIT,")
9483

9584
def write_non_streaming_call(
9685
*,
@@ -107,10 +96,17 @@ def make_non_streaming_request(
10796
if is_async:
10897
writer.write("await ")
10998
writer.write_node(reference_to_client)
110-
writer.write(f'.request(method="{method}", url=')
111-
writer.write_node(url)
112-
writer.write(", ")
99+
writer.write_line(f".request(")
100+
113101
with writer.indent():
102+
if path is not None:
103+
writer.write_node(path)
104+
writer.write(",")
105+
if url is not None:
106+
writer.write("base_url=")
107+
writer.write_node(url)
108+
writer.write(",")
109+
writer.write_line(f'method="{method}",')
114110
add_request_params(writer=writer)
115111
writer.write_line(")")
116112

@@ -119,10 +115,17 @@ def write_streaming_call(*, writer: AST.NodeWriter) -> None:
119115
writer.write("async ")
120116
writer.write("with ")
121117
writer.write_node(reference_to_client)
122-
writer.write(f'.stream(method="{method}", url=')
123-
writer.write_node(url)
124-
writer.write(", ")
118+
writer.write(f".stream(")
119+
125120
with writer.indent():
121+
if path is not None:
122+
writer.write_node(path)
123+
writer.write(",")
124+
if url is not None:
125+
writer.write("base_url=")
126+
writer.write_node(url)
127+
writer.write(",")
128+
writer.write_line(f'method="{method}",')
126129
add_request_params(writer=writer)
127130
writer.write_line(f") as {response_variable_name}:")
128131

generators/python/src/fern_python/generators/sdk/client_generator/endpoint_function_generator.py

Lines changed: 26 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from fern_python.codegen import AST
88
from fern_python.codegen.ast.ast_node.node_writer import NodeWriter
9-
from fern_python.external_dependencies import HttpX, UrlLibParse
9+
from fern_python.external_dependencies import HttpX
1010
from fern_python.generators.sdk.client_generator.endpoint_response_code_writer import (
1111
EndpointDummySnippetConfig,
1212
EndpointResponseCodeWriter,
@@ -129,20 +129,9 @@ def __init__(
129129
_named_parameter_names.append(name)
130130
self._path_parameter_names[path_parameter.name] = name
131131

132-
def generate(self) -> GeneratedEndpointFunction:
133-
is_primitive: bool = (
134-
self._endpoint.request_body.visit(
135-
inlined_request_body=lambda _: False,
136-
reference=lambda referenced_request_body: self._is_httpx_primitive_data(
137-
type_reference=referenced_request_body.request_body_type, allow_optional=True
138-
),
139-
file_upload=lambda _: False,
140-
bytes=lambda _: False,
141-
)
142-
if self._endpoint.request_body is not None
143-
else False
144-
)
132+
self.is_default_body_parameter_used = self.request_body_parameters is not None
145133

134+
def generate(self) -> GeneratedEndpointFunction:
146135
endpoint_snippets = self._generate_endpoint_snippets(
147136
package=self._package,
148137
service=self._service,
@@ -174,7 +163,6 @@ def generate(self) -> GeneratedEndpointFunction:
174163
idempotency_headers=self._idempotency_headers,
175164
request_body_parameters=self.request_body_parameters,
176165
is_async=self._is_async,
177-
is_primitive=is_primitive,
178166
parameters=unnamed_parameters,
179167
named_parameters=self._named_parameters,
180168
),
@@ -311,7 +299,6 @@ def _create_endpoint_body_writer(
311299
idempotency_headers: List[ir_types.HttpHeader],
312300
request_body_parameters: Optional[AbstractRequestBodyParameters],
313301
is_async: bool,
314-
is_primitive: bool,
315302
named_parameters: List[AST.NamedFunctionParameter],
316303
parameters: List[AST.FunctionParameter],
317304
) -> AST.CodeWriter:
@@ -325,11 +312,6 @@ def write(writer: AST.NodeWriter) -> None:
325312
writer.write_node(AST.Expression(request_pre_fetch_statements))
326313

327314
json_request_body = request_body_parameters.get_json_body() if request_body_parameters is not None else None
328-
encoded_json_request_body = (
329-
self._context.core_utilities.jsonable_encoder(json_request_body)
330-
if json_request_body is not None
331-
else None
332-
)
333315

334316
method = endpoint.method.visit(
335317
get=lambda: "GET",
@@ -340,38 +322,8 @@ def write(writer: AST.NodeWriter) -> None:
340322
)
341323

342324
def write_request_body(writer: AST.NodeWriter) -> None:
343-
if is_primitive:
344-
if encoded_json_request_body:
345-
writer.write_node(encoded_json_request_body)
346-
else:
347-
# If there's an existing request body:
348-
# - Use it if the additional body params are none (e.g. request options has no impact)
349-
# - If additional body params is not none, json encode it and spread both dicts into a new one
350-
# - NOTE: With the is_primitive bail out, we do not acknowledge the additional body parameters,
351-
# to not have to merge together an integer and a hash, for example
352-
# If there is not an existing request body, send the encoded dict
353-
additional_parameters = (
354-
f"{EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('additional_body_parameters')"
355-
)
356-
additional_parameters_defaulted = f"{EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('additional_body_parameters', {'{}'})"
357-
json_encoded_additional_params = self._context.core_utilities.jsonable_encoder(
358-
self._context.core_utilities.remove_none_from_dict(
359-
AST.Expression(additional_parameters_defaulted)
360-
)
361-
)
362-
if encoded_json_request_body:
363-
writer.write_node(encoded_json_request_body)
364-
writer.write(
365-
f" if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is None or {additional_parameters} is None else"
366-
)
367-
writer.write(" {**")
368-
writer.write_node(encoded_json_request_body)
369-
writer.write(", **(")
370-
writer.write_node(json_encoded_additional_params)
371-
writer.write(")}")
372-
else:
373-
writer.write_node(json_encoded_additional_params)
374-
writer.write(f" if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is not None else None")
325+
if json_request_body is not None:
326+
writer.write_node(json_request_body)
375327

376328
is_streaming = (
377329
True
@@ -394,9 +346,6 @@ def write_request_body(writer: AST.NodeWriter) -> None:
394346
),
395347
)
396348

397-
timeout = AST.Expression(
398-
f"{EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('timeout_in_seconds') if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is not None and {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('timeout_in_seconds') is not None else self.{self._client_wrapper_member_name}.{ClientWrapperGenerator.GET_TIMEOUT_METHOD_NAME}()"
399-
)
400349
files = (
401350
request_body_parameters.get_files()
402351
if request_body_parameters is not None and request_body_parameters.get_files() is not None
@@ -406,32 +355,27 @@ def write_request_body(writer: AST.NodeWriter) -> None:
406355
HttpX.make_request(
407356
is_streaming=is_streaming,
408357
is_async=is_async,
409-
url=(
410-
self._get_environment_as_str(endpoint=endpoint)
411-
if is_endpoint_path_empty(endpoint)
412-
else UrlLibParse.urljoin(
413-
self._get_environment_as_str(endpoint=endpoint),
414-
self._get_path_for_endpoint(endpoint),
415-
)
416-
),
358+
path=self._get_path_for_endpoint(endpoint=endpoint)
359+
if not is_endpoint_path_empty(endpoint)
360+
else None,
361+
url=self._get_environment_as_str(endpoint=endpoint),
417362
method=method,
418363
query_parameters=self._get_query_parameters_for_endpoint(endpoint=endpoint),
419-
request_body=AST.Expression(AST.CodeWriter(write_request_body)) if (method != "GET") else None,
364+
request_body=AST.Expression(AST.CodeWriter(write_request_body))
365+
if (method != "GET") and json_request_body is not None
366+
else None,
420367
content=request_body_parameters.get_content() if request_body_parameters is not None else None,
421368
files=self._context.core_utilities.httpx_tuple_converter(files) if files is not None else None,
422369
response_variable_name=EndpointResponseCodeWriter.RESPONSE_VARIABLE,
370+
request_options_variable_name=EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE,
423371
headers=self._get_headers_for_endpoint(
424372
service=service, endpoint=endpoint, idempotency_headers=idempotency_headers
425373
),
426-
auth=None,
427-
timeout=timeout,
428374
response_code_writer=response_code_writer.get_writer(),
429375
reference_to_client=AST.Expression(
430376
f"self.{self._client_wrapper_member_name}.{ClientWrapperGenerator.HTTPX_CLIENT_MEMBER_NAME}"
431377
),
432-
max_retries=AST.Expression(
433-
f"{EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('max_retries') if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is not None else 0"
434-
),
378+
is_default_body_parameter_used=self.is_default_body_parameter_used,
435379
)
436380
)
437381

@@ -832,7 +776,9 @@ def write_ternary(writer: AST.NodeWriter) -> None:
832776

833777
return reference
834778

835-
def _get_environment_as_str(self, *, endpoint: ir_types.HttpEndpoint) -> AST.Expression:
779+
# Only get the environment expression if the environment is multipleBaseUrls, if it's
780+
# not we'll leverage the URL from the client wrapper
781+
def _get_environment_as_str(self, *, endpoint: ir_types.HttpEndpoint) -> Optional[AST.Expression]:
836782
if self._context.ir.environments is not None:
837783
environments_as_union = self._context.ir.environments.environments.get_as_union()
838784
if environments_as_union.type == "multipleBaseUrls":
@@ -845,9 +791,7 @@ def _get_environment_as_str(self, *, endpoint: ir_types.HttpEndpoint) -> AST.Exp
845791
return AST.Expression(
846792
f"self.{self._client_wrapper_member_name}.{ClientWrapperGenerator.GET_ENVIRONMENT_METHOD_NAME}().{url_reference}"
847793
)
848-
return AST.Expression(
849-
f"self.{self._client_wrapper_member_name}.{ClientWrapperGenerator.GET_BASE_URL_METHOD_NAME}()"
850-
)
794+
return None
851795

852796
def _get_headers_for_endpoint(
853797
self,
@@ -878,42 +822,18 @@ def _get_headers_for_endpoint(
878822
)
879823
)
880824

881-
def write_headers_dict_default(writer: AST.NodeWriter) -> None:
882-
writer.write("{")
883-
writer.write(
884-
f"**self.{self._client_wrapper_member_name}.{ClientWrapperGenerator.GET_HEADERS_METHOD_NAME}(),"
885-
)
886-
writer.write(
887-
f"**({EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('additional_headers', {'{}'}) if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is not None else {'{}'}),"
888-
)
889-
writer.write_line("}")
890-
891825
if len(headers) == 0:
892-
self._context.core_utilities.jsonable_encoder(
893-
self._context.core_utilities.remove_none_from_dict(
894-
AST.Expression(AST.CodeWriter(write_headers_dict_default)),
895-
)
896-
)
826+
return None
897827

898828
def write_headers_dict(writer: AST.NodeWriter) -> None:
899829
writer.write("{")
900-
writer.write(
901-
f"**self.{self._client_wrapper_member_name}.{ClientWrapperGenerator.GET_HEADERS_METHOD_NAME}(),"
902-
)
903-
for i, (header_key, header_value) in enumerate(headers):
830+
for _, (header_key, header_value) in enumerate(headers):
904831
writer.write(f'"{header_key}": ')
905832
writer.write_node(header_value)
906833
writer.write(", ")
907-
writer.write(
908-
f"**({EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('additional_headers', {'{}'}) if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is not None else {'{}'}),"
909-
)
910-
writer.write_line("},")
834+
writer.write_line("}")
911835

912-
return self._context.core_utilities.jsonable_encoder(
913-
self._context.core_utilities.remove_none_from_dict(
914-
AST.Expression(AST.CodeWriter(write_headers_dict)),
915-
)
916-
)
836+
return AST.Expression(AST.CodeWriter(write_headers_dict))
917837

918838
def _get_query_parameter_reference(self, query_parameter: ir_types.QueryParameter) -> AST.Expression:
919839
possible_query_literal = self._context.get_literal_value(query_parameter.value_type)
@@ -934,32 +854,17 @@ def _get_query_parameters_for_endpoint(
934854
]
935855

936856
if len(query_parameters) == 0:
937-
return self._context.core_utilities.get_encode_query(
938-
self._context.core_utilities.jsonable_encoder(
939-
AST.Expression(
940-
f"{EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('additional_query_parameters') if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is not None else None"
941-
)
942-
)
943-
)
857+
return None
944858

945859
def write_query_parameters_dict(writer: AST.NodeWriter) -> None:
946860
writer.write("{")
947-
for i, (query_param_key, query_param_value) in enumerate(query_parameters):
861+
for _, (query_param_key, query_param_value) in enumerate(query_parameters):
948862
writer.write(f'"{query_param_key}": ')
949863
writer.write_node(query_param_value)
950864
writer.write(", ")
951-
writer.write(
952-
f"**({EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('additional_query_parameters', {'{}'}) if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is not None else {'{}'}),"
953-
)
954-
writer.write_line("},")
865+
writer.write_line("}")
955866

956-
return self._context.core_utilities.get_encode_query(
957-
self._context.core_utilities.jsonable_encoder(
958-
self._context.core_utilities.remove_none_from_dict(
959-
AST.Expression(AST.CodeWriter(write_query_parameters_dict)),
960-
)
961-
)
962-
)
867+
return AST.Expression(AST.CodeWriter(write_query_parameters_dict))
963868

964869
def _is_datetime(
965870
self,

0 commit comments

Comments
 (0)