Chipping and batching data#

What is separation?

What isn’t?

Following on from the previous tutorial, let’s 🧑‍🎓 learn more about creating a more complicated 🌈 raster data pipeline. Specifically, we’ll go through the following:

  • Loading Cloud-Optimized GeoTIFFs (COGs) from different geographic regions

  • Cutting up each large GeoTIFF into several 512 x 512 pixel chips 🥨

  • Creating batches of chips/tensors to feed into a DataLoader 🏋️

Some terminology 📜 disambiguation:

  • scene - the big image (e.g. 10000x10000 pixels) from a satellite 🛰️ (e.g. a GeoTIFF)

  • chip - the small image (e.g. 512x512 pixels) cut ✂️ out from a satellite scene to be loaded as a tensor


🎉 Getting started#

Load up them libraries!

import pystac
import planetary_computer
import rioxarray

import torch
import torchdata
import zen3geo

0️⃣ Find Cloud-Optimized GeoTIFFs ☁️#

Synthetic-Aperture Radar (SAR) from a STAC catalog! We’ll get some Sentinel-1 Ground-Range Detected (GRD) data over Osaka and Tokyo in Japan 🇯🇵.


item_urls = [
    # Osaka
    # Tokyo
]

# Load each STAC item's metadata and sign the assets
items = [pystac.Item.from_file(item_url) for item_url in item_urls]
signed_items = [planetary_computer.sign(item) for item in items]
signed_items
[<Item id=S1A_IW_GRDH_1SDV_20220614T210034_20220614T210059_043664_05368A>,
 <Item id=S1A_IW_GRDH_1SDV_20220616T204349_20220616T204414_043693_053764>]

Inspect one of the data assets#

The Sentinel-1 STAC item contains several assets. These include different 〰️ polarizations (e.g. ‘VH’, ‘VV’). Let’s just take the ‘thumbnail’ product for now, which is an RGB preview with the red 🟥 channel (R) representing the co-polarization (VV or HH), the green 🟩 channel (G) representing the cross-polarization (VH or HV) and the blue 🟦 channel (B) representing the ratio of the cross and co-polarizations.

url: str = signed_items[0].assets["thumbnail"].href
da = rioxarray.open_rasterio(filename=url)
da
/home/docs/checkouts/ NotGeoreferencedWarning: Dataset has no geotransform, gcps, or rpcs. The identity matrix will be returned.
  warnings.warn(str(rio_warning.message), type(rio_warning.message))  # type: ignore
<xarray.DataArray (band: 3, y: 348, x: 503)> Size: 525kB
[525132 values with dtype=uint8]
Coordinates:
  * band         (band) int64 24B 1 2 3
  * x            (x) float64 4kB 0.5 1.5 2.5 3.5 4.5 ... 499.5 500.5 501.5 502.5
  * y            (y) float64 3kB 0.5 1.5 2.5 3.5 4.5 ... 344.5 345.5 346.5 347.5
    spatial_ref  int64 8B 0
Attributes:
    scale_factor:  1.0
    add_offset:    0.0

This is what the Sentinel-1 radar image looks like over Osaka on 14 June 2022.

Sentinel-1 GRD image over Osaka, Japan on 20220614

1️⃣ Creating 512x512 chips from large satellite scenes 🪟#

Unless you have a lot of RAM, it is common to cut ✂️ a large satellite scene into multiple smaller chips (or patches, tiles 🀄, etc) first. This is typically done in a rolling or sliding window 🪟 fashion, via a nested loop through the y-dimension and x-dimension in strides of, say, 512 pixels x 512 pixels.
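The nested loop described above can be sketched in plain Python, independent of any library. This is only an illustration under stated assumptions: the scene is a tiny list of lists standing in for a raster, the helper name `cut_chips` is hypothetical, and windows that do not fit fully at the right or bottom edge are simply dropped.

```python
def cut_chips(scene, chip_size):
    """Cut a 2D scene (a list of rows) into non-overlapping square chips.

    Hypothetical helper for illustration only. Edge windows that cannot
    hold a full chip are dropped.
    """
    height, width = len(scene), len(scene[0])
    chips = []
    # Outer loop walks down the y-dimension, inner loop across the x-dimension
    for y0 in range(0, height - chip_size + 1, chip_size):
        for x0 in range(0, width - chip_size + 1, chip_size):
            chip = [row[x0 : x0 + chip_size] for row in scene[y0 : y0 + chip_size]]
            chips.append(chip)
    return chips


# A toy 5 x 7 "scene" whose pixel values encode their own (y, x) position
scene = [[10 * y + x for x in range(7)] for y in range(5)]
chips = cut_chips(scene, chip_size=2)
print(f"Number of chips: {len(chips)}")
print(f"First chip: {chips[0]}")
```

In practice the same pattern is applied to much larger arrays with a chip size such as 512, and a library handles the bookkeeping.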

Let’s begin by setting up the first part of the DataPipe, which is to read the satellite scene 🖼️ using rioxarray.

# Just get the VV polarization for now from Sentinel-1
urls = [item.assets["vv"].href for item in signed_items]
dp = torchdata.datapipes.iter.IterableWrapper(iterable=urls)
dp_rioxarray = dp.read_from_rioxarray(overview_level=3)

Slicing with XbatcherSlicer 🍕#

To create the chips, we’ll be using xbatcher which allows slicing 🔪 of an n-dimensional datacube along any dimension (e.g. longitude, latitude, time 🕛). This xbatcher library is integrated into ☯ zen3geo as a DataPipe called zen3geo.datapipes.XbatcherSlicer (functional name: slice_with_xbatcher), which can be used as follows:

dp_xbatcher = dp_rioxarray.slice_with_xbatcher(input_dims={"y": 512, "x": 512})

This should give us about 12 chips in total, 6 from each of the 2 Sentinel-1 images that were passed in.

print(f"Number of chips: {len(dp_xbatcher)}")
Number of chips: 12

Now, if you want to customize the sliding window (e.g. do overlapping strides), pass in extra parameters to slice_with_xbatcher, and it will be handled by xbatcher.BatchGenerator.

dp_xbatcher = dp_rioxarray.slice_with_xbatcher(
    input_dims={"y": 512, "x": 512}, input_overlap={"y": 256, "x": 256}
)

Great, and this overlapping stride method should give us more 512x512 chips 🧮 than before.

print(f"Number of chips: {len(dp_xbatcher)}")
Number of chips: 30
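The jump from 12 to 30 chips follows from simple window arithmetic. Here is a rough sketch, assuming each window advances by `window - overlap` pixels and partial windows at the edge are dropped. The helper `count_windows` and the 1100 x 1600 pixel scene size are hypothetical (the actual overview dimensions are not shown in this tutorial), and xbatcher’s own edge handling may differ.

```python
def count_windows(dim_size: int, window: int, overlap: int = 0) -> int:
    """Count full windows along one dimension (hypothetical helper)."""
    stride = window - overlap  # how far each window advances
    if stride <= 0:
        raise ValueError("overlap must be smaller than the window size")
    if dim_size < window:
        return 0
    return (dim_size - window) // stride + 1


# Assumed scene size in pixels, chosen for illustration only
height, width = 1100, 1600

no_overlap = count_windows(height, 512) * count_windows(width, 512)
with_overlap = count_windows(height, 512, 256) * count_windows(width, 512, 256)
print(f"Chips per scene without overlap: {no_overlap}")
print(f"Chips per scene with 256 px overlap: {with_overlap}")
```

Under these assumed dimensions a single scene yields 6 chips without overlap and 15 with a 256-pixel overlap, which is consistent with 12 and 30 chips across two scenes.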

Double-check that single chips are of the correct dimensions (band: 1, y: 512, x: 512).

chips = list(dp_xbatcher)
sample = chips[0]
sample
<xarray.DataArray (band: 1, y: 512, x: 512)> Size: 524kB
array([[[  0, 157, 154, ..., 232, 205, 220],
        [  0, 152, 149, ..., 222, 235, 209],
        [  0, 149, 166, ..., 199, 200, 222],
        [189, 156, 158, ..., 193, 265, 164],
        [171, 170, 168, ..., 250, 161, 135],
        [155, 162, 178, ..., 179, 128, 125]]], dtype=uint16)
Coordinates:
  * band         (band) int64 8B 1
    spatial_ref  int64 8B 0
Dimensions without coordinates: y, x
Attributes:
    AREA_OR_POINT:             Area
    TIFFTAG_DATETIME:          2022:06:14 22:54:32
    TIFFTAG_SOFTWARE:          Sentinel-1 IPF 003.52
    _FillValue:                0
    scale_factor:              1.0
    add_offset:                0.0


Please do not use overlapping strides (i.e. input_overlap > 0) if you will be 🪓 splitting your chips into training, validation and test sets later! If you have, say, 60 overlapping chips and then divide those 🍪 chips randomly into train/val/test sets of 30/20/10, you will have information leakage 🚰 between the 30 training chips and the 20 validation plus 10 test chips, so your model’s reported validation and test metrics 📈 will overestimate the actual performance 😲!

Ideally, your train/val/test chips should be situated independently within spatially contiguous blocks 🧱. See these links for more information on why:

  • Kattenborn, T., Schiefer, F., Frey, J., Feilhauer, H., Mahecha, M. D., & Dormann, C. F. (2022). Spatially autocorrelated training and validation samples inflate performance assessment of convolutional neural networks. ISPRS Open Journal of Photogrammetry and Remote Sensing, 5, 100018.

  • pangeo-data/xbatcher#78

Yes, spatial statistics 🧮 matter, geography is special 🤓.
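One way to keep the splits spatially separated is to assign whole blocks of a scene to each split instead of shuffling individual chips. The following is a minimal sketch, assuming non-overlapping chips identified by their top-left (y, x) pixel offsets. The function name `split_by_column_block`, the 2 x 10 chip grid and the 6/2/2 column allocation are illustrative assumptions, not part of zen3geo or xbatcher.

```python
CHIP_SIZE = 512  # pixels per chip side, as used in this tutorial


def split_by_column_block(origins, train_cols, val_cols):
    """Assign chips to splits according to the block of columns they sit in.

    Hypothetical helper: the leftmost `train_cols` chip columns go to train,
    the next `val_cols` columns to val, and everything else to test.
    """
    splits = {"train": [], "val": [], "test": []}
    for y0, x0 in origins:
        col = x0 // CHIP_SIZE  # chip column index within the scene
        if col < train_cols:
            splits["train"].append((y0, x0))
        elif col < train_cols + val_cols:
            splits["val"].append((y0, x0))
        else:
            splits["test"].append((y0, x0))
    return splits


# Hypothetical grid: 2 rows x 10 columns of non-overlapping chips
origins = [(y * CHIP_SIZE, x * CHIP_SIZE) for y in range(2) for x in range(10)]
splits = split_by_column_block(origins, train_cols=6, val_cols=2)
for name, members in splits.items():
    print(f"{name}: {len(members)} chips")
```

Each split then occupies its own contiguous strip of the scene, so neighbouring pixels mostly end up in the same split. Chips on either side of a block boundary are still adjacent, so a buffer column between blocks may be worth considering.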

2️⃣ Pool chips into mini-batches ⚙️#

In total, we now have a set of 30 🍪 chips of size 512 x 512 pixels each. These chips can be divided into batches that are of a reasonable size. Let’s use torchdata.datapipes.iter.Batcher (functional name: batch) to do so.

dp_batch = dp_xbatcher.batch(batch_size=10)
print(f"Number of items in first batch: {len(list(dp_batch)[0])}")
Number of items in first batch: 10

Now each batch will have 10 chips of size 512 x 512, with each chip being an xarray.DataArray.


Notice how no mosaicking or reprojection was done for the two satellite scenes. This is the beauty of zen3geo - full flexibility of combining geospatial datasets 😎. Respect the native coordinate system and let the data flow directly into your models!

Oh, and to be super clear, of the 3 batches of 10 chips each:

  • The first batch has 10 chips from the 1st satellite scene over Osaka

  • The second batch has 5 chips over Osaka, and 5 chips over Tokyo

  • The third batch has 10 chips from the 2nd satellite scene over Tokyo
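The batch composition listed above can be reproduced with a small stand-alone sketch. It assumes 15 chips per scene (30 chips from two scenes, in scene order) and uses plain strings in place of real chips. The `batch` generator here is a hypothetical stand-in for the DataPipe’s batching step, not its actual implementation.

```python
from collections import Counter
from itertools import islice


def batch(iterable, batch_size):
    """Yield consecutive lists of up to `batch_size` items (illustrative)."""
    iterator = iter(iterable)
    while chunk := list(islice(iterator, batch_size)):
        yield chunk


# Assumed: 15 chips from the Osaka scene followed by 15 from the Tokyo scene
chips = ["Osaka"] * 15 + ["Tokyo"] * 15

compositions = [dict(Counter(b)) for b in batch(chips, batch_size=10)]
for i, composition in enumerate(compositions):
    print(f"Batch {i}: {composition}")
```

Because chips are batched in the order they are produced, the middle batch straddles the boundary between the two scenes.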

Stack many chips in mini-batches into a single tensor 🥞#

Let’s now stack all these chips into a single tensor per batch, with a (number, channel, height, width) shape like (10, 1, 512, 512). We’ll need a custom 🪄 collate function to do the conversion (from xarray.DataArray to torch.Tensor) and stacking.

def xr_collate_fn(samples) -> torch.Tensor:
    """
    Converts individual xarray.DataArray objects to a torch.Tensor (int16
    dtype), and stacks them all into a single torch.Tensor.
    """
    tensors = [
        torch.as_tensor(data=sample.data.astype(dtype="int16")) for sample in samples
    ]
    return torch.stack(tensors=tensors)

Then, pass this collate function to torchdata.datapipes.iter.Collator (functional name: collate).

dp_collate = dp_batch.collate(collate_fn=xr_collate_fn)
print(f"Number of mini-batches: {len(dp_collate)}")
print(f"Mini-batch tensor shape: {list(dp_collate)[0].shape}")
Number of mini-batches: 3
Mini-batch tensor shape: torch.Size([10, 1, 512, 512])

Into a DataLoader 🏋️#

One more thing: throw the DataPipe into torch.utils.data.DataLoader! Set batch_size to None, since we’ve handled the batching manually in the above sections already.

dataloader = torch.utils.data.DataLoader(dataset=dp_collate, batch_size=None)
for i, batch in enumerate(dataloader):
    tensor = batch
    print(f"Batch {i}: {tensor.shape}")
Batch 0: torch.Size([10, 1, 512, 512])
Batch 1: torch.Size([10, 1, 512, 512])
Batch 2: torch.Size([10, 1, 512, 512])

Lights, camera, action 💥