22from __future__ import annotations
33import os
44from typing_extensions import Annotated
5- from typing import Union , Optional
6- from pydantic import PlainSerializer , PlainValidator , BaseModel
5+ from typing import Union , Optional , TYPE_CHECKING
6+ from pydantic import PlainSerializer , PlainValidator
77import numpy as np
88import base64
99import gzip
1010
11- from dataclasses import dataclass
11+ from dataclasses import dataclass , field
1212
1313import io
1414from scipy .interpolate import interp1d
@@ -28,18 +28,35 @@ class Pattern:
2828 :param name: name of the pattern
2929 """
3030
31- x : Optional [ PydanticNpArray ] = None
32- y : Optional [ PydanticNpArray ] = None
31+ x : PydanticNpArray = field ( init = False )
32+ y : PydanticNpArray = field ( init = False )
3333 name : str = ""
3434
35- def __post_init__ (self ):
36- if self .y is None and self .x is None :
37- self .x = np .linspace (0 , 10 , 101 )
38- self .y = np .log (self .x ** 2 ) - (self .x * 0.2 ) ** 2
35+ def __init__ (
36+ self ,
37+ x : Optional [PydanticNpArray ] = None ,
38+ y : Optional [PydanticNpArray ] = None ,
39+ name : str = "" ,
40+ ):
41+ x_values : PydanticNpArray
42+ y_values : PydanticNpArray
43+
44+ if x is None and y is None :
45+ x_values = np .linspace (0 , 10 , 101 )
46+ y_values = np .log (x_values ** 2 ) - (x_values * 0.2 ) ** 2
47+ elif x is None or y is None :
48+ raise ValueError ("Either both x and y must be provided or neither." )
49+ else :
50+ x_values = x
51+ y_values = y
3952
40- if len (self . x ) != len (self . y ):
53+ if len (x_values ) != len (y_values ):
4154 raise ValueError ("x and y values must have the same length" )
4255
56+ self .x = x_values
57+ self .y = y_values
58+ self .name = name
59+
4360 def load (self , filename : str , skiprows : int = 0 ):
4461 """
4562 Loads a pattern from a file. The file can be either a .xy or a .chi file. The .chi file will be loaded with
@@ -61,7 +78,7 @@ def load(self, filename: str, skiprows: int = 0):
6178 return - 1
6279
6380 @staticmethod
64- def from_file (filename : str , skip_rows : int = 0 ) -> Pattern | "-1" :
81+ def from_file (filename : str , skip_rows : int = 0 ) -> Pattern | int :
6582 """
6683 Loads a pattern from a file. The file can be either a .xy or a .chi file. The .chi file will be loaded with
6784 skiprows=4 by default.
@@ -89,15 +106,17 @@ def save(self, filename: str, header: str = ""):
89106 :param filename: path to the file
90107 :param header: header to be written to the file
91108 """
92- data = np .dstack ((self .x , self .y ))
109+ x , y = self .data
110+ data = np .dstack ((x , y ))
93111 np .savetxt (filename , data [0 ], header = header )
94112
95113 def smooth (self , amount : float ) -> Pattern :
96114 """
97115 Smoothing the pattern by applying a gaussian filter. Returns the smoothed pattern.
98116 :param amount: amount of smoothing to be applied
99117 """
100- return Pattern (self .x , gaussian_filter1d (self .y , amount ))
118+ x , y = self .data
119+ return Pattern (x , gaussian_filter1d (y , amount ))
101120
102121 def rebin (self , bin_size : float ) -> Pattern :
103122 """
@@ -117,7 +136,7 @@ def rebin(self, bin_size: float) -> Pattern:
117136 return Pattern (new_x , new_y )
118137
119138 @property
120- def data (self ) -> tuple [np . ndarray , np . ndarray ]:
139+ def data (self ) -> tuple [PydanticNpArray , PydanticNpArray ]:
121140 """
122141 Returns the data of the pattern as a tuple of x and y values.
123142
@@ -132,8 +151,11 @@ def data(self, data: tuple[np.ndarray, np.ndarray]):
132151
133152 :param data: tuple of x and y values
134153 """
135- self .x = data [0 ]
136- self .y = data [1 ]
154+ x_values , y_values = data
155+ if len (x_values ) != len (y_values ):
156+ raise ValueError ("x and y values must have the same length" )
157+ self .x = x_values
158+ self .y = y_values
137159
138160 def limit (self , x_min : float , x_max : float ) -> Pattern :
139161 """
@@ -144,10 +166,9 @@ def limit(self, x_min: float, x_max: float) -> Pattern:
144166 :return: limited Pattern
145167 """
146168 x , y = self .data
147- return Pattern (
148- x [np .where ((x_min < x ) & (x < x_max ))],
149- y [np .where ((x_min < x ) & (x < x_max ))],
150- )
169+ x_limited = x [np .where ((x_min < x ) & (x < x_max ))]
170+ y_limited = y [np .where ((x_min < x ) & (x < x_max ))]
171+ return Pattern (x_limited , y_limited )
151172
152173 def extend_to (self , x_value : float , y_value : float ) -> Pattern :
153174 """
@@ -159,23 +180,24 @@ def extend_to(self, x_value: float, y_value: float) -> Pattern:
159180 :param y_value: number to fill the pattern with
160181 :return: extended Pattern
161182 """
162- x_step = np .mean (np .diff (self .x ))
163- x_min = np .min (self .x )
164- x_max = np .max (self .x )
183+ x , y = self .data
184+ x_step = np .mean (np .diff (x ))
185+ x_min = np .min (x )
186+ x_max = np .max (x )
165187 if x_value < x_min :
166188 x_fill = np .arange (x_min - x_step , x_value - x_step * 0.5 , - x_step )[::- 1 ]
167189 y_fill = np .zeros (x_fill .shape )
168190 y_fill .fill (y_value )
169191
170- new_x = np .concatenate ((x_fill , self . x ))
171- new_y = np .concatenate ((y_fill , self . y ))
192+ new_x = np .concatenate ((x_fill , x ))
193+ new_y = np .concatenate ((y_fill , y ))
172194 elif x_value > x_max :
173195 x_fill = np .arange (x_max + x_step , x_value + x_step * 0.5 , x_step )
174196 y_fill = np .zeros (x_fill .shape )
175197 y_fill .fill (y_value )
176198
177- new_x = np .concatenate ((self . x , x_fill ))
178- new_y = np .concatenate ((self . y , y_fill ))
199+ new_x = np .concatenate ((x , x_fill ))
200+ new_y = np .concatenate ((y , y_fill ))
179201 else :
180202 return self
181203
@@ -187,10 +209,11 @@ def to_dict(self) -> dict:
187209
188210 :return: dictionary representation of the pattern
189211 """
212+ x , y = self .data
190213 return {
191214 "name" : self .name ,
192- "x" : self . x .tolist (),
193- "y" : self . y .tolist (),
215+ "x" : x .tolist (),
216+ "y" : y .tolist (),
194217 }
195218
196219 @staticmethod
@@ -300,7 +323,7 @@ def __mul__(self, other: float) -> Pattern:
300323 """
301324 return self .__rmul__ (other )
302325
303- def __eq__ (self , other : Pattern ) -> bool :
326+ def __eq__ (self , other : object ) -> bool :
304327 """
305328 Checks if two patterns are equal. Two patterns are equal if their data
306329 is equal.
@@ -327,6 +350,14 @@ def __str__(self):
327350
328351
329352def validate (value ):
353+ """
354+ Validates a numpy array. If the value is a list, it is converted to a numpy array.
355+ If the value is a string, it is decoded from a base64 encoded string.
356+ If the value is already a numpy array, it is returned as is.
357+
358+ :param value: The value to validate
359+ :return: The validated numpy array
360+ """
330361 if isinstance (value , list ):
331362 return np .array (value )
332363 if isinstance (value , str ):
@@ -343,8 +374,18 @@ def validate(value):
343374
344375
345376def serialize (value ):
346- if isinstance (value , np .ndarray ):
347- # Save numpy array to compressed bytyes
377+ """
378+ Serializes a numpy array to a base64 encoded string.
379+ If the value is a numpy array, it is saved to a compressed bytes buffer and then encoded to a base64 string.
380+ If the value is a list, it is converted to a numpy array and then serialized.
381+ If the value is a string, it is decoded from a base64 encoded string and then deserialized.
382+ If the value is already a numpy array, it is returned as is.
383+
384+ :param value: The value to serialize
385+ :return: The serialized numpy array
386+ """
387+ if isinstance (value , (list , np .ndarray )):
388+ # Save numpy array to compressed bytes buffer
348389 with io .BytesIO () as buffer :
349390 np .save (buffer , value , allow_pickle = False )
350391 binary_data = buffer .getvalue ()
@@ -354,6 +395,15 @@ def serialize(value):
354395 raise TypeError (f"Invalid type for numpy array: { type (value )} " )
355396
356397
357- PydanticNpArray = Annotated [
358- np .ndarray , PlainValidator (validate ), PlainSerializer (serialize )
359- ]
398+ """
399+ This is a workaround to allow numpy arrays to be used as fields in a pydantic model.
400+ It is necessary because pydantic does not support numpy arrays as fields.
401+ We use a custom validator and serializer to convert the numpy array to a base64 encoded string
402+ and back again.
403+ """
404+ if TYPE_CHECKING :
405+ PydanticNpArray = np .ndarray
406+ else :
407+ PydanticNpArray = Annotated [
408+ np .ndarray , PlainValidator (validate ), PlainSerializer (serialize )
409+ ]
0 commit comments