Skip to content

Commit a1fa8c7

Browse files
committed
refactor: refines msgpack decoding in api generator and boolean fields in block model
Improves msgpack decoding in Algod, Indexer and KMD clients by handling byte keys and values. This prevents decoding errors when encountering non-UTF-8 byte sequences. Additionally, adds decoding for boolean fields in block models to correctly interpret raw values as booleans. This addresses issues with inconsistent data representation.
1 parent a531367 commit a1fa8c7

File tree

22 files changed

+1347
-220
lines changed

22 files changed

+1347
-220
lines changed

api/oas-generator/src/oas_generator/builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,4 +694,5 @@ def build_client_descriptor(
694694
uses_signed_transaction=uses_signed_txn,
695695
uses_msgpack=operation_builder.uses_msgpack,
696696
include_block_models=operation_builder.uses_block_models,
697+
include_ledger_state_delta_models="LedgerStateDelta" in registry.entries,
697698
)

api/oas-generator/src/oas_generator/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,4 @@ class ClientDescriptor:
146146
uses_signed_transaction: bool = False
147147
uses_msgpack: bool = False
148148
include_block_models: bool = False
149+
include_ledger_state_delta_models: bool = False

api/oas-generator/src/oas_generator/renderer/engine.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,31 @@ class TemplateRenderer:
2626
"Block",
2727
"GetBlock",
2828
]
29+
LEDGER_STATE_DELTA_EXPORTS: ClassVar[list[str]] = [
30+
"LedgerTealValue",
31+
"LedgerStateSchema",
32+
"LedgerAppParams",
33+
"LedgerAppLocalState",
34+
"LedgerAppLocalStateDelta",
35+
"LedgerAppParamsDelta",
36+
"LedgerAppResourceRecord",
37+
"LedgerAssetHolding",
38+
"LedgerAssetHoldingDelta",
39+
"LedgerAssetParams",
40+
"LedgerAssetParamsDelta",
41+
"LedgerAssetResourceRecord",
42+
"LedgerVotingData",
43+
"LedgerAccountBaseData",
44+
"LedgerAccountData",
45+
"LedgerBalanceRecord",
46+
"LedgerAccountDeltas",
47+
"LedgerKvValueDelta",
48+
"LedgerIncludedTransactions",
49+
"LedgerModifiedCreatable",
50+
"LedgerAlgoCount",
51+
"LedgerAccountTotals",
52+
"LedgerStateDelta",
53+
]
2954

3055
def __init__(self, template_dir: Path | None = None) -> None:
3156
if template_dir:
@@ -56,6 +81,8 @@ def render(self, client: ctx.ClientDescriptor, config: GeneratorConfig) -> dict[
5681
files[models_dir / "__init__.py"] = self._render_template("models/__init__.py.j2", context)
5782
files[models_dir / "_serde_helpers.py"] = self._render_template("models/_serde_helpers.py.j2", context)
5883
for model in context["client"].models:
84+
if context["client"].include_ledger_state_delta_models and model.name == "LedgerStateDelta":
85+
continue
5986
model_context = {**context, "model": model}
6087
files[models_dir / f"{model.module_name}.py"] = self._render_template("models/model.py.j2", model_context)
6188
for enum in context["client"].enums:
@@ -67,7 +94,11 @@ def render(self, client: ctx.ClientDescriptor, config: GeneratorConfig) -> dict[
6794
"models/type_alias.py.j2", alias_context
6895
)
6996
if client.include_block_models:
70-
files[models_dir / "block.py"] = self._render_template("models/block.py.j2", context)
97+
files[models_dir / "_block.py"] = self._render_template("models/block.py.j2", context)
98+
if client.include_ledger_state_delta_models:
99+
files[models_dir / "_ledger_state_delta.py"] = self._render_template(
100+
"models/ledger_state_delta.py.j2", context
101+
)
71102
files[target / "py.typed"] = ""
72103
return files
73104

@@ -85,6 +116,10 @@ def _build_context(self, client: ctx.ClientDescriptor, config: GeneratorConfig)
85116
for name in self.BLOCK_MODEL_EXPORTS:
86117
if name not in model_exports:
87118
model_exports.append(name)
119+
if client.include_ledger_state_delta_models:
120+
for name in self.LEDGER_STATE_DELTA_EXPORTS:
121+
if name not in model_exports:
122+
model_exports.append(name)
88123
metadata_usage = self._collect_metadata_usage(client)
89124
model_modules = [{"module": model.module_name, "name": model.name} for model in client.models]
90125
enum_modules = [{"module": enum.module_name, "name": enum.name} for enum in client.enums]
@@ -105,6 +140,7 @@ def _build_context(self, client: ctx.ClientDescriptor, config: GeneratorConfig)
105140
"needs_datetime": any(model.requires_datetime for model in client.models),
106141
"client_needs_datetime": self._client_requires_datetime(client),
107142
"block_exports": self.BLOCK_MODEL_EXPORTS,
143+
"ledger_state_delta_exports": self.LEDGER_STATE_DELTA_EXPORTS,
108144
"needs_literal": needs_literal,
109145
}
110146

