Skip to content

Commit bafb8ab

Browse files
weiji14seisman
andauthored
Let required_z check work on xarray.Dataset inputs (#1523)
Modify if-statement check in `data_kind` helper function to use `len(data.data_vars)` to check the shape of xarray.Dataset inputs. Added two regression tests in test_clib and test_blockm to ensure this works on both the low-level clib and high-level module APIs. * Mention required_z parameter in docstring Co-authored-by: Dongdong Tian <[email protected]>
1 parent c0a8dfa commit bafb8ab

File tree

4 files changed

+44
-6
lines changed

4 files changed

+44
-6
lines changed

pygmt/clib/session.py

+2
Original file line numberDiff line numberDiff line change
@@ -1375,6 +1375,8 @@ def virtualfile_from_data(
13751375
extra_arrays : list of 1d arrays
13761376
Optional. A list of numpy arrays in addition to x, y and z. All
13771377
of these arrays must be of the same size as the x/y/z arrays.
1378+
required_z : bool
1379+
State whether the 'z' column is required.
13781380
13791381
Returns
13801382
-------

pygmt/helpers/utils.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@ def data_kind(data, x=None, y=None, z=None, required_z=False):
3939
x/y : 1d arrays or None
4040
x and y columns as numpy arrays.
4141
z : 1d array or None
42-
z column as numpy array. To be used optionally when x and y
43-
are given.
42+
z column as numpy array. To be used optionally when x and y are given.
43+
required_z : bool
44+
State whether the 'z' column is required.
4445
4546
Returns
4647
-------
@@ -80,7 +81,10 @@ def data_kind(data, x=None, y=None, z=None, required_z=False):
8081
elif hasattr(data, "__geo_interface__"):
8182
kind = "geojson"
8283
elif data is not None:
83-
if required_z and data.shape[1] < 3:
84+
if required_z and (
85+
getattr(data, "shape", (3, 3))[1] < 3 # np.array, pd.DataFrame
86+
or len(getattr(data, "data_vars", (0, 1, 2))) < 3 # xr.Dataset
87+
):
8488
raise GMTInvalidInput("data must provide x, y, and z columns.")
8589
kind = "matrix"
8690
else:

pygmt/tests/test_blockm.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
"""
44
import os
55

6+
import numpy as np
67
import numpy.testing as npt
78
import pandas as pd
89
import pytest
10+
import xarray as xr
911
from pygmt import blockmean, blockmode
1012
from pygmt.datasets import load_sample_bathymetry
1113
from pygmt.exceptions import GMTInvalidInput
@@ -31,12 +33,13 @@ def test_blockmean_input_dataframe(dataframe):
3133
npt.assert_allclose(output.iloc[0], [245.888877, 29.978707, -384.0])
3234

3335

34-
def test_blockmean_input_table_matrix(dataframe):
36+
@pytest.mark.parametrize("array_func", [np.array, xr.Dataset])
37+
def test_blockmean_input_table_matrix(array_func, dataframe):
3538
"""
3639
Run blockmean using table input that is not a pandas.DataFrame but still a
3740
matrix.
3841
"""
39-
table = dataframe.values
42+
table = array_func(dataframe)
4043
output = blockmean(table=table, spacing="5m", region=[245, 255, 20, 30])
4144
assert isinstance(output, pd.DataFrame)
4245
assert output.shape == (5849, 3)

pygmt/tests/test_clib.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,36 @@ def test_virtual_file_bad_direction():
419419
print("This should have failed")
420420

421421

422-
def test_virtualfile_from_data_required_z_matrix():
422+
@pytest.mark.parametrize(
423+
"array_func,kind",
424+
[(np.array, "matrix"), (pd.DataFrame, "vector"), (xr.Dataset, "vector")],
425+
)
426+
def test_virtualfile_from_data_required_z_matrix(array_func, kind):
427+
"""
428+
Test that function works when third z column in a matrix is needed and
429+
provided.
430+
"""
431+
shape = (5, 3)
432+
dataframe = pd.DataFrame(
433+
data=np.arange(shape[0] * shape[1]).reshape(shape), columns=["x", "y", "z"]
434+
)
435+
data = array_func(dataframe)
436+
with clib.Session() as lib:
437+
with lib.virtualfile_from_data(data=data, required_z=True) as vfile:
438+
with GMTTempFile() as outfile:
439+
lib.call_module("info", f"{vfile} ->{outfile.name}")
440+
output = outfile.read(keep_tabs=True)
441+
bounds = "\t".join(
442+
[
443+
f"<{i.min():.0f}/{i.max():.0f}>"
444+
for i in (dataframe.x, dataframe.y, dataframe.z)
445+
]
446+
)
447+
expected = f"<{kind} memory>: N = {shape[0]}\t{bounds}\n"
448+
assert output == expected
449+
450+
451+
def test_virtualfile_from_data_required_z_matrix_missing():
423452
"""
424453
Test that function fails when third z column in a matrix is needed but not
425454
provided.

0 commit comments

Comments
 (0)