
Commit ca85050

refactor: fix some of the type issues with pattern
1 parent d7b56b4 commit ca85050

File tree

1 file changed: +85 -35 lines changed


glassure/pattern.py

Lines changed: 85 additions & 35 deletions
@@ -2,13 +2,13 @@
 from __future__ import annotations
 import os
 from typing_extensions import Annotated
-from typing import Union, Optional
-from pydantic import PlainSerializer, PlainValidator, BaseModel
+from typing import Union, Optional, TYPE_CHECKING
+from pydantic import PlainSerializer, PlainValidator
 import numpy as np
 import base64
 import gzip

-from dataclasses import dataclass
+from dataclasses import dataclass, field

 import io
 from scipy.interpolate import interp1d
@@ -28,18 +28,35 @@ class Pattern:
     :param name: name of the pattern
     """

-    x: Optional[PydanticNpArray] = None
-    y: Optional[PydanticNpArray] = None
+    x: PydanticNpArray = field(init=False)
+    y: PydanticNpArray = field(init=False)
     name: str = ""

-    def __post_init__(self):
-        if self.y is None and self.x is None:
-            self.x = np.linspace(0, 10, 101)
-            self.y = np.log(self.x**2) - (self.x * 0.2) ** 2
+    def __init__(
+        self,
+        x: Optional[PydanticNpArray] = None,
+        y: Optional[PydanticNpArray] = None,
+        name: str = "",
+    ):
+        x_values: PydanticNpArray
+        y_values: PydanticNpArray
+
+        if x is None and y is None:
+            x_values = np.linspace(0, 10, 101)
+            y_values = np.log(x_values**2) - (x_values * 0.2) ** 2
+        elif x is None or y is None:
+            raise ValueError("Either both x and y must be provided or neither.")
+        else:
+            x_values = x
+            y_values = y

-        if len(self.x) != len(self.y):
+        if len(x_values) != len(y_values):
             raise ValueError("x and y values must have the same length")

+        self.x = x_values
+        self.y = y_values
+        self.name = name
+
     def load(self, filename: str, skiprows: int = 0):
         """
         Loads a pattern from a file. The file can be either a .xy or a .chi file. The .chi file will be loaded with
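For readers following the refactor, a minimal usage sketch of the new constructor behavior (not part of the commit; it assumes Pattern is importable as glassure.pattern.Pattern):

import numpy as np
from glassure.pattern import Pattern

# No arguments: the default curve on np.linspace(0, 10, 101) is generated.
default_pattern = Pattern()
assert len(default_pattern.x) == 101

# Both arrays: stored after the shared length check.
x = np.linspace(0, 10, 500)
p = Pattern(x, np.sin(x), name="sine")

# Only one array: the new guard raises a clear ValueError.
try:
    Pattern(x=np.linspace(0, 10, 500))
except ValueError as e:
    print(e)  # Either both x and y must be provided or neither.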
@@ -61,7 +78,7 @@ def load(self, filename: str, skiprows: int = 0):
             return -1

     @staticmethod
-    def from_file(filename: str, skip_rows: int = 0) -> Pattern | "-1":
+    def from_file(filename: str, skip_rows: int = 0) -> Pattern | int:
         """
         Loads a pattern from a file. The file can be either a .xy or a .chi file. The .chi file will be loaded with
         skiprows=4 by default.
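The old annotation Pattern | "-1" used a string literal where a type belongs; Pattern | int matches the existing -1 error return. A hedged usage sketch of handling that return value (the file path is hypothetical):

from glassure.pattern import Pattern

result = Pattern.from_file("my_pattern.xy")  # hypothetical path
if isinstance(result, int):                  # -1 signals a failed load
    raise IOError("could not load my_pattern.xy")
pattern = result                             # narrowed to Pattern from here on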
@@ -89,15 +106,17 @@ def save(self, filename: str, header: str = ""):
         :param filename: path to the file
         :param header: header to be written to the file
         """
-        data = np.dstack((self.x, self.y))
+        x, y = self.data
+        data = np.dstack((x, y))
         np.savetxt(filename, data[0], header=header)

     def smooth(self, amount: float) -> Pattern:
         """
         Smoothing the pattern by applying a gaussian filter. Returns the smoothed pattern.
         :param amount: amount of smoothing to be applied
         """
-        return Pattern(self.x, gaussian_filter1d(self.y, amount))
+        x, y = self.data
+        return Pattern(x, gaussian_filter1d(y, amount))

     def rebin(self, bin_size: float) -> Pattern:
         """
@@ -117,7 +136,7 @@ def rebin(self, bin_size: float) -> Pattern:
         return Pattern(new_x, new_y)

     @property
-    def data(self) -> tuple[np.ndarray, np.ndarray]:
+    def data(self) -> tuple[PydanticNpArray, PydanticNpArray]:
         """
         Returns the data of the pattern as a tuple of x and y values.

@@ -132,8 +151,11 @@ def data(self, data: tuple[np.ndarray, np.ndarray]):

         :param data: tuple of x and y values
         """
-        self.x = data[0]
-        self.y = data[1]
+        x_values, y_values = data
+        if len(x_values) != len(y_values):
+            raise ValueError("x and y values must have the same length")
+        self.x = x_values
+        self.y = y_values

     def limit(self, x_min: float, x_max: float) -> Pattern:
         """
@@ -144,10 +166,9 @@ def limit(self, x_min: float, x_max: float) -> Pattern:
         :return: limited Pattern
         """
         x, y = self.data
-        return Pattern(
-            x[np.where((x_min < x) & (x < x_max))],
-            y[np.where((x_min < x) & (x < x_max))],
-        )
+        x_limited = x[np.where((x_min < x) & (x < x_max))]
+        y_limited = y[np.where((x_min < x) & (x < x_max))]
+        return Pattern(x_limited, y_limited)

     def extend_to(self, x_value: float, y_value: float) -> Pattern:
         """
@@ -159,23 +180,24 @@ def extend_to(self, x_value: float, y_value: float) -> Pattern:
         :param y_value: number to fill the pattern with
         :return: extended Pattern
         """
-        x_step = np.mean(np.diff(self.x))
-        x_min = np.min(self.x)
-        x_max = np.max(self.x)
+        x, y = self.data
+        x_step = np.mean(np.diff(x))
+        x_min = np.min(x)
+        x_max = np.max(x)
         if x_value < x_min:
             x_fill = np.arange(x_min - x_step, x_value - x_step * 0.5, -x_step)[::-1]
             y_fill = np.zeros(x_fill.shape)
             y_fill.fill(y_value)

-            new_x = np.concatenate((x_fill, self.x))
-            new_y = np.concatenate((y_fill, self.y))
+            new_x = np.concatenate((x_fill, x))
+            new_y = np.concatenate((y_fill, y))
         elif x_value > x_max:
             x_fill = np.arange(x_max + x_step, x_value + x_step * 0.5, x_step)
             y_fill = np.zeros(x_fill.shape)
             y_fill.fill(y_value)

-            new_x = np.concatenate((self.x, x_fill))
-            new_y = np.concatenate((self.y, y_fill))
+            new_x = np.concatenate((x, x_fill))
+            new_y = np.concatenate((y, y_fill))
         else:
             return self

@@ -187,10 +209,11 @@ def to_dict(self) -> dict:

         :return: dictionary representation of the pattern
         """
+        x, y = self.data
         return {
             "name": self.name,
-            "x": self.x.tolist(),
-            "y": self.y.tolist(),
+            "x": x.tolist(),
+            "y": y.tolist(),
         }

     @staticmethod
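With this change, to_dict reads through the data property but still produces the same plain-Python structure. A quick sketch of the output (hypothetical values, same assumed import):

import numpy as np
from glassure.pattern import Pattern

p = Pattern(np.array([0.0, 1.0]), np.array([2.0, 3.0]), name="demo")
print(p.to_dict())
# {'name': 'demo', 'x': [0.0, 1.0], 'y': [2.0, 3.0]}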
@@ -300,7 +323,7 @@ def __mul__(self, other: float) -> Pattern:
         """
         return self.__rmul__(other)

-    def __eq__(self, other: Pattern) -> bool:
+    def __eq__(self, other: object) -> bool:
         """
         Checks if two patterns are equal. Two patterns are equal if their data
         is equal.
@@ -327,6 +350,14 @@ def __str__(self):


 def validate(value):
+    """
+    Validates a numpy array. If the value is a list, it is converted to a numpy array.
+    If the value is a string, it is decoded from a base64 encoded string.
+    If the value is already a numpy array, it is returned as is.
+
+    :param value: The value to validate
+    :return: The validated numpy array
+    """
     if isinstance(value, list):
         return np.array(value)
     if isinstance(value, str):
@@ -343,8 +374,18 @@ def validate(value):


 def serialize(value):
-    if isinstance(value, np.ndarray):
-        # Save numpy array to compressed bytyes
+    """
+    Serializes a numpy array to a base64 encoded string.
+    If the value is a numpy array, it is saved to a compressed bytes buffer and then encoded to a base64 string.
+    If the value is a list, it is converted to a numpy array and then serialized.
+    If the value is a string, it is decoded from a base64 encoded string and then deserialized.
+    If the value is already a numpy array, it is returned as is.
+
+    :param value: The value to serialize
+    :return: The serialized numpy array
+    """
+    if isinstance(value, (list, np.ndarray)):
+        # Save numpy array to compressed bytes buffer
         with io.BytesIO() as buffer:
             np.save(buffer, value, allow_pickle=False)
             binary_data = buffer.getvalue()
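Based on the docstrings added here, serialize and validate are intended as a round trip between numpy arrays and base64-encoded strings. A sketch of that round trip, treating both helpers as black boxes (assumed importable from glassure.pattern):

import numpy as np
from glassure.pattern import serialize, validate

arr = np.linspace(0, 1, 11)
encoded = serialize(arr)     # base64-encoded string (per the new docstring)
decoded = validate(encoded)  # back to a numpy array
assert np.allclose(arr, decoded)

# Lists are accepted on both ends as well.
assert isinstance(validate([1, 2, 3]), np.ndarray)
assert isinstance(serialize([1, 2, 3]), str)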
@@ -354,6 +395,15 @@ def serialize(value):
     raise TypeError(f"Invalid type for numpy array: {type(value)}")


-PydanticNpArray = Annotated[
-    np.ndarray, PlainValidator(validate), PlainSerializer(serialize)
-]
+"""
+This is a workaround to allow numpy arrays to be used as fields in a pydantic model.
+It is necessary because pydantic does not support numpy arrays as fields.
+We use a custom validator and serializer to convert the numpy array to a base64 encoded string
+and back again.
+"""
+if TYPE_CHECKING:
+    PydanticNpArray = np.ndarray
+else:
+    PydanticNpArray = Annotated[
+        np.ndarray, PlainValidator(validate), PlainSerializer(serialize)
+    ]
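With the TYPE_CHECKING split, static checkers see PydanticNpArray as plain np.ndarray, while pydantic still receives the Annotated form with the validator/serializer pair at runtime. A minimal sketch of how the alias could be used in a pydantic model (the Spectrum model is hypothetical and not part of glassure; the import path is assumed):

import numpy as np
from pydantic import BaseModel
from glassure.pattern import PydanticNpArray  # assumed import path

class Spectrum(BaseModel):  # hypothetical example model
    intensity: PydanticNpArray

s = Spectrum(intensity=[1.0, 2.0, 3.0])     # validate(): list -> np.ndarray
dumped = s.model_dump()                     # serialize(): np.ndarray -> base64 string
restored = Spectrum.model_validate(dumped)  # validate(): base64 string -> np.ndarray
assert np.allclose(s.intensity, restored.intensity)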
