Skip to content

Commit

Permalink
improvement, python: clean up endpoint functions by centralizing logic (
Browse files Browse the repository at this point in the history
  • Loading branch information
armandobelardo authored Jun 3, 2024
1 parent a44db33 commit 4006d14
Show file tree
Hide file tree
Showing 351 changed files with 25,782 additions and 34,701 deletions.
345 changes: 325 additions & 20 deletions generators/python/core_utilities/sdk/http_client.py

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions generators/python/sdk/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [2.8.0] - 2024-06-03

- Improvement: Endpoint function request logic has been abstracted into the request function of the wrapped httpx client.

## [2.7.0] - 2024-05-30

- Improvement: The generator now outputs an `exampleId` alongside each generated snippet so that
we can correlate snippets with the relevant examples. This is useful for retrieving examples from
Fern's API and making sure that you can show multiple snippets in the generated docs.
we can correlate snippets with the relevant examples. This is useful for retrieving examples from
Fern's API and making sure that you can show multiple snippets in the generated docs.

## [2.6.1] - 2024-05-31

Expand Down
2 changes: 1 addition & 1 deletion generators/python/sdk/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.7.0
2.8.0
51 changes: 27 additions & 24 deletions generators/python/src/fern_python/external_dependencies/httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,21 @@ class HttpX:
@staticmethod
def make_request(
*,
url: AST.Expression,
path: Optional[AST.Expression],
url: Optional[AST.Expression],
method: str,
query_parameters: Optional[AST.Expression],
request_body: Optional[AST.Expression],
headers: Optional[AST.Expression],
files: Optional[AST.Expression],
content: Optional[AST.Expression],
auth: Optional[AST.Expression],
timeout: AST.Expression,
response_variable_name: str,
request_options_variable_name: str,
is_async: bool,
is_streaming: bool,
response_code_writer: AST.CodeWriter,
reference_to_client: AST.Expression,
max_retries: AST.Expression,
is_default_body_parameter_used: bool,
) -> AST.Expression:
def add_request_params(*, writer: AST.NodeWriter) -> None:
if query_parameters is not None:
Expand Down Expand Up @@ -76,21 +76,10 @@ def add_request_params(*, writer: AST.NodeWriter) -> None:
writer.write_node(headers)
writer.write_line(",")

if auth is not None:
writer.write("auth=")
writer.write_node(auth)
writer.write_line(",")

writer.write("timeout=")
writer.write_node(timeout)
writer.write_line(",")

writer.write("retries=0")
writer.write_line(",")
writer.write(f"request_options={request_options_variable_name},")

writer.write("max_retries=")
writer.write_node(max_retries)
writer.write_line(", # type: ignore")
if is_default_body_parameter_used:
writer.write_line("omit=OMIT,")

def write_non_streaming_call(
*,
Expand All @@ -107,10 +96,17 @@ def make_non_streaming_request(
if is_async:
writer.write("await ")
writer.write_node(reference_to_client)
writer.write(f'.request(method="{method}", url=')
writer.write_node(url)
writer.write(", ")
writer.write_line(f".request(")

with writer.indent():
if path is not None:
writer.write_node(path)
writer.write(",")
if url is not None:
writer.write("base_url=")
writer.write_node(url)
writer.write(",")
writer.write_line(f'method="{method}",')
add_request_params(writer=writer)
writer.write_line(")")

Expand All @@ -119,10 +115,17 @@ def write_streaming_call(*, writer: AST.NodeWriter) -> None:
writer.write("async ")
writer.write("with ")
writer.write_node(reference_to_client)
writer.write(f'.stream(method="{method}", url=')
writer.write_node(url)
writer.write(", ")
writer.write(f".stream(")

with writer.indent():
if path is not None:
writer.write_node(path)
writer.write(",")
if url is not None:
writer.write("base_url=")
writer.write_node(url)
writer.write(",")
writer.write_line(f'method="{method}",')
add_request_params(writer=writer)
writer.write_line(f") as {response_variable_name}:")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from fern_python.codegen import AST
from fern_python.codegen.ast.ast_node.node_writer import NodeWriter
from fern_python.external_dependencies import HttpX, UrlLibParse
from fern_python.external_dependencies import HttpX
from fern_python.generators.sdk.client_generator.endpoint_response_code_writer import (
EndpointDummySnippetConfig,
EndpointResponseCodeWriter,
Expand Down Expand Up @@ -129,20 +129,9 @@ def __init__(
_named_parameter_names.append(name)
self._path_parameter_names[path_parameter.name] = name

def generate(self) -> GeneratedEndpointFunction:
is_primitive: bool = (
self._endpoint.request_body.visit(
inlined_request_body=lambda _: False,
reference=lambda referenced_request_body: self._is_httpx_primitive_data(
type_reference=referenced_request_body.request_body_type, allow_optional=True
),
file_upload=lambda _: False,
bytes=lambda _: False,
)
if self._endpoint.request_body is not None
else False
)
self.is_default_body_parameter_used = self.request_body_parameters is not None

def generate(self) -> GeneratedEndpointFunction:
endpoint_snippets = self._generate_endpoint_snippets(
package=self._package,
service=self._service,
Expand Down Expand Up @@ -174,7 +163,6 @@ def generate(self) -> GeneratedEndpointFunction:
idempotency_headers=self._idempotency_headers,
request_body_parameters=self.request_body_parameters,
is_async=self._is_async,
is_primitive=is_primitive,
parameters=unnamed_parameters,
named_parameters=self._named_parameters,
),
Expand Down Expand Up @@ -311,7 +299,6 @@ def _create_endpoint_body_writer(
idempotency_headers: List[ir_types.HttpHeader],
request_body_parameters: Optional[AbstractRequestBodyParameters],
is_async: bool,
is_primitive: bool,
named_parameters: List[AST.NamedFunctionParameter],
parameters: List[AST.FunctionParameter],
) -> AST.CodeWriter:
Expand All @@ -325,11 +312,6 @@ def write(writer: AST.NodeWriter) -> None:
writer.write_node(AST.Expression(request_pre_fetch_statements))

json_request_body = request_body_parameters.get_json_body() if request_body_parameters is not None else None
encoded_json_request_body = (
self._context.core_utilities.jsonable_encoder(json_request_body)
if json_request_body is not None
else None
)

method = endpoint.method.visit(
get=lambda: "GET",
Expand All @@ -340,38 +322,8 @@ def write(writer: AST.NodeWriter) -> None:
)

def write_request_body(writer: AST.NodeWriter) -> None:
if is_primitive:
if encoded_json_request_body:
writer.write_node(encoded_json_request_body)
else:
# If there's an existing request body:
# - Use it if the additional body params are none (e.g. request options has no impact)
# - If additional body params is not none, json encode it and spread both dicts into a new one
# - NOTE: With the is_primitive bail out, we do not acknowledge the additional body parameters,
# to not have to merge together an integer and a hash, for example
# If there is not an existing request body, send the encoded dict
additional_parameters = (
f"{EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('additional_body_parameters')"
)
additional_parameters_defaulted = f"{EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('additional_body_parameters', {'{}'})"
json_encoded_additional_params = self._context.core_utilities.jsonable_encoder(
self._context.core_utilities.remove_none_from_dict(
AST.Expression(additional_parameters_defaulted)
)
)
if encoded_json_request_body:
writer.write_node(encoded_json_request_body)
writer.write(
f" if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is None or {additional_parameters} is None else"
)
writer.write(" {**")
writer.write_node(encoded_json_request_body)
writer.write(", **(")
writer.write_node(json_encoded_additional_params)
writer.write(")}")
else:
writer.write_node(json_encoded_additional_params)
writer.write(f" if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is not None else None")
if json_request_body is not None:
writer.write_node(json_request_body)

is_streaming = (
True
Expand All @@ -394,9 +346,6 @@ def write_request_body(writer: AST.NodeWriter) -> None:
),
)

timeout = AST.Expression(
f"{EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('timeout_in_seconds') if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is not None and {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('timeout_in_seconds') is not None else self.{self._client_wrapper_member_name}.{ClientWrapperGenerator.GET_TIMEOUT_METHOD_NAME}()"
)
files = (
request_body_parameters.get_files()
if request_body_parameters is not None and request_body_parameters.get_files() is not None
Expand All @@ -406,32 +355,27 @@ def write_request_body(writer: AST.NodeWriter) -> None:
HttpX.make_request(
is_streaming=is_streaming,
is_async=is_async,
url=(
self._get_environment_as_str(endpoint=endpoint)
if is_endpoint_path_empty(endpoint)
else UrlLibParse.urljoin(
self._get_environment_as_str(endpoint=endpoint),
self._get_path_for_endpoint(endpoint),
)
),
path=self._get_path_for_endpoint(endpoint=endpoint)
if not is_endpoint_path_empty(endpoint)
else None,
url=self._get_environment_as_str(endpoint=endpoint),
method=method,
query_parameters=self._get_query_parameters_for_endpoint(endpoint=endpoint),
request_body=AST.Expression(AST.CodeWriter(write_request_body)) if (method != "GET") else None,
request_body=AST.Expression(AST.CodeWriter(write_request_body))
if (method != "GET") and json_request_body is not None
else None,
content=request_body_parameters.get_content() if request_body_parameters is not None else None,
files=self._context.core_utilities.httpx_tuple_converter(files) if files is not None else None,
response_variable_name=EndpointResponseCodeWriter.RESPONSE_VARIABLE,
request_options_variable_name=EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE,
headers=self._get_headers_for_endpoint(
service=service, endpoint=endpoint, idempotency_headers=idempotency_headers
),
auth=None,
timeout=timeout,
response_code_writer=response_code_writer.get_writer(),
reference_to_client=AST.Expression(
f"self.{self._client_wrapper_member_name}.{ClientWrapperGenerator.HTTPX_CLIENT_MEMBER_NAME}"
),
max_retries=AST.Expression(
f"{EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('max_retries') if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is not None else 0"
),
is_default_body_parameter_used=self.is_default_body_parameter_used,
)
)

Expand Down Expand Up @@ -832,7 +776,9 @@ def write_ternary(writer: AST.NodeWriter) -> None:

return reference

def _get_environment_as_str(self, *, endpoint: ir_types.HttpEndpoint) -> AST.Expression:
# Only get the environment expression if the environment is multipleBaseUrls, if it's
# not we'll leverage the URL from the client wrapper
def _get_environment_as_str(self, *, endpoint: ir_types.HttpEndpoint) -> Optional[AST.Expression]:
if self._context.ir.environments is not None:
environments_as_union = self._context.ir.environments.environments.get_as_union()
if environments_as_union.type == "multipleBaseUrls":
Expand All @@ -845,9 +791,7 @@ def _get_environment_as_str(self, *, endpoint: ir_types.HttpEndpoint) -> AST.Exp
return AST.Expression(
f"self.{self._client_wrapper_member_name}.{ClientWrapperGenerator.GET_ENVIRONMENT_METHOD_NAME}().{url_reference}"
)
return AST.Expression(
f"self.{self._client_wrapper_member_name}.{ClientWrapperGenerator.GET_BASE_URL_METHOD_NAME}()"
)
return None

def _get_headers_for_endpoint(
self,
Expand Down Expand Up @@ -878,42 +822,18 @@ def _get_headers_for_endpoint(
)
)

def write_headers_dict_default(writer: AST.NodeWriter) -> None:
writer.write("{")
writer.write(
f"**self.{self._client_wrapper_member_name}.{ClientWrapperGenerator.GET_HEADERS_METHOD_NAME}(),"
)
writer.write(
f"**({EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('additional_headers', {'{}'}) if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is not None else {'{}'}),"
)
writer.write_line("}")

if len(headers) == 0:
self._context.core_utilities.jsonable_encoder(
self._context.core_utilities.remove_none_from_dict(
AST.Expression(AST.CodeWriter(write_headers_dict_default)),
)
)
return None

def write_headers_dict(writer: AST.NodeWriter) -> None:
writer.write("{")
writer.write(
f"**self.{self._client_wrapper_member_name}.{ClientWrapperGenerator.GET_HEADERS_METHOD_NAME}(),"
)
for i, (header_key, header_value) in enumerate(headers):
for _, (header_key, header_value) in enumerate(headers):
writer.write(f'"{header_key}": ')
writer.write_node(header_value)
writer.write(", ")
writer.write(
f"**({EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('additional_headers', {'{}'}) if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is not None else {'{}'}),"
)
writer.write_line("},")
writer.write_line("}")

return self._context.core_utilities.jsonable_encoder(
self._context.core_utilities.remove_none_from_dict(
AST.Expression(AST.CodeWriter(write_headers_dict)),
)
)
return AST.Expression(AST.CodeWriter(write_headers_dict))

def _get_query_parameter_reference(self, query_parameter: ir_types.QueryParameter) -> AST.Expression:
possible_query_literal = self._context.get_literal_value(query_parameter.value_type)
Expand All @@ -934,32 +854,17 @@ def _get_query_parameters_for_endpoint(
]

if len(query_parameters) == 0:
return self._context.core_utilities.get_encode_query(
self._context.core_utilities.jsonable_encoder(
AST.Expression(
f"{EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('additional_query_parameters') if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is not None else None"
)
)
)
return None

def write_query_parameters_dict(writer: AST.NodeWriter) -> None:
writer.write("{")
for i, (query_param_key, query_param_value) in enumerate(query_parameters):
for _, (query_param_key, query_param_value) in enumerate(query_parameters):
writer.write(f'"{query_param_key}": ')
writer.write_node(query_param_value)
writer.write(", ")
writer.write(
f"**({EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE}.get('additional_query_parameters', {'{}'}) if {EndpointFunctionGenerator.REQUEST_OPTIONS_VARIABLE} is not None else {'{}'}),"
)
writer.write_line("},")
writer.write_line("}")

return self._context.core_utilities.get_encode_query(
self._context.core_utilities.jsonable_encoder(
self._context.core_utilities.remove_none_from_dict(
AST.Expression(AST.CodeWriter(write_query_parameters_dict)),
)
)
)
return AST.Expression(AST.CodeWriter(write_query_parameters_dict))

def _is_datetime(
self,
Expand Down
Loading

0 comments on commit 4006d14

Please sign in to comment.