1
1
#!/usr/bin/env python3
2
+ # fmt: off
2
3
"""👋🌎 🌲🔥
3
4
# Raster clustering
4
5
## Usage
5
6
### Overview
6
7
1. Choose your raster files
7
8
2. Configure nodata and scaling strategies in the `config.toml` file
8
9
3. Choose "number of clusters" or "distance threshold" for the [Agglomerative](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) clustering algorithm
10
+ - Start with a distance threshold of 10.0 and decrease for less or increase for more clusters
11
+ - After calibrating the distance threshold;
12
+ - [Sieve](https://gdal.org/en/latest/programs/gdal_sieve.html) small clusters (merge them to the biggest neighbor) with the `--sieve integer_pixels_size` option
9
13
14
+ ### Execution
10
15
```bash
11
- source pyqgisdev/bin/activate # activate your qgis dev environment
16
+ # get command line help
17
+ python -m fire2a.agglomerative_clustering -h
18
+ python -m fire2a.agglomerative_clustering --help
19
+
20
+ # activate your qgis dev environment
21
+ source ~/pyqgisdev/bin/activate
22
+ # execute
12
23
(qgis) $ python -m fire2a.agglomerative_clustering -d 10.0
13
24
14
- # windows💩
25
+ # windows💩 users should use QGIS's python
15
26
C:\\ PROGRA~1\\ QGIS33~1.3\\ bin\\ python-qgis.bat -m fire2a.agglomerative_clustering -d 10.0
16
-
17
- # check help
18
- python agglomerative_clustering_pipeline.py -h
19
27
```
20
- [how to: windows 💩 use qgis- python](https://github.com/fire2a/fire2a-lib/tree/main/qgis-launchers)
28
+ [More info: how to use QGIS's Python on Windows 💩](https://github.com/fire2a/fire2a-lib/tree/main/qgis-launchers)
21
29
22
- ### 1. Choose your raster files
30
+ ### Preparation
31
+ #### 1. Choose your raster files
23
32
- Any [GDAL compatible](https://gdal.org/en/latest/drivers/raster/index.html) raster will be read
24
33
- Place them all in the same directory where the script will be executed
25
34
- "Quote them" if they have any non alphanumerical chars [a-zA-Z0-9]
26
35
27
- ### 2. Preprocessing configuration
36
+ #### 2. Preprocessing configuration
28
37
See the `config.toml` file for an example configuration of the preprocessing steps. The file is structured as follows:
29
38
30
39
```toml
54
63
- [SimpleImputer](https://scikit-learn.org/stable/modules/generated/sklearn.impute.SimpleImputer.html)
55
64
56
65
57
- ### 3. Clustering configuration
66
+ #### 3. Clustering configuration
58
67
1. The __Agglomerative__ clustering algorithm is used. The following parameters are mutually exclusive:
59
68
- `-n` or `--n_clusters`: The number of clusters to form as well as the number of centroids to generate.
60
69
- `-d` or `--distance_threshold`: The linkage distance threshold above which clusters will not be merged. When scaling, start at 10.0 and decrease it (at 0.0 the full hierarchy is computed).
61
70
62
71
For passing more parameters, see [here](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html)
63
-
64
72
"""
73
+ # fmt: on
74
+ # from IPython.terminal.embed import InteractiveShellEmbed
75
+ # InteractiveShellEmbed()()
65
76
import logging
66
77
import sys
67
78
from pathlib import Path
75
86
from sklearn .pipeline import Pipeline
76
87
from sklearn .preprocessing import OneHotEncoder , RobustScaler , StandardScaler
77
88
89
+ from fire2a .utils import fprint
90
+
78
91
logger = logging .getLogger (__name__ )
79
92
80
93
@@ -209,9 +222,8 @@ def pipelie(observations, info_list, height, width, **kwargs):
209
222
210
223
# Get the neighbors of each cell in a 2D grid
211
224
grid_points = np .indices ((height , width )).reshape (2 , - 1 ).T
212
- connectivity = radius_neighbors_graph (
213
- grid_points , radius = 2 ** (1 / 2 ), metric = "euclidean" , include_self = False , n_jobs = - 1
214
- )
225
+ # grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1
226
+ connectivity = radius_neighbors_graph (grid_points , radius = 1 , metric = "manhattan" , include_self = False , n_jobs = - 1 )
215
227
216
228
# Create the clustering object
217
229
clustering = AgglomerativeClustering (connectivity = connectivity , ** kwargs )
@@ -235,8 +247,107 @@ def pipelie(observations, info_list, height, width, **kwargs):
235
247
return labels_reshaped , pipeline
236
248
237
249
238
- def postprocess (labels_reshaped , pipeline , data_list , info_list , width , height , args ):
250
def write(
    label_map,
    width,
    height,
    output_raster="",
    output_poly="output.shp",
    authid="EPSG:3857",
    geotransform=(0, 1, 0, 0, 0, 1),
    nodata=None,
    feedback=None,
):
    """Write the cluster label map as a raster band and a polygon layer.

    Each distinct cluster id in ``label_map`` is polygonized into one feature
    carrying its id (``DN``), ``area`` and ``perimeter`` attributes.

    Args:
        label_map (np.ndarray): 2D (height x width) array of integer cluster labels
        width (int): raster width in pixels
        height (int): raster height in pixels
        output_raster (str): output raster path; "" keeps the raster in memory (MEM driver)
        output_poly (str): output polygon file path; driver inferred from the extension
        authid (str): CRS authority id, e.g. "EPSG:3857"
        geotransform (tuple): 6-element GDAL geotransform
        nodata (int | float | None): nodata value to set on the raster band, if any
        feedback: QGIS feedback object for log messages (optional)

    Returns:
        bool: True on success
    """
    from osgeo import gdal, ogr, osr

    from fire2a.processing_utils import get_output_raster_format, get_vector_driver_from_filename

    # setup drivers for raster and polygon output formats
    if output_raster == "":
        raster_driver = "MEM"
    else:
        try:
            raster_driver = get_output_raster_format(output_raster, feedback=feedback)
        except Exception:
            raster_driver = "GTiff"
    try:
        poly_driver = get_vector_driver_from_filename(output_poly)
    except Exception:
        poly_driver = "ESRI Shapefile"

    # create raster output (band written after polygonizing, see below)
    src_ds = gdal.GetDriverByName(raster_driver).Create(output_raster, width, height, 1, gdal.GDT_Int64)
    src_ds.SetGeoTransform(geotransform)  # != 0 ?
    src_ds.SetProjection(authid)  # != 0 ?

    # create polygon output
    drv = ogr.GetDriverByName(poly_driver)
    dst_ds = drv.CreateDataSource(output_poly)
    sp_ref = osr.SpatialReference()
    sp_ref.SetFromUserInput(authid)  # != 0 ?
    dst_lyr = dst_ds.CreateLayer("clusters", srs=sp_ref, geom_type=ogr.wkbPolygon)
    dst_lyr.CreateField(ogr.FieldDefn("DN", ogr.OFTInteger))  # != 0 ?
    dst_lyr.CreateField(ogr.FieldDefn("area", ogr.OFTInteger))
    dst_lyr.CreateField(ogr.FieldDefn("perimeter", ogr.OFTInteger))

    # Polygonize each cluster id separately so every cluster becomes exactly
    # one feature (a single gdal.Polygonize pass would split clusters that
    # share a label value but are not spatially connected).
    mem_drv = ogr.GetDriverByName("Memory")
    tmp_ds = mem_drv.CreateDataSource("tmp_ds")
    src_band = src_ds.GetRasterBand(1)  # invariant: hoisted out of the loop
    areas = []
    for cluster_id in np.unique(label_map):
        # temporarily write only this cluster into the band
        data = np.zeros_like(label_map)
        data -= 1  # labels are in 0..NC, so -1 never collides with a real id
        mask = label_map == cluster_id
        data[mask] = label_map[mask]
        src_band.WriteArray(data)
        # polygonize into a throwaway in-memory layer
        tmp_lyr = tmp_ds.CreateLayer("", srs=sp_ref)
        gdal.Polygonize(src_band, None, tmp_lyr, -1)
        feat = tmp_lyr.GetNextFeature()
        if feat is None:
            # nothing polygonized for this id; skip instead of crashing on None
            continue
        geom = feat.GetGeometryRef()
        feature = ogr.Feature(dst_lyr.GetLayerDefn())
        feature.SetGeometry(geom)
        feature.SetField("DN", int(cluster_id))
        area = geom.GetArea()
        areas += [area]
        feature.SetField("area", int(area))
        feature.SetField("perimeter", int(geom.Boundary().Length()))
        dst_lyr.CreateFeature(feature)

    if areas:
        fprint(f"Clusters: {min(areas)=} {max(areas)=}", level="info", feedback=feedback, logger=logger)
    # replace the temporary per-cluster band content with the full label map
    src_band = src_ds.GetRasterBand(1)
    if nodata is not None:  # 'is not None': a nodata value of 0 must still be set
        src_band.SetNoDataValue(nodata)
    src_band.WriteArray(label_map)
    # flush & close datasets
    src_ds.FlushCache()
    src_ds = None
    dst_ds.FlushCache()
    dst_ds = None
    return True
