Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get band mean and standard deviation for better normalization #49

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions malpolon/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import os
import re
from typing import Iterable, Union
import warnings

import numpy as np
import rasterio
from shapely import Point, Polygon
from torchgeo.datasets import BoundingBox

Expand Down Expand Up @@ -155,3 +157,58 @@ def get_files_path_recursively(path, *args, suffix='') -> list:
for f in filenames
if re.search(rf"^.*({suffix})\.({ext_list})$", f)]
return result


def get_mean_sd_by_band(path, compute_if_needed=False, ignore_zeros=True):
'''
Reads metadata or computes mean and sd of each band of a geotiff.
If the metadata is not available, mean and standard deviation can be computed via numpy.

Parameters
----------
path : str
path to a geotiff file
ignore_zeros : boolean
ignore zeros when computing mean and sd via numpy

Returns
-------
means : list
list of mean values per band
sds : list
list of standard deviation values per band
'''

src = rasterio.open(path)
means = []
sds = []

for band in range(1, src.count+1):
try:
tags = src.tags(band)
if 'STATISTICS_MEAN' in tags and 'STATISTICS_STDDEV' in tags:
mean = float(tags['STATISTICS_MEAN'])
sd = float(tags['STATISTICS_STDDEV'])
means.append(mean)
sds.append(sd)
else:
raise KeyError("Statistics metadata not found.")

except KeyError:
if compute_if_needed:
arr = src.read(band)
if ignore_zeros:
mean = np.ma.masked_equal(arr, 0).mean()
sd = np.ma.masked_equal(arr, 0).std()
else:
mean = np.mean(arr)
sd = np.std(arr)
means.append(float(mean))
sds.append(float(sd))
else :
warnings.warn("Statistics metadata not found and computation not enabled.", UserWarning)
except Exception as e:
print(f"Error processing band {band}: {e}")

src.close()
return means, sds