Skip to content

Commit a3db3b2

Browse files
committed
fix: tests
1 parent b802a7d commit a3db3b2

File tree

5 files changed

+199
-38
lines changed

5 files changed

+199
-38
lines changed

benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def benchmark_conversion(data: dict, parallel: bool, workers: int = 4, chunk_siz
4949
for _ in range(iterations):
5050
converter = Json2xml(data, parallel=parallel, workers=workers, chunk_size=chunk_size)
5151
start = time.perf_counter()
52-
result = converter.to_xml()
52+
converter.to_xml()
5353
end = time.perf_counter()
5454
times.append(end - start)
5555

json2xml/dicttoxml.py

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,17 @@ def dict2xml_str(
262262
parse dict2xml
263263
"""
264264
ids: list[str] = [] # initialize list of unique ids
265+
item = dict(item) # copy to avoid modifying the original dict
265266
", ".join(str(key) for key in item)
266267
subtree = "" # Initialize subtree with default empty string
267268

268269
if attr_type:
269270
attr["type"] = get_xml_type(item)
270271
val_attr: dict[str, str] = item.pop("@attrs", attr) # update attr with custom @attr if exists
272+
# Handle other @ keys as attributes
273+
for key in list(item.keys()):
274+
if key.startswith('@') and key not in ('@val', '@flat', '@attrs'):
275+
val_attr[key[1:]] = item.pop(key)
271276
rawitem = item["@val"] if "@val" in item else item
272277
if is_primitive_type(rawitem):
273278
if isinstance(rawitem, dict):
@@ -522,7 +527,15 @@ def convert_kv(
522527
if attr_type:
523528
attr["type"] = get_xml_type(val)
524529
attr_string = make_attrstring(attr)
525-
return f"<{key}{attr_string}>{wrap_cdata(val) if cdata else escape_xml(val)}</{key}>"
530+
val_str = str(val)
531+
if cdata:
532+
if '<![CDATA[' in val_str:
533+
content = val_str
534+
else:
535+
content = wrap_cdata(val)
536+
else:
537+
content = escape_xml(val)
538+
return f"<{key}{attr_string}>{content}</{key}>"
526539

527540

528541
def convert_bool(
@@ -566,7 +579,8 @@ def dicttoxml(
566579
list_headers: bool = False,
567580
parallel: bool = False,
568581
workers: int | None = None,
569-
chunk_size: int = 100
582+
chunk_size: int = 100,
583+
min_items_for_parallel: int = 10
570584
) -> bytes:
571585
"""
572586
Converts a python object into XML.
@@ -668,6 +682,10 @@ def dicttoxml(
668682
Default is 100
669683
Number of list items to process per chunk in parallel mode.
670684
685+
:param int min_items_for_parallel:
686+
Default is 10
687+
Minimum number of items in a dictionary to enable parallel processing.
688+
671689
Dictionaries-keys with special char '@' has special meaning:
672690
@attrs: This allows custom xml attributes:
673691
@@ -718,17 +736,61 @@ def dicttoxml(
718736
ns = xml_namespaces[prefix]
719737
namespace_str += f' xmlns:{prefix}="{ns}"'
720738

739+
def _dispatch_convert(
740+
obj, ids, parent,
741+
attr_type, item_func, cdata, item_wrap, list_headers,
742+
parallel, workers, chunk_size, min_items_for_parallel, xml_namespaces
743+
):
744+
should_use_parallel = parallel
745+
if parallel:
746+
if cdata:
747+
should_use_parallel = False
748+
if isinstance(obj, dict) and any(isinstance(k, str) and k.startswith('@') for k in obj.keys()):
749+
should_use_parallel = False
750+
if xml_namespaces:
751+
should_use_parallel = False
752+
if should_use_parallel:
753+
if isinstance(obj, dict):
754+
return convert_dict_parallel(
755+
obj, ids, parent,
756+
attr_type=attr_type, item_func=item_func, cdata=cdata,
757+
item_wrap=item_wrap, list_headers=list_headers,
758+
workers=workers, min_items_for_parallel=min_items_for_parallel
759+
)
760+
if isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)):
761+
return convert_list_parallel(
762+
obj, ids, parent,
763+
attr_type=attr_type, item_func=item_func, cdata=cdata,
764+
item_wrap=item_wrap, list_headers=list_headers,
765+
workers=workers, chunk_size=chunk_size
766+
)
767+
# fallback to serial
768+
return convert(
769+
obj, ids,
770+
attr_type, item_func, cdata, item_wrap,
771+
parent=parent, list_headers=list_headers
772+
)
773+
774+
should_use_parallel = parallel
721775
if parallel:
776+
if cdata:
777+
should_use_parallel = False
778+
if isinstance(obj, dict) and any(isinstance(k, str) and k.startswith('@') for k in obj.keys()):
779+
should_use_parallel = False
780+
if xml_namespaces:
781+
should_use_parallel = False
782+
783+
if should_use_parallel:
722784
from json2xml.parallel import convert_dict_parallel, convert_list_parallel
723785

724786
if root:
725787
output.append('<?xml version="1.0" encoding="UTF-8" ?>')
726788
if isinstance(obj, dict):
727789
output_elem = convert_dict_parallel(
728790
obj, ids, custom_root, attr_type, item_func, cdata, item_wrap,
729-
list_headers=list_headers, workers=workers, min_items_for_parallel=10
791+
list_headers=list_headers, workers=workers, min_items_for_parallel=min_items_for_parallel
730792
)
731-
elif isinstance(obj, Sequence):
793+
elif isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)):
732794
output_elem = convert_list_parallel(
733795
obj, ids, custom_root, attr_type, item_func, cdata, item_wrap,
734796
list_headers=list_headers, workers=workers, chunk_size=chunk_size
@@ -742,11 +804,11 @@ def dicttoxml(
742804
if isinstance(obj, dict):
743805
output.append(
744806
convert_dict_parallel(
745-
obj, ids, "", attr_type, item_func, cdata, item_wrap,
746-
list_headers=list_headers, workers=workers, min_items_for_parallel=10
807+
obj, ids, "", attr_type, item_func, cdata, item_wrap,
808+
list_headers=list_headers, workers=workers, min_items_for_parallel=min_items_for_parallel
747809
)
748810
)
749-
elif isinstance(obj, Sequence):
811+
elif isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)):
750812
output.append(
751813
convert_list_parallel(
752814
obj, ids, "", attr_type, item_func, cdata, item_wrap,

json2xml/json2xml.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
parallel: bool = False,
2424
workers: int | None = None,
2525
chunk_size: int = 100,
26+
min_items_for_parallel: int = 10,
2627
):
2728
self.data = data
2829
self.pretty = pretty
@@ -33,6 +34,7 @@ def __init__(
3334
self.parallel = parallel
3435
self.workers = workers
3536
self.chunk_size = chunk_size
37+
self.min_items_for_parallel = min_items_for_parallel
3638

3739
def to_xml(self) -> Any | None:
3840
"""
@@ -48,6 +50,7 @@ def to_xml(self) -> Any | None:
4850
parallel=self.parallel,
4951
workers=self.workers,
5052
chunk_size=self.chunk_size,
53+
min_items_for_parallel=self.min_items_for_parallel,
5154
)
5255
if self.pretty:
5356
try:

json2xml/parallel.py

Lines changed: 57 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,46 +6,61 @@
66
import threading
77
from collections.abc import Callable, Sequence
88
from concurrent.futures import ThreadPoolExecutor, as_completed
9+
from functools import lru_cache
910
from typing import Any
1011

11-
from json2xml import dicttoxml
12-
1312

1413
def is_free_threaded() -> bool:
    """Detect whether this interpreter is a free-threaded (no-GIL) build.

    Note:
        This probes the private hook ``sys._is_gil_enabled``, which may change
        or disappear in future Python versions. Whenever the hook is missing
        or raises, the interpreter is assumed to have the GIL enabled.

    Returns:
        bool: True on a free-threaded build, False otherwise.
    """
    probe = getattr(sys, '_is_gil_enabled', None)
    if probe is None:
        # No hook at all: treat as a regular GIL build.
        return False
    try:
        return not probe()
    except Exception:
        # Hook present but unusable: fall back to assuming the GIL is on.
        return False
2232

2333

24-
def get_optimal_workers(workers: int | None = None) -> int:
34+
def get_optimal_workers(
35+
workers: int | None = None,
36+
max_workers_limit: int | None = None
37+
) -> int:
2538
"""
2639
Get the optimal number of worker threads.
2740
2841
Args:
2942
workers: Explicitly specified worker count. If None, auto-detect.
43+
max_workers_limit: Optional cap for worker count on non-free-threaded Python.
3044
3145
Returns:
3246
int: Number of worker threads to use.
3347
"""
3448
if workers is not None:
3549
return max(1, workers)
3650

37-
cpu_count = os.cpu_count() or 4
51+
cpu_count = os.cpu_count() or 1
3852

3953
if is_free_threaded():
40-
return cpu_count
54+
optimal = cpu_count
4155
else:
42-
return min(4, cpu_count)
43-
56+
# Use configurable limit or default to 4
57+
limit = max_workers_limit if max_workers_limit is not None else 4
58+
optimal = min(limit, cpu_count)
4459

45-
_validation_cache: dict[str, bool] = {}
46-
_validation_cache_lock = threading.Lock()
60+
return max(1, optimal)
4761

4862

63+
@lru_cache(maxsize=None)
4964
def key_is_valid_xml_cached(key: str) -> bool:
5065
"""
5166
Thread-safe cached version of key_is_valid_xml.
@@ -56,16 +71,8 @@ def key_is_valid_xml_cached(key: str) -> bool:
5671
Returns:
5772
bool: True if the key is valid XML, False otherwise.
5873
"""
59-
with _validation_cache_lock:
60-
if key in _validation_cache:
61-
return _validation_cache[key]
62-
63-
result = dicttoxml.key_is_valid_xml(key)
64-
65-
with _validation_cache_lock:
66-
_validation_cache[key] = result
67-
68-
return result
74+
from json2xml import dicttoxml
75+
return dicttoxml.key_is_valid_xml(key)
6976

7077

7178
def make_valid_xml_name_cached(key: str, attr: dict[str, Any]) -> tuple[str, dict[str, Any]]:
@@ -79,6 +86,7 @@ def make_valid_xml_name_cached(key: str, attr: dict[str, Any]) -> tuple[str, dic
7986
Returns:
8087
tuple: Valid XML key and updated attributes.
8188
"""
89+
from json2xml import dicttoxml
8290
key = dicttoxml.escape_xml(key)
8391

8492
if key_is_valid_xml_cached(key):
@@ -129,7 +137,9 @@ def _convert_dict_item(
129137
import datetime
130138
import numbers
131139

132-
attr = {} if not ids else {"id": f"{dicttoxml.get_unique_id(parent)}"}
140+
from json2xml import dicttoxml
141+
142+
attr = {"id": f"{dicttoxml.get_unique_id(parent)}"} if ids else {}
133143
key, attr = make_valid_xml_name_cached(key, attr)
134144

135145
if isinstance(val, bool):
@@ -203,8 +213,11 @@ def convert_dict_parallel(
203213
min_items_for_parallel: Minimum items to enable parallelization.
204214
205215
Returns:
206-
str: XML string.
216+
str: XML string.
207217
"""
218+
if not isinstance(obj, dict):
219+
raise TypeError("obj must be a dict")
220+
from json2xml import dicttoxml
208221
if len(obj) < min_items_for_parallel:
209222
return dicttoxml.convert_dict(
210223
obj, ids, parent, attr_type, item_func, cdata, item_wrap, list_headers
@@ -225,7 +238,14 @@ def convert_dict_parallel(
225238

226239
for future in as_completed(future_to_idx):
227240
idx = future_to_idx[future]
228-
results[idx] = future.result()
241+
try:
242+
results[idx] = future.result()
243+
except Exception as e:
244+
# Cancel remaining futures
245+
for f in future_to_idx:
246+
if not f.done():
247+
f.cancel()
248+
raise e
229249

230250
return "".join(results[idx] for idx in range(len(items)))
231251

@@ -256,8 +276,9 @@ def _convert_list_chunk(
256276
start_offset: Starting index for this chunk.
257277
258278
Returns:
259-
str: XML string for this chunk.
279+
str: XML string for this chunk.
260280
"""
281+
from json2xml import dicttoxml
261282
return dicttoxml.convert_list(
262283
items, ids, parent, attr_type, item_func, cdata, item_wrap, list_headers
263284
)
@@ -291,8 +312,11 @@ def convert_list_parallel(
291312
chunk_size: Number of items per chunk.
292313
293314
Returns:
294-
str: XML string.
315+
str: XML string.
295316
"""
317+
if not isinstance(items, Sequence) or isinstance(items, (str, bytes)):
318+
raise TypeError("items must be a sequence (not str or bytes)")
319+
from json2xml import dicttoxml
296320
if len(items) < chunk_size:
297321
return dicttoxml.convert_list(
298322
items, ids, parent, attr_type, item_func, cdata, item_wrap, list_headers
@@ -313,6 +337,13 @@ def convert_list_parallel(
313337

314338
for future in as_completed(future_to_idx):
315339
idx = future_to_idx[future]
316-
results[idx] = future.result()
340+
try:
341+
results[idx] = future.result()
342+
except Exception as e:
343+
# Cancel remaining futures
344+
for f in future_to_idx:
345+
if not f.done():
346+
f.cancel()
347+
raise e
317348

318349
return "".join(results[idx] for idx in range(len(chunks)))

0 commit comments

Comments
 (0)