Skip to content

Commit ea0b743

Browse files
committed
fix: some more improvements
1 parent 6fb1bdc commit ea0b743

File tree

1 file changed

+71
-68
lines changed

1 file changed

+71
-68
lines changed

json2xml/dicttoxml.py

Lines changed: 71 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@
88
from typing import Any, Union
99

1010
from defusedxml.minidom import parseString
11+
from xml.parsers.expat import ExpatError
1112

1213
# Create a safe random number generator
1314

1415
# Set up logging
1516
LOG = logging.getLogger("dicttoxml")
1617

18+
# Module-level set for true uniqueness tracking
19+
_used_ids: set[str] = set()
20+
1721

1822
def make_id(element: str, start: int = 100000, end: int = 999999) -> str:
1923
"""
@@ -41,16 +45,11 @@ def get_unique_id(element: str) -> str:
4145
Returns:
4246
str: The unique ID.
4347
"""
44-
ids: list[str] = [] # initialize list of unique ids
4548
this_id = make_id(element)
46-
dup = True
47-
while dup:
48-
if this_id not in ids:
49-
dup = False
50-
ids.append(this_id)
51-
else:
52-
this_id = make_id(element)
53-
return ids[-1]
49+
while this_id in _used_ids:
50+
this_id = make_id(element)
51+
_used_ids.add(this_id)
52+
return this_id
5453

5554

