Skip to content

Commit 988df06

Browse files
committed
extended tests to l2a product level
1 parent 1a1c8e8 commit 988df06

File tree

5 files changed

+167
-34
lines changed

5 files changed

+167
-34
lines changed

tests/test_mask.py

+67-34
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from ukis_csmask.utils import reclassify, cohen_kappa_score
1212

1313

14+
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
15+
16+
1417
@pytest.mark.parametrize(
1518
"img, band_order, nodata_value",
1619
[
@@ -22,7 +25,7 @@
2225
],
2326
)
2427
def test_csmask_init(img, band_order, nodata_value):
25-
CSmask(img=img, band_order=band_order, nodata_value=nodata_value)
28+
CSmask(img=img, band_order=band_order, nodata_value=nodata_value, providers=providers)
2629

2730

2831
@pytest.mark.parametrize(
@@ -38,7 +41,7 @@ def test_csmask_init(img, band_order, nodata_value):
3841
)
3942
def test_csmask_init_raises(img, band_order, nodata_value):
4043
with pytest.raises(TypeError):
41-
CSmask(img=img, band_order=band_order, nodata_value=nodata_value)
44+
CSmask(img=img, band_order=band_order, nodata_value=nodata_value, providers=providers)
4245

4346

4447
@pytest.mark.parametrize(
@@ -50,83 +53,113 @@ def test_csmask_init_raises(img, band_order, nodata_value):
5053
)
5154
def test_csmask_init_warns(img, band_order, nodata_value):
5255
with pytest.warns(UserWarning):
53-
CSmask(img=img, band_order=band_order, nodata_value=nodata_value)
56+
CSmask(img=img, band_order=band_order, nodata_value=nodata_value, providers=providers)
5457

5558

5659
@pytest.mark.filterwarnings("ignore::UserWarning")
5760
@pytest.mark.parametrize(
58-
"data",
61+
"data, product_level",
5962
[
60-
np.load(r"tests/testfiles/sentinel2.npz"),
61-
np.load(r"tests/testfiles/landsat8.npz"),
62-
np.load(r"tests/testfiles/landsat7.npz"),
63-
np.load(r"tests/testfiles/landsat5.npz"),
63+
(np.load(r"tests/testfiles/sentinel2.npz"), "l1c"),
64+
(np.load(r"tests/testfiles/landsat8.npz"), "l1c"),
65+
(np.load(r"tests/testfiles/landsat7.npz"), "l1c"),
66+
(np.load(r"tests/testfiles/landsat5.npz"), "l1c"),
67+
(np.load(r"tests/testfiles/sentinel2.npz"), "l2a"),
68+
(np.load(r"tests/testfiles/landsat8.npz"), "l2a"),
69+
(np.load(r"tests/testfiles/landsat7.npz"), "l2a"),
70+
(np.load(r"tests/testfiles/landsat5.npz"), "l2a"),
6471
],
6572
)
66-
def test_csmask_csm_6band(data):
67-
csmask = CSmask(img=data["img"], band_order=["Blue", "Green", "Red", "NIR", "SWIR16", "SWIR22"])
73+
def test_csmask_csm_6band(data, product_level):
74+
csmask = CSmask(
75+
img=data["img"],
76+
product_level=product_level,
77+
band_order=["Blue", "Green", "Red", "NIR", "SWIR16", "SWIR22"],
78+
providers=providers,
79+
)
6880
y_pred = csmask.csm
6981
y_true = reclassify(data["msk"], {"reclass_value_from": [0, 1, 2, 3, 4], "reclass_value_to": [2, 0, 0, 0, 1]})
7082
y_true = y_true.ravel()
7183
y_pred = y_pred.ravel()
7284
kappa = round(cohen_kappa_score(y_true, y_pred), 2)
73-
assert kappa >= 0.75
85+
assert kappa >= 0.70
7486

7587