348
+
239
349
350
+ def postprocess (labels_reshaped , pipeline , data_list , info_list , width , height , args ):
240
351
# trick to plot
241
352
effective_num_clusters = len (np .unique (labels_reshaped ))
242
353
@@ -313,6 +424,56 @@ def plot(data_list, info_list):
313
424
plt .show ()
314
425
315
426
427
def sieve_filter(data, threshold=2, connectedness=4, feedback=None):
    """Apply a GDAL sieve filter to remove small clusters by merging them into
    their largest neighbor. https://gdal.org/en/latest/programs/gdal_sieve.html#gdal-sieve

    Args:
        data (np.ndarray): 2D integer array of cluster labels to filter
        threshold (int): clusters with fewer pixels than this are merged away
        connectedness (int): pixel connectivity to consider when filtering, 4 or 8
        feedback (QgsTaskFeedback): feedback object to report progress inside QGIS plugins

    Returns:
        np.ndarray: the filtered data, or the unmodified input if the filter fails
    """
    logger.info("Applying sieve filter")
    from osgeo import gdal

    height, width = data.shape
    num_clusters = len(np.unique(data))
    src_ds = gdal.GetDriverByName("MEM").Create("sieve", width, height, 1, gdal.GDT_Int64)
    src_band = src_ds.GetRasterBand(1)
    src_band.WriteArray(data)
    # in-place sieve: source and destination band are the same
    if 0 != gdal.SieveFilter(src_band, None, src_band, threshold, connectedness):
        fprint("Error applying sieve filter", level="error", feedback=feedback, logger=logger)
        # bug fix: previously fell through to `return sieved` with `sieved`
        # unassigned (NameError); return the input unchanged instead
        return data
    sieved = src_band.ReadAsArray()
    src_band = None
    src_ds = None
    num_sieved = len(np.unique(sieved))
    fprint(
        f"Reduced from {num_clusters} to {num_sieved} clusters, {num_clusters - num_sieved} less",
        level="info",
        feedback=feedback,
        logger=logger,
    )
    fprint(
        "Please try again increasing distance_threshold or reducing n_clusters instead...",
        level="info",
        feedback=feedback,
        logger=logger,
    )
    return sieved