5655
ELEMENT = Union[
@@ -77,23 +76,22 @@ def get_xml_type(val: ELEMENT) -> str:
7776
Returns:
7877
str: The XML type.
7978
"""
80-
if val is not None:
81-
if type(val).__name__ in ("str", "unicode"):
82-
return "str"
83-
if type(val).__name__ in ("int", "long"):
84-
return "int"
85-
if type(val).__name__ == "float":
86-
return "float"
87-
if type(val).__name__ == "bool":
88-
return "bool"
89-
if isinstance(val, numbers.Number):
90-
return "number"
91-
if isinstance(val, dict):
92-
return "dict"
93-
if isinstance(val, Sequence):
94-
return "list"
95-
else:
79+
if val is None:
9680
return "null"
81+
if isinstance(val, bool): # Check bool before int (bool is subclass of int)
82+
return "bool"
83+
if isinstance(val, int):
84+
return "int"
85+
if isinstance(val, float):
86+
return "float"
87+
if isinstance(val, str):
88+
return "str"
89+
if isinstance(val, numbers.Number):
90+
return "number"
91+
if isinstance(val, dict):
92+
return "dict"
93+
if isinstance(val, Sequence):
94+
return "list"
9795
return type(val).__name__
9896

9997

@@ -102,19 +100,19 @@ def escape_xml(s: str | int | float | numbers.Number) -> str:
102100
Escape a string for use in XML.
103101
104102
Args:
105-
s (str | numbers.Number): The string to escape.
103+
s (str | int | float | numbers.Number): The string to escape.
106104
107105
Returns:
108106
str: The escaped string.
109107
"""
108+
s_str = str(s) # Convert to string once
110109
if isinstance(s, str):
111-
s = str(s) # avoid UnicodeDecodeError
112-
s = s.replace("&", "&")
113-
s = s.replace('"', """)
114-
s = s.replace("'", "'")
115-
s = s.replace("<", "&lt;")
116-
s = s.replace(">", "&gt;")
117-
return str(s)
110+
s_str = s_str.replace("&", "&amp;")
111+
s_str = s_str.replace('"', "&quot;")
112+
s_str = s_str.replace("'", "&apos;")
113+
s_str = s_str.replace("<", "&lt;")
114+
s_str = s_str.replace(">", "&gt;")
115+
return s_str
118116

119117

120118
def make_attrstring(attr: dict[str, Any]) -> str:
@@ -145,37 +143,39 @@ def key_is_valid_xml(key: str) -> bool:
145143
try:
146144
parseString(test_xml)
147145
return True
148-
except Exception: # minidom does not implement exceptions well
146+
except (ExpatError, ValueError) as e:
147+
LOG.debug(f"Invalid XML name '{key}': {e}")
149148
return False
150149

151150

152-
def make_valid_xml_name(key: str, attr: dict[str, Any]) -> tuple[str, dict[str, Any]]:
151+
def make_valid_xml_name(key: str | int, attr: dict[str, Any]) -> tuple[str, dict[str, Any]]:
153152
"""Tests an XML name and fixes it if invalid"""
154-
key = escape_xml(key)
153+
key_str = str(key) # Ensure we're working with strings
154+
key_str = escape_xml(key_str)
155155
# nothing happens at escape_xml if attr is not a string, we don't
156156
# need to pass it to the method at all.
157157
# attr = escape_xml(attr)
158158

159159
# pass through if key is already valid
160-
if key_is_valid_xml(key):
161-
return key, attr
160+
if key_is_valid_xml(key_str):
161+
return key_str, attr
162162

163163
# prepend a lowercase n if the key is numeric
164-
if isinstance(key, int) or key.isdigit():
165-
return f"n{key}", attr
164+
if key_str.isdigit():
165+
return f"n{key_str}", attr
166166

167167
# replace spaces with underscores if that fixes the problem
168-
if key_is_valid_xml(key.replace(" ", "_")):
169-
return key.replace(" ", "_"), attr
168+
if key_is_valid_xml(key_str.replace(" ", "_")):
169+
return key_str.replace(" ", "_"), attr
170170

171171
# allow namespace prefixes + ignore @flat in key
172-
if key_is_valid_xml(key.replace(":", "").replace("@flat", "")):
173-
return key, attr
172+
if key_is_valid_xml(key_str.replace(":", "").replace("@flat", "")):
173+
return key_str, attr
174174

175175
# key is still invalid - move it into a name attribute
176-
attr["name"] = key
177-
key = "key"
178-
return key, attr
176+
attr["name"] = key_str
177+
key_str = "key"
178+
return key_str, attr
179179

180180

181181
def wrap_cdata(s: str | int | float | numbers.Number) -> str:
@@ -188,6 +188,25 @@ def default_item_func(parent: str) -> str:
188188
return "item"
189189

190190

191+
def _build_namespace_string(xml_namespaces: dict[str, Any]) -> str:
192+
"""Build XML namespace string from namespace dictionary."""
193+
parts = []
194+
195+
for prefix, value in xml_namespaces.items():
196+
if prefix == 'xsi' and isinstance(value, dict):
197+
for schema_att, ns in value.items():
198+
if schema_att == 'schemaInstance':
199+
parts.append(f'xmlns:{prefix}="{ns}"')
200+
elif schema_att == 'schemaLocation':
201+
parts.append(f'xsi:{schema_att}="{ns}"')
202+
elif prefix == 'xmlns':
203+
parts.append(f'xmlns="{value}"')
204+
else:
205+
parts.append(f'xmlns:{prefix}="{value}"')
206+
207+
return ' ' + ' '.join(parts) if parts else ''
208+
209+
191210
def convert(
192211
obj: ELEMENT,
193212
ids: Any,
@@ -262,7 +281,6 @@ def dict2xml_str(
262281
parse dict2xml
263282
"""
264283
ids: list[str] = [] # initialize list of unique ids
265-
", ".join(str(key) for key in item)
266284
subtree = "" # Initialize subtree with default empty string
267285

268286
if attr_type:
@@ -562,7 +580,7 @@ def dicttoxml(
562580
item_wrap: bool = True,
563581
item_func: Callable[[str], str] = default_item_func,
564582
cdata: bool = False,
565-
xml_namespaces: dict[str, Any] = {},
583+
xml_namespaces: dict[str, Any] | None = None,
566584
list_headers: bool = False
567585
) -> bytes:
568586
"""
@@ -681,26 +699,11 @@ def dicttoxml(
681699
<list a="b" c="d"><item>4</item><item>5</item><item>6</item></list>
682700
683701
"""
702+
if xml_namespaces is None:
703+
xml_namespaces = {}
704+
684705
output = []
685-
namespace_str = ""
686-
for prefix in xml_namespaces:
687-
if prefix == 'xsi':
688-
for schema_att in xml_namespaces[prefix]:
689-
if schema_att == 'schemaInstance':
690-
ns = xml_namespaces[prefix]['schemaInstance']
691-
namespace_str += f' xmlns:{prefix}="{ns}"'
692-
elif schema_att == 'schemaLocation':
693-
ns = xml_namespaces[prefix][schema_att]
694-
namespace_str += f' xsi:{schema_att}="{ns}"'
695-
696-
elif prefix == 'xmlns':
697-
# xmns needs no prefix
698-
ns = xml_namespaces[prefix]
699-
namespace_str += f' xmlns="{ns}"'
700-
701-
else:
702-
ns = xml_namespaces[prefix]
703-
namespace_str += f' xmlns:{prefix}="{ns}"'
706+
namespace_str = _build_namespace_string(xml_namespaces)
704707
if root:
705708
output.append('<?xml version="1.0" encoding="UTF-8" ?>')
706709
output_elem = convert(

0 commit comments

Comments
 (0)