Skip to content

Commit

Permalink
Improving eval code.
Browse files Browse the repository at this point in the history
  • Loading branch information
scott-griffiths committed Jan 28, 2024
1 parent abe81e9 commit 142b412
Showing 1 changed file with 77 additions and 37 deletions.
114 changes: 77 additions & 37 deletions bitformat/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from bitstring import Bits, Dtype, Array
from typing import Sequence, Any, Iterator, Tuple, List, Dict
from types import CodeType
import copy
import ast

Expand All @@ -21,9 +22,29 @@ def __new__(cls, use_colour: bool) -> Colour:

colour = Colour(True)

def _compile_safe_eval(s: str) -> CodeType:
start = s.find('{')
end = s.find('}')
if start == -1 or end == -1:
raise ValueError(f'Invalid expression: {s}. It should start and end with braces.')
s = s[start + 1:end].strip()
# Only allowing operations for integer maths or boolean comparisons.
node_whitelist = {'BinOp', 'Name', 'Add', 'Expr', 'Mult', 'FloorDiv', 'Sub', 'Load', 'Module', 'Constant',
'UnaryOp', 'USub', 'Mod', 'Pow', 'BitAnd', 'BitXor', 'BitOr', 'And', 'Or', 'BoolOp', 'LShift',
'RShift',
'Eq', 'NotEq', 'Compare', 'LtE', 'GtE'}
nodes_used = set([x.__class__.__name__ for x in ast.walk(ast.parse(s))])
bad_nodes = nodes_used - node_whitelist
if bad_nodes:
raise ValueError(f"Disallowed operations used in expression '{s}'. Disallowed nodes were: {bad_nodes}.")
if '__' in s:
raise ValueError(f"Invalid expression: '{s}'. Double underscores are not permitted.")
code = compile(s, "<string>", "eval")
return code


