# Source code for zen3geo.datapipes.pyogrio

"""
DataPipes for :doc:`pyogrio <pyogrio:index>`.
"""
from typing import Any, Dict, Iterator, Optional

try:
    import pyogrio
except ImportError:
    pyogrio = None
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.utils import StreamWrapper


@functional_datapipe("read_from_pyogrio")
class PyogrioReaderIterDataPipe(IterDataPipe[StreamWrapper]):
    """
    Takes vector files (e.g. FlatGeoBuf, GeoPackage, GeoJSON) from local disk
    or URLs (as long as they can be read by pyogrio) and yields
    :py:class:`geopandas.GeoDataFrame` objects (functional name:
    ``read_from_pyogrio``).

    Based on
    https://github.com/pytorch/data/blob/v0.4.0/torchdata/datapipes/iter/load/iopath.py#L42-L97

    Parameters
    ----------
    source_datapipe : IterDataPipe[str]
        A DataPipe that contains filepaths or URL links to vector files such as
        FlatGeoBuf, GeoPackage, GeoJSON, etc.

    kwargs : Optional
        Extra keyword arguments to pass to :py:func:`pyogrio.read_dataframe`.

    Yields
    ------
    stream_obj : geopandas.GeoDataFrame
        A :py:class:`geopandas.GeoDataFrame` object containing the vector
        data, wrapped in a ``torchdata.datapipes.utils.StreamWrapper``.

    Raises
    ------
    ModuleNotFoundError
        If ``pyogrio`` is not installed. See
        :doc:`install instructions for pyogrio <pyogrio:install>`, and ensure
        that ``geopandas`` is installed too (e.g. via
        ``pip install pyogrio[geopandas]``) before using this class.

    Example
    -------
    >>> import pytest
    >>> pyogrio = pytest.importorskip("pyogrio")
    ...
    >>> from torchdata.datapipes.iter import IterableWrapper
    >>> from zen3geo.datapipes import PyogrioReader
    ...
    >>> # Read in GeoPackage data using DataPipe
    >>> file_url: str = "https://github.com/geopandas/pyogrio/raw/v0.4.0/pyogrio/tests/fixtures/test_gpkg_nulls.gpkg"
    >>> dp = IterableWrapper(iterable=[file_url])
    >>> dp_pyogrio = dp.read_from_pyogrio()
    ...
    >>> # Loop or iterate over the DataPipe stream
    >>> it = iter(dp_pyogrio)
    >>> geodataframe = next(it)
    >>> geodataframe  # doctest: +NORMALIZE_WHITESPACE
    StreamWrapper<   col_bool  col_int8  ...  col_float64                 geometry
    0       1.0       1.0  ...          1.5  POINT (0.00000 0.00000)
    1       0.0       2.0  ...          2.5  POINT (1.00000 1.00000)
    2       1.0       3.0  ...          3.5  POINT (2.00000 2.00000)
    3       NaN       NaN  ...          NaN  POINT (4.00000 4.00000)
    <BLANKLINE>
    [4 rows x 12 columns]>
    """

    # ``**kwargs: Any`` annotates each individual keyword *value*; the old
    # ``Optional[Dict[str, Any]]`` wrongly claimed every value was a dict.
    def __init__(self, source_datapipe: IterDataPipe[str], **kwargs: Any) -> None:
        """
        Store the upstream DataPipe and the keyword arguments that will be
        forwarded to :py:func:`pyogrio.read_dataframe`.

        Raises
        ------
        ModuleNotFoundError
            If the optional ``pyogrio`` dependency failed to import at module
            load time (the module-level fallback sets ``pyogrio`` to ``None``).
        """
        # Fail fast at construction rather than on first iteration.
        if pyogrio is None:
            raise ModuleNotFoundError(
                "Package `pyogrio` is required to be installed to use this datapipe. "
                "Please use `pip install pyogrio[geopandas]` or "
                "`conda install -c conda-forge pyogrio` "
                "to install the package"
            )
        self.source_datapipe: IterDataPipe[str] = source_datapipe
        self.kwargs: Dict[str, Any] = kwargs

    def __iter__(self) -> Iterator[StreamWrapper]:
        """
        Read each filepath/URL from ``source_datapipe`` with
        :py:func:`pyogrio.read_dataframe` and yield the result wrapped in a
        ``StreamWrapper``.
        """
        for filename in self.source_datapipe:
            yield StreamWrapper(pyogrio.read_dataframe(filename, **self.kwargs))

    def __len__(self) -> int:
        """
        Return the length of ``source_datapipe``: one output per input
        filepath/URL.
        """
        return len(self.source_datapipe)