475
+
476
+
316
477
def read_toml (config_toml = "config.toml" ):
317
478
if sys .version_info >= (3 , 11 ):
318
479
import tomllib
@@ -352,7 +513,8 @@ def arg_parser(argv=None):
352
513
)
353
514
aggclu .add_argument ("-n" , "--n_clusters" , type = int , help = "Number of clusters" )
354
515
355
- parser .add_argument ("-o" , "--output" , help = "Output raster file, warning overwrites!" , default = "output.tif" )
516
+ parser .add_argument ("-or" , "--output_raster" , help = "Output raster file, warning overwrites!" , default = "" )
517
+ parser .add_argument ("-op" , "--output_poly" , help = "Output polygons file, warning overwrites!" , default = "output.gpkg" )
356
518
parser .add_argument ("-a" , "--authid" , type = str , help = "Output raster authid" , default = "EPSG:3857" )
357
519
parser .add_argument (
358
520
"-g" , "--geotransform" , type = str , help = "Output raster geotransform" , default = "(0, 1, 0, 0, 0, 1)"
@@ -361,7 +523,7 @@ def arg_parser(argv=None):
361
523
"-nw" ,
362
524
"--no_write" ,
363
525
action = "store_true" ,
364
- help = "Do not write output raster" ,
526
+ help = "Do not write outputs raster nor polygons " ,
365
527
default = False ,
366
528
)
367
529
parser .add_argument (
@@ -371,6 +533,12 @@ def arg_parser(argv=None):
371
533
help = "Run in script mode, returning the label_map and the pipeline object" ,
372
534
default = False ,
373
535
)
536
+ parser .add_argument (
537
+ "--sieve" ,
538
+ type = int ,
539
+ help = "Use GDAL sieve filter to merge small clusters (number of pixels) into the biggest neighbor" ,
540
+ )
541
+ parser .add_argument ("--verbose" , "-v" , action = "count" , default = 0 , help = "WARNING:1, INFO:2, DEBUG:3" )
374
542
args = parser .parse_args (argv )
375
543
args .geotransform = tuple (map (float , args .geotransform [1 :- 1 ].split ("," )))
376
544
if Path (args .config_file ).is_file () is False :
@@ -390,8 +558,13 @@ def main(argv=None):
390
558
argv = sys .argv [1 :]
391
559
args = arg_parser (argv )
392
560
393
- # 1 LEE ARGUMENTOS
394
- logger .debug (args )
561
+ if args .verbose != 0 :
562
+ global logger
563
+ from fire2a import setup_logger
564
+
565
+ logger = setup_logger (verbosity = args .verbose )
566
+
567
+ logger .info ("args %s" , args )
395
568
396
569
# 2 LEE CONFIG
397
570
config = read_toml (args .config_file )
@@ -436,15 +609,23 @@ def main(argv=None):
436
609
distance_threshold = args .distance_threshold ,
437
610
)
438
611
612
+ # SIEVE
613
+ if args .sieve :
614
+ labels_reshaped = sieve_filter (labels_reshaped , args .sieve )
615
+
439
616
# 7 debug postprocess
440
- postprocess (labels_reshaped , pipeline , data_list , info_list , width , height , args )
617
+ # postprocess(labels_reshaped, pipeline, data_list, info_list, width, height, args)
441
618
442
619
# 8. ESCRIBIR RASTER
443
620
if not args .no_write :
444
- from fire2a .raster import write_raster
445
-
446
- if not write_raster (
447
- labels_reshaped , outfile = args .output , authid = args .authid , geotransform = args .geotransform , logger = logger
621
+ if not write (
622
+ labels_reshaped ,
623
+ width ,
624
+ height ,
625
+ output_raster = args .output_raster ,
626
+ output_poly = args .output_poly ,
627
+ authid = args .authid ,
628
+ geotransform = args .geotransform ,
448
629
):
449
630
logger .error ("Error writing output raster" )
450
631
0 commit comments