Skip to content

Commit 1c3e341

Browse files
authored
Merge pull request #20 from fire2a/cluster-polygonize
Cluster polygonize
2 parents c1c9e63 + f65267f commit 1c3e341

File tree

2 files changed

+210
-26
lines changed

2 files changed

+210
-26
lines changed

src/fire2a/agglomerative_clustering.py

+204-23
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,39 @@
11
#!/usr/bin/env python3
2+
# fmt: off
23
"""👋🌎 🌲🔥
34
# Raster clustering
45
## Usage
56
### Overview
67
1. Choose your raster files
78
2. Configure nodata and scaling strategies in the `config.toml` file
89
3. Choose "number of clusters" or "distance threshold" for the [Agglomerative](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) clustering algorithm
10+
- Start with a distance threshold of 10.0 and decrease for less or increase for more clusters
11+
- After calibrating the distance threshold;
12+
- [Sieve](https://gdal.org/en/latest/programs/gdal_sieve.html) small clusters (merge them to the biggest neighbor) with the `--sieve integer_pixels_size` option
913
14+
### Execution
1015
```bash
11-
source pyqgisdev/bin/activate # activate your qgis dev environment
16+
# get command line help
17+
python -m fire2a.agglomerative_clustering -h
18+
python -m fire2a.agglomerative_clustering --help
19+
20+
# activate your qgis dev environment
21+
source ~/pyqgisdev/bin/activate
22+
# execute
1223
(qgis) $ python -m fire2a.agglomerative_clustering -d 10.0
1324
14-
# windows💩
25+
# windows💩 users should use QGIS's python
1526
C:\\PROGRA~1\\QGIS33~1.3\\bin\\python-qgis.bat -m fire2a.agglomerative_clustering -d 10.0
16-
17-
# check help
18-
python agglomerative_clustering_pipeline.py -h
1927
```
20-
[how to: windows 💩 use qgis-python](https://github.com/fire2a/fire2a-lib/tree/main/qgis-launchers)
28+
[More info on how to use QGIS's Python on Windows 💩](https://github.com/fire2a/fire2a-lib/tree/main/qgis-launchers)
2129
22-
### 1. Choose your raster files
30+
### Preparation
31+
#### 1. Choose your raster files
2332
- Any [GDAL compatible](https://gdal.org/en/latest/drivers/raster/index.html) raster will be read
2433
- Place them all in the same directory where the script will be executed
2534
- "Quote them" if they have any non alphanumerical chars [a-zA-Z0-9]
2635
27-
### 2. Preprocessing configuration
36+
#### 2. Preprocessing configuration
2837
See the `config.toml` file for example of the configuration of the preprocessing steps. The file is structured as follows:
2938
3039
```toml
@@ -54,14 +63,16 @@
5463
- [SimpleImputer](https://scikit-learn.org/stable/modules/generated/sklearn.impute.SimpleImputer.html)
5564
5665
57-
### 3. Clustering configuration
66+
#### 3. Clustering configuration
5867
1. __Agglomerative__ clustering algorithm is used. The following parameters are mutually exclusive:
5968
- `-n` or `--n_clusters`: The number of clusters to form as well as the number of centroids to generate.
6069
- `-d` or `--distance_threshold`: The linkage distance threshold above which clusters will not be merged. When scaling, start at 10.0 and decrease it (at 0.0 the whole tree is computed).
6170
6271
For passing more parameters, see [here](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html)
63-
6472
"""
73+
# fmt: on
74+
# from IPython.terminal.embed import InteractiveShellEmbed
75+
# InteractiveShellEmbed()()
6576
import logging
6677
import sys
6778
from pathlib import Path
@@ -75,6 +86,8 @@
7586
from sklearn.pipeline import Pipeline
7687
from sklearn.preprocessing import OneHotEncoder, RobustScaler, StandardScaler
7788

89+
from fire2a.utils import fprint
90+
7891
logger = logging.getLogger(__name__)
7992

8093

@@ -209,9 +222,8 @@ def pipelie(observations, info_list, height, width, **kwargs):
209222

210223
# Get the neighbors of each cell in a 2D grid
211224
grid_points = np.indices((height, width)).reshape(2, -1).T
212-
connectivity = radius_neighbors_graph(
213-
grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1
214-
)
225+
# grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1
226+
connectivity = radius_neighbors_graph(grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1)
215227

