Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions src/json_stream/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import collections
import copy
from abc import ABC
from collections import OrderedDict
from itertools import chain
from collections import OrderedDict, deque
from itertools import chain, count as itertools_count, zip_longest
from typing import Optional, Iterator, Any

from json_stream.tokenizer import TokenType
Expand Down Expand Up @@ -106,6 +106,10 @@ def __len__(self) -> int:
def __repr__(self): # pragma: no cover
return f"<{type(self).__name__}: {repr(self._data)}, {'STREAMING' if self.streaming else 'DONE'}>"

def __contains__(self, item):
self.read_all()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be implemented so that, if item is near the start, you don't have to read the whole thing.

return item in self._data


class TransientStreamingJSONBase(StreamingJSONBase, ABC):
def __init__(self, token_stream):
Expand Down Expand Up @@ -158,6 +162,20 @@ def _load_item(self):
def _get__iter__(self):
    """Return the iterator produced by ``self._iter_items()``."""
    return self._iter_items()

def index(self, item, start=0, stop=None):
    """Return the position of the first match for ``item`` in ``[start, stop)``.

    Items are taken from ``self._iter_items()`` and numbered from zero. As
    with ``list.index``, ``stop`` is exclusive: the item at position ``stop``
    is never considered a match.

    :param item: the object to look for; matched by identity or equality.
    :param start: first position that may match (``None`` is treated as 0).
    :param stop: position at which to stop searching, or ``None`` for no limit.
    :return: the zero-based position of the first match.
    :raises ValueError: if no item in the requested range matches.
    """
    if start is None:
        start = 0
    if stop is not None and stop <= start:
        # Empty search range: fail without pulling anything from the iterator.
        raise ValueError(f"{item!r} is not in list")
    for position, value in enumerate(self._iter_items()):
        if position >= start and (value is item or value == item):
            return position
        # Stop before requesting the item at ``stop`` so it is not consumed.
        if stop is not None and position + 1 >= stop:
            break
    raise ValueError(f"{item!r} is not in list")

@staticmethod
def _index_args(*args):
return args[:args.index(None) if None in args else -1]


class PersistentStreamingJSONList(PersistentStreamingJSONBase, StreamingJSONList):
def _init_persistent_data(self):
Expand All @@ -183,6 +201,21 @@ def __getitem__(self, k) -> Any:
pass
return self._find_item(k)

def index(self, item, /, start=None, stop=None):
    """Return the position of ``item``, falling back to the parent search.

    ``self._data`` is searched first. If the item is not there, the parent
    class's ``index()`` is called and its result is offset by the number of
    items ``self._data`` held before that call.

    NOTE(review): the same ``start``/``stop`` bounds are forwarded to both
    searches -- confirm whether the fallback expects them relative to the
    whole list or only to the part not yet held in ``self._data``.

    :param item: the object to look for.
    :param start: optional lower bound, passed through ``_index_args``.
    :param stop: optional upper bound, passed through ``_index_args``.
    :return: the position of ``item``.
    """
    bounds = self._index_args(start, stop)
    try:
        found = self._data.index(item, *bounds)
    except ValueError:
        # Measure the held data before the fallback search runs.
        held = len(self._data)
        found = held + super().index(item, *bounds)
    return found

def count(self, item):
    """Return how many times ``item`` occurs, after reading everything.

    :param item: the object to count.
    :return: the result of ``self._data.count(item)`` once ``read_all()`` has run.
    """
    self.read_all()
    loaded = self._data
    return loaded.count(item)

def __reversed__(self):
self.read_all()
return reversed(self._data)


class TransientStreamingJSONList(TransientStreamingJSONBase, StreamingJSONList):
def __init__(self, token_stream):
Expand All @@ -202,6 +235,25 @@ def _find_item(self, i):
return v
raise IndexError(f"Index {i} out of range")

def index(self, item, /, start=None, stop=None):
    """Return the position of ``item`` found via the parent class's search.

    The parent result is added to the value ``self._index`` had before the
    search, plus one.

    :param item: the object to look for.
    :param start: optional lower bound; must not be negative.
    :param stop: optional upper bound; must not be negative.
    :return: ``self._index`` (as read before searching) + parent result + 1.
    :raises IndexError: if ``start`` or ``stop`` is negative.
    """
    for bound in (start, stop):
        if bound is not None and bound < 0:
            raise IndexError("Negative indices not supported for transient lists")
    # Read the current position before the search advances anything.
    base = self._index
    return base + super().index(item, *self._index_args(start, stop)) + 1

def count(self, item):
    """Count the items from ``self._iter_items()`` that match ``item``.

    An item matches if it is the same object as ``item`` or compares equal
    to it.

    :param item: the object to count.
    :return: the number of matching items.
    """
    self._check_started()
    tally = itertools_count()
    matches = (value for value in self._iter_items() if value is item or value == item)
    # zip() draws from ``matches`` before ``tally``, so the tally advances once
    # per match only; the zero-length deque drains the pairs without storing them.
    deque(zip(matches, tally), maxlen=0)
    return next(tally)

def __reversed__(self):
self._check_started()
# this approach releases memory as iterator advances
stack = deque(self._iter_items())
while stack:
yield stack.pop()


class StreamingJSONObject(StreamingJSONBase, ABC):
INCOMPLETE_ERROR = "Unterminated object at end of file"
Expand Down Expand Up @@ -275,6 +327,18 @@ def __getitem__(self, k) -> Any:
pass
return self._find_item(k)

def __eq__(self, other):
if not isinstance(other, Mapping):
return NotImplemented
self.read_all()
return self._data == other

def __ne__(self, other):
if not isinstance(other, Mapping):
return NotImplemented
self.read_all()
return self._data != other


class TransientStreamingJSONObject(TransientStreamingJSONBase, StreamingJSONObject):
def _find_item(self, k):
Expand All @@ -299,3 +363,9 @@ def keys(self):
def values(self):
    """Return a generator over the second element of each pair from ``_iter_items()``."""
    self._check_started()
    pairs = self._iter_items()
    return (value for _key, value in pairs)

def __eq__(self, other):
if not isinstance(other, Mapping):
return NotImplemented
not_equal = object() # sentinel for length differences
return all(a == b for a, b in zip_longest(self.items(), other.items(), fillvalue=not_equal))