# -*- coding: utf-8 -*-
import logging

import pyproj
import geopandas as gpd
from osgeo import gdal, ogr, osr

log = logging.getLogger(__name__)

def open_raster(path, read_only=True):
    """Open a raster file as a shared GDAL dataset.

    Parameters
    ----------
    path : str
        Location of the raster file to open.
    read_only : bool
        When True (the default) the file is opened read-only; pass False
        to open it in "update" mode instead.

    Returns
    -------
    GDAL dataset
    """
    if read_only:
        mode = gdal.GA_ReadOnly
    else:
        mode = gdal.GA_Update
    return gdal.OpenShared(path, mode)
def open_vector(path, with_geopandas=False, read_only=True):
    """Open a vector dataset using OGR or GeoPandas.

    Parameters
    ----------
    path : str
        Path to vector file.
    with_geopandas : bool
        Set to True to open with geopandas, else use OGR.
    read_only : bool
        If opening with OGR, set to False to open in "update" mode.
        Ignored when ``with_geopandas`` is True.

    Returns
    -------
    GeoDataFrame if ``with_geopandas`` else OGR datasource.
    """
    if with_geopandas:
        return gpd.read_file(path)
    # OGR takes an "update" flag, which is the inverse of read-only.
    return ogr.OpenShared(path, update=not read_only)
def _get_dataset_epsg(dataset):
    """Get the EPSG code from a GDAL Dataset.

    Parameters
    ----------
    dataset : gdal.Dataset

    Returns
    -------
    int

    Raises
    ------
    ValueError
        If no EPSG code can be identified for the dataset's projection.
    """
    return _epsg_from_projection(dataset.GetProjection())


def _get_datasource_epsg(datasource):
    """Get the EPSG code from a OGR DataSource.

    Only the first layer of the datasource is inspected.

    Parameters
    ----------
    datasource : ogr.DataSource

    Returns
    -------
    int

    Raises
    ------
    ValueError
        If the first layer has no spatial reference, or no EPSG code can
        be identified for it.
    """
    layer = datasource.GetLayer(0)
    spatial_ref = layer.GetSpatialRef()
    if spatial_ref is None:
        raise ValueError(
            'Unable to get EPSG: layer has no spatial reference.')
    spatial_ref.AutoIdentifyEPSG()
    return _authority_code_as_int(spatial_ref)


def _get_gpd_epsg(gdf):
    """Get the EPSG code from a GeoPandas DataFrame.

    Parameters
    ----------
    gdf : gpd.GeoDataFrame

    Returns
    -------
    int

    Raises
    ------
    ValueError
        If no EPSG code can be identified for the frame's CRS.
    """
    # NOTE(review): the original expression was truncated after
    # ``pyproj.Proj(`` in the source this was recovered from.  It is
    # reconstructed here on the assumption that the frame's ``crs`` is
    # converted to a Proj4 string -- confirm against the upstream code.
    return _epsg_from_projection(pyproj.Proj(gdf.crs).srs)


def _epsg_from_projection(prj):
    """Return the EPSG code from a projection string.

    Parameters
    ----------
    prj : str
        Either a well-known text string (starting with ``PROJCS`` or
        ``GEOGCS``) or a Proj4 string.

    Returns
    -------
    int

    Raises
    ------
    ValueError
        If no EPSG code can be identified for the projection.
    """
    srs = osr.SpatialReference()
    if prj.startswith(("PROJCS", "GEOGCS")):
        # ESRI well-known text strings.
        srs.ImportFromESRI([prj])
    else:
        # Proj4 strings.
        srs.ImportFromProj4(prj)
    srs.AutoIdentifyEPSG()
    return _authority_code_as_int(srs)


def _authority_code_as_int(spatial_ref):
    """Return the root authority code of a spatial reference as an int.

    Parameters
    ----------
    spatial_ref : osr.SpatialReference

    Returns
    -------
    int

    Raises
    ------
    ValueError
        If the spatial reference carries no authority code, instead of
        the opaque ``TypeError`` that ``int(None)`` would produce.
    """
    code = spatial_ref.GetAuthorityCode(None)
    if code is None:
        raise ValueError(
            'Unable to identify an EPSG code for the spatial reference.')
    return int(code)
def get_epsg(data):
    """Look up the EPSG code of a supported geospatial object.

    Parameters
    ----------
    data : gdal.Dataset or ogr.DataSource or gpd.GeoDataFrame

    Returns
    -------
    int

    Raises
    ------
    ValueError
        If ``data`` is not one of the supported types.
    """
    # Checked in order; the first matching type decides the handler.
    handlers = (
        (gdal.Dataset, _get_dataset_epsg),
        (ogr.DataSource, _get_datasource_epsg),
        (gpd.GeoDataFrame, _get_gpd_epsg),
    )
    for data_type, handler in handlers:
        if isinstance(data, data_type):
            return handler(data)
    raise ValueError('Unable to get EPSG from: {}'.format(data))
def same_epsg(data1, data2):
    """Tell whether two geospatial objects share an EPSG code.

    Parameters
    ----------
    data1 : gdal.Dataset or ogr.DataSource or gpd.GeoDataFrame
    data2 : gdal.Dataset or ogr.DataSource or gpd.GeoDataFrame

    Returns
    -------
    bool
        True when both inputs resolve to the same EPSG code.
    """
    first_epsg = get_epsg(data1)
    second_epsg = get_epsg(data2)
    return first_epsg == second_epsg