216228
# Create the clustering object
217229
clustering = AgglomerativeClustering(connectivity=connectivity, **kwargs)
@@ -235,8 +247,107 @@ def pipelie(observations, info_list, height, width, **kwargs):
235247
return labels_reshaped, pipeline
236248

237249

238-
def postprocess(labels_reshaped, pipeline, data_list, info_list, width, height, args):
250+
def write(
    label_map,
    width,
    height,
    output_raster="",
    output_poly="output.shp",
    authid="EPSG:3857",
    geotransform=(0, 1, 0, 0, 0, 1),
    nodata=None,
    feedback=None,
):
    """Write the cluster label map as a raster and as one polygon feature per cluster.

    Args:
        label_map (np.ndarray): 2D (height x width) array of integer cluster labels.
        width (int): Raster width in pixels.
        height (int): Raster height in pixels.
        output_raster (str): Output raster filename; "" keeps the raster in memory only.
        output_poly (str): Output polygon filename; the driver is inferred from its extension.
        authid (str): CRS identifier applied to both outputs, e.g. "EPSG:3857".
        geotransform (tuple): GDAL geotransform for the output raster.
        nodata (int | float | None): NoData value to set on the raster band, if any.
        feedback: Optional QGIS feedback object for progress/messages — TODO confirm type.

    Returns:
        bool: True when both outputs were written.
    """
    from osgeo import gdal, ogr, osr

    from fire2a.processing_utils import get_output_raster_format, get_vector_driver_from_filename

    # Resolve output drivers, falling back to safe defaults when detection fails.
    if output_raster == "":
        raster_driver = "MEM"
    else:
        try:
            raster_driver = get_output_raster_format(output_raster, feedback=feedback)
        except Exception:
            raster_driver = "GTiff"
    try:
        poly_driver = get_vector_driver_from_filename(output_poly)
    except Exception:
        poly_driver = "ESRI Shapefile"

    # Create the raster output. The band is (re)written per cluster below and
    # restored to the full label map at the end.
    src_ds = gdal.GetDriverByName(raster_driver).Create(output_raster, width, height, 1, gdal.GDT_Int64)
    src_ds.SetGeoTransform(geotransform)  # TODO check != 0 return code
    src_ds.SetProjection(authid)  # TODO check != 0 return code

    # Create the polygon output layer with one attribute row per cluster.
    drv = ogr.GetDriverByName(poly_driver)
    dst_ds = drv.CreateDataSource(output_poly)
    sp_ref = osr.SpatialReference()
    sp_ref.SetFromUserInput(authid)  # TODO check != 0 return code
    dst_lyr = dst_ds.CreateLayer("clusters", srs=sp_ref, geom_type=ogr.wkbPolygon)
    dst_lyr.CreateField(ogr.FieldDefn("DN", ogr.OFTInteger))
    dst_lyr.CreateField(ogr.FieldDefn("area", ogr.OFTInteger))
    dst_lyr.CreateField(ogr.FieldDefn("perimeter", ogr.OFTInteger))

    # Polygonize each cluster separately: write a band holding only this
    # cluster's pixels (background = -1), polygonize it into a scratch layer,
    # then copy the resulting geometry into the destination layer.
    # NOTE(review): gdal.Polygonize is called without a mask band, so the -1
    # background region is polygonized too; GetNextFeature assumes the cluster
    # polygon comes first — confirm, or pass a mask band instead.
    mem_drv = ogr.GetDriverByName("Memory")
    tmp_ds = mem_drv.CreateDataSource("tmp_ds")
    areas = []
    for cluster_id in np.unique(label_map):
        # temporarily overwrite the band with only this cluster's pixels
        src_band = src_ds.GetRasterBand(1)
        data = np.zeros_like(label_map)
        data -= 1  # labels are in 0..NC, so -1 marks "not this cluster"
        data[label_map == cluster_id] = label_map[label_map == cluster_id]
        src_band.WriteArray(data)
        # polygonize into a scratch layer
        tmp_lyr = tmp_ds.CreateLayer("", srs=sp_ref)
        gdal.Polygonize(src_band, None, tmp_lyr, -1)
        feat = tmp_lyr.GetNextFeature()
        if feat is None:
            # fix: polygonization yielded no feature for this cluster; the
            # original code crashed on feat.GetGeometryRef()
            continue
        geom = feat.GetGeometryRef()
        feature = ogr.Feature(dst_lyr.GetLayerDefn())
        feature.SetGeometry(geom)
        feature.SetField("DN", int(cluster_id))
        areas += [geom.GetArea()]
        feature.SetField("area", int(geom.GetArea()))
        feature.SetField("perimeter", int(geom.Boundary().Length()))
        dst_lyr.CreateFeature(feature)

    if areas:  # fix: min()/max() of an empty sequence raises ValueError
        fprint(f"Clusters: {min(areas)=} {max(areas)=}", level="info", feedback=feedback, logger=logger)
    # Restore the full label map into the temporarily clobbered raster band.
    src_band = src_ds.GetRasterBand(1)
    if nodata:
        src_band.SetNoDataValue(nodata)
    src_band.WriteArray(label_map)
    # Flush and close both datasets so the files are fully written to disk.
    src_ds.FlushCache()
    src_ds = None
    dst_ds.FlushCache()
    dst_ds = None
    return True
