Skip to content

Refactor _validate_data_input to simplify the codes #3818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 37 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
df88e02
Figure.plot/plot3d/text: Pass a dict of vectors rather than x/y/extra…
seisman Mar 3, 2025
ce57c59
Check if a dict of vectors contain None
seisman Mar 3, 2025
2cb9295
clib.virtualfile_in: Remove the 'extra_arrays' parameter
seisman Mar 3, 2025
362e5bd
Figure.plot3d: Add one more test to increase code coverage
seisman Mar 3, 2025
8f72e4c
Clarify that dictionary will be recognized as vectors
seisman Mar 3, 2025
5d3d308
Rename 'required_z' to 'required_ncols' in _validate_data_input
seisman Mar 5, 2025
406cb39
Rename 'required_z' to 'required_ncols' in Session.virtualfile_in
seisman Mar 5, 2025
af189e6
Rename 'required_z' to 'required_ncols' in Session.virtualfile_in tests
seisman Mar 5, 2025
2d55e5a
Rename 'required_z' to 'required_ncols' in module wrappers
seisman Mar 5, 2025
f017adb
Merge branch 'main' into refactor/virtualfile_in_extra_arrays
seisman Mar 6, 2025
c199dbf
Clarify that the data parameter can accept dicts
seisman Mar 6, 2025
07924ca
Merge branch 'main' into refactor/virtualfile_in_extra_arrays
seisman Mar 10, 2025
1d0d0dc
Fix a typo
seisman Mar 10, 2025
84e8d8a
Merge branch 'main' into refactor/virtualfile_in
seisman Mar 10, 2025
5eeb37b
Rename required_ncols to the shorter ncols
seisman Mar 10, 2025
a7ee253
Merge branch 'main' into refactor/virtualfile_in_extra_arrays
seisman Mar 24, 2025
5cd8275
Merge branch 'main' into refactor/virtualfile_in_extra_arrays
seisman Mar 26, 2025
e44e712
Improve docstrings for the new test
seisman Mar 26, 2025
5330d0a
Raise warnings when extra_arrays is used
seisman Mar 26, 2025
0a63ef0
Add TODO comments
seisman Mar 26, 2025
2851af8
Merge branch 'main' into refactor/virtualfile_in
seisman Mar 26, 2025
0e0c1d8
Deprecate required_z in a backward-compatible way
seisman Mar 26, 2025
f8293cd
Add one test with the deprecated 'required_z' parameter to increase c…
seisman Mar 26, 2025
60f1758
Refactor _validate_data_input
seisman Feb 21, 2025
3c61fea
Merge branch 'refactor/virtualfile_in' into refactor/validate_data_input
seisman Mar 26, 2025
312aa39
Merge branch 'main' into refactor/virtualfile_in
seisman Mar 28, 2025
b7ca239
Move TODO comment to the top
seisman Mar 28, 2025
a01afbb
Merge branch 'main' into refactor/virtualfile_in
seisman Apr 15, 2025
da874d1
Put ncols before required_data
seisman Apr 15, 2025
ea367c8
Merge branch 'refactor/virtualfile_in' into refactor/validate_data_input
seisman Apr 15, 2025
72c1730
Merge branch 'main' into refactor/validate_data_input
seisman Apr 28, 2025
2f22824
Merge branch 'main' into refactor/validate_data_input
seisman May 5, 2025
afe944e
Remove unrelated changes
seisman May 5, 2025
98e277b
Simplify
seisman May 5, 2025
0b16de5
Merge branch 'main' into refactor/validate_data_input
seisman May 7, 2025
71a62a9
info and histogram needs one column
seisman May 7, 2025
8b5d6e1
Update doctetss
seisman May 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1847,16 +1847,13 @@ def virtualfile_in( # noqa: PLR0912
)
mincols = 3

# Specify either data or x/y/z.
if data is not None and any(v is not None for v in (x, y, z)):
msg = "Too much data. Use either data or x/y/z."
raise GMTInvalidInput(msg)

# Determine the kind of data.
kind = data_kind(data, required=required)
_validate_data_input(
data=data,
x=x,
y=y,
z=z,
required=required,
mincols=mincols,
kind=kind,
)

if check_kind:
valid_kinds = ("file", "arg") if required is False else ("file",)
Expand Down Expand Up @@ -1922,6 +1919,9 @@ def virtualfile_in( # noqa: PLR0912
_virtualfile_from = self.virtualfile_from_vectors
_data = data.T

# Check if _data to be passed to the virtualfile_from_ function is valid.
_validate_data_input(data=_data, kind=kind, mincols=mincols)

# Finally create the virtualfile from the data, to be passed into GMT
file_context = _virtualfile_from(_data)
return file_context
Expand Down
138 changes: 55 additions & 83 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import time
import webbrowser
from collections.abc import Iterable, Mapping, Sequence
from itertools import islice
from pathlib import Path
from typing import Any, Literal

Expand Down Expand Up @@ -40,119 +39,96 @@
"ISO-8859-15",
"ISO-8859-16",
]
# Type hints for the list of possible data kinds.
Kind = Literal[
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
]


