Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor _validate_data_input to simplify the codes #3818

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 20 additions & 21 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1765,14 +1765,13 @@
seg.header = None
seg.text = None

def virtualfile_in(
def virtualfile_in( # noqa: PLR0912
self,
check_kind=None,
data=None,
x=None,
y=None,
z=None,
extra_arrays=None,
required_z=False,
required_data=True,
):
Expand All @@ -1794,9 +1793,6 @@
data input.
x/y/z : 1-D arrays or None
x, y, and z columns as numpy arrays.
extra_arrays : list of 1-D arrays
Optional. A list of numpy arrays in addition to x, y, and z.
All of these arrays must be of the same size as the x/y/z arrays.
required_z : bool
State whether the 'z' column is required.
required_data : bool
Expand Down Expand Up @@ -1829,23 +1825,25 @@
... print(fout.read().strip())
<vector memory>: N = 3 <7/9> <4/6> <1/3>
"""
# Specify either data or x/y/z.
if data is not None and any(v is not None for v in (x, y, z)):
msg = "Too much data. Use either data or x/y/z."
raise GMTInvalidInput(msg)

# Determine the kind of data.
kind = data_kind(data, required=required_data)
_validate_data_input(
data=data,
x=x,
y=y,
z=z,
required_z=required_z,
required_data=required_data,
kind=kind,
)

# Check if the kind of data is valid.
if check_kind:
valid_kinds = ("file", "arg") if required_data is False else ("file",)
if check_kind == "raster":
valid_kinds += ("grid", "image")
elif check_kind == "vector":
valid_kinds += ("empty", "matrix", "vectors", "geojson")
match check_kind:
case "raster":
valid_kinds += ("grid", "image")
case "vector":
valid_kinds += ("empty", "matrix", "vectors", "geojson")
case _:
msg = f"Invalid value for check_kind: '{check_kind}'."
raise GMTInvalidInput(msg)

Check warning on line 1846 in pygmt/clib/session.py

View check run for this annotation

Codecov / codecov/patch

pygmt/clib/session.py#L1844-L1846

Added lines #L1844 - L1846 were not covered by tests
if kind not in valid_kinds:
msg = f"Unrecognized data type for {check_kind}: {type(data)}."
raise GMTInvalidInput(msg)
Expand Down Expand Up @@ -1879,11 +1877,9 @@
_data = [x, y]
if z is not None:
_data.append(z)
if extra_arrays:
_data.extend(extra_arrays)
case "vectors":
if hasattr(data, "items") and not hasattr(data, "to_frame"):
# pandas.DataFrame or xarray.Dataset types.
# Dictionary, pandas.DataFrame or xarray.Dataset types.
# pandas.Series will be handled below like a 1-D numpy.ndarray.
_data = [array for _, array in data.items()]
else:
Expand All @@ -1898,6 +1894,9 @@
_virtualfile_from = self.virtualfile_from_vectors
_data = data.T

# Check if _data to be passed to the virtualfile_from_ function is valid.
_validate_data_input(data=_data, kind=kind, required_z=required_z)

# Finally create the virtualfile from the data, to be passed into GMT
file_context = _virtualfile_from(_data)
return file_context
Expand Down
116 changes: 50 additions & 66 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,110 +39,98 @@
"ISO-8859-15",
"ISO-8859-16",
]
# Type hints for the list of possible data kinds.
Kind = Literal[
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
]


def _validate_data_input(
data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None
) -> None:
def _validate_data_input(data: Any, kind: Kind, required_z: bool = False) -> None:
"""
Check if the combination of data/x/y/z is valid.
Check if the data to be passed to the virtualfile_from_ functions is valid.

Examples
--------
>>> _validate_data_input(data="infile")
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6])
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], z=[7, 8, 9])
>>> _validate_data_input(data=None, required_data=False)
>>> _validate_data_input()
The "empty" kind means the data is given via a series of vectors like x/y/z.

>>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6]], kind="empty")
>>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], kind="empty")
>>> _validate_data_input(data=[None, [4, 5, 6]], kind="empty")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: No input data provided.
>>> _validate_data_input(x=[1, 2, 3])
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
>>> _validate_data_input(data=[[1, 2, 3], None], kind="empty")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
>>> _validate_data_input(y=[4, 5, 6])
>>> _validate_data_input(data=[None, None], kind="empty")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_z=True)
>>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6]], kind="empty", required_z=True)
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z.

