Skip to content

Commit 246ad6c

Browse files
authored
fix: make sure not to override user set values for from_sample (bentoml#3610)
1 parent 699c9c2 commit 246ad6c

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

src/bentoml/_internal/io_descriptors/numpy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,8 +483,10 @@ async def predict(input: NDArray[np.int16]) -> NDArray[Any]:
483483
raise BentoMLException(
484484
f"Failed to create a 'numpy.ndarray' from given sample {sample}"
485485
) from None
486-
self._dtype = sample.dtype
487-
self._shape = sample.shape
486+
if self._dtype is None:
487+
self._dtype = sample.dtype
488+
if self._shape is None:
489+
self._shape = sample.shape
488490
return sample
489491

490492
async def from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray:

src/bentoml/_internal/io_descriptors/pandas.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -427,8 +427,10 @@ def predict(inputs: pd.DataFrame) -> pd.DataFrame: ...
427427
raise InvalidArgument(
428428
f"Failed to create a 'pd.DataFrame' from sample {sample}: {e}"
429429
) from None
430-
self._shape = sample.shape
431-
self._columns = [str(i) for i in list(sample.columns)]
430+
if self._shape is None:
431+
self._shape = sample.shape
432+
if self._columns is None:
433+
self._columns = [str(i) for i in list(sample.columns)]
432434
if self._dtype is None:
433435
self._dtype = sample.dtypes
434436
return sample
@@ -933,8 +935,10 @@ def predict(inputs: pd.Series) -> pd.Series: ...
933935
"""
934936
if not isinstance(sample, pd.Series):
935937
sample = pd.Series(sample)
936-
self._dtype = sample.dtype
937-
self._shape = sample.shape
938+
if self._dtype is None:
939+
self._dtype = sample.dtype
940+
if self._shape is None:
941+
self._shape = sample.shape
938942
return sample
939943

940944
def input_type(self) -> LazyType[ext.PdSeries]:

tests/unit/_internal/io/test_numpy.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,14 @@ def test_verify_numpy_ndarray(caplog: LogCaptureFixture):
117117
assert "Failed to reshape" in caplog.text
118118

119119

120+
def test_from_sample_ensure_not_override():
121+
example = NumpyNdarray.from_sample(np.ones((2, 2, 3)), dtype=np.float32)
122+
assert example._dtype == np.float32
123+
124+
example = NumpyNdarray.from_sample(np.ones((2, 2, 3)), shape=(2, 2, 3))
125+
assert example._shape == (2, 2, 3)
126+
127+
120128
def generate_1d_array(dtype: pb.NDArray.DType.ValueType, length: int = 3):
121129
if dtype == pb.NDArray.DTYPE_BOOL:
122130
return [True] * length

0 commit comments

Comments
 (0)