api/oas-generator/src/oas_generator/renderer/templates/client.py.j2

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ class {{ client.class_name }}:
248248
return response.content
249249
content_type = response.headers.get("content-type", "application/json")
250250
if "msgpack" in content_type:
251-
data = msgpack.unpackb(response.content, raw=False, strict_map_key=False)
251+
data = msgpack.unpackb(response.content, raw=True, strict_map_key=False)
252252
data = self._normalize_msgpack(data)
253253
elif content_type.startswith("application/json"):
254254
data = response.json()
@@ -264,7 +264,18 @@ class {{ client.class_name }}:
264264

265265
def _normalize_msgpack(self, value: object) -> object:
266266
if isinstance(value, dict):
267-
return {key: self._normalize_msgpack(item) for key, item in value.items()}
267+
normalized: dict[object, object] = {}
268+
for key, item in value.items():
269+
normalized[self._coerce_msgpack_key(key)] = self._normalize_msgpack(item)
270+
return normalized
268271
if isinstance(value, list):
269272
return [self._normalize_msgpack(item) for item in value]
270273
return value
274+
275+
def _coerce_msgpack_key(self, key: object) -> object:
276+
if isinstance(key, bytes):
277+
try:
278+
return key.decode("utf-8")
279+
except UnicodeDecodeError:
280+
return key
281+
return key

api/oas-generator/src/oas_generator/renderer/templates/models/__init__.py.j2

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11

22