348+
239349

350+
def postprocess(labels_reshaped, pipeline, data_list, info_list, width, height, args):
240351
# trick to plot
241352
effective_num_clusters = len(np.unique(labels_reshaped))
242353

@@ -313,6 +424,56 @@ def plot(data_list, info_list):
313424
plt.show()
314425

315426

427+
def sieve_filter(data, threshold=2, connectedness=4, feedback=None):
    """Apply a GDAL sieve filter to remove small clusters from the data.

    Clusters smaller than `threshold` pixels are merged into their largest
    neighbor. https://gdal.org/en/latest/programs/gdal_sieve.html#gdal-sieve

    Args:
        data (np.ndarray): The 2D input data to filter
        threshold (int): Clusters with fewer pixels than this are merged away
        connectedness (int): Pixel connectivity used when filtering: 4 or 8
        feedback (QgsTaskFeedback): A feedback object to report progress to use inside QGIS plugins
    Returns:
        np.ndarray: The filtered data, or the original `data` when the filter fails
    """
    logger.info("Applying sieve filter")
    from osgeo import gdal

    height, width = data.shape
    num_clusters = len(np.unique(data))
    # Stage the labels in an in-memory GDAL raster so SieveFilter can run on it.
    src_ds = gdal.GetDriverByName("MEM").Create("sieve", width, height, 1, gdal.GDT_Int64)
    src_band = src_ds.GetRasterBand(1)
    src_band.WriteArray(data)
    # In-place sieve: source and destination are the same band.
    if 0 != gdal.SieveFilter(src_band, None, src_band, threshold, connectedness):
        fprint("Error applying sieve filter", level="error", feedback=feedback, logger=logger)
        # fix: the original fell through here and crashed with a NameError on
        # the undefined `sieved`; return the unfiltered input instead
        return data
    sieved = src_band.ReadAsArray()
    src_band = None
    src_ds = None
    num_sieved = len(np.unique(sieved))
    fprint(
        f"Reduced from {num_clusters} to {num_sieved} clusters, {num_clusters-num_sieved} less",
        level="info",
        feedback=feedback,
        logger=logger,
    )
    fprint(
        "Please try again increasing distance_threshold or reducing n_clusters instead...",
        level="info",
        feedback=feedback,
        logger=logger,
    )
    return sieved
475+
476+
316477
def read_toml(config_toml="config.toml"):
317478
if sys.version_info >= (3, 11):
318479
import tomllib
@@ -352,7 +513,8 @@ def arg_parser(argv=None):
352513
)
353514
aggclu.add_argument("-n", "--n_clusters", type=int, help="Number of clusters")
354515