def _create_empty_raster(template, out_path, n_bands=1, no_data_value=None):
    """Create a new empty raster using GDAL.

    Inherits all but the data from the given template dataset: size,
    driver, geotransform, projection and the data type of band 1.

    Parameters
    ----------
    template : gdal.Dataset
    out_path : str
    n_bands : int or None
        Number of bands to create in the output raster. If None, the
        band count of ``template`` is used.
    no_data_value : float or None
        No data value, if None uses the same as ``template`` (taken from
        band 1, like the data type). If the template has none either, no
        value is set.

    Returns
    -------
    gdal.Dataset

    Raises
    ------
    RuntimeError
        If the driver fails to create the output raster.
    """
    # Raster size.
    x_size = template.RasterXSize
    y_size = template.RasterYSize
    n_bands = int(n_bands) if n_bands is not None else template.RasterCount
    dtype = template.GetRasterBand(1).DataType

    # Create the output with the same driver as the template.
    driver = template.GetDriver()
    out_dataset = driver.Create(out_path, x_size, y_size, n_bands, dtype)
    if out_dataset is None:
        # Fail with a clear message rather than an AttributeError below.
        raise RuntimeError('Unable to create raster: {}'.format(out_path))

    # Set the projection.
    out_dataset.SetGeoTransform(template.GetGeoTransform())
    out_dataset.SetProjection(template.GetProjectionRef())

    # Set the no data value, falling back to the template's as documented.
    if no_data_value is None:
        no_data_value = template.GetRasterBand(1).GetNoDataValue()
    if no_data_value is not None:
        out_ndv = float(no_data_value)
        for i in range(1, out_dataset.RasterCount + 1):
            out_dataset.GetRasterBand(i).SetNoDataValue(out_ndv)

    return out_dataset
def burn_vector_mask_into_raster(raster_path, vector_path, out_path,
                                 vector_field=None):
    """Create a new raster based on the input raster with vector features
    burned into the raster.

    To be used as a mask for pixels in the vector.

    Parameters
    ----------
    raster_path : str
        Path of the raster used as the template for the output.
    vector_path : str
        Path of the vector whose features are burned.
    out_path : str
        Path for output raster. Format and data type are the same as the
        raster at ``raster_path``.
    vector_field : str or None
        Name of a field in the vector to burn values from. If None, all
        vector features are burned with a constant value of 1.

    Returns
    -------
    gdal.Dataset
        Single band raster with vector geometries burned, re-opened
        read-only from ``out_path``.

    Raises
    ------
    ValueError
        If the raster and vector do not have the same EPSG code.
    """
    ras = open_raster(raster_path)
    vec = open_vector(vector_path)

    # The vector is not reprojected: differing EPSG codes are an error.
    # Each code is resolved once and reused for the error message.
    ras_epsg = get_epsg(ras)
    vec_epsg = get_epsg(vec)
    if ras_epsg != vec_epsg:
        raise ValueError(
            'Raster and vector are not the same EPSG.\n'
            '{} != {}'.format(ras_epsg, vec_epsg)
        )

    # Create an empty raster for GDALRasterize to burn vector values to.
    out_ds = _create_empty_raster(ras, out_path, n_bands=1, no_data_value=0)

    if vector_field is None:
        # Use a constant value for all features.
        burn_values = [1]
        attribute = None
    else:
        # Use the values given in the vector field.
        burn_values = None
        attribute = vector_field

    # Options for Rasterize.
    # note: burnValues and attribute are exclusive.
    rasterize_opts = gdal.RasterizeOptions(
        bands=[1],
        burnValues=burn_values,
        attribute=attribute,
        allTouched=True)
    gdal.Rasterize(out_ds, vector_path, options=rasterize_opts)

    # Explicitly close raster to ensure it is saved.
    out_ds.FlushCache()
    out_ds = None

    return open_raster(out_path)
def get_raster_band_names(raster):
    """Obtain the names of bands from a raster.

    For each band its description is used when set. Otherwise the
    dataset metadata is queried for a ``Band_N`` entry, and if that is
    missing or empty the 1-indexed name ``Band_N`` itself is returned.

    Parameters
    ----------
    raster : gdal.Dataset

    Returns
    -------
    list[str]
        One name per band, in band order.
    """
    band_names = []
    # Dataset metadata is only needed for bands without a description, so
    # it is fetched lazily and at most once instead of once per band.
    metadata = None
    for i in range(1, raster.RasterCount + 1):
        band = raster.GetRasterBand(i)
        description = band.GetDescription()
        if description:
            # Use the band description.
            band_names.append(description)
            continue
        if metadata is None:
            metadata = band.GetDataset().GetMetadata_Dict() or {}
        # If in metadata (and non-empty), use that entry, else Band_N.
        default_name = 'Band_{}'.format(band.GetBand())
        band_names.append(metadata.get(default_name) or default_name)
    return band_names
def get_pixels(ras, mask, mask_val=None):
    """Get pixels from a raster (with optional mask).

    Parameters
    ----------
    ras : np.ndarray
        Array of raster data, either 2D in the form [y][x] or 3D in the
        form [bands][y][x].
    mask : np.ndarray or None
        Array (2D) used to select pixels. If None, ``ras`` is returned
        unchanged.
    mask_val : int or None
        Value of the data pixels in the mask. If None (the default), all
        non-zero mask pixels are selected. A value of 0 selects the
        zero-valued mask pixels.

    Returns
    -------
    np.ndarray
        Array of non-masked data: 1D for 2D input, [bands][pixels] for
        3D input.
    """
    if mask is None:
        return ras

    # Compare against None rather than truthiness so that ``mask_val=0``
    # is honoured instead of silently falling back to the default.
    if mask_val is not None:
        (i, j) = (mask == mask_val).nonzero()
    else:
        (i, j) = mask.nonzero()

    return ras[i, j] if ras.ndim == 2 else ras[:, i, j]