Source code for zen3geo.datapipes.xbatcher

"""
DataPipes for :doc:`xbatcher <xbatcher:index>`.
"""
from typing import Any, Dict, Hashable, Iterator, Optional, Tuple, Union

import xarray as xr

try:
    import xbatcher
except ImportError:
    xbatcher = None
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe


@functional_datapipe("slice_with_xbatcher")
class XbatcherSlicerIterDataPipe(IterDataPipe[Union[xr.DataArray, xr.Dataset]]):
    """
    Takes an :py:class:`xarray.DataArray` or :py:class:`xarray.Dataset`
    and creates a sliced window view (also known as a chip or tile) of the
    n-dimensional array (functional name: ``slice_with_xbatcher``).

    Parameters
    ----------
    source_datapipe : IterDataPipe[xarray.DataArray or xarray.Dataset]
        A DataPipe that contains :py:class:`xarray.DataArray` or
        :py:class:`xarray.Dataset` objects.

    input_dims : dict
        A dictionary specifying the size of the inputs in each dimension to
        slice along, e.g. ``{'lon': 64, 'lat': 64}``. These are the dimensions
        the machine learning library will see. All other dimensions will be
        stacked into one dimension called ``batch``.

    kwargs : Optional
        Extra keyword arguments to pass to
        :py:func:`xbatcher.BatchGenerator`.

    Yields
    ------
    chip : xarray.DataArray or xarray.Dataset
        An :py:class:`xarray.DataArray` or :py:class:`xarray.Dataset` object
        containing the sliced raster data, with the size/shape defined by the
        ``input_dims`` parameter.

    Raises
    ------
    ModuleNotFoundError
        If ``xbatcher`` is not installed. Follow
        :doc:`install instructions for xbatcher <xbatcher:index>`
        before using this class.

    Example
    -------
    >>> import pytest
    >>> import numpy as np
    >>> import xarray as xr
    >>> xbatcher = pytest.importorskip("xbatcher")
    ...
    >>> from torchdata.datapipes.iter import IterableWrapper
    >>> from zen3geo.datapipes import XbatcherSlicer
    ...
    >>> # Sliced window view of xarray.DataArray using DataPipe
    >>> dataarray: xr.DataArray = xr.DataArray(
    ...     data=np.ones(shape=(3, 128, 128)),
    ...     name="foo",
    ...     dims=["band", "y", "x"]
    ... )
    >>> dp = IterableWrapper(iterable=[dataarray])
    >>> dp_xbatcher = dp.slice_with_xbatcher(input_dims={"y": 64, "x": 64})
    ...
    >>> # Loop or iterate over the DataPipe stream
    >>> it = iter(dp_xbatcher)
    >>> dataarray_chip = next(it)
    >>> dataarray_chip
    <xarray.Dataset>
    Dimensions:  (band: 3, y: 64, x: 64)
    Dimensions without coordinates: band, y, x
    Data variables:
        foo      (band, y, x) float64 1.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0 1.0
    """

    def __init__(
        self,
        source_datapipe: IterDataPipe[Union[xr.DataArray, xr.Dataset]],
        input_dims: Dict[Hashable, int],
        **kwargs: Any,
    ) -> None:
        # ``xbatcher`` is an optional dependency; the module-level import
        # falls back to ``None`` when it is missing, so fail early here with
        # an actionable message instead of at iteration time.
        if xbatcher is None:
            raise ModuleNotFoundError(
                "Package `xbatcher` is required to be installed to use this datapipe. "
                "Please use `pip install xbatcher` "
                "to install the package"
            )
        self.source_datapipe: IterDataPipe[
            Union[xr.DataArray, xr.Dataset]
        ] = source_datapipe
        self.input_dims: Dict[Hashable, int] = input_dims
        # Forwarded verbatim to the xbatcher generator in ``__iter__``.
        self.kwargs: Dict[str, Any] = kwargs

    def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]:
        """
        Iterate over every object in the source datapipe and yield the
        chips that xbatcher generates from it, one chip at a time.
        """
        for dataarray in self.source_datapipe:
            # Only objects exposing a ``name`` attribute that is ``None`` are
            # renamed (presumably unnamed DataArrays; Dataset inputs are
            # passed through untouched).
            if hasattr(dataarray, "name") and dataarray.name is None:
                # Workaround for ValueError: unable to convert unnamed
                # DataArray to a Dataset without providing an explicit name
                dataarray = dataarray.to_dataset(
                    name=xr.backends.api.DATAARRAY_VARIABLE
                )[xr.backends.api.DATAARRAY_VARIABLE]
                # NOTE: directly assigning ``dataarray.name`` was tried
                # before and did not work, hence the round-trip above.

            # ``.batch`` is the accessor provided by xbatcher.
            yield from dataarray.batch.generator(
                input_dims=self.input_dims, **self.kwargs
            )

    # NOTE: ``__len__`` is intentionally not defined. The previously
    # sketched ``len(self.source_datapipe)`` would count the source objects,
    # not the chips that ``__iter__`` yields from them.