class Field:
def __init__(self, dtype: Dtype | Bits | str, name: str | None = None, value: Any = None, items: int = 1):
def __init__(self, dtype: Dtype | Bits | str, name: str | None = None, value: Any = None, items: str | int = 1):
if name == '':
name = None
self._bits = None
Expand Down Expand Up @@ -70,6 +91,10 @@ def __init__(self, dtype: Dtype | Bits | str, name: str | None = None, value: An
if name is not None and not name.isidentifier():
raise ValueError(f"The Field name '{name}' is not a valid Python identifier.")
self.name = name
try:
items = int(items)
except ValueError:
items = _compile_safe_eval(items)
self.items = items
if self.dtype.length == 0:
raise ValueError(f"A field's dtype cannot have a length of zero (dtype = {self.dtype}).")
Expand All @@ -78,34 +103,62 @@ def __init__(self, dtype: Dtype | Bits | str, name: str | None = None, value: An

@staticmethod
def _parse_dtype_str(dtype_str: str) -> Tuple[str, str | None, str | None, int]:
# The string has the form 'dtype [* items] [<name>] [= value]'
# But there may be chars inside {} sections that should be ignored.
# So we scan to find first real *, <, > and =
asterix_pos = -1
lessthan_pos = -1
greaterthan_pos = -1
equals_pos = -1
inside_braces = False
for pos, char in enumerate(dtype_str):
if char == '{':
if inside_braces:
raise ValueError(f"Two consecutive opening braces found in '{dtype_str}'.")
inside_braces = True
if char == '}':
if not inside_braces:
raise ValueError(f"Closing brace found with no matching opening brace in '{dtype_str}'.")
inside_braces = False
if inside_braces:
continue
if char == '*':
if asterix_pos != -1:
raise ValueError(f"More than one '*' found in '{dtype_str}'.")
asterix_pos = pos
if char == '<':
if lessthan_pos != -1:
raise ValueError(f"More than one '<' found in '{dtype_str}'.")
lessthan_pos = pos
if char == '>':
if greaterthan_pos != -1:
raise ValueError(f"More than one '>' found in '{dtype_str}'.")
greaterthan_pos = pos
if char == '=':
if equals_pos != -1:
raise ValueError(f"More than one '=' found in '{dtype_str}'.")
equals_pos = pos

name = value = None
items = 1
# Check to see if it includes a value:
q = dtype_str.find('=')
if q != -1:
value = dtype_str[q + 1:]
dtype_str = dtype_str[:q]
if equals_pos != -1:
value = dtype_str[equals_pos + 1:]
dtype_str = dtype_str[:equals_pos]
# Check if it has a name:
name_start = dtype_str.find('<')
if name_start != -1:
name_end = dtype_str.find('>')
if name_end == -1:
if lessthan_pos != -1:
if greaterthan_pos == -1:
raise ValueError(
f"An opening '<' was supplied in the formatted dtype '{dtype_str} but without a closing '>'.")
name = dtype_str[name_start + 1:name_end]
name = dtype_str[lessthan_pos + 1:greaterthan_pos]
name = name.strip()
chars_after_name = dtype_str[name_end + 1:]
chars_after_name = dtype_str[greaterthan_pos + 1:]
if chars_after_name != '' and not chars_after_name.isspace():
raise ValueError(f"There should be no trailing characters after the <name>.")
dtype_str = dtype_str[:name_start]
multiply_pos = dtype_str.find('*')
if multiply_pos != -1:
items = dtype_str[multiply_pos + 1:]
try:
items = int(items)
except ValueError:
pass
dtype_str = dtype_str[:multiply_pos]
dtype_str = dtype_str[:lessthan_pos]
if asterix_pos != -1:
items = dtype_str[asterix_pos + 1:]
dtype_str = dtype_str[:asterix_pos]
return dtype_str, name, value, items

def _getvalue(self) -> Any:
Expand Down Expand Up @@ -172,7 +225,7 @@ def __eq__(self, other: Any) -> bool:
class Format:

def __init__(self, name: str | None = None,
fields: Sequence[Field | Format | str] | None = None) -> None:
fields: Sequence[Field | Format | str | Dtype | Bits] | None = None) -> None:
self.name = name
if self.name == '':
self.name = None
Expand Down Expand Up @@ -284,21 +337,8 @@ def append(self, value: Any) -> None:
self.__iadd__(value)

@staticmethod
def _safe_eval(s: str, vars_: Dict) -> Any:
start = s.find('{')
end = s.find('}')
if start == -1 or end == -1:
raise ValueError(f'Invalid expression: {s}. It should start and end with braces.')
s = s[start + 1:end]

node_whitelist = {'BinOp', 'Name', 'Add', 'Expr', 'Mult', 'FloorDiv', 'Sub', 'Load', 'Module'}
nodes_used = set([x.__class__.__name__ for x in ast.walk(ast.parse(s))])
bad_nodes = nodes_used - node_whitelist
if bad_nodes:
raise ValueError(f"Disallowed operations used in expression '{s}'. Disallowed nodes were: {bad_nodes}.")
if '__' in s:
raise ValueError(f'Invalid expression: {s}. Double underscores are not permitted.')
return eval(s, {"__builtins__": {}}, vars_)
def _safe_eval(code: CodeType, vars_: Dict) -> Any:
return eval(code, {"__builtins__": {}}, vars_)

def _build(self, value_iter: Iterator[Field]) -> Tuple[Sequence[Field | Format], Dict]:
out_fields = []
Expand All @@ -313,7 +353,7 @@ def _build(self, value_iter: Iterator[Field]) -> Tuple[Sequence[Field | Format],
out_fields.append(f)
continue
if field.bits is None:
if isinstance(field.items, str):
if isinstance(field.items, CodeType):
field.items = Format._safe_eval(field.items, vars_)
field = Field(field.dtype, field.name, next(value_iter), field.items)
out_fields.append(field)
Expand Down

0 comments on commit 142b412

Please sign in to comment.