7688
@pytest.mark.filterwarnings("ignore::UserWarning")
7789
@pytest.mark.parametrize(
78-
"data",
90+
"data, product_level",
7991
[
80-
np.load(r"tests/testfiles/sentinel2.npz"),
81-
np.load(r"tests/testfiles/landsat8.npz"),
82-
np.load(r"tests/testfiles/landsat7.npz"),
83-
np.load(r"tests/testfiles/landsat5.npz"),
92+
(np.load(r"tests/testfiles/sentinel2.npz"), "l1c"),
93+
(np.load(r"tests/testfiles/landsat8.npz"), "l1c"),
94+
(np.load(r"tests/testfiles/landsat7.npz"), "l1c"),
95+
(np.load(r"tests/testfiles/landsat5.npz"), "l1c"),
96+
(np.load(r"tests/testfiles/sentinel2.npz"), "l2a"),
97+
(np.load(r"tests/testfiles/landsat8.npz"), "l2a"),
98+
(np.load(r"tests/testfiles/landsat7.npz"), "l2a"),
99+
(np.load(r"tests/testfiles/landsat5.npz"), "l2a"),
84100
],
85101
)
86-
def test_csmask_valid_6band(data):
87-
csmask = CSmask(img=data["img"], band_order=["Blue", "Green", "Red", "NIR", "SWIR16", "SWIR22"])
102+
def test_csmask_valid_6band(data, product_level):
103+
csmask = CSmask(
104+
img=data["img"],
105+
product_level=product_level,
106+
band_order=["Blue", "Green", "Red", "NIR", "SWIR16", "SWIR22"],
107+
providers=providers,
108+
)
88109
y_pred = csmask.valid
89110
y_true = reclassify(data["msk"], {"reclass_value_from": [0, 1, 2, 3, 4], "reclass_value_to": [0, 1, 1, 1, 0]})
90111
y_true_inverted = ~y_true.astype(bool)
91112
y_true = (~ndimage.binary_dilation(y_true_inverted, iterations=4).astype(bool)).astype(np.uint8)
92113
y_true = y_true.ravel()
93114
y_pred = y_pred.ravel()
94115
kappa = round(cohen_kappa_score(y_true, y_pred), 2)
95-
assert kappa >= 0.75
116+
assert kappa >= 0.70
96117

97118

98119
@pytest.mark.filterwarnings("ignore::UserWarning")
99120
@pytest.mark.parametrize(
100-
"data",
121+
"data, product_level",
101122
[
102-
np.load(r"tests/testfiles/sentinel2.npz"),
103-
np.load(r"tests/testfiles/landsat8.npz"),
104-
np.load(r"tests/testfiles/landsat7.npz"),
105-
np.load(r"tests/testfiles/landsat5.npz"),
123+
(np.load(r"tests/testfiles/sentinel2.npz"), "l1c"),
124+
(np.load(r"tests/testfiles/landsat8.npz"), "l1c"),
125+
(np.load(r"tests/testfiles/landsat7.npz"), "l1c"),
126+
(np.load(r"tests/testfiles/landsat5.npz"), "l1c"),
127+
(np.load(r"tests/testfiles/sentinel2.npz"), "l2a"),
128+
(np.load(r"tests/testfiles/landsat8.npz"), "l2a"),
129+
(np.load(r"tests/testfiles/landsat7.npz"), "l2a"),
130+
(np.load(r"tests/testfiles/landsat5.npz"), "l2a"),
106131
],
107132
)
108-
def test_csmask_csm_4band(data):
109-
csmask = CSmask(img=data["img"], band_order=["Blue", "Green", "Red", "NIR"])
133+
def test_csmask_csm_4band(data, product_level):
134+
csmask = CSmask(
135+
img=data["img"], product_level=product_level, band_order=["Blue", "Green", "Red", "NIR"], providers=providers
136+
)
110137
y_pred = csmask.csm
111138
y_true = reclassify(data["msk"], {"reclass_value_from": [0, 1, 2, 3, 4], "reclass_value_to": [2, 0, 0, 0, 1]})
112139
y_true = y_true.ravel()
113140
y_pred = y_pred.ravel()
114141
kappa = round(cohen_kappa_score(y_true, y_pred), 2)
115-
assert kappa >= 0.50
142+
assert kappa >= 0.70
116143

117144

