11
11
from ukis_csmask .utils import reclassify , cohen_kappa_score
12
12
13
13
14
+ providers = ["CUDAExecutionProvider" , "CPUExecutionProvider" ]
15
+
16
+
14
17
@pytest .mark .parametrize (
15
18
"img, band_order, nodata_value" ,
16
19
[
22
25
],
23
26
)
24
27
def test_csmask_init (img , band_order , nodata_value ):
25
- CSmask (img = img , band_order = band_order , nodata_value = nodata_value )
28
+ CSmask (img = img , band_order = band_order , nodata_value = nodata_value , providers = providers )
26
29
27
30
28
31
@pytest .mark .parametrize (
@@ -38,7 +41,7 @@ def test_csmask_init(img, band_order, nodata_value):
38
41
)
39
42
def test_csmask_init_raises (img , band_order , nodata_value ):
40
43
with pytest .raises (TypeError ):
41
- CSmask (img = img , band_order = band_order , nodata_value = nodata_value )
44
+ CSmask (img = img , band_order = band_order , nodata_value = nodata_value , providers = providers )
42
45
43
46
44
47
@pytest .mark .parametrize (
@@ -50,83 +53,113 @@ def test_csmask_init_raises(img, band_order, nodata_value):
50
53
)
51
54
def test_csmask_init_warns (img , band_order , nodata_value ):
52
55
with pytest .warns (UserWarning ):
53
- CSmask (img = img , band_order = band_order , nodata_value = nodata_value )
56
+ CSmask (img = img , band_order = band_order , nodata_value = nodata_value , providers = providers )
54
57
55
58
56
59
@pytest .mark .filterwarnings ("ignore::UserWarning" )
57
60
@pytest .mark .parametrize (
58
- "data" ,
61
+ "data, product_level " ,
59
62
[
60
- np .load (r"tests/testfiles/sentinel2.npz" ),
61
- np .load (r"tests/testfiles/landsat8.npz" ),
62
- np .load (r"tests/testfiles/landsat7.npz" ),
63
- np .load (r"tests/testfiles/landsat5.npz" ),
63
+ (np .load (r"tests/testfiles/sentinel2.npz" ), "l1c" ),
64
+ (np .load (r"tests/testfiles/landsat8.npz" ), "l1c" ),
65
+ (np .load (r"tests/testfiles/landsat7.npz" ), "l1c" ),
66
+ (np .load (r"tests/testfiles/landsat5.npz" ), "l1c" ),
67
+ (np .load (r"tests/testfiles/sentinel2.npz" ), "l2a" ),
68
+ (np .load (r"tests/testfiles/landsat8.npz" ), "l2a" ),
69
+ (np .load (r"tests/testfiles/landsat7.npz" ), "l2a" ),
70
+ (np .load (r"tests/testfiles/landsat5.npz" ), "l2a" ),
64
71
],
65
72
)
66
- def test_csmask_csm_6band (data ):
67
- csmask = CSmask (img = data ["img" ], band_order = ["Blue" , "Green" , "Red" , "NIR" , "SWIR16" , "SWIR22" ])
73
+ def test_csmask_csm_6band (data , product_level ):
74
+ csmask = CSmask (
75
+ img = data ["img" ],
76
+ product_level = product_level ,
77
+ band_order = ["Blue" , "Green" , "Red" , "NIR" , "SWIR16" , "SWIR22" ],
78
+ providers = providers ,
79
+ )
68
80
y_pred = csmask .csm
69
81
y_true = reclassify (data ["msk" ], {"reclass_value_from" : [0 , 1 , 2 , 3 , 4 ], "reclass_value_to" : [2 , 0 , 0 , 0 , 1 ]})
70
82
y_true = y_true .ravel ()
71
83
y_pred = y_pred .ravel ()
72
84
kappa = round (cohen_kappa_score (y_true , y_pred ), 2 )
73
- assert kappa >= 0.75
85
+ assert kappa >= 0.70
74
86
75
87
76
88
@pytest .mark .filterwarnings ("ignore::UserWarning" )
77
89
@pytest .mark .parametrize (
78
- "data" ,
90
+ "data, product_level " ,
79
91
[
80
- np .load (r"tests/testfiles/sentinel2.npz" ),
81
- np .load (r"tests/testfiles/landsat8.npz" ),
82
- np .load (r"tests/testfiles/landsat7.npz" ),
83
- np .load (r"tests/testfiles/landsat5.npz" ),
92
+ (np .load (r"tests/testfiles/sentinel2.npz" ), "l1c" ),
93
+ (np .load (r"tests/testfiles/landsat8.npz" ), "l1c" ),
94
+ (np .load (r"tests/testfiles/landsat7.npz" ), "l1c" ),
95
+ (np .load (r"tests/testfiles/landsat5.npz" ), "l1c" ),
96
+ (np .load (r"tests/testfiles/sentinel2.npz" ), "l2a" ),
97
+ (np .load (r"tests/testfiles/landsat8.npz" ), "l2a" ),
98
+ (np .load (r"tests/testfiles/landsat7.npz" ), "l2a" ),
99
+ (np .load (r"tests/testfiles/landsat5.npz" ), "l2a" ),
84
100
],
85
101
)
86
- def test_csmask_valid_6band (data ):
87
- csmask = CSmask (img = data ["img" ], band_order = ["Blue" , "Green" , "Red" , "NIR" , "SWIR16" , "SWIR22" ])
102
+ def test_csmask_valid_6band (data , product_level ):
103
+ csmask = CSmask (
104
+ img = data ["img" ],
105
+ product_level = product_level ,
106
+ band_order = ["Blue" , "Green" , "Red" , "NIR" , "SWIR16" , "SWIR22" ],
107
+ providers = providers ,
108
+ )
88
109
y_pred = csmask .valid
89
110
y_true = reclassify (data ["msk" ], {"reclass_value_from" : [0 , 1 , 2 , 3 , 4 ], "reclass_value_to" : [0 , 1 , 1 , 1 , 0 ]})
90
111
y_true_inverted = ~ y_true .astype (bool )
91
112
y_true = (~ ndimage .binary_dilation (y_true_inverted , iterations = 4 ).astype (bool )).astype (np .uint8 )
92
113
y_true = y_true .ravel ()
93
114
y_pred = y_pred .ravel ()
94
115
kappa = round (cohen_kappa_score (y_true , y_pred ), 2 )
95
- assert kappa >= 0.75
116
+ assert kappa >= 0.70
96
117
97
118
98
119
@pytest .mark .filterwarnings ("ignore::UserWarning" )
99
120
@pytest .mark .parametrize (
100
- "data" ,
121
+ "data, product_level " ,
101
122
[
102
- np .load (r"tests/testfiles/sentinel2.npz" ),
103
- np .load (r"tests/testfiles/landsat8.npz" ),
104
- np .load (r"tests/testfiles/landsat7.npz" ),
105
- np .load (r"tests/testfiles/landsat5.npz" ),
123
+ (np .load (r"tests/testfiles/sentinel2.npz" ), "l1c" ),
124
+ (np .load (r"tests/testfiles/landsat8.npz" ), "l1c" ),
125
+ (np .load (r"tests/testfiles/landsat7.npz" ), "l1c" ),
126
+ (np .load (r"tests/testfiles/landsat5.npz" ), "l1c" ),
127
+ (np .load (r"tests/testfiles/sentinel2.npz" ), "l2a" ),
128
+ (np .load (r"tests/testfiles/landsat8.npz" ), "l2a" ),
129
+ (np .load (r"tests/testfiles/landsat7.npz" ), "l2a" ),
130
+ (np .load (r"tests/testfiles/landsat5.npz" ), "l2a" ),
106
131
],
107
132
)
108
- def test_csmask_csm_4band (data ):
109
- csmask = CSmask (img = data ["img" ], band_order = ["Blue" , "Green" , "Red" , "NIR" ])
133
+ def test_csmask_csm_4band (data , product_level ):
134
+ csmask = CSmask (
135
+ img = data ["img" ], product_level = product_level , band_order = ["Blue" , "Green" , "Red" , "NIR" ], providers = providers
136
+ )
110
137
y_pred = csmask .csm
111
138
y_true = reclassify (data ["msk" ], {"reclass_value_from" : [0 , 1 , 2 , 3 , 4 ], "reclass_value_to" : [2 , 0 , 0 , 0 , 1 ]})
112
139
y_true = y_true .ravel ()
113
140
y_pred = y_pred .ravel ()
114
141
kappa = round (cohen_kappa_score (y_true , y_pred ), 2 )
115
- assert kappa >= 0.50
142
+ assert kappa >= 0.70
116
143
117
144
118
145
@pytest .mark .filterwarnings ("ignore::UserWarning" )
119
146
@pytest .mark .parametrize (
120
- "data" ,
147
+ "data, product_level " ,
121
148
[
122
- np .load (r"tests/testfiles/sentinel2.npz" ),
123
- np .load (r"tests/testfiles/landsat8.npz" ),
124
- np .load (r"tests/testfiles/landsat7.npz" ),
125
- np .load (r"tests/testfiles/landsat5.npz" ),
149
+ (np .load (r"tests/testfiles/sentinel2.npz" ), "l1c" ),
150
+ (np .load (r"tests/testfiles/landsat8.npz" ), "l1c" ),
151
+ (np .load (r"tests/testfiles/landsat7.npz" ), "l1c" ),
152
+ (np .load (r"tests/testfiles/landsat5.npz" ), "l1c" ),
153
+ (np .load (r"tests/testfiles/sentinel2.npz" ), "l2a" ),
154
+ (np .load (r"tests/testfiles/landsat8.npz" ), "l2a" ),
155
+ (np .load (r"tests/testfiles/landsat7.npz" ), "l2a" ),
156
+ (np .load (r"tests/testfiles/landsat5.npz" ), "l2a" ),
126
157
],
127
158
)
128
- def test_csmask_valid_4band (data ):
129
- csmask = CSmask (img = data ["img" ], band_order = ["Blue" , "Green" , "Red" , "NIR" ])
159
+ def test_csmask_valid_4band (data , product_level ):
160
+ csmask = CSmask (
161
+ img = data ["img" ], product_level = product_level , band_order = ["Blue" , "Green" , "Red" , "NIR" ], providers = providers
162
+ )
130
163
y_pred = csmask .valid
131
164
y_true = reclassify (data ["msk" ], {"reclass_value_from" : [0 , 1 , 2 , 3 , 4 ], "reclass_value_to" : [0 , 1 , 1 , 1 , 0 ]})
132
165
y_true_inverted = ~ y_true .astype (bool )
0 commit comments