33
{% if client.uses_signed_transaction %}from algokit_transact.models.signed_transaction import SignedTransaction
4-
{% endif %}{% for item in model_modules %}from .{{ item.module }} import {{ item.name }}
5-
{% endfor %}{% for item in enum_modules %}from .{{ item.module }} import {{ item.name }}
4+
{% endif %}{% for item in model_modules %}{% if not (client.include_ledger_state_delta_models and item.name == "LedgerStateDelta") %}from .{{ item.module }} import {{ item.name }}
5+
{% endif %}{% endfor %}{% for item in enum_modules %}from .{{ item.module }} import {{ item.name }}
66
{% endfor %}{% for item in alias_modules %}from .{{ item.module }} import {{ item.name }}
7-
{% endfor %}{% if client.include_block_models %}from .block import (
7+
{% endfor %}{% if client.include_block_models %}from ._block import (
88
{{ block_exports | join(',\n ') }}
99
)
10+
{% endif %}{% if client.include_ledger_state_delta_models %}from ._ledger_state_delta import (
11+
{{ ledger_state_delta_exports | join(',\n ') }}
12+
)
1013
{% endif %}
1114

1215
__all__ = [
1316
{% for name in model_exports %}"{{ name }}",
1417
{% endfor %}
1518
]
16-

api/oas-generator/src/oas_generator/renderer/templates/models/_serde_helpers.py.j2

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ from typing import Callable, TypeAlias, TypeVar
88

99
from algokit_common.serde import from_wire, to_wire
1010

11-
T = TypeVar("T")
12-
E = TypeVar("E", bound=Enum)
13-
KT = TypeVar("KT")
11+
DecodedT = TypeVar("DecodedT")
12+
EnumValueT = TypeVar("EnumValueT", bound=Enum)
13+
MapKeyT = TypeVar("MapKeyT")
1414
BytesLike: TypeAlias = bytes | bytearray | memoryview
1515

1616

@@ -33,11 +33,22 @@ def decode_bytes_base64(raw: object) -> bytes:
3333
try:
3434
return base64.b64decode(raw.encode("ascii"), validate=True)
3535
except (BinasciiError, UnicodeEncodeError) as exc:
36+
raise ValueError("Invalid base64 payload") from exc
37+
raise TypeError(f"Unsupported value for bytes field: {type(raw)!r}")
38+
39+
40+
def decode_bytes_map_key(raw: object) -> bytes:
41+
if isinstance(raw, bytes | bytearray | memoryview):
42+
return bytes(raw)
43+
if isinstance(raw, str):
44+
try:
45+
return decode_bytes_base64(raw)
46+
except ValueError:
3647
try:
3748
return raw.encode("utf-8")
3849
except UnicodeEncodeError as fallback_exc:
39-
raise ValueError("Invalid base64 payload") from fallback_exc
40-
raise TypeError(f"Unsupported value for bytes field: {type(raw)!r}")
50+
raise ValueError("Invalid bytes map key") from fallback_exc
51+
raise TypeError(f"Unsupported map key for bytes field: {type(raw)!r}")
4152

4253

4354
def encode_bytes_sequence(values: Iterable[BytesLike | None] | None) -> list[str | None] | None:
@@ -77,11 +88,11 @@ def encode_model_sequence(values: Iterable[object] | None) -> list[dict[str, obj
7788
return encoded or None
7889

7990

80-
def decode_model_sequence(cls_factory: Callable[[], type[T]], raw: object) -> list[T] | None:
91+
def decode_model_sequence(cls_factory: Callable[[], type[DecodedT]], raw: object) -> list[DecodedT] | None:
8192
if not isinstance(raw, list):
8293
return None
8394
cls = cls_factory()
84-
decoded: list[T] = []
95+
decoded: list[DecodedT] = []
8596
for item in raw:
8697
if isinstance(item, Mapping):
8798
decoded.append(from_wire(cls, item))
@@ -99,11 +110,11 @@ def encode_enum_sequence(values: Iterable[object] | None) -> list[object] | None
99110
return encoded or None
100111

101112

102-
def decode_enum_sequence(enum_factory: Callable[[], type[E]], raw: object) -> list[E] | None:
113+
def decode_enum_sequence(enum_factory: Callable[[], type[EnumValueT]], raw: object) -> list[EnumValueT] | None:
103114
if not isinstance(raw, list):
104115
return None
105116
enum_cls = enum_factory()
106-
decoded: list[E] = []
117+
decoded: list[EnumValueT] = []
107118
for item in raw:
108119
try:
109120
decoded.append(enum_cls(item))
@@ -113,7 +124,7 @@ def decode_enum_sequence(enum_factory: Callable[[], type[E]], raw: object) -> li
113124

114125

115126
def encode_model_mapping(
116-
factory: Callable[[], type[T]],
127+
factory: Callable[[], type[DecodedT]],
117128
mapping: Mapping[object, object] | None,
118129
*,
119130
key_encoder: Callable[[object], str] | None = None,
@@ -140,24 +151,30 @@ def encode_model_mapping(
140151

141152

142153
def decode_model_mapping(
143-
factory: Callable[[], type[T]],
154+
factory: Callable[[], type[DecodedT]],
144155
raw: object,
145156
*,
146-
key_decoder: Callable[[object], KT] | None = None,
147-
) -> dict[KT, T] | None:
157+
key_decoder: Callable[[object], MapKeyT] | None = None,
158+
) -> dict[MapKeyT, DecodedT] | None:
148159
if not isinstance(raw, Mapping):
149160
return None
150161
cls = factory()
151-
decoded: dict[KT, T] = {}
162+
decoded: dict[MapKeyT, DecodedT] = {}
152163
for key, value in raw.items():
153164
if isinstance(value, Mapping):
154165
decoded_key = key_decoder(key) if key_decoder is not None else key
155166
decoded[decoded_key] = from_wire(cls, value)
156167
return decoded or None
157168

158169

170+
def decode_optional_bool(raw: object) -> bool | None:
171+
if raw is None:
172+
return None
173+
return bool(raw)
174+
175+
159176
def mapping_encoder(
160-
factory: Callable[[], type[T]],
177+
factory: Callable[[], type[DecodedT]],
161178
*,
162179
key_encoder: Callable[[object], str] | None = None,
163180
) -> Callable[[Mapping[object, object] | None], dict[str, object] | None]:
@@ -168,11 +185,11 @@ def mapping_encoder(
168185

169186

170187
def mapping_decoder(
171-
factory: Callable[[], type[T]],
188+
factory: Callable[[], type[DecodedT]],
172189
*,
173-
key_decoder: Callable[[object], KT] | None = None,
174-
) -> Callable[[object], dict[KT, T] | None]:
175-
def _decode(raw: object) -> dict[KT, T] | None:
190+
key_decoder: Callable[[object], MapKeyT] | None = None,
191+
) -> Callable[[object], dict[MapKeyT, DecodedT] | None]:
192+
def _decode(raw: object) -> dict[MapKeyT, DecodedT] | None:
176193
return decode_model_mapping(factory, raw, key_decoder=key_decoder)
177194

178195
return _decode

api/oas-generator/src/oas_generator/renderer/templates/models/block.py.j2

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ from collections.abc import Mapping
44
from dataclasses import dataclass, field
55
from typing import Any, cast
66

7-
from algokit_common.serde import flatten, nested, wire
7+
from algokit_common.serde import flatten, nested, wire, addr
88
from algokit_transact.models.signed_transaction import SignedTransaction
99

1010
from ._serde_helpers import (
11-
decode_bytes_base64,
11+
decode_bytes_map_key,
1212
decode_model_mapping,
1313
decode_model_sequence,
14+
decode_optional_bool,
1415
encode_bytes_base64,
1516
encode_model_mapping,
1617
encode_model_sequence,
@@ -64,18 +65,28 @@ def _decode_state_proof_tracking_key(key: object) -> int:
6465

6566

6667
def _decode_block_state_delta(raw: object) -> BlockStateDelta | None:
67-
decoded = decode_model_mapping(lambda: BlockEvalDelta, raw, key_decoder=decode_bytes_base64)
68+
decoded = decode_model_mapping(lambda: BlockEvalDelta, raw, key_decoder=decode_bytes_map_key)
6869
return decoded or None
6970

7071

72+
def _encode_local_delta_index_key(key: object) -> str:
73+
if isinstance(key, bool):
74+
return str(int(key))
75+
if isinstance(key, int):
76+
return str(key)
77+
if isinstance(key, str):
78+
return str(int(key))
79+
raise TypeError("Local delta keys must be numeric")
80+
81+
7182
def _encode_local_deltas(mapping: Mapping[int, BlockStateDelta] | None) -> dict[str, object] | None:
7283
if mapping is None:
7384
return None
7485
out: dict[str, object] = {}
7586
for key, value in mapping.items():
7687
encoded = _encode_block_state_delta(value)
7788
if encoded:
78-
out[str(int(key))] = encoded
89+
out[_encode_local_delta_index_key(key)] = encoded
7990
return out or None
8091

8192

@@ -156,8 +167,8 @@ class SignedTxnInBlock:
156167
)
157168
config_asset: int | None = field(default=None, metadata=wire("caid"))
158169
application_id: int | None = field(default=None, metadata=wire("apid"))
159-
has_genesis_id: bool | None = field(default=None, metadata=wire("hgi"))
160-
has_genesis_hash: bool | None = field(default=None, metadata=wire("hgh"))
170+
has_genesis_id: bool | None = field(default=None, metadata=wire("hgi", decode=decode_optional_bool))
171+
has_genesis_hash: bool | None = field(default=None, metadata=wire("hgh", decode=decode_optional_bool))
161172

162173

163174
@dataclass(slots=True)
@@ -206,12 +217,12 @@ class Block:
206217
timestamp: int | None = field(default=None, metadata=wire("ts"))
207218
genesis_id: str | None = field(default=None, metadata=wire("gen"))
208219
genesis_hash: bytes | None = field(default=None, metadata=wire("gh"))
209-
proposer: bytes | None = field(default=None, metadata=wire("prp"))
220+
proposer: bytes | None = field(default=None, metadata=addr("prp"))
210221
fees_collected: int | None = field(default=None, metadata=wire("fc"))
211222
bonus: int | None = field(default=None, metadata=wire("bi"))
212223
proposer_payout: int | None = field(default=None, metadata=wire("pp"))
213-
fee_sink: bytes | None = field(default=None, metadata=wire("fees"))
214-
rewards_pool: bytes | None = field(default=None, metadata=wire("rwd"))
224+
fee_sink: bytes | None = field(default=None, metadata=addr("fees"))
225+
rewards_pool: bytes | None = field(default=None, metadata=addr("rwd"))
215226
rewards_level: int | None = field(default=None, metadata=wire("earn"))
216227
rewards_rate: int | None = field(default=None, metadata=wire("rate"))
217228
rewards_residue: int | None = field(default=None, metadata=wire("frac"))

0 commit comments

Comments
 (0)