Chipping and batching data#
Following on from the previous tutorial, let's learn how to create a more complicated raster data pipeline. Specifically, we'll go through the following:
Load Cloud-Optimized GeoTIFFs (COGs) from different geographic regions
Cut up each large GeoTIFF into several 512 x 512 pixel chips
Create batches of chips/tensors to feed into a DataLoader
Some terminology disambiguation:
scene - the big image (e.g. 10000x10000 pixels) from a satellite (e.g. a GeoTIFF)
chip - the small image (e.g. 512x512 pixels) cut out from a satellite scene to be loaded as a tensor
Getting started#
Load up them libraries!
import pystac
import planetary_computer
import rioxarray
import torch
import torchdata
import zen3geo
0️⃣ Find Cloud-Optimized GeoTIFFs#
Let's fetch some Synthetic-Aperture Radar (SAR) data from a STAC catalog! Specifically, we'll get Sentinel-1 Ground-Range Detected (GRD) data over Osaka and Tokyo in Japan.
item_urls = [
# Osaka
"https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-1-grd/items/S1A_IW_GRDH_1SDV_20220614T210034_20220614T210059_043664_05368A",
# Tokyo
"https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-1-grd/items/S1A_IW_GRDH_1SDV_20220616T204349_20220616T204414_043693_053764",
]
# Load each STAC item's metadata and sign the assets
items = [pystac.Item.from_file(item_url) for item_url in item_urls]
signed_items = [planetary_computer.sign(item) for item in items]
signed_items
[<Item id=S1A_IW_GRDH_1SDV_20220614T210034_20220614T210059_043664_05368A>,
<Item id=S1A_IW_GRDH_1SDV_20220616T204349_20220616T204414_043693_053764>]
Inspect one of the data assets ๐ฑ#
The Sentinel-1 STAC item contains several assets, including different polarizations (e.g. "VH", "VV"). Let's just take the "thumbnail" product for now, which is an RGB preview, with the red channel (R) representing the co-polarization (VV or HH), the green channel (G) representing the cross-polarization (VH or HV), and the blue channel (B) representing the ratio of the cross- and co-polarizations.
url: str = signed_items[0].assets["thumbnail"].href
da = rioxarray.open_rasterio(filename=url)
da
/home/docs/checkouts/readthedocs.org/user_builds/zen3geo/envs/v0.2.0/lib/python3.10/site-packages/rioxarray/_io.py:851: NotGeoreferencedWarning: Dataset has no geotransform, gcps, or rpcs. The identity matrix will be returned.
warnings.warn(str(rio_warning.message), type(rio_warning.message)) # type: ignore
<xarray.DataArray (band: 3, y: 348, x: 503)>
[525132 values with dtype=uint8]
Coordinates:
  * band         (band) int64 1 2 3
  * x            (x) float64 0.5 1.5 2.5 3.5 4.5 ... 499.5 500.5 501.5 502.5
  * y            (y) float64 0.5 1.5 2.5 3.5 4.5 ... 344.5 345.5 346.5 347.5
    spatial_ref  int64 0
Attributes:
    scale_factor:  1.0
    add_offset:    0.0
This is how the Sentinel-1 radar image looks over Osaka on 14 June 2022.
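The channel recipe described above can be sketched with plain NumPy. Note that the VV/VH arrays below are random stand-ins, and a real thumbnail would also rescale values to the 0-255 range:

```python
import numpy as np

rng = np.random.default_rng(seed=42)

# Hypothetical co-pol (VV) and cross-pol (VH) backscatter arrays, same shape
vv = rng.uniform(low=0.01, high=1.0, size=(4, 4)).astype("float32")
vh = rng.uniform(low=0.001, high=0.1, size=(4, 4)).astype("float32")

# R = co-polarization, G = cross-polarization, B = cross/co ratio
rgb = np.stack([vv, vh, vh / vv], axis=0)
print(rgb.shape)  # (3, 4, 4)
```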
1️⃣ Creating 512x512 chips from large satellite scenes#
Unless you have a lot of RAM, it is common to first cut a large satellite scene into multiple smaller chips (or patches, tiles, etc). This is typically done in a rolling or sliding window fashion, via a nested loop through the y-dimension and x-dimension in strides of say, 512 pixels x 512 pixels.
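That nested loop can be sketched in bare-bones form on an in-memory array. The array size here is made up for illustration; the real scenes are much bigger:

```python
import numpy as np

# Toy 2-band "scene" (hypothetical 1024x1536 pixel size, all zeros)
scene = np.zeros(shape=(2, 1024, 1536), dtype="uint16")
chip_size = 512

# Slide a 512x512 window over the y- and x-dimensions in 512 pixel strides
chips = []
for y in range(0, scene.shape[1] - chip_size + 1, chip_size):
    for x in range(0, scene.shape[2] - chip_size + 1, chip_size):
        chips.append(scene[:, y : y + chip_size, x : x + chip_size])

print(len(chips))       # 6 chips (2 rows x 3 columns)
print(chips[0].shape)   # (2, 512, 512)
```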
Let's begin by setting up the first part of the DataPipe, which is to read the satellite scene using rioxarray.
# Just get the VV polarization for now from Sentinel-1
urls = [item.assets["vv"].href for item in signed_items]
dp = torchdata.datapipes.iter.IterableWrapper(iterable=urls)
dp_rioxarray = dp.read_from_rioxarray(overview_level=3)
dp_rioxarray
RioXarrayReaderIterDataPipe
Slicing with XbatcherSlicer#
To create the chips, we'll be using xbatcher, which allows slicing of an n-dimensional datacube along any dimension (e.g. longitude, latitude, time). This xbatcher library is integrated into zen3geo as a DataPipe called zen3geo.datapipes.XbatcherSlicer, which can be used as follows:
dp_xbatcher = dp_rioxarray.slice_with_xbatcher(input_dims={"y": 512, "x": 512})
dp_xbatcher
XbatcherSlicerIterDataPipe
This should give us 12 chips in total, 6 from each of the two Sentinel-1 images that were passed in.
chips = [chip for chip in dp_xbatcher]
print(f"Number of chips: {len(chips)}")
Number of chips: 12
Now, if you want to customize the sliding window (e.g. do overlapping strides), pass in extra parameters to slice_with_xbatcher, and they will be handled by xbatcher.BatchGenerator.
dp_xbatcher = dp_rioxarray.slice_with_xbatcher(
    input_dims={"y": 512, "x": 512}, input_overlap={"y": 256, "x": 256}
)
dp_xbatcher
XbatcherSlicerIterDataPipe
Great, and this overlapping stride method should give us more 512x512 chips than before.
chips = [chip for chip in dp_xbatcher]
print(f"Number of chips: {len(chips)}")
Number of chips: 30
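The chip counts follow the usual sliding-window formula, assuming windows that don't fit fully inside the scene are dropped. The scene sizes below are made up for illustration, not the actual Sentinel-1 scene dimensions:

```python
def num_windows(size: int, window: int, stride: int) -> int:
    """Count sliding windows of `window` pixels with step `stride` along an
    axis of `size` pixels, dropping any window that doesn't fit fully."""
    return 0 if size < window else (size - window) // stride + 1

# Hypothetical 1500x2000 pixel scene, non-overlapping 512x512 windows
rows = num_windows(1500, window=512, stride=512)
cols = num_windows(2000, window=512, stride=512)
print(rows * cols)  # 6

# Halving the stride (256 pixel overlap) yields many more chips
rows = num_windows(1500, window=512, stride=256)
cols = num_windows(2000, window=512, stride=256)
print(rows * cols)  # 24
```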
Double-check that a single chip has the correct dimensions (band: 1, y: 512, x: 512).
sample = chips[0]
sample
<xarray.Dataset>
Dimensions:                        (band: 1, y: 512, x: 512)
Coordinates:
  * band                           (band) int64 1
    spatial_ref                    int64 0
Dimensions without coordinates: y, x
Data variables:
    __xarray_dataarray_variable__  (band, y, x) uint16 0 157 154 ... 179 128 125
2️⃣ Pool chips into mini-batches#
In total, we now have a set of 30 chips of size 512 x 512 pixels each.
These chips can be divided into batches that are of a reasonable size.
Let's use torchdata.datapipes.iter.Batcher() to do so.
dp_batch = dp_xbatcher.batch(batch_size=10)
print(f"Number of items in first batch: {len(list(dp_batch)[0])}")
Number of items in first batch: 10
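Conceptually, Batcher just groups consecutive items from the DataPipe. Here is a pure-Python sketch of that behaviour (not torchdata's actual implementation):

```python
from itertools import islice

def batched(iterable, batch_size: int):
    """Yield lists of up to `batch_size` consecutive items,
    mimicking what torchdata's Batcher does to a DataPipe."""
    iterator = iter(iterable)
    while batch := list(islice(iterator, batch_size)):
        yield batch

# 30 items grouped into batches of 10, like the 30 chips above
batches = list(batched(range(30), batch_size=10))
print(len(batches))      # 3
print(len(batches[0]))   # 10
```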
Now each batch will have 10 chips of size 512 x 512, with each chip being an xarray.Dataset.
Note
Notice how no mosaicking nor reprojection was done for the two satellite scenes. This is the beauty of zen3geo - full flexibility in combining geospatial datasets. Respect the native coordinate system and let the data flow directly into your models!
Oh, and to be super clear, of the 3 batches of 10 chips each:
The first batch has 10 chips from the 1st satellite scene over Osaka
The second batch has 5 chips over Osaka, and 5 chips over Tokyo
The third batch has 10 chips from the 2nd satellite scene over Tokyo
Stack many chips in mini-batches into a single tensor#
Let's now stack all these chips into a single tensor per batch, with a (number, channel, height, width) shape like (10, 1, 512, 512). We'll need a custom collate function to do the conversion (from xarray.Dataset to torch.Tensor) and stacking.
def xr_collate_fn(samples) -> torch.Tensor:
    """
    Converts individual xarray.Dataset objects to a torch.Tensor (int16 dtype),
    and stacks them all into a single torch.Tensor.
    """
    tensors = [
        torch.as_tensor(
            data=sample.data_vars.get(key="__xarray_dataarray_variable__").data.astype("int16")
        )
        for sample in samples
    ]
    return torch.stack(tensors=tensors)
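As a quick sanity check, here is how the collate function behaves on some hypothetical in-memory chips (tiny 4x4 arrays instead of 512x512, with the function repeated so the snippet stands alone):

```python
import numpy as np
import torch
import xarray as xr

def xr_collate_fn(samples) -> torch.Tensor:
    """Same collate function as above, repeated here for a self-contained check."""
    tensors = [
        torch.as_tensor(
            data=sample.data_vars.get(key="__xarray_dataarray_variable__").data.astype("int16")
        )
        for sample in samples
    ]
    return torch.stack(tensors=tensors)

# Two hypothetical 4x4 single-band chips, each wrapped as an xarray.Dataset
samples = [
    xr.DataArray(
        data=np.full(shape=(1, 4, 4), fill_value=i, dtype="uint16"),
        dims=("band", "y", "x"),
    ).to_dataset(name="__xarray_dataarray_variable__")
    for i in range(2)
]
stacked = xr_collate_fn(samples)
print(stacked.shape)  # torch.Size([2, 1, 4, 4])
print(stacked.dtype)  # torch.int16
```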
Then, pass this collate function to torchdata.datapipes.iter.Collator.
dp_collate = dp_batch.collate(collate_fn=xr_collate_fn)
print(f"Number of mini-batches: {len(list(dp_collate))}")
print(f"Mini-batch tensor shape: {list(dp_collate)[0].shape}")
Number of mini-batches: 3
Mini-batch tensor shape: torch.Size([10, 1, 512, 512])
Into a DataLoader#
One more thing: throw the DataPipe into torch.utils.data.DataLoader! Set batch_size to None, since we've handled the batching manually in the above sections already.
dataloader = torch.utils.data.DataLoader(dataset=dp_collate, batch_size=None)
for i, batch in enumerate(dataloader):
    tensor = batch
    print(f"Batch {i}: {tensor.shape}")
Batch 0: torch.Size([10, 1, 512, 512])
Batch 1: torch.Size([10, 1, 512, 512])
Batch 2: torch.Size([10, 1, 512, 512])
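From here, each (number, channel, height, width) mini-batch tensor is ready for a model's forward pass. A hypothetical sketch with a toy convolutional layer (the Conv2d layer and the random stand-in batch are illustrative, not part of this tutorial's pipeline):

```python
import torch

# Toy model: a single conv layer expecting (N, C, H, W) input with 1 channel
model = torch.nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3)

# Random stand-in for one mini-batch from the DataLoader above
batch = torch.randint(low=0, high=255, size=(10, 1, 512, 512), dtype=torch.int16)

# Conv layers need floating point inputs, so cast the int16 tensor first
output = model(batch.float())
print(output.shape)  # torch.Size([10, 4, 510, 510])
```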
Lights, camera, action!