def _validate_data_input( # noqa: PLR0912
data=None, x=None, y=None, z=None, required=True, mincols=2, kind=None
) -> None:
def _validate_data_input(data: Any, kind: Kind, mincols: int = 2) -> None:
"""
Check if the combination of data/x/y/z is valid.
Check if the data to be passed to the virtualfile_from_ functions has the required
number of columns.

Only checks the "empty"/"vectors"/"matrix" kinds.

Examples
--------
>>> _validate_data_input(data="infile")
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6])
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], z=[7, 8, 9])
>>> _validate_data_input(data=None, required=False)
>>> _validate_data_input()
The "empty" kind means the data is given via x/y/z.

>>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6]], kind="empty")
>>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], kind="empty")
>>> _validate_data_input(data=[None, [4, 5, 6]], kind="empty")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: No input data provided.
>>> _validate_data_input(x=[1, 2, 3])
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
>>> _validate_data_input(data=[[1, 2, 3], None], kind="empty")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
>>> _validate_data_input(y=[4, 5, 6])
>>> _validate_data_input(data=[None, None], kind="empty")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], mincols=3)
>>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6]], kind="empty", mincols=3)
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z.
>>> import numpy as np
>>> import pandas as pd
>>> import xarray as xr
>>> data = np.arange(8).reshape((4, 2))
>>> _validate_data_input(data=data, mincols=3, kind="matrix")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
>>> _validate_data_input(
... data=pd.DataFrame(data, columns=["x", "y"]),
... mincols=3,
... kind="vectors",
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
>>> _validate_data_input(
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
... mincols=3,
... kind="vectors",
... )

The "vectors" kind means the data is a series of 1-D vectors.

>>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6]], kind="vectors")
>>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], kind="vectors")
>>> _validate_data_input(data=[None, [4, 5, 6]], kind="vectors")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
>>> _validate_data_input(data="infile", x=[1, 2, 3])
pygmt.exceptions.GMTInvalidInput: At least one column is None.
>>> _validate_data_input(data=[[1, 2, 3], None], kind="vectors")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", y=[4, 5, 6])
pygmt.exceptions.GMTInvalidInput: At least one column is None.
>>> _validate_data_input(data=[None, None], kind="vectors")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", x=[1, 2, 3], y=[4, 5, 6])
pygmt.exceptions.GMTInvalidInput: At least one column is None.
>>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6]], kind="vectors", mincols=3)
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", z=[7, 8, 9])
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.

The "matrix" kind means the data is given via a 2-D numpy.ndarray.

>>> import numpy as np
>>> data = np.arange(8).reshape((4, 2))
>>> _validate_data_input(data=data, kind="matrix", mincols=2)
>>> _validate_data_input(data=data, kind="matrix", mincols=3)
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.

Raises
------
GMTInvalidInput
If the data input is not valid.
"""
required_z = mincols >= 3
if data is None: # data is None
if x is None and y is None: # both x and y are None
if required: # data is not optional
msg = "No input data provided."
match kind:
case "empty": # data = [x, y] or [x, y, z]
if len(data) < 2 or any(v is None for v in data[:2]):
msg = "Must provide both x and y."
raise GMTInvalidInput(msg)
elif x is None or y is None: # either x or y is None
msg = "Must provide both x and y."
raise GMTInvalidInput(msg)
if required_z and z is None: # both x and y are not None, now check z
msg = "Must provide x, y, and z."
raise GMTInvalidInput(msg)
else: # data is not None
if x is not None or y is not None or z is not None:
msg = "Too much data. Use either data or x/y/z."
raise GMTInvalidInput(msg)
# check if data has the required z column
if required_z:
msg = "data must provide x, y, and z columns."
if kind == "matrix" and data.shape[1] < 3:
if mincols >= 3 and (len(data) < 3 or data[:3] is None):
msg = "Must provide x, y, and z."
raise GMTInvalidInput(msg)
if kind == "vectors":
if hasattr(data, "shape") and (
(len(data.shape) == 1 and data.shape[0] < 3)
or (len(data.shape) > 1 and data.shape[1] < 3)
): # np.ndarray or pd.DataFrame
raise GMTInvalidInput(msg)
if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset
raise GMTInvalidInput(msg)
if kind == "vectors" and isinstance(data, dict):
# Iterator over the up-to-3 first elements.
arrays = list(islice(data.values(), 3))
if len(arrays) < 2 or any(v is None for v in arrays[:2]): # Check x/y
msg = "Must provide x and y."
case "vectors": # A list of 1-D vectors or 2-D numpy array
if (actual_cols := len(data)) < mincols:
msg = f"Need at least {mincols} columns but {actual_cols} column(s) are given."
raise GMTInvalidInput(msg)
if required_z and (len(arrays) < 3 or arrays[2] is None): # Check z
msg = "Must provide x, y, and z."
if any(array is None for array in data[:mincols]):
msg = "At least one column is None."
raise GMTInvalidInput(msg)
case "matrix": # 2-D numpy.ndarray
if (actual_cols := data.shape[1]) < mincols:
msg = f"Need at least {mincols} columns but {actual_cols} column(s) are given."
raise GMTInvalidInput(msg)


Expand Down Expand Up @@ -272,11 +248,7 @@ def _check_encoding(argstr: str) -> Encoding:
return "ISOLatin1+"


def data_kind(
data: Any, required: bool = True
) -> Literal[
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
]:
def data_kind(data: Any, required: bool = True) -> Kind:
r"""
Check the kind of data that is provided to a module.

Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def histogram(self, data: PathLike | TableLike, **kwargs):
"""
kwargs = self._preprocess(**kwargs)
with Session() as lib:
with lib.virtualfile_in(check_kind="vector", data=data) as vintbl:
with lib.virtualfile_in(check_kind="vector", data=data, mincols=1) as vintbl:
lib.call_module(
module="histogram", args=build_arg_list(kwargs, infile=vintbl)
)
4 changes: 3 additions & 1 deletion pygmt/src/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def info(data: PathLike | TableLike, **kwargs) -> np.ndarray | str:
"""
with Session() as lib:
with GMTTempFile() as tmpfile:
with lib.virtualfile_in(check_kind="vector", data=data) as vintbl:
with lib.virtualfile_in(
check_kind="vector", data=data, mincols=1
) as vintbl:
lib.call_module(
module="info",
args=build_arg_list(kwargs, infile=vintbl, outfile=tmpfile.name),
Expand Down