The "matrix" kind means the data is given via a 2-D numpy.ndarray.

>>> import numpy as np
>>> import pandas as pd
>>> import xarray as xr
>>> data = np.arange(8).reshape((4, 2))
>>> _validate_data_input(data=data, required_z=True, kind="matrix")
>>> _validate_data_input(data=data, kind="matrix", required_z=True)
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.

The "vectors" kind means the original data is either dictionary, list, tuple,
pandas.DataFrame, pandas.Series, xarray.Dataset, or xarray.DataArray.

>>> _validate_data_input(
... data=pd.DataFrame(data, columns=["x", "y"]),
... required_z=True,
... kind="vectors",
... required_z=True,
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
>>> _validate_data_input(
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
... required_z=True,
... kind="vectors",
... required_z=True,
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
>>> _validate_data_input(data="infile", x=[1, 2, 3])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", y=[4, 5, 6])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", x=[1, 2, 3], y=[4, 5, 6])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", z=[7, 8, 9])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.

Raises
------
GMTInvalidInput
If the data input is not valid.
"""
if data is None: # data is None
if x is None and y is None: # both x and y are None
if required_data: # data is not optional
msg = "No input data provided."
# Determine the required number of columns based on the required_z flag.
required_cols = 3 if required_z else 1
Comment on lines +111 to +112
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most modules usually requires 2 columns but there are exceptions like histogram and info will works for 1 column. So the default value is 1 if required_z is False.

Will refactor this part in PR #3369


match kind:
case "empty": # data = [x, y], [x, y, z], [x, y, z, ...]
if len(data) < 2 or any(v is None for v in data[:2]):
msg = "Must provide both x and y."
raise GMTInvalidInput(msg)
elif x is None or y is None: # either x or y is None
msg = "Must provide both x and y."
raise GMTInvalidInput(msg)
if required_z and z is None: # both x and y are not None, now check z
msg = "Must provide x, y, and z."
raise GMTInvalidInput(msg)
else: # data is not None
if x is not None or y is not None or z is not None:
msg = "Too much data. Use either data or x/y/z."
raise GMTInvalidInput(msg)
# check if data has the required z column
if required_z:
msg = "data must provide x, y, and z columns."
if kind == "matrix" and data.shape[1] < 3:
if required_z and (len(data) < 3 or data[:3] is None):
msg = "Must provide x, y, and z."
raise GMTInvalidInput(msg)
case "matrix": # 2-D numpy.ndarray
if (actual_cols := data.shape[1]) < required_cols:
msg = f"Need at least {required_cols} columns but {actual_cols} column(s) are given."
raise GMTInvalidInput(msg)
case "vectors":
# "vectors" means the original data is either dictionary, list, tuple,
# pandas.DataFrame, pandas.Series, xarray.Dataset, or xarray.DataArray.
# The original data is converted to a list of vectors or a 2-D numpy.ndarray
# in the virtualfile_in function.
if (actual_cols := len(data)) < required_cols:
msg = f"Need at least {required_cols} columns but {actual_cols} column(s) are given."

Check warning on line 132 in pygmt/helpers/utils.py

View check run for this annotation

Codecov / codecov/patch

pygmt/helpers/utils.py#L132

Added line #L132 was not covered by tests
raise GMTInvalidInput(msg)
if kind == "vectors":
if hasattr(data, "shape") and (
(len(data.shape) == 1 and data.shape[0] < 3)
or (len(data.shape) > 1 and data.shape[1] < 3)
): # np.ndarray or pd.DataFrame
raise GMTInvalidInput(msg)
if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset
raise GMTInvalidInput(msg)


def _is_printable_ascii(argstr: str) -> bool:
Expand Down Expand Up @@ -261,11 +249,7 @@
return "ISOLatin1+"


def data_kind(
data: Any, required: bool = True
) -> Literal[
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
]:
def data_kind(data: Any, required: bool = True) -> Kind:
r"""
Check the kind of data that is provided to a module.

Expand Down
25 changes: 13 additions & 12 deletions pygmt/src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
w="wrap",
)
@kwargs_to_strings(R="sequence", c="sequence_comma", i="sequence_comma", p="sequence")
def plot(
def plot( # noqa: PLR0912
self,
data=None,
x=None,
Expand Down Expand Up @@ -232,34 +232,37 @@
kwargs = self._preprocess(**kwargs)

kind = data_kind(data)
extra_arrays = []
if kind == "empty": # Add more columns for vectors input
if kind == "empty": # Data is given via a series of vectors.
data = {"x": x, "y": y}
# Parameters for vector styles
if (
isinstance(kwargs.get("S"), str)
and len(kwargs["S"]) >= 1
and kwargs["S"][0] in "vV"
and is_nonstr_iter(direction)
):
extra_arrays.extend(direction)
data.update({"x2": direction[0], "y2": direction[1]})
# Fill
if is_nonstr_iter(kwargs.get("G")):
extra_arrays.append(kwargs.get("G"))
data["fill"] = kwargs["G"]
del kwargs["G"]
# Size
if is_nonstr_iter(size):
extra_arrays.append(size)
data["size"] = size
# Intensity and transparency
for flag in ["I", "t"]:
for flag, name in ["I", "intensity"], ["t", "transparency"]:
if is_nonstr_iter(kwargs.get(flag)):
extra_arrays.append(kwargs.get(flag))
data[name] = kwargs[flag]
kwargs[flag] = ""
# Symbol must be at the last column
if is_nonstr_iter(symbol):
if "S" not in kwargs:
kwargs["S"] = True
extra_arrays.append(symbol)
data["symbol"] = symbol
else:
if any(v is not None for v in (x, y)):
msg = "Too much data. Use either data or x/y/z."
raise GMTInvalidInput(msg)

Check warning on line 265 in pygmt/src/plot.py

View check run for this annotation

Codecov / codecov/patch

pygmt/src/plot.py#L264-L265

Added lines #L264 - L265 were not covered by tests
for name, value in [
("direction", direction),
("fill", kwargs.get("G")),
Expand All @@ -277,7 +280,5 @@
kwargs["S"] = "s0.2c"

with Session() as lib:
with lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, extra_arrays=extra_arrays
) as vintbl:
with lib.virtualfile_in(check_kind="vector", data=data) as vintbl:
lib.call_module(module="plot", args=build_arg_list(kwargs, infile=vintbl))
31 changes: 14 additions & 17 deletions pygmt/src/plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
w="wrap",
)
@kwargs_to_strings(R="sequence", c="sequence_comma", i="sequence_comma", p="sequence")
def plot3d(
def plot3d( # noqa: PLR0912
self,
data=None,
x=None,
Expand Down Expand Up @@ -210,35 +210,38 @@ def plot3d(
kwargs = self._preprocess(**kwargs)

kind = data_kind(data)
extra_arrays = []

if kind == "empty": # Add more columns for vectors input
if kind == "empty": # Data is given via a series of vectors.
data = {"x": x, "y": y, "z": z}
# Parameters for vector styles
if (
isinstance(kwargs.get("S"), str)
and len(kwargs["S"]) >= 1
and kwargs["S"][0] in "vV"
and is_nonstr_iter(direction)
):
extra_arrays.extend(direction)
data.update({"x2": direction[0], "y2": direction[1]})
# Fill
if is_nonstr_iter(kwargs.get("G")):
extra_arrays.append(kwargs.get("G"))
data["fill"] = kwargs["G"]
del kwargs["G"]
# Size
if is_nonstr_iter(size):
extra_arrays.append(size)
data["size"] = size
# Intensity and transparency
for flag in ["I", "t"]:
for flag, name in [("I", "intensity"), ("t", "transparency")]:
if is_nonstr_iter(kwargs.get(flag)):
extra_arrays.append(kwargs.get(flag))
data[name] = kwargs[flag]
kwargs[flag] = ""
# Symbol must be at the last column
if is_nonstr_iter(symbol):
if "S" not in kwargs:
kwargs["S"] = True
extra_arrays.append(symbol)
data["symbol"] = symbol
else:
if any(v is not None for v in (x, y, z)):
msg = "Too much data. Use either data or x/y/z."
raise GMTInvalidInput(msg)

for name, value in [
("direction", direction),
("fill", kwargs.get("G")),
Expand All @@ -257,12 +260,6 @@ def plot3d(

with Session() as lib:
with lib.virtualfile_in(
check_kind="vector",
data=data,
x=x,
y=y,
z=z,
extra_arrays=extra_arrays,
required_z=True,
check_kind="vector", data=data, required_z=True
) as vintbl:
lib.call_module(module="plot3d", args=build_arg_list(kwargs, infile=vintbl))
Loading
Loading