Source code for zen3geo.datapipes.datashader

"""
DataPipes for :doc:`datashader <datashader:index>`.
"""
from typing import Any, Dict, Iterator, Optional, Union

try:
    import datashader
except ImportError:
    datashader = None
try:
    import spatialpandas
    from spatialpandas.geometry import (
        LineDtype,
        MultiLineDtype,
        MultiPointDtype,
        MultiPolygonDtype,
        PointDtype,
        PolygonDtype,
    )
except ImportError:
    spatialpandas = None

import xarray as xr
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe


[docs] @functional_datapipe("rasterize_with_datashader") class DatashaderRasterizerIterDataPipe(IterDataPipe): """ Takes vector :py:class:`geopandas.GeoSeries` or :py:class:`geopandas.GeoDataFrame` geometries and rasterizes them using :py:class:`datashader.Canvas` to yield an :py:class:`xarray.DataArray` raster with the input geometries aggregated into a fixed-sized grid (functional name: ``rasterize_with_datashader``). Parameters ---------- source_datapipe : IterDataPipe[datashader.Canvas] A DataPipe that contains :py:class:`datashader.Canvas` objects with a ``.crs`` attribute. This will be the template defining the output raster's spatial extent and x/y range. vector_datapipe : IterDataPipe[geopandas.GeoDataFrame] A DataPipe that contains :py:class:`geopandas.GeoSeries` or :py:class:`geopandas.GeoDataFrame` vector geometries with a :py:attr:`.crs <geopandas.GeoDataFrame.crs>` property. agg : Optional[datashader.reductions.Reduction] Reduction operation to compute. Default depends on the input vector type: - For points, default is :py:class:`datashader.reductions.count` - For lines, default is :py:class:`datashader.reductions.any` - For polygons, default is :py:class:`datashader.reductions.any` For more information, refer to the section on Aggregation under datashader's :doc:`datashader:getting_started/Pipeline` docs. kwargs : Optional Extra keyword arguments to pass to the :py:class:`datashader.Canvas` class's aggregation methods such as ``datashader.Canvas.points``. Yields ------ raster : xarray.DataArray An :py:class:`xarray.DataArray` object containing the raster data. This raster will have a :py:attr:`rioxarray.rioxarray.XRasterBase.crs` property and a proper affine transform viewable with :py:meth:`rioxarray.rioxarray.XRasterBase.transform`. Raises ------ ModuleNotFoundError If ``spatialpandas`` is not installed. Please install it (e.g. via ``pip install spatialpandas``) before using this class. ValueError If either the length of the ``vector_datapipe`` is not 1, or if the length of the ``vector_datapipe`` is not equal to the length of the ``source_datapipe``. I.e. the ratio of vector:canvas must be 1:N or be exactly N:N. AttributeError If either the canvas in ``source_datapipe`` or vector geometry in ``vector_datapipe`` is missing a ``.crs`` attribute. Please set the coordinate reference system (e.g. using ``canvas.crs = 'OGC:CRS84'`` for the :py:class:`datashader.Canvas` input or ``vector = vector.set_crs(crs='OGC:CRS84')`` for the :py:class:`geopandas.GeoSeries` or :py:class:`geopandas.GeoDataFrame` input) before passing them into the datapipe. NotImplementedError If the input vector geometry type to ``vector_datapipe`` is not supported, typically when a :py:class:`shapely.geometry.GeometryCollection` is used. Supported types include `Point`, `LineString`, and `Polygon`, plus their multipart equivalents `MultiPoint`, `MultiLineString`, and `MultiPolygon`. Example ------- >>> import pytest >>> datashader = pytest.importorskip("datashader") >>> pyogrio = pytest.importorskip("pyogrio") >>> spatialpandas = pytest.importorskip("spatialpandas") ... >>> from torchdata.datapipes.iter import IterableWrapper >>> from zen3geo.datapipes import DatashaderRasterizer ... >>> # Read in a vector point data source >>> geodataframe = pyogrio.read_dataframe( ... "https://github.com/geopandas/pyogrio/raw/v0.4.0/pyogrio/tests/fixtures/test_gpkg_nulls.gpkg", ... read_geometry=True, ... ) >>> assert geodataframe.crs == "EPSG:4326" # latitude/longitude coords >>> dp_vector = IterableWrapper(iterable=[geodataframe]) ... >>> # Setup blank raster canvas where we will burn vector geometries onto >>> canvas = datashader.Canvas( ... plot_width=5, ... plot_height=6, ... x_range=(160000.0, 620000.0), ... y_range=(0.0, 450000.0), ... ) >>> canvas.crs = "EPSG:32631" # UTM Zone 31N, North of Gulf of Guinea >>> dp_canvas = IterableWrapper(iterable=[canvas]) ... >>> # Rasterize vector point geometries onto blank canvas >>> dp_datashader = dp_canvas.rasterize_with_datashader( ... vector_datapipe=dp_vector ... ) ... >>> # Loop or iterate over the DataPipe stream >>> it = iter(dp_datashader) >>> dataarray = next(it) >>> dataarray <xarray.DataArray (y: 6, x: 5)> array([[0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, 0, 0, 0], [1, 0, 0, 0, 0]], dtype=uint32) Coordinates: * x (x) float64 2.094e+05 3.083e+05 4.072e+05 5.06e+05 6.049e+05 * y (y) float64 4.157e+05 3.47e+05 2.783e+05 ... 1.41e+05 7.237e+04 spatial_ref int64 0 ... >>> dataarray.rio.crs CRS.from_epsg(32631) >>> dataarray.rio.transform() Affine(98871.00388807665, 0.0, 160000.0, 0.0, -68660.4193667199, 450000.0) """ def __init__( self, source_datapipe: IterDataPipe, vector_datapipe: IterDataPipe, agg: Optional = None, **kwargs: Optional[Dict[str, Any]], ) -> None: if spatialpandas is None: raise ModuleNotFoundError( "Package `spatialpandas` is required to be installed to use this datapipe. " "Please use `pip install spatialpandas` or " "`conda install -c conda-forge spatialpandas` " "to install the package" ) self.source_datapipe: IterDataPipe = source_datapipe # datashader.Canvas self.vector_datapipe: IterDataPipe = vector_datapipe # geopandas.GeoDataFrame self.agg: Optional = agg # Datashader Aggregation/Reduction function self.kwargs = kwargs len_vector_datapipe: int = len(self.vector_datapipe) len_canvas_datapipe: int = len(self.source_datapipe) if len_vector_datapipe != 1 and len_vector_datapipe != len_canvas_datapipe: raise ValueError( f"Unmatched lengths for the canvas datapipe ({self.source_datapipe}) " f"and vector datapipe ({self.vector_datapipe}). \n" f"The vector datapipe's length ({len_vector_datapipe}) should either " f"be (1) to allow for broadcasting, or match the canvas datapipe's " f"length of ({len_canvas_datapipe})." ) def __iter__(self) -> Iterator[xr.DataArray]: # Broadcast vector iterator to match length of raster iterator for canvas, vector in self.source_datapipe.zip_longest( self.vector_datapipe, fill_value=list(self.vector_datapipe).pop() ): # print(canvas, vector) # If canvas has no CRS attribute, set one to prevent AttributeError canvas.crs = getattr(canvas, "crs", None) if canvas.crs is None: raise AttributeError( "Missing crs information for datashader.Canvas with " f"x_range: {canvas.x_range} and y_range: {canvas.y_range}. " "Please set crs using e.g. `canvas.crs = 'OGC:CRS84'`." ) # Reproject vector geometries to coordinate reference system # of the raster canvas if both are different try: if vector.crs != canvas.crs: vector = vector.to_crs(crs=canvas.crs) except (AttributeError, ValueError) as e: raise AttributeError( f"Missing crs information for input {vector.__class__} object " f"with the following bounds: \n {vector.bounds} \n" f"Please set crs using e.g. `vector = vector.set_crs(crs='OGC:CRS84')`." ) from e # Convert vector to spatialpandas format to allow datashader's # rasterization methods to work try: _vector = spatialpandas.GeoDataFrame(data=vector.geometry) except ValueError as e: if str(e) == "Unable to convert data argument to a GeometryList array": raise NotImplementedError( f"Unsupported geometry type(s) {set(vector.geom_type)} detected, " "only point, line or polygon vector geometry types " "(or their multi- equivalents) are supported." ) from e else: raise e # Determine geometry type to know which rasterization method to use vector_dtype: spatialpandas.geometry.GeometryDtype = _vector.geometry.dtype if isinstance(vector_dtype, (PointDtype, MultiPointDtype)): raster: xr.DataArray = canvas.points( source=_vector, agg=self.agg, geometry="geometry", **self.kwargs ) elif isinstance(vector_dtype, (LineDtype, MultiLineDtype)): raster: xr.DataArray = canvas.line( source=_vector, agg=self.agg, geometry="geometry", **self.kwargs ) elif isinstance(vector_dtype, (PolygonDtype, MultiPolygonDtype)): raster: xr.DataArray = canvas.polygons( source=_vector, agg=self.agg, geometry="geometry", **self.kwargs ) # Convert boolean dtype rasters to uint8 to enable reprojection if raster.dtype == "bool": raster: xr.DataArray = raster.astype(dtype="uint8") # Set coordinate transform for raster and ensure affine # transform is correct (the y-coordinate goes from North to South) raster: xr.DataArray = raster.rio.set_crs(input_crs=canvas.crs) # assert raster.rio.transform().e > 0 # y goes South to North _raster: xr.DataArray = raster.rio.reproject( dst_crs=canvas.crs, shape=raster.rio.shape ) # assert _raster.rio.transform().e < 0 # y goes North to South yield _raster def __len__(self) -> int: return len(self.source_datapipe)
[docs] @functional_datapipe("canvas_from_xarray") class XarrayCanvasIterDataPipe(IterDataPipe[Union[xr.DataArray, xr.Dataset]]): """ Takes an :py:class:`xarray.DataArray` or :py:class:`xarray.Dataset` and creates a blank :py:class:`datashader.Canvas` based on the spatial extent and coordinates of the input (functional name: ``canvas_from_xarray``). Parameters ---------- source_datapipe : IterDataPipe[xarrray.DataArray] A DataPipe that contains :py:class:`xarray.DataArray` or :py:class:`xarray.Dataset` objects. These data objects need to have both a ``.rio.x_dim`` and ``.rio.y_dim`` attribute, which is present if the original dataset was opened using :py:func:`rioxarray.open_rasterio`, or by setting it manually using :py:meth:`rioxarray.rioxarray.XRasterBase.set_spatial_dims`. kwargs : Optional Extra keyword arguments to pass to :py:class:`datashader.Canvas`. Yields ------ canvas : datashader.Canvas A :py:class:`datashader.Canvas` object representing the same spatial extent and x/y coordinates of the input raster grid. This canvas will also have a ``.crs`` attribute that captures the original Coordinate Reference System from the input xarray object's :py:attr:`rioxarray.rioxarray.XRasterBase.crs` property. Raises ------ ModuleNotFoundError If ``datashader`` is not installed. Follow :doc:`install instructions for datashader <datashader:getting_started/index>` before using this class. Example ------- >>> import pytest >>> import numpy as np >>> import xarray as xr >>> datashader = pytest.importorskip("datashader") ... >>> from torchdata.datapipes.iter import IterableWrapper >>> from zen3geo.datapipes import XarrayCanvas ... >>> # Create blank canvas from xarray.DataArray using DataPipe >>> y = np.arange(0, -3, step=-1) >>> x = np.arange(0, 6) >>> dataarray: xr.DataArray = xr.DataArray( ... data=np.zeros(shape=(1, 3, 6)), ... coords=dict(band=[1], y=y, x=x), ... ) >>> dataarray = dataarray.rio.set_spatial_dims(x_dim="x", y_dim="y") >>> dp = IterableWrapper(iterable=[dataarray]) >>> dp_canvas = dp.canvas_from_xarray() ... >>> # Loop or iterate over the DataPipe stream >>> it = iter(dp_canvas) >>> canvas = next(it) >>> print(canvas.raster(source=dataarray)) <xarray.DataArray (band: 1, y: 3, x: 6)> array([[[0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.]]]) Coordinates: * x (x) int64 0 1 2 3 4 5 * y (y) int64 0 -1 -2 * band (band) int64 1 ... """ def __init__( self, source_datapipe: IterDataPipe[Union[xr.DataArray, xr.Dataset]], **kwargs: Optional[Dict[str, Any]], ) -> None: if datashader is None: raise ModuleNotFoundError( "Package `datashader` is required to be installed to use this datapipe. " "Please use `pip install datashader` or " "`conda install -c conda-forge datashader` " "to install the package" ) self.source_datapipe: IterDataPipe[ Union[xr.DataArray, xr.Dataset] ] = source_datapipe self.kwargs = kwargs def __iter__(self) -> Iterator: for dataarray in self.source_datapipe: x_dim: str = dataarray.rio.x_dim y_dim: str = dataarray.rio.y_dim plot_width: int = len(dataarray[x_dim]) plot_height: int = len(dataarray[y_dim]) xmin, ymin, xmax, ymax = dataarray.rio.bounds() canvas = datashader.Canvas( plot_width=plot_width, plot_height=plot_height, x_range=(xmin, xmax), y_range=(ymin, ymax), **self.kwargs, ) canvas.crs = dataarray.rio.crs yield canvas def __len__(self) -> int: return len(self.source_datapipe)