118145
@pytest.mark.filterwarnings("ignore::UserWarning")
119146
@pytest.mark.parametrize(
120-
"data",
147+
"data, product_level",
121148
[
122-
np.load(r"tests/testfiles/sentinel2.npz"),
123-
np.load(r"tests/testfiles/landsat8.npz"),
124-
np.load(r"tests/testfiles/landsat7.npz"),
125-
np.load(r"tests/testfiles/landsat5.npz"),
149+
(np.load(r"tests/testfiles/sentinel2.npz"), "l1c"),
150+
(np.load(r"tests/testfiles/landsat8.npz"), "l1c"),
151+
(np.load(r"tests/testfiles/landsat7.npz"), "l1c"),
152+
(np.load(r"tests/testfiles/landsat5.npz"), "l1c"),
153+
(np.load(r"tests/testfiles/sentinel2.npz"), "l2a"),
154+
(np.load(r"tests/testfiles/landsat8.npz"), "l2a"),
155+
(np.load(r"tests/testfiles/landsat7.npz"), "l2a"),
156+
(np.load(r"tests/testfiles/landsat5.npz"), "l2a"),
126157
],
127158
)
128-
def test_csmask_valid_4band(data):
129-
csmask = CSmask(img=data["img"], band_order=["Blue", "Green", "Red", "NIR"])
159+
def test_csmask_valid_4band(data, product_level):
160+
csmask = CSmask(
161+
img=data["img"], product_level=product_level, band_order=["Blue", "Green", "Red", "NIR"], providers=providers
162+
)
130163
y_pred = csmask.valid
131164
y_true = reclassify(data["msk"], {"reclass_value_from": [0, 1, 2, 3, 4], "reclass_value_to": [0, 1, 1, 1, 0]})
132165
y_true_inverted = ~y_true.astype(bool)

ukis_csmask/model_4b_l1c.json

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
{
2+
"data": {
3+
"in_channels": 4,
4+
"classes": 3,
5+
"class_weights": [
6+
0.45997971296310425,
7+
1.6431068181991577,
8+
4.600073337554932
9+
],
10+
"dataset_statistics": {
11+
"img": [
12+
[
13+
0.27552253007888794,
14+
0.2556644380092621,
15+
0.269801527261734,
16+
0.33431220054626465
17+
],
18+
[
19+
0.24488002061843872,
20+
0.22757470607757568,
21+
0.2537152171134949,
22+
0.23213787376880646
23+
]
24+
],
25+
"band_names": [
26+
"blue",
27+
"green",
28+
"red",
29+
"nir"
30+
],
31+
"cnt_classes": [
32+
168881375,
33+
47277517,
34+
16887124
35+
]
36+
},
37+
"target_size": [
38+
256,
39+
256
40+
]
41+
},
42+
"model": {
43+
"version": "1.0.0-4BL1C",
44+
"decoder_name": "unet",
45+
"encoder_name": "efficientnet-b4",
46+
"encoder_weights": "imagenet",
47+
"loss_fn": "celovasz",
48+
"multi-task": false
49+
}
50+
}

ukis_csmask/model_4b_l1c.onnx

37.8 MB
Binary file not shown.

ukis_csmask/model_4b_l2a.json

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
{
2+
"data": {
3+
"in_channels": 4,
4+
"classes": 3,
5+
"class_weights": [
6+
0.4578048586845398,
7+
1.677311658859253,
8+
4.556408882141113
9+
],
10+
"dataset_statistics": {
11+
"img": [
12+
[
13+
0.2697608172893524,
14+
0.2809424102306366,
15+
0.2910798490047455,
16+
0.3679029643535614
17+
],
18+
[
19+
0.33664360642433167,
20+
0.312863826751709,
21+
0.3114243149757385,
22+
0.2728745937347412
23+
]
24+
],
25+
"band_names": [
26+
"blue",
27+
"green",
28+
"red",
29+
"nir"
30+
],
31+
"cnt_classes": [
32+
174455417,
33+
47615803,
34+
17528396
35+
]
36+
},
37+
"target_size": [
38+
256,
39+
256
40+
]
41+
},
42+
"model": {
43+
"version": "1.0.0-4BL2A",
44+
"decoder_name": "unet",
45+
"encoder_name": "efficientnet-b4",
46+
"encoder_weights": "imagenet",
47+
"loss_fn": "celovasz",
48+
"multi-task": false
49+
}
50+
}

ukis_csmask/model_4b_l2a.onnx

37.8 MB
Binary file not shown.

0 commit comments

Comments
 (0)