Stacking layers

Do not see them differently

Do not consider all as the same

Unwaveringly, practice guarding the One

In Geographic Information Systems 🌐, geographic data is arranged as different ‘layers’ 🍰. For example:

  • Multispectral or hyperspectral 🌈 optical satellites collect different radiometric bands from slices along the electromagnetic spectrum

  • Synthetic Aperture Radar (SAR) 📡 sensors have different polarizations such as HH, HV, VH & VV

  • Satellite laser and radar altimeters 🛰️ measure elevation which can be turned into a Digital Elevation Model (DEM)

As long as these layers are georeferenced 📍 though, they can be stacked! This tutorial will cover the following topics:

  • Searching for spatiotemporal datasets in a dynamic STAC Catalog 📚

  • Stacking time-series 📆 data into a 4D tensor of shape (time, channel, y, x)

  • Organizing different geographic 🗺️ layers into a dataset suitable for change detection

🎉 Getting started

Load up them libraries!

import os

import matplotlib.pyplot as plt
import numpy as np
import planetary_computer
import pystac
import rasterio
import torch
import torchdata
import xarray as xr
import zen3geo

0๏ธโƒฃ Search for spatiotemporal data ๐Ÿ“…#

This time, we’ll be looking at change detection using time-series data. The focus area is Gunung Talamau, Sumatra Barat, Indonesia 🇮🇩 where an earthquake on 25 Feb 2022 triggered a series of landslides ⛰️. Affected areas will be mapped using Sentinel-1 Radiometrically Terrain Corrected (RTC) intensity SAR data 📡 obtained via a spatiotemporal query to a STAC API.


This is what the Sentinel-1 radar image looks like over Sumatra Barat, Indonesia on 23 February 2022, two days before the earthquake.

Sentinel-1 image over Sumatra Barat, Indonesia on 20220223

Sentinel-1 PolSAR time-series ⏳

Before we start, we’ll need to set the PC_SDK_SUBSCRIPTION_KEY environment variable 🔡 to access the Sentinel-1 RTC data from Planetary Computer 💻. The steps are:

  1. Get a 🪐 Planetary Computer account at https://planetarycomputer.microsoft.com/account/request

  2. Follow 🧑‍🏫 instructions to get a subscription key

  3. Go to https://planetarycomputer.developer.azure-api.net/profile. You should have a Primary key 🔑 and a Secondary key 🗝️; click on ‘Show’ to reveal them. Copy and paste the key below, or better, set it securely 🔐 in something like a .env file!

# Uncomment the line below and set your Planetary Computer subscription key
# os.environ["PC_SDK_SUBSCRIPTION_KEY"] = "abcdefghijklmnopqrstuvwxyz123456"
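
If you go the .env route, here’s a minimal sketch, assuming the python-dotenv package (not among the imports above, so a hypothetical extra for this tutorial) and a .env file in the working directory containing a line like PC_SDK_SUBSCRIPTION_KEY=your-key.

# Hypothetical: load the key from a .env file (requires `pip install python-dotenv`)
# from dotenv import load_dotenv
# load_dotenv()  # reads .env and populates os.environ["PC_SDK_SUBSCRIPTION_KEY"]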

Done? Let’s now define an 🧭 area of interest and 📆 time range covering one month before and one month after the earthquake ⚠️.

# Spatiotemporal query on STAC catalog for Sentinel-1 RTC data
query = dict(
    bbox=[99.8, -0.24, 100.07, 0.15],  # West, South, East, North
    datetime=["2022-01-25T00:00:00Z", "2022-03-25T23:59:59Z"],
    collections=["sentinel-1-rtc"],
)
dp = torchdata.datapipes.iter.IterableWrapper(iterable=[query])

Then, search over a dynamic STAC Catalog 📚 for items matching the spatiotemporal query ❔ using zen3geo.datapipes.PySTACAPISearcher (functional name: search_for_pystac_item).

dp_pystac_client = dp.search_for_pystac_item(
    catalog_url="https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)

Tip

Confused about which parameters go where 😕? Here’s some clarification (with a combined sketch after this list):

  1. Different spatiotemporal queries (e.g. for multiple geographical areas) should go in torchdata.datapipes.iter.IterableWrapper, e.g. IterableWrapper(iterable=[query_area1, query_area2]). The query dictionaries will be passed to pystac_client.Client.search().

  2. Common parameters to interact with the STAC API Client should go in search_for_pystac_item(), e.g. the STAC API’s URL (see https://stacindex.org/catalogs?access=public&type=api for a public list) and connection-related parameters. These will be passed to pystac_client.Client.open().
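
Putting both together, a hypothetical two-area search might look like this (query_area1 and query_area2, including the second bounding box, are illustrative only and not part of this tutorial’s pipeline):

query_area1 = dict(bbox=[99.8, -0.24, 100.07, 0.15], collections=["sentinel-1-rtc"])
query_area2 = dict(bbox=[100.2, -0.5, 100.5, -0.2], collections=["sentinel-1-rtc"])
dp_two_areas = torchdata.datapipes.iter.IterableWrapper(
    iterable=[query_area1, query_area2]  # each dict goes to pystac_client.Client.search()
)
dp_two_area_items = dp_two_areas.search_for_pystac_item(
    catalog_url="https://planetarycomputer.microsoft.com/api/stac/v1",  # goes to Client.open()
    modifier=planetary_computer.sign_inplace,
)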

The output is a pystac_client.ItemSearch 🔎 instance that only holds the STAC API query information ℹ️ but doesn’t request any data! We’ll need to order it 🧞 to return something like a pystac.ItemCollection.

def get_all_items(item_search) -> pystac.ItemCollection:
    return item_search.item_collection()
dp_sen1_items = dp_pystac_client.map(fn=get_all_items)
dp_sen1_items
MapperIterDataPipe

Take a peek 🫣 to see if the query does contain STAC items.

it = iter(dp_sen1_items)
item_collection = next(it)
item_collection.items
[<Item id=S1A_IW_GRDH_1SDV_20220320T230514_20220320T230548_042411_050E99_rtc>,
 <Item id=S1A_IW_GRDH_1SDV_20220320T230449_20220320T230514_042411_050E99_rtc>,
 <Item id=S1A_IW_GRDH_1SDV_20220319T114141_20220319T114206_042389_050DD7_rtc>,
 <Item id=S1A_IW_GRDH_1SDV_20220308T230513_20220308T230548_042236_0508AF_rtc>,
 <Item id=S1A_IW_GRDH_1SDV_20220308T230448_20220308T230513_042236_0508AF_rtc>,
 <Item id=S1A_IW_GRDH_1SDV_20220307T114141_20220307T114206_042214_0507E8_rtc>,
 <Item id=S1A_IW_GRDH_1SDV_20220224T230514_20220224T230548_042061_0502B9_rtc>,
 <Item id=S1A_IW_GRDH_1SDV_20220224T230449_20220224T230514_042061_0502B9_rtc>,
 <Item id=S1A_IW_GRDH_1SDV_20220223T114141_20220223T114206_042039_0501F9_rtc>,
 <Item id=S1A_IW_GRDH_1SDV_20220212T230514_20220212T230548_041886_04FCA5_rtc>,
 <Item id=S1A_IW_GRDH_1SDV_20220212T230449_20220212T230514_041886_04FCA5_rtc>,
 <Item id=S1A_IW_GRDH_1SDV_20220211T114141_20220211T114206_041864_04FBD9_rtc>,
 <Item id=S1A_IW_GRDH_1SDV_20220131T230514_20220131T230548_041711_04F691_rtc>,
 <Item id=S1A_IW_GRDH_1SDV_20220131T230449_20220131T230514_041711_04F691_rtc>,
 <Item id=S1A_IW_GRDH_1SDV_20220130T114142_20220130T114207_041689_04F5D0_rtc>]
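
Each pystac.Item also carries its acquisition timestamp in the standard pystac.Item.datetime attribute, which is handy for double-checking the temporal coverage:

# Print the acquisition datetime of each Sentinel-1 STAC item
for item in item_collection.items:
    print(item.id, item.datetime)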

Copernicus Digital Elevation Model (DEM) ⛰️

Since landslides 🛝 typically happen on steep slopes, it can be useful to have a 🏔️ topographic layer. Let’s set up a STAC query 🙋 to get the 30m spatial resolution Copernicus DEM.

# Spatiotemporal query on STAC catalog for Copernicus DEM 30m data
query = dict(
    bbox=[99.8, -0.24, 100.07, 0.15],  # West, South, East, North
    collections=["cop-dem-glo-30"],
)
dp_copdem30 = torchdata.datapipes.iter.IterableWrapper(iterable=[query])

Just to be fancy, let’s chain 🔗 together the next two datapipes.

dp_copdem30_items = dp_copdem30.search_for_pystac_item(
    catalog_url="https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
).map(fn=get_all_items)
dp_copdem30_items
MapperIterDataPipe

This is one of the four DEM tiles 🀫 that will be returned from the query.

Copernicus 30m DEM over Sumatra Barat, Indonesia

Landslide extent vector polygons 🔶

Now for the target labels 🏷️. Following Vector segmentation masks, we’ll first load the digitized landslide polygons from a vector file 📁 using zen3geo.datapipes.PyogrioReader (functional name: read_from_pyogrio).

# https://gdal.org/user/virtual_file_systems.html#vsizip-zip-archives
shape_url = "/vsizip/vsicurl/https://unosat-maps.web.cern.ch/ID/LS20220308IDN/LS20220308IDN_SHP.zip/LS20220308IDN_SHP/S2_20220304_LandslideExtent_MountTalakmau.shp"

dp_shapes = torchdata.datapipes.iter.IterableWrapper(iterable=[shape_url])
dp_pyogrio = dp_shapes.read_from_pyogrio()
dp_pyogrio
PyogrioReaderIterDataPipe

Let’s take a look at the geopandas.GeoDataFrame data table 📊 to see the attributes inside.

it = iter(dp_pyogrio)
geodataframe = next(it)
print(geodataframe.bounds)
geodataframe.dropna(axis="columns")
        minx      miny        maxx      maxy
0  99.806331 -0.248744  100.065765  0.147054
SensorDate    2022-04-03
EventCode     LS20220308IDN
Area_m2       6238480.0
Area_ha       623.848
Shape_Leng    2.873971
Shape_Area    0.000507
SensorID      Sentinel-2
Confidence    To Be Evaluated
FieldValid    Not yet field validated
Main_Dmg      Landslide
StaffID       TH
geometry      MULTIPOLYGON (((99.80642 -0.24865, 99.80642 -0...
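
For a quick visual check, the polygons can also be plotted directly via geopandas.GeoDataFrame.plot (a throwaway sketch; the styling choices here are arbitrary):

# Quick look at the digitized landslide extent polygons
geodataframe.plot(figsize=(6, 6), color="orange", edgecolor="red")
plt.title("Landslide extent polygons over Gunung Talamau")
plt.show()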

We’ll show you what the landslide segmentation masks 😷 look like after they’ve been rasterized later 😉.

1๏ธโƒฃ Stack bands, append variables ๐Ÿ“š#

There are now three layers 🍰 to handle, two rasters and a vector. This section will give you step-by-step 📶 instructions to combine them using xarray like so:

  1. Stack the Sentinel-1 🛰️ time-series STAC Items (GeoTIFFs) into an xarray.DataArray.

  2. Combine the Sentinel-1 and Copernicus DEM ⛰️ xarray.DataArray layers into a single xarray.Dataset.

  3. Using the xarray.Dataset as a canvas template, rasterize the landslide 🛝 polygon extents, and append the resulting segmentation mask as another data variable 🗃️ in the xarray.Dataset.

Stack multi-channel time-series GeoTIFFs 🗓️

Each pystac.Item in a pystac.ItemCollection represents a 🛰️ Sentinel-1 RTC image captured at a particular datetime ⌚. Let’s subset the data to just the mountain area, and stack 🥞 all the STAC items into a 4D time-series tensor using zen3geo.datapipes.StackSTACStacker (functional name: stack_stac_items)!

dp_sen1_stack = dp_sen1_items.stack_stac_items(
    assets=["vh", "vv"],  # SAR polarizations
    epsg=32647,  # UTM Zone 47N
    resolution=30,  # Spatial resolution of 30 metres
    bounds_latlon=[99.933681, -0.009951, 100.065765, 0.147054],  # W, S, E, N
    xy_coords="center",  # pixel centroid coords instead of topleft corner
    dtype=np.float16,  # Use a lightweight data type
)
dp_sen1_stack
StackSTACStackerIterDataPipe

The keyword arguments are 📨 passed to stackstac.stack() behind the scenes. The important ❕ parameters to set in this case are:

  • assets: The STAC item assets 🍱 (typically the ‘band’ names)

  • epsg: The 🌐 EPSG projection code, best if you know the native projection (see the sketch after this list)

  • resolution: Spatial resolution 📏. The Sentinel-1 RTC is actually at 10m, but we’ll resample to 30m to keep things small 🤏 and match the Copernicus DEM.
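
Unsure about the native projection? A hedged way to peek at it is through the STAC projection extension metadata, assuming the items record a proj:epsg property (reusing the Sentinel-1 item_collection from earlier):

# Peek at the native EPSG code of the first Sentinel-1 STAC item
first_item = item_collection.items[0]
print(first_item.properties.get("proj:epsg"))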

The result is a single xarray.DataArray ‘datacube’ 🧊 with dimensions (time, band, y, x).

it = iter(dp_sen1_stack)
dataarray = next(it)
dataarray
<xarray.DataArray 'stackstac-3b069ee52dbc03b1fd6c981837d6c185' (time: 15,
                                                                band: 2,
                                                                y: 579, x: 491)>
dask.array<fetch_raster_window, shape=(15, 2, 579, 491), dtype=float16, chunksize=(1, 1, 579, 491), chunktype=numpy.ndarray>
Coordinates: (12/39)
  * time                                   (time) datetime64[ns] 2022-01-30T1...
    id                                     (time) <U66 'S1A_IW_GRDH_1SDV_2022...
  * band                                   (band) <U2 'vh' 'vv'
  * x                                      (x) float64 6.039e+05 ... 6.186e+05
  * y                                      (y) float64 1.624e+04 ... -1.095e+03
    sar:resolution_range                   int64 20
    ...                                     ...
    sar:looks_equivalent_number            float64 4.4
    sat:absolute_orbit                     (time) int64 41689 41711 ... 42411
    description                            (band) <U173 'Terrain-corrected ga...
    title                                  (band) <U41 'VH: vertical transmit...
    raster:bands                           object {'nodata': -32768, 'data_ty...
    epsg                                   int64 32647
Attributes:
    spec:        RasterSpec(epsg=32647, bounds=(603870, -1110, 618600, 16260)...
    crs:         epsg:32647
    transform:   | 30.00, 0.00, 603870.00|\n| 0.00,-30.00, 16260.00|\n| 0.00,...
    resolution:  30

Append single-band DEM to datacube 🧊

Time for layer number 2 💕. Let’s read the Copernicus DEM ⛰️ STAC Item into an xarray.DataArray first, again via zen3geo.datapipes.StackSTACStacker (functional name: stack_stac_items). We’ll need to ensure ✔️ that the DEM is reprojected to the same 🌐 coordinate reference system and 📐 aligned to the same spatial extent as the Sentinel-1 time-series.

dp_copdem_stack = dp_copdem30_items.stack_stac_items(
    assets=["data"],
    epsg=32647,  # UTM Zone 47N
    resolution=30,  # Spatial resolution of 30 metres
    bounds_latlon=[99.933681, -0.009951, 100.065765, 0.147054],  # W, S, E, N
    xy_coords="center",  # pixel centroid coords instead of topleft corner
    dtype=np.float16,  # Use a lightweight data type
    resampling=rasterio.enums.Resampling.bilinear,  # Bilinear resampling
)
dp_copdem_stack
StackSTACStackerIterDataPipe
it = iter(dp_copdem_stack)
dataarray = next(it)
dataarray
<xarray.DataArray 'stackstac-afdc91bccb12ec4ba1376f87b0641b62' (time: 4,
                                                                band: 1,
                                                                y: 579, x: 491)>
dask.array<fetch_raster_window, shape=(4, 1, 579, 491), dtype=float16, chunksize=(1, 1, 579, 491), chunktype=numpy.ndarray>
Coordinates:
  * time        (time) datetime64[ns] 2021-04-22 2021-04-22 ... 2021-04-22
    id          (time) <U40 'Copernicus_DSM_COG_10_S01_00_E100_00_DEM' ... 'C...
  * band        (band) <U4 'data'
  * x           (x) float64 6.039e+05 6.039e+05 ... 6.186e+05 6.186e+05
  * y           (y) float64 1.624e+04 1.622e+04 ... -1.065e+03 -1.095e+03
    proj:epsg   int64 4326
    platform    <U8 'TanDEM-X'
    gsd         int64 30
    proj:shape  object {3600}
    epsg        int64 32647
Attributes:
    spec:        RasterSpec(epsg=32647, bounds=(603870, -1110, 618600, 16260)...
    crs:         epsg:32647
    transform:   | 30.00, 0.00, 603870.00|\n| 0.00,-30.00, 16260.00|\n| 0.00,...
    resolution:  30

Why are there 4 ⏳ time layers? Actually, the STAC query returned four DEM tiles 🀫, and stackstac.stack() has stacked all of them along a dimension named ‘time’ (probably better named ‘tile’). Fear not, the tiles can be joined 💒 into a single terrain mosaic layer with dimensions (“band”, “y”, “x”) using zen3geo.datapipes.StackSTACMosaicker (functional name: mosaic_dataarray).

dp_copdem_mosaic = dp_copdem_stack.mosaic_dataarray(nodata=0)
dp_copdem_mosaic
StackSTACMosaickerIterDataPipe

Great! The two xarray.DataArray objects (Sentinel-1 and Copernicus DEM mosaic) can now be combined 🪢. First, use torchdata.datapipes.iter.Zipper (functional name: zip) to put the two xarray.DataArray objects into a tuple 🎵.

dp_sen1_copdem = dp_sen1_stack.zip(dp_copdem_mosaic)
dp_sen1_copdem
ZipperIterDataPipe

Next, use torchdata.datapipes.iter.Collator (functional name: collate) to convert 🤸 the tuple of xarray.DataArray objects into an xarray.Dataset 🧊, similar to what was done in Object detection boxes.

def sardem_collate_fn(sar_and_dem: tuple) -> xr.Dataset:
    """
    Combine a pair of xarray.DataArray (SAR, DEM) inputs into an
    xarray.Dataset with data variables named 'vh', 'vv' and 'dem'.
    """
    # Turn 2 xr.DataArray objects into 1 xr.Dataset with multiple data vars
    sar, dem = sar_and_dem

    # Initialize xr.Dataset with VH and VV channels
    dataset: xr.Dataset = sar.sel(band="vh").to_dataset(name="vh")
    dataset["vv"] = sar.sel(band="vv")

    # Add Copernicus DEM mosaic as another layer
    dataset["dem"] = dem.squeeze()

    return dataset
dp_vhvvdem_dataset = dp_sen1_copdem.collate(collate_fn=sardem_collate_fn)
dp_vhvvdem_dataset
CollatorIterDataPipe

Here’s how the current xarray.Dataset 🧱 is structured. Notice that VH and VV polarization channels 📡 are now two separate data variables, each with dimensions (time, y, x). The DEM ⛰️ data is not a time-series, so it has dimensions (y, x) only. All the ‘band’ dimensions have been removed ❌ and are now data variables within the xarray.Dataset 😎.

it = iter(dp_vhvvdem_dataset)
dataset = next(it)
dataset
<xarray.Dataset>
Dimensions:                                (time: 15, x: 491, y: 579)
Coordinates: (12/41)
  * time                                   (time) datetime64[ns] 2022-01-30T1...
    id                                     (time) <U66 'S1A_IW_GRDH_1SDV_2022...
    band                                   <U2 'vh'
  * x                                      (x) float64 6.039e+05 ... 6.186e+05
  * y                                      (y) float64 1.624e+04 ... -1.095e+03
    sar:resolution_range                   int64 20
    ...                                     ...
    description                            <U173 'Terrain-corrected gamma nau...
    title                                  <U41 'VH: vertical transmit, horiz...
    raster:bands                           object {'nodata': -32768, 'data_ty...
    epsg                                   int64 32647
    gsd                                    int64 30
    proj:shape                             object {3600}
Data variables:
    vh                                     (time, y, x) float16 dask.array<chunksize=(1, 579, 491), meta=np.ndarray>
    vv                                     (time, y, x) float16 dask.array<chunksize=(1, 579, 491), meta=np.ndarray>
    dem                                    (y, x) float16 dask.array<chunksize=(579, 491), meta=np.ndarray>

Visualize the DataPipe graph ⛓️ too for good measure.

torchdata.datapipes.utils.to_graph(dp=dp_vhvvdem_dataset)
_images/stacking_36_0.svg

Rasterize target labels to datacube extent 🏷️

The landslide polygons 🔶 can now be rasterized and added as another layer to our datacube 🧊. Following Vector segmentation masks, we’ll first fork the DataPipe into two branches 🫒 using torchdata.datapipes.iter.Forker (functional name: fork).

dp_vhvvdem_canvas, dp_vhvvdem_datacube = dp_vhvvdem_dataset.fork(num_instances=2)
dp_vhvvdem_canvas, dp_vhvvdem_datacube
(_ChildDataPipe, _ChildDataPipe)

Next, create a blank canvas 📃 using zen3geo.datapipes.XarrayCanvas (functional name: canvas_from_xarray) and rasterize 🖌 the vector polygons onto the template canvas using zen3geo.datapipes.DatashaderRasterizer (functional name: rasterize_with_datashader).

dp_datashader = dp_vhvvdem_canvas.canvas_from_xarray().rasterize_with_datashader(
    vector_datapipe=dp_pyogrio
)
dp_datashader
DatashaderRasterizerIterDataPipe

Cool, and this layer can be added 🧮 as another data variable in the datacube.

def cubemask_collate_fn(cube_and_mask: tuple) -> xr.Dataset:
    """
    Merge target 'mask' (xarray.DataArray) into an existing datacube
    (xarray.Dataset) as another data variable.
    """
    datacube, mask = cube_and_mask

    merged_datacube = xr.merge(objects=[datacube, mask.rename("mask")], join="override")

    return merged_datacube
dp_datacube = dp_vhvvdem_datacube.zip(dp_datashader).collate(
    collate_fn=cubemask_collate_fn
)
dp_datacube
CollatorIterDataPipe

Inspect the datacube 🧊 and visualize all the layers 🧅 within.

it = iter(dp_datacube)
datacube = next(it)
datacube
<xarray.Dataset>
Dimensions:                                (time: 15, x: 491, y: 579)
Coordinates: (12/42)
  * time                                   (time) datetime64[ns] 2022-01-30T1...
    id                                     (time) <U66 'S1A_IW_GRDH_1SDV_2022...
    band                                   <U2 'vh'
  * x                                      (x) float64 6.039e+05 ... 6.186e+05
  * y                                      (y) float64 1.624e+04 ... -1.095e+03
    sar:resolution_range                   int64 20
    ...                                     ...
    title                                  <U41 'VH: vertical transmit, horiz...
    raster:bands                           object {'nodata': -32768, 'data_ty...
    epsg                                   int64 32647
    gsd                                    int64 30
    proj:shape                             object {3600}
    spatial_ref                            int64 0
Data variables:
    vh                                     (time, y, x) float16 dask.array<chunksize=(1, 579, 491), meta=np.ndarray>
    vv                                     (time, y, x) float16 dask.array<chunksize=(1, 579, 491), meta=np.ndarray>
    dem                                    (y, x) float16 dask.array<chunksize=(579, 491), meta=np.ndarray>
    mask                                   (y, x) uint8 0 0 0 0 0 ... 0 0 0 0 0
dataslice = datacube.sel(time="2022-02-23T11:41:54.329096000").compute()

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(11, 12), sharex=True, sharey=True)

dataslice.vh.plot.imshow(ax=ax[0][0], cmap="bone", robust=True)
ax[0][0].set_title("Sentinel-1 RTC 20220223 VH")

dataslice.vv.plot.imshow(ax=ax[0][1], cmap="bone", robust=True)
ax[0][1].set_title("Sentinel-1 RTC 20220223 VV")

dataslice.dem.plot.imshow(ax=ax[1][0], cmap="gist_earth")
ax[1][0].set_title("Copernicus DEM")

dataslice.mask.plot.imshow(ax=ax[1][1], cmap="binary_r")
ax[1][1].set_title("Landslide mask")

plt.show()
_images/stacking_46_0.png

2๏ธโƒฃ Splitters and lumpers ๐Ÿชจ#

There are many ways to do change detection 🕵️. Here is but one ☝️.

Slice spatially and temporally 💇

For the splitters, let’s first slice the datacube along the spatial dimension into 256x256 chips 🍪 using zen3geo.datapipes.XbatcherSlicer (functional name: slice_with_xbatcher). Refer to Chipping and batching data if you need a 🧑‍🎓 refresher.

dp_xbatcher = dp_datacube.slice_with_xbatcher(input_dims={"y": 256, "x": 256})
dp_xbatcher
XbatcherSlicerIterDataPipe

Next, we’ll use the earthquake ⚠️ date to divide each 256x256 SAR time-series chip 🍕 with dimensions (time, y, x) into pre-event and post-event tensors. The target landslide 🛝 mask will be split out too.

def pre_post_target_tuple(
    datachip: xr.Dataset, event_time: str = "2022-02-25T01:39:27"
) -> tuple[xr.Dataset, xr.Dataset, xr.Dataset]:
    """
    From a single xarray.Dataset, split it into a tuple containing the
    pre/post/target tensors.
    """
    pre_times = datachip.time <= np.datetime64(event_time)
    post_times = datachip.time > np.datetime64(event_time)

    return (
        datachip.sel(time=pre_times)[["vv", "vh", "dem"]],
        datachip.sel(time=post_times)[["vv", "vh", "dem"]],
        datachip[["mask"]],
    )
dp_pre_post_target = dp_xbatcher.map(fn=pre_post_target_tuple)
dp_pre_post_target
MapperIterDataPipe

Inspect 👀 the shapes of one of the data chips that has been split into pre/post/target 🍡 xarray.Dataset objects.

it = iter(dp_pre_post_target)
pre, post, target = next(it)
print(f"Before: {pre.sizes}")
print(f"After: {post.sizes}")
print(f"Target: {target.sizes}")
/home/docs/checkouts/readthedocs.org/user_builds/zen3geo/envs/v0.5.0/lib/python3.10/site-packages/torch/utils/data/datapipes/iter/combining.py:248: UserWarning: Some child DataPipes are not exhausted when __iter__ is called. We are resetting the buffer and each child DataPipe will read from the start again.
  warnings.warn("Some child DataPipes are not exhausted when __iter__ is called. We are resetting "
Before: Frozen({'time': 9, 'y': 256, 'x': 256})
After: Frozen({'time': 6, 'y': 256, 'x': 256})
Target: Frozen({'y': 256, 'x': 256})

Cool, at this point, you may want to decide 🤔 how to handle the different-sized before and after time-series images 🎞️ (one option is sketched below). Or maybe not, and torch.Tensor objects are all you desire ❤️‍🔥.
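
If you do want equal-length stacks, here’s one minimal sketch (an illustrative option, not part of the original pipeline) that trims both to the shorter series; padding or temporal compositing would be equally valid choices.

# Hypothetical: trim pre/post datasets to the same number of time steps
def trim_to_equal_times(triple_tuple) -> tuple[xr.Dataset, xr.Dataset, xr.Dataset]:
    pre, post, target = triple_tuple
    n_times = min(pre.sizes["time"], post.sizes["time"])
    # Keep the last n pre-event scenes and the first n post-event scenes
    return (
        pre.isel(time=slice(-n_times, None)),
        post.isel(time=slice(0, n_times)),
        target,
    )

# dp_trimmed = dp_pre_post_target.map(fn=trim_to_equal_times)

Either way, converting the split xarray.Dataset objects into torch.Tensor objects looks like this: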

def dataset_to_tensors(triple_tuple) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Converts xarray.Datasets in a tuple into torch.Tensor objects.
    """
    pre, post, target = triple_tuple

    _pre: torch.Tensor = torch.as_tensor(pre.to_array().data)
    _post: torch.Tensor = torch.as_tensor(post.to_array().data)
    _target: torch.Tensor = torch.as_tensor(target.mask.data.astype("uint8"))

    return _pre, _post, _target
dp_tensors = dp_pre_post_target.map(fn=dataset_to_tensors)
dp_tensors
MapperIterDataPipe

This is the final DataPipe graph ⛓️.

torchdata.datapipes.utils.to_graph(dp=dp_tensors)
_images/stacking_58_0.svg

Into a DataLoader 🏋️

Time to connect the DataPipe to torch.utils.data.DataLoader ♻️!

dataloader = torch.utils.data.DataLoader2(dataset=dp_tensors, batch_size=None)
for i, batch in enumerate(dataloader):
    pre, post, target = batch
    print(f"Batch {i} - pre: {pre.shape}, post: {post.shape}, target: {target.shape}")
/home/docs/checkouts/readthedocs.org/user_builds/zen3geo/envs/v0.5.0/lib/python3.10/site-packages/torch/utils/data/datapipes/iter/combining.py:248: UserWarning: Some child DataPipes are not exhausted when __iter__ is called. We are resetting the buffer and each child DataPipe will read from the start again.
  warnings.warn("Some child DataPipes are not exhausted when __iter__ is called. We are resetting "
Batch 0 - pre: torch.Size([3, 9, 256, 256]), post: torch.Size([3, 6, 256, 256]), target: torch.Size([256, 256])
Batch 1 - pre: torch.Size([3, 9, 256, 256]), post: torch.Size([3, 6, 256, 256]), target: torch.Size([256, 256])
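
As a taste of what could come next, here’s a hypothetical log-ratio change indicator, a classic SAR change detection heuristic (purely illustrative and not part of this tutorial’s pipeline; the channel order follows the vv/vh/dem selection made earlier):

# Hypothetical: log-ratio of temporal means between post- and pre-event stacks
pre_mean = pre.float().mean(dim=1)  # (channel, y, x), averaged over time
post_mean = post.float().mean(dim=1)
epsilon = 1e-6  # guard against division by zero on empty pixels
log_ratio = torch.log((post_mean + epsilon) / (pre_mean + epsilon))
print(log_ratio.shape)  # torch.Size([3, 256, 256])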

Don’t just see change, be the change 🪧!

See also

This data pipeline is adapted from (a subset of) some amazing 🧪 research done during the Frontier Development Lab 2022 - Self Supervised Learning on SAR data for Change Detection challenge 🚀. Watch the final showcase video at https://www.youtube.com/watch?v=igAUTJwbmsY 📽️!