355-
parser.add_argument("-o", "--output", help="Output raster file, warning overwrites!", default="output.tif")
516+
parser.add_argument("-or", "--output_raster", help="Output raster file, warning overwrites!", default="")
517+
parser.add_argument("-op", "--output_poly", help="Output polygons file, warning overwrites!", default="output.gpkg")
356518
parser.add_argument("-a", "--authid", type=str, help="Output raster authid", default="EPSG:3857")
357519
parser.add_argument(
358520
"-g", "--geotransform", type=str, help="Output raster geotransform", default="(0, 1, 0, 0, 0, 1)"
@@ -361,7 +523,7 @@ def arg_parser(argv=None):
361523
"-nw",
362524
"--no_write",
363525
action="store_true",
364-
help="Do not write output raster",
526+
help="Do not write outputs raster nor polygons",
365527
default=False,
366528
)
367529
parser.add_argument(
@@ -371,6 +533,12 @@ def arg_parser(argv=None):
371533
help="Run in script mode, returning the label_map and the pipeline object",
372534
default=False,
373535
)
536+
parser.add_argument(
537+
"--sieve",
538+
type=int,
539+
help="Use GDAL sieve filter to merge small clusters (number of pixels) into the biggest neighbor",
540+
)
541+
parser.add_argument("--verbose", "-v", action="count", default=0, help="WARNING:1, INFO:2, DEBUG:3")
374542
args = parser.parse_args(argv)
375543
args.geotransform = tuple(map(float, args.geotransform[1:-1].split(",")))
376544
if Path(args.config_file).is_file() is False:
@@ -390,8 +558,13 @@ def main(argv=None):
390558
argv = sys.argv[1:]
391559
args = arg_parser(argv)
392560

393-
# 1 LEE ARGUMENTOS
394-
logger.debug(args)
561+
if args.verbose != 0:
562+
global logger
563+
from fire2a import setup_logger
564+
565+
logger = setup_logger(verbosity=args.verbose)
566+
567+
logger.info("args %s", args)
395568

396569
# 2 LEE CONFIG
397570
config = read_toml(args.config_file)
@@ -436,15 +609,23 @@ def main(argv=None):
436609
distance_threshold=args.distance_threshold,
437610
)
438611

612+
# SIEVE
613+
if args.sieve:
614+
labels_reshaped = sieve_filter(labels_reshaped, args.sieve)
615+
439616
# 7 debug postprocess
440-
postprocess(labels_reshaped, pipeline, data_list, info_list, width, height, args)
617+
# postprocess(labels_reshaped, pipeline, data_list, info_list, width, height, args)
441618

442619
# 8. ESCRIBIR RASTER
443620
if not args.no_write:
444-
from fire2a.raster import write_raster
445-
446-
if not write_raster(
447-
labels_reshaped, outfile=args.output, authid=args.authid, geotransform=args.geotransform, logger=logger
621+
if not write(
622+
labels_reshaped,
623+
width,
624+
height,
625+
output_raster=args.output_raster,
626+
output_poly=args.output_poly,
627+
authid=args.authid,
628+
geotransform=args.geotransform,
448629
):
449630
logger.error("Error writing output raster")
450631

src/fire2a/raster.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def get_rlayer_info(layer: QgsRasterLayer):
213213
"cellsize_y": layer.rasterUnitsPerPixelY(),
214214
"nodata": ndv,
215215
"bands": layer.bandCount(),
216+
"file": layer.publicSource(),
216217
}
217218

218219

@@ -398,6 +399,7 @@ def write_raster(
398399
driver_name="GTiff",
399400
authid="EPSG:3857",
400401
geotransform=(0, 1, 0, 0, 0, 1),
402+
nodata=None,
401403
feedback=None,
402404
logger=None, # logger default ?
403405
):
@@ -438,9 +440,10 @@ def write_raster(
438440
ds.SetGeoTransform(geotransform)
439441
ds.SetProjection(authid)
440442
band = ds.GetRasterBand(1)
441-
if 0 != band.SetNoDataValue(-9999):
442-
fprint("Set NoData failed", level="warning", feedback=feedback, logger=logger)
443-
return False
443+
if nodata:
444+
if 0 != band.SetNoDataValue(nodata):
445+
fprint("Set NoData failed", level="warning", feedback=feedback, logger=logger)
446+
return False
444447
if 0 != band.WriteArray(data):
445448
fprint(f"WriteArray failed for Burn Probability {burn_prob}", level="warning", feedback=feedback, logger=logger)
446449
return False

0 commit comments

Comments
 (0)