10
10
from s3fs import S3FileSystem
11
11
from shapely .geometry import Point
12
12
13
- from kedro_datasets .geopandas import GeoJSONDataset
13
+ from kedro_datasets .geopandas import GenericDataset
14
14
15
15
16
16
@pytest .fixture (params = [None ])
@@ -24,16 +24,36 @@ def save_version(request):
24
24
25
25
26
26
@pytest .fixture
27
- def filepath (tmp_path ):
27
+ def filepath_geojson (tmp_path ):
28
28
return (tmp_path / "test.geojson" ).as_posix ()
29
29
30
30
31
+ @pytest .fixture
32
+ def filepath_parquet (tmp_path ):
33
+ return (tmp_path / "test.parquet" ).as_posix ()
34
+
35
+
36
+ @pytest .fixture
37
+ def filepath_feather (tmp_path ):
38
+ return (tmp_path / "test.feather" ).as_posix ()
39
+
40
+
41
+ @pytest .fixture
42
+ def filepath_postgis (tmp_path ):
43
+ return (tmp_path / "test.sql" ).as_posix ()
44
+
45
+
46
+ @pytest .fixture
47
+ def filepath_abc (tmp_path ):
48
+ return tmp_path / "test.abc"
49
+
50
+
31
51
@pytest .fixture (params = [None ])
32
52
def load_args (request ):
33
53
return request .param
34
54
35
55
36
- @pytest .fixture (params = [{ "driver" : "GeoJSON" } ])
56
+ @pytest .fixture (params = [None ])
37
57
def save_args (request ):
38
58
return request .param
39
59
@@ -47,20 +67,77 @@ def dummy_dataframe():
47
67
48
68
49
69
@pytest .fixture
50
- def geojson_dataset (filepath , load_args , save_args , fs_args ):
51
- return GeoJSONDataset (
52
- filepath = filepath , load_args = load_args , save_args = save_args , fs_args = fs_args
70
+ def geojson_dataset (filepath_geojson , load_args , save_args , fs_args ):
71
+ return GenericDataset (
72
+ filepath = filepath_geojson ,
73
+ load_args = load_args ,
74
+ save_args = save_args ,
75
+ fs_args = fs_args ,
76
+ )
77
+
78
+
79
+ @pytest .fixture
80
+ def parquet_dataset (filepath_parquet , load_args , save_args , fs_args ):
81
+ return GenericDataset (
82
+ filepath = filepath_parquet ,
83
+ file_format = "parquet" ,
84
+ load_args = load_args ,
85
+ save_args = save_args ,
86
+ fs_args = fs_args ,
87
+ )
88
+
89
+
90
+ @pytest .fixture
91
+ def parquet_dataset_bad_config (filepath_parquet , load_args , save_args , fs_args ):
92
+ return GenericDataset (
93
+ filepath = filepath_parquet ,
94
+ load_args = load_args ,
95
+ save_args = save_args ,
96
+ fs_args = fs_args ,
97
+ )
98
+
99
+
100
+ @pytest .fixture
101
+ def feather_dataset (filepath_feather , load_args , save_args , fs_args ):
102
+ return GenericDataset (
103
+ filepath = filepath_feather ,
104
+ file_format = "feather" ,
105
+ load_args = load_args ,
106
+ save_args = save_args ,
107
+ fs_args = fs_args ,
108
+ )
109
+
110
+
111
+ @pytest .fixture
112
+ def postgis_dataset (filepath_postgis , load_args , save_args , fs_args ):
113
+ return GenericDataset (
114
+ filepath = filepath_postgis ,
115
+ file_format = "postgis" ,
116
+ load_args = load_args ,
117
+ save_args = save_args ,
118
+ fs_args = fs_args ,
53
119
)
54
120
55
121
56
122
@pytest .fixture
57
- def versioned_geojson_dataset (filepath , load_version , save_version ):
58
- return GeoJSONDataset (
59
- filepath = filepath , version = Version (load_version , save_version )
123
+ def abc_dataset (filepath_abc , load_args , save_args , fs_args ):
124
+ return GenericDataset (
125
+ filepath = filepath_abc ,
126
+ file_format = "abc" ,
127
+ load_args = load_args ,
128
+ save_args = save_args ,
129
+ fs_args = fs_args ,
60
130
)
61
131
62
132
63
- class TestGeoJSONDataset :
133
+ @pytest .fixture
134
+ def versioned_geojson_dataset (filepath_geojson , load_version , save_version ):
135
+ return GenericDataset (
136
+ filepath = filepath_geojson , version = Version (load_version , save_version )
137
+ )
138
+
139
+
140
+ class TestGenericDataset :
64
141
def test_save_and_load (self , geojson_dataset , dummy_dataframe ):
65
142
"""Test that saved and reloaded data matches the original one."""
66
143
geojson_dataset .save (dummy_dataframe )
@@ -72,7 +149,7 @@ def test_save_and_load(self, geojson_dataset, dummy_dataframe):
72
149
@pytest .mark .parametrize ("geojson_dataset" , [{"index" : False }], indirect = True )
73
150
def test_load_missing_file (self , geojson_dataset ):
74
151
"""Check the error while trying to load from missing source."""
75
- pattern = r"Failed while loading data from dataset GeoJSONDataset "
152
+ pattern = r"Failed while loading data from dataset GenericDataset "
76
153
with pytest .raises (DatasetError , match = pattern ):
77
154
geojson_dataset .load ()
78
155
@@ -82,6 +159,39 @@ def test_exists(self, geojson_dataset, dummy_dataframe):
82
159
geojson_dataset .save (dummy_dataframe )
83
160
assert geojson_dataset .exists ()
84
161
162
+ def test_load_parquet_dataset (self , parquet_dataset , dummy_dataframe ):
163
+ parquet_dataset .save (dummy_dataframe )
164
+ reloaded_df = parquet_dataset .load ()
165
+ assert_frame_equal (reloaded_df , dummy_dataframe )
166
+
167
+ def test_load_feather_dataset (self , feather_dataset , dummy_dataframe ):
168
+ feather_dataset .save (dummy_dataframe )
169
+ reloaded_df = feather_dataset .load ()
170
+ assert_frame_equal (reloaded_df , dummy_dataframe )
171
+
172
+ def test_bad_load (
173
+ self , parquet_dataset_bad_config , dummy_dataframe , filepath_parquet
174
+ ):
175
+ dummy_dataframe .to_parquet (filepath_parquet )
176
+ pattern = r"Failed while loading data from dataset GenericDataset(.*)"
177
+ with pytest .raises (DatasetError , match = pattern ):
178
+ parquet_dataset_bad_config .load ()
179
+
180
+ def test_none_file_system_target (self , postgis_dataset , dummy_dataframe ):
181
+ pattern = "Cannot load or save a dataset of file_format 'postgis' as it does not support a filepath target/source."
182
+ with pytest .raises (DatasetError , match = pattern ):
183
+ postgis_dataset .save (dummy_dataframe )
184
+
185
+ def test_unknown_file_format (self , abc_dataset , dummy_dataframe , filepath_abc ):
186
+ pattern = "Unable to retrieve 'geopandas.DataFrame.to_abc' method"
187
+ with pytest .raises (DatasetError , match = pattern ):
188
+ abc_dataset .save (dummy_dataframe )
189
+
190
+ filepath_abc .write_bytes (b"" )
191
+ pattern = "Unable to retrieve 'geopandas.read_abc' method"
192
+ with pytest .raises (DatasetError , match = pattern ):
193
+ abc_dataset .load ()
194
+
85
195
@pytest .mark .parametrize (
86
196
"load_args" , [{"crs" : "init:4326" }, {"crs" : "init:2154" , "driver" : "GeoJSON" }]
87
197
)
@@ -118,7 +228,7 @@ def test_open_extra_args(self, geojson_dataset, fs_args):
118
228
],
119
229
)
120
230
def test_protocol_usage (self , path , instance_type ):
121
- geojson_dataset = GeoJSONDataset (filepath = path )
231
+ geojson_dataset = GenericDataset (filepath = path )
122
232
assert isinstance (geojson_dataset ._fs , instance_type )
123
233
124
234
path = path .split (PROTOCOL_DELIMITER , 1 )[- 1 ]
@@ -129,18 +239,18 @@ def test_protocol_usage(self, path, instance_type):
129
239
def test_catalog_release (self , mocker ):
130
240
fs_mock = mocker .patch ("fsspec.filesystem" ).return_value
131
241
filepath = "test.geojson"
132
- geojson_dataset = GeoJSONDataset (filepath = filepath )
242
+ geojson_dataset = GenericDataset (filepath = filepath )
133
243
geojson_dataset .release ()
134
244
fs_mock .invalidate_cache .assert_called_once_with (filepath )
135
245
136
246
137
- class TestGeoJSONDatasetVersioned :
247
+ class TestGenericDatasetVersioned :
138
248
def test_version_str_repr (self , load_version , save_version ):
139
249
"""Test that version is in string representation of the class instance
140
250
when applicable."""
141
251
filepath = "test.geojson"
142
- ds = GeoJSONDataset (filepath = filepath )
143
- ds_versioned = GeoJSONDataset (
252
+ ds = GenericDataset (filepath = filepath )
253
+ ds_versioned = GenericDataset (
144
254
filepath = filepath , version = Version (load_version , save_version )
145
255
)
146
256
assert filepath in str (ds )
@@ -149,8 +259,8 @@ def test_version_str_repr(self, load_version, save_version):
149
259
assert filepath in str (ds_versioned )
150
260
ver_str = f"version=Version(load={ load_version } , save='{ save_version } ')"
151
261
assert ver_str in str (ds_versioned )
152
- assert "GeoJSONDataset " in str (ds_versioned )
153
- assert "GeoJSONDataset " in str (ds )
262
+ assert "GenericDataset " in str (ds_versioned )
263
+ assert "GenericDataset " in str (ds )
154
264
assert "protocol" in str (ds_versioned )
155
265
assert "protocol" in str (ds )
156
266
@@ -163,7 +273,7 @@ def test_save_and_load(self, versioned_geojson_dataset, dummy_dataframe):
163
273
164
274
def test_no_versions (self , versioned_geojson_dataset ):
165
275
"""Check the error if no versions are available for load."""
166
- pattern = r"Did not find any versions for GeoJSONDataset \(.+\)"
276
+ pattern = r"Did not find any versions for GenericDataset \(.+\)"
167
277
with pytest .raises (DatasetError , match = pattern ):
168
278
versioned_geojson_dataset .load ()
169
279
@@ -178,7 +288,7 @@ def test_prevent_override(self, versioned_geojson_dataset, dummy_dataframe):
178
288
version."""
179
289
versioned_geojson_dataset .save (dummy_dataframe )
180
290
pattern = (
181
- r"Save path \'.+\' for GeoJSONDataset \(.+\) must not "
291
+ r"Save path \'.+\' for GenericDataset \(.+\) must not "
182
292
r"exist if versioning is enabled"
183
293
)
184
294
with pytest .raises (DatasetError , match = pattern ):
@@ -197,7 +307,7 @@ def test_save_version_warning(
197
307
the subsequent load path."""
198
308
pattern = (
199
309
rf"Save version '{ save_version } ' did not match load version "
200
- rf"'{ load_version } ' for GeoJSONDataset \(.+\)"
310
+ rf"'{ load_version } ' for GenericDataset \(.+\)"
201
311
)
202
312
with pytest .warns (UserWarning , match = pattern ):
203
313
versioned_geojson_dataset .save (dummy_dataframe )
@@ -206,7 +316,7 @@ def test_http_filesystem_no_versioning(self):
206
316
pattern = "Versioning is not supported for HTTP protocols."
207
317
208
318
with pytest .raises (DatasetError , match = pattern ):
209
- GeoJSONDataset (
319
+ GenericDataset (
210
320
filepath = "https://example/file.geojson" , version = Version (None , None )
211
321
)
212
322
0 commit comments