Chipping and batching data

Following on from the previous tutorial, let’s 🧑‍🎓 learn how to build a more complicated 🌈 raster data pipeline. Specifically, we’ll go through the following:

  • Load Cloud-Optimized GeoTIFFs (COGs) from different geographic regions 🌏

  • Cut up each large GeoTIFF into several 512 x 512 pixel chips 🥨

  • Create batches of chips/tensors to feed into a DataLoader 🏋️

Some terminology 📜 disambiguation:

  • scene - the big image (e.g. 10000x10000 pixels) from a satellite 🛰️ (e.g. a GeoTIFF)

  • chip - the small image (e.g. 512x512 pixels) cut ✂️ out from a satellite scene to be loaded as a tensor


🎉 Getting started

Load up them libraries!

import pystac
import planetary_computer
import rioxarray

import torch
import torchdata
import zen3geo

0️⃣ Find Cloud-Optimized GeoTIFFs ☁️

Synthetic-Aperture Radar (SAR) from a STAC catalog! We’ll get some Sentinel-1 Ground-Range Detected (GRD) data over Osaka and Tokyo in Japan 🇯🇵.


item_urls = [
    # Osaka
    # Tokyo
]

# Load each STAC item's metadata and sign the assets
items = [pystac.Item.from_file(item_url) for item_url in item_urls]
signed_items = [planetary_computer.sign(item) for item in items]
signed_items
[<Item id=S1A_IW_GRDH_1SDV_20220614T210034_20220614T210059_043664_05368A>,
 <Item id=S1A_IW_GRDH_1SDV_20220616T204349_20220616T204414_043693_053764>]

Inspect one of the data assets 🍱

The Sentinel-1 STAC item contains several assets. These include different 〰️ polarizations (e.g. ‘VH’, ‘VV’). Let’s just take the ‘thumbnail’ product for now, which is an RGB preview: the red 🟥 channel (R) represents the co-polarization (VV or HH), the green 🟩 channel (G) represents the cross-polarization (VH or HV), and the blue 🟦 channel (B) represents the ratio of the cross- and co-polarizations.
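To make that channel layout concrete, here is a rough NumPy sketch. It stacks a co-polarized band (red), a cross-polarized band (green) and their cross/co ratio (blue) into a (band, y, x) array. The helper `sar_rgb_preview` and the pixel values are hypothetical. The sketch only mirrors the channel order. It does not reproduce the scaling or clipping applied to the real thumbnail.

```python
import numpy as np


def sar_rgb_preview(co_pol: np.ndarray, cross_pol: np.ndarray) -> np.ndarray:
    """Stack co-pol (R), cross-pol (G) and cross/co ratio (B) into a (3, y, x) array.

    Hypothetical helper: it mirrors the channel order described above, not the
    exact rendering used for the real thumbnail.
    """
    co = co_pol.astype(np.float32)
    cross = cross_pol.astype(np.float32)
    # Avoid dividing by zero where the co-polarized band has no data (value 0)
    ratio = np.divide(cross, co, out=np.zeros_like(co), where=co != 0)
    return np.stack([co, cross, ratio], axis=0)


# Made-up 2 x 2 pixel values, for illustration only
vv = np.array([[100, 200], [0, 50]], dtype=np.uint16)  # co-polarization -> red
vh = np.array([[20, 50], [0, 25]], dtype=np.uint16)  # cross-polarization -> green
rgb = sar_rgb_preview(co_pol=vv, cross_pol=vh)
print(rgb.shape)  # (3, 2, 2)
```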

url: str = signed_items[0].assets["thumbnail"].href
da = rioxarray.open_rasterio(filename=url)
da
/home/docs/checkouts/ NotGeoreferencedWarning: Dataset has no geotransform, gcps, or rpcs. The identity matrix will be returned.
  warnings.warn(str(rio_warning.message), type(rio_warning.message))  # type: ignore
<xarray.DataArray (band: 3, y: 348, x: 503)>
[525132 values with dtype=uint8]
Coordinates:
  * band         (band) int64 1 2 3
  * x            (x) float64 0.5 1.5 2.5 3.5 4.5 ... 499.5 500.5 501.5 502.5
  * y            (y) float64 0.5 1.5 2.5 3.5 4.5 ... 344.5 345.5 346.5 347.5
    spatial_ref  int64 0
Attributes:
    scale_factor:  1.0
    add_offset:    0.0

This is what the Sentinel-1 radar image looks like over Osaka on 14 June 2022.

Sentinel-1 image over Osaka, Japan on 20220614

1️⃣ Creating 512x512 chips from large satellite scenes 🪟

Unless you have a lot of RAM, it is common to first cut ✂️ a large satellite scene into multiple smaller chips (or patches, tiles 🀄, etc.). This is typically done in a rolling or sliding window 🪟 fashion, via a nested loop through the y-dimension and x-dimension in strides of, say, 512 x 512 pixels.
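The nested loop described above can be sketched in plain NumPy. The sketch assumes a (band, y, x) array is already in memory. It simply skips any window that would run past the scene edge. The helper name `sliding_window_chips` and the 1100 x 1600 pixel scene size are hypothetical.

```python
import numpy as np


def sliding_window_chips(scene: np.ndarray, size: int = 512, stride: int = 512) -> list:
    """Cut a (band, y, x) array into (band, size, size) chips with a nested loop.

    Windows that would run past the scene edge are skipped, so edge pixels
    that cannot fill a whole chip are dropped.
    """
    _, height, width = scene.shape
    chips = []
    for y0 in range(0, height - size + 1, stride):  # outer loop: y-dimension
        for x0 in range(0, width - size + 1, stride):  # inner loop: x-dimension
            chips.append(scene[:, y0 : y0 + size, x0 : x0 + size])
    return chips


# Hypothetical single-band scene of 1100 x 1600 pixels (not a real Sentinel-1 size)
scene = np.zeros((1, 1100, 1600), dtype=np.uint16)
demo_chips = sliding_window_chips(scene)
print(len(demo_chips))  # 6 (2 rows x 3 columns)
print(demo_chips[0].shape)  # (1, 512, 512)
```

A `stride` smaller than `size` makes neighbouring chips overlap.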

Let’s begin by setting up the first part of the DataPipe, which is to read the satellite scene 🖼️ using rioxarray.

# Just get the VV polarization for now from Sentinel-1
urls = [item.assets["vv"].href for item in signed_items]
dp = torchdata.datapipes.iter.IterableWrapper(iterable=urls)
dp_rioxarray = dp.read_from_rioxarray(overview_level=3)

Slicing with XbatcherSlicer 🍕

To create the chips, we’ll be using xbatcher, which allows slicing 🔪 of an n-dimensional datacube along any dimension (e.g. longitude, latitude, time 🕛). The xbatcher library is integrated into ☯ zen3geo as a DataPipe called zen3geo.datapipes.XbatcherSlicer, which can be used as follows:

dp_xbatcher = dp_rioxarray.slice_with_xbatcher(input_dims={"y": 512, "x": 512})

This should give us 12 chips in total, 6 from each of the 2 Sentinel-1 images that were passed in.

chips = [chip for chip in dp_xbatcher]
print(f"Number of chips: {len(chips)}")
Number of chips: 12

Now, if you want to customize the sliding window (e.g. to use overlapping strides), pass extra parameters to slice_with_xbatcher, and they will be handled by xbatcher.BatchGenerator.

dp_xbatcher = dp_rioxarray.slice_with_xbatcher(
    input_dims={"y": 512, "x": 512}, input_overlap={"y": 256, "x": 256}
)

Great, and this overlapping stride method should give us more 512x512 chips 🧮 than before.

chips = [chip for chip in dp_xbatcher]
print(f"Number of chips: {len(chips)}")
Number of chips: 30
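Why does the overlap increase the count? This sketch assumes xbatcher only yields windows that fit completely inside the scene, with a new window starting every `size - overlap` pixels. Under that assumption, the number of chips along one dimension is `(length - size) // stride + 1`. Take a hypothetical 1100 x 1600 pixel scene, since the actual overview shapes are not printed in this tutorial. It gives 2 x 3 = 6 chips without overlap and 3 x 5 = 15 chips with a 256-pixel overlap. Two such scenes would be consistent with the 12 and 30 chips counted above. The helper `count_windows` is hypothetical.

```python
def count_windows(length: int, size: int = 512, overlap: int = 0) -> int:
    """Number of complete windows of `size` pixels along one dimension.

    Assumes windows start every (size - overlap) pixels and that a window
    running past the edge is dropped rather than padded.
    """
    stride = size - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than size")
    if length < size:
        return 0
    return (length - size) // stride + 1


# Hypothetical scene shape in pixels (y, x); the real shapes are not printed above
height, width = 1100, 1600
no_overlap = count_windows(height) * count_windows(width)
with_overlap = count_windows(height, overlap=256) * count_windows(width, overlap=256)
print(no_overlap, with_overlap)  # 6 15
```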

Double-check that single chips are of the correct dimensions (band: 1, y: 512, x: 512).

sample = chips[0]
sample
<xarray.Dataset>
Dimensions:                        (band: 1, y: 512, x: 512)
Coordinates:
  * band                           (band) int64 1
    spatial_ref                    int64 0
Dimensions without coordinates: y, x
Data variables:
    __xarray_dataarray_variable__  (band, y, x) uint16 0 157 154 ... 179 128 125

2️⃣ Pool chips into mini-batches ⚙️

In total, we now have a set of 30 chips of size 512 x 512 pixels each. These chips can be divided into batches that are of a reasonable size. Let’s use torchdata.datapipes.iter.Batcher() to do so.

dp_batch = dp_xbatcher.batch(batch_size=10)
print(f"Number of items in first batch: {len(list(dp_batch)[0])}")
Number of items in first batch: 10

Now each batch will have 10 chips of size 512 x 512, with each chip being an xarray.Dataset.


Notice how neither mosaicking nor reprojection was done for the two satellite scenes. This is the beauty of zen3geo: full flexibility in combining geospatial datasets 😎. Respect the native coordinate system and let the data flow directly into your models!

Oh, and to be super clear, of the 3 batches of 10 chips each:

  • The first batch has 10 chips from the 1st satellite scene over Osaka

  • The second batch has 5 chips over Osaka, and 5 chips over Tokyo

  • The third batch has 10 chips from the 2nd satellite scene over Tokyo
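The mixed second batch comes from the batcher grouping chips strictly in the order they arrive, regardless of which scene they came from. Here that means 15 chips from Osaka, then 15 from Tokyo. The plain-Python stand-in below reproduces the breakdown above. It uses a hypothetical `batch` helper and assumes Batcher’s default of keeping a final partial batch (drop_last=False).

```python
def batch(items: list, batch_size: int) -> list:
    """Group consecutive items into lists of `batch_size` (last batch may be smaller).

    A plain-Python stand-in for torchdata's Batcher with drop_last=False.
    """
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]


# Label each chip by its source scene: 15 from Osaka, then 15 from Tokyo
chip_sources = ["Osaka"] * 15 + ["Tokyo"] * 15
for i, chunk in enumerate(batch(chip_sources, batch_size=10)):
    print(i, {city: chunk.count(city) for city in ("Osaka", "Tokyo")})
# 0 {'Osaka': 10, 'Tokyo': 0}
# 1 {'Osaka': 5, 'Tokyo': 5}
# 2 {'Osaka': 0, 'Tokyo': 10}
```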

Stack many chips in mini-batches into a single tensor 🥞

Let’s now stack all these chips into a single tensor per batch, with a (number, channel, height, width) shape like (10, 1, 512, 512). We’ll need a custom 🪄 collate function to do the conversion (from xarray.Dataset to torch.Tensor) and stacking.

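def xr_collate_fn(samples) -> torch.Tensor:
    """
    Converts individual xarray.Dataset objects to a torch.Tensor (int16 dtype),
    and stacks them all into a single torch.Tensor.
    """
    tensors = [
        # Cast uint16 to int16, since torch has limited support for uint16
        torch.as_tensor(data=sample["__xarray_dataarray_variable__"].data.astype("int16"))
        for sample in samples
    ]
    return torch.stack(tensors=tensors)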
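torch.stack adds a new leading batch dimension to the (band, y, x) chips. NumPy’s np.stack behaves analogously, so the same shape arithmetic can be checked with NumPy alone. The random uint16 values below are hypothetical stand-ins for the chip data. The last line shows one caveat of the int16 cast in xr_collate_fn: uint16 values above 32767 do not fit into int16 and wrap around.

```python
import numpy as np

# Ten hypothetical uint16 chips shaped (band, y, x), standing in for the
# "__xarray_dataarray_variable__" arrays of the xarray.Dataset chips
rng = np.random.default_rng(seed=0)
fake_chips = [rng.integers(0, 1000, size=(1, 512, 512), dtype=np.uint16) for _ in range(10)]

# Cast each chip to int16, then stack along a new leading (batch) axis
stacked = np.stack([chip.astype(np.int16) for chip in fake_chips], axis=0)
print(stacked.shape, stacked.dtype)  # (10, 1, 512, 512) int16

# Caveat: uint16 values above 32767 wrap around when cast to int16
print(np.array([40000], dtype=np.uint16).astype(np.int16))  # [-25536]
```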

Then, pass this collate function to torchdata.datapipes.iter.Collator.

dp_collate = dp_batch.collate(collate_fn=xr_collate_fn)
print(f"Number of mini-batches: {len(list(dp_collate))}")
print(f"Mini-batch tensor shape: {list(dp_collate)[0].shape}")
Number of mini-batches: 3
Mini-batch tensor shape: torch.Size([10, 1, 512, 512])

Into a DataLoader 🏋️

One more thing 🎁: throw the DataPipe into a torch.utils.data.DataLoader! Set batch_size to None, since we’ve already handled the batching manually in the sections above.

dataloader = torch.utils.data.DataLoader(dataset=dp_collate, batch_size=None)
for i, batch in enumerate(dataloader):
    tensor = batch
    print(f"Batch {i}: {tensor.shape}")
Batch 0: torch.Size([10, 1, 512, 512])
Batch 1: torch.Size([10, 1, 512, 512])
Batch 2: torch.Size([10, 1, 512, 512])

Lights, camera, action 💥