Model Training with PyTorch Lightning

In this lesson, we will train a PyTorch Lightning model using the DataPipe created in the prior lesson. We’ll cover how to wrap a torch DataPipe in a PyTorch Lightning DataModule to quickly gain the advantages of PyTorch Lightning for model development, which include:

  • running your model training on any hardware: CPU, GPU, or TPU

  • distributed training for large models and datasets

  • easier logging, metrics, and visualizations

To set up the Jupyter environment, set up Hatch as described in the README, then run:

hatch -e nb shell
exit
jupyter lab

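Let’s start by importing the packages we will use. Note that the HLS imagery we stream later is hosted on a protected NASA endpoint and requires Earthdata authentication. As before, we use compositional programming throughout: importing a package such as zen3geo makes its DataPipe methods (e.g. `.search_for_pystac_item`) available to chain onto any DataPipe.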

import os
import urllib

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import pystac
import segmentation_models_pytorch as smp
import stackstac
import torch
import torch.nn as nn  # PyTorch Lightning NN (neural network) module
import torch.nn.functional as F
import torch.optim as optim  # Training optimizer module
import torchdata
import xarray as xr
import zen3geo
from lightning.pytorch.callbacks import DeviceStatsMonitor
from torch.utils.data import DataLoader
from torchdata.dataloader2 import DataLoader2

Just like before, we will create our query to constrain a search for data hosted on NASA’s LP DAAC using spatial and temporal bounds.

# Spatial extent of the search, in decimal degrees
bbox = [
    -119.1,  # west
    36.2,  # south
    -118.2,  # east
    36.9,  # north
]
# Temporal extent of the search: [start, end] as ISO 8601 UTC timestamps
time_range = [
    "2021-08-15T00:00:00Z",
    "2021-09-15T23:59:59Z",
]
# Harmonized Landsat Sentinel-2 (HLS), Sentinel-2-derived "S30" product, v2.0
collection_ids = ["HLSS30.v2.0"]

The following should look familiar. We are creating:

  • A Python dictionary for querying the source HLS imagery

  • A URL pointing to the burn scar vector labels

  • Some configurations related to the GDAL environment

  • Pre-processing functions

# STAC search parameters for the HLS source imagery
image_query = {
    "bbox": bbox,
    "datetime": time_range,
    "collections": collection_ids,
}

# Burn scar polygons (GeoJSON) used as the segmentation labels
vector_url = (
    "https://gist.githubusercontent.com/weiji14/"
    "286032ac2498d10e050ba585257dd50d/raw/"
    "c897c7c1b3b8354ec8c6e8327df38fcfee79b4ef/burn_scars.geojson"
)

# GDAL options layered on top of stackstac's defaults for remote reads.
# The cookie file/jar keep HTTP session cookies between requests (presumably
# for the NASA Earthdata login session — confirm for your setup).
gdal_env = stackstac.DEFAULT_GDAL_ENV.updated(
    always={
        "GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR",  # skip remote dir listing
        "GDAL_HTTP_MERGE_CONSECUTIVE_RANGES": "YES",  # fewer range requests
        "GDAL_HTTP_COOKIEFILE": os.path.expanduser("~/cookies.txt"),
        "GDAL_HTTP_COOKIEJAR": os.path.expanduser("~/cookies.txt"),
    }
)


def cloud_cover_filter(item: "pystac.Item", threshold=20) -> bool:
    """
    Return True if the item's cloud cover is at most ``threshold`` percent.

    Parameters
    ----------
    item : pystac.Item
        STAC item whose ``properties`` contain an ``eo:cloud_cover`` entry.
    threshold : int or float, default 20
        Maximum cloud cover percentage to accept (inclusive).

    Returns
    -------
    bool
        ``item.properties["eo:cloud_cover"] <= threshold``.
    """
    return item.properties["eo:cloud_cover"] <= threshold


def zero_mask_filter(dataarray: "xr.DataArray") -> bool:
    """
    Return True if ``dataarray`` contains at least one value greater than zero.

    Meant to be used as a filter on the mask column so that chips whose label
    mask is entirely zero are dropped. The test is ``max() > 0``, so an array
    whose only non-zero values are negative also returns False.
    """
    return bool(dataarray.max() > 0)


def xr_collate_fn(samples: tuple) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Collate a batch of (image, mask) xarray.DataArray pairs into two tensors.

    Parameters
    ----------
    samples : sequence of (xarray.DataArray, xarray.DataArray)
        One batch of (image chip, mask chip) pairs.

    Returns
    -------
    images : torch.Tensor
        int16 tensor stacking every image chip along a new first axis. NaN
        pixels are replaced with 0 before the integer cast.
    masks : torch.Tensor
        uint8 tensor stacking every mask chip along a new first axis.
    """
    image_tensors = [
        torch.as_tensor(data=np.nan_to_num(sample[0].data).astype(dtype="int16"))
        for sample in samples
    ]
    mask_tensors = [
        torch.as_tensor(data=sample[1].data.astype(dtype="uint8")) for sample in samples
    ]
    return torch.stack(tensors=image_tensors), torch.stack(tensors=mask_tensors)

Now we will take all of the steps from the prior lesson and compile them into a module based on the Lightning DataModule structure. Note that this module also partitions the batches into training and validation sets and provides a separate dataloader for each. This way, we can train the model on the training set and hold out the validation set for monitoring and hyperparameter tuning (a separate test set would still be needed for a final, unbiased evaluation).

class BurnScarsDataModule(L.LightningDataModule):
    """
    LightningDataModule serving HLS image chips and burn scar label masks.

    ``setup`` builds a torchdata pipeline in six steps:

    - search NASA's CMR STAC catalog with ``image_query``;
    - stack five HLS bands;
    - slice the stack into 512x512 chips;
    - rasterize the burn scar polygons from ``vector_url`` onto each chip;
    - batch and collate the (image, mask) pairs;
    - randomly split the batches ~70/30 into training and validation DataPipes.

    Parameters
    ----------
    image_query : dict
        STAC search parameters (``bbox``, ``datetime``, ``collections``).
    vector_url : str
        URL of the burn scar vector labels.
    batch_size : int
        Number of chips per batch (stored as ``self.bs``).
    """

    def __init__(self, image_query, vector_url, batch_size):
        super().__init__()
        self.image_query = image_query
        self.vector_url = vector_url
        self.bs = batch_size

    def setup(self, stage):
        """
        Build ``self.dp_train`` and ``self.dp_val``.

        Lightning passes ``stage`` as "fit", "validate", "test" or "predict".
        Every non-None value builds the same pipeline; ``None`` is a no-op.

        Raises
        ------
        ValueError
            If the pipeline yields no batches, so there is nothing to split.
        """
        if stage is None:
            return

        # Datapipe for the raster imagery
        dp = torchdata.datapipes.iter.IterableWrapper(iterable=[self.image_query])
        dp_hls_stack = (
            dp.search_for_pystac_item(
                catalog_url="https://cmr.earthdata.nasa.gov/stac/LPCLOUD",
            )
            .list_pystac_items_by_search()
            # .sharding_filter()  # currently disabled
            .filter(filter_fn=cloud_cover_filter)
            .stack_stac_items(
                assets=["B04", "B03", "B02", "B08", "B12"],  # RGB+NIR+SWIR bands
                epsg=32611,  # UTM Zone 11N
                resolution=30,  # Spatial resolution of 30 metres
                xy_coords="center",  # pixel centroid coords instead of topleft corner
                dtype=np.float16,  # Use a lightweight data type
                rescale=False,  # Don't apply scale and offset to prevent UFuncTypeError
                # https://github.com/gjoseph92/stackstac/issues/133
                gdal_env=gdal_env,
            )
        )

        # DataPipe for the vector labels
        _, dp_stream = (
            torchdata.datapipes.iter.IterableWrapper(iterable=[self.vector_url])
            .read_from_http()  # outputs a tuple of (url, I/O stream)
            .unzip(sequence_length=2)  # read just the I/O stream from the tuple
        )
        dp_pyogrio = dp_stream.read_from_pyogrio()

        # Fuse raster and vector datapipes: each chip is used twice, once as
        # the image and once as the canvas the labels are rasterized onto.
        dp_xbatcher = dp_hls_stack.slice_with_xbatcher(
            input_dims={"time": 1, "y": 512, "x": 512}
        )
        dp_chip_canvas, dp_chip_image = dp_xbatcher.fork(num_instances=2)
        dp_datashader = dp_chip_canvas.canvas_from_xarray().rasterize_with_datashader(
            vector_datapipe=dp_pyogrio
        )
        # NOTE: filtering out empty masks (zero_mask_filter on column 1) is
        # disabled because .filter() removes __len__, which the random split
        # below relies on.
        dp_hls_mask_filtered = dp_chip_image.zip(dp_datashader)

        # Batch and Collate, and split into train/val batches
        dp_collate = dp_hls_mask_filtered.batch(batch_size=self.bs).collate(
            collate_fn=xr_collate_fn
        )

        n_batches = len(dp_collate)
        if n_batches == 0:
            raise ValueError(
                "The burn scar datapipe produced no batches to split into "
                "train/val sets; check image_query and data access."
            )
        # Use integer weights that sum exactly to total_length, as torchdata
        # recommends so split lengths are known in advance. Two independent
        # round() calls can over/undershoot by one (a length of 5 gives
        # round(3.5) + round(1.5) == 6), which also skews the 70/30 ratio.
        n_train = round(n_batches * 0.7)
        self.dp_train, self.dp_val = dp_collate.random_split(
            total_length=n_batches,
            weights={"train": n_train, "val": n_batches - n_train},
            seed=42,
        )

    def graph_dp(self):
        """Return a graph visualization of the training DataPipe."""
        return torchdata.datapipes.utils.to_graph(dp=self.dp_train)

    def show_batch(self):
        """
        Plot the first training batch.

        Each row shows an RGB composite of the image chip (left) next to its
        burn scar mask (right).
        """
        images, masks = next(iter(self.dp_train))
        # The batch can hold fewer than self.bs chips if the data runs short.
        n_rows = len(images)
        # squeeze=False keeps ``axes`` 2-D even when there is a single row.
        _, axes = plt.subplots(nrows=n_rows, ncols=2, squeeze=False)
        for i in range(n_rows):
            # images[i] is indexed as (time, band, y, x). Bands follow the
            # stacked asset order B04, B03, B02, B08, B12, so indices 0-2 are
            # red/green/blue. permute moves bands last while keeping (y, x)
            # so the image lines up with its mask.
            rgb = images[i][0, [0, 1, 2], :, :].permute(1, 2, 0)
            image = torch.clip(rgb, min=0)
            axes[i][0].imshow(X=image / 2**16)
            axes[i][1].imshow(masks[i] / 2**8)
        plt.tight_layout()

    def train_dataloader(self):
        """DataLoader over the training DataPipe (already batched upstream)."""
        return DataLoader(dataset=self.dp_train, batch_size=None)

    def val_dataloader(self):
        """DataLoader over the validation DataPipe (already batched upstream)."""
        return DataLoader(dataset=self.dp_val, batch_size=None)

Model architecture setup

Now we will set up our model module. We will again adopt PyTorch Lightning’s structure, this time using PyTorch’s nn.Module (short for neural network module) to set up the core ‘backbone’ neural network model.

Oftentimes, we want to speed up the model’s loss convergence by using a pre-trained backbone model. This approach, however, can get tricky when our images have a different number of channels than the pre-trained model we want to fine-tune expects. There are solutions that simplify integrating backbone models with custom training data, one of which we will implement below.

How to handle single-channel or multispectral imagery?

Segmentation Models PyTorch (SMP) abstracts away the complication of using a segmentation ML model with a source dataset that has a different number of bands/channels than the original computer vision dataset the model was designed for. This can be useful for multispectral or hyperspectral imagery, which has more bands than typical RGB images.

The next cell defines the backbone model class, which we will use to set up a U-Net model that will be called within our custom model class later.

class Backbone(nn.Module):
    """
    Thin ``nn.Module`` wrapper around a segmentation_models_pytorch U-Net.

    All constructor arguments, plus any extra ``**kwargs``, are forwarded
    unchanged to ``smp.Unet``. The resulting network is stored as
    ``self.backbone``.
    """

    def __init__(
        self,
        encoder_name,
        encoder_depth,
        encoder_weights,
        in_channels,
        classes,
        activation,
        **kwargs,
    ):
        super().__init__()
        # kwargs can never repeat the named arguments above, so merging them
        # into one mapping forwards exactly the same keywords to smp.Unet.
        unet_options = dict(
            encoder_name=encoder_name,
            encoder_depth=encoder_depth,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation,
        )
        unet_options.update(kwargs)
        self.backbone = smp.Unet(**unet_options)

    def forward(self, xb):
        """Run the wrapped U-Net on the input batch ``xb``."""
        unet = self.backbone
        return unet(xb)

Custom model class

We will use PyTorch Lightning’s LightningModule structure to create a custom model that uses the backbone class. Note that we can also add more metrics to optimize and log during training, such as Intersection over Union (IoU) for segmentation, or mean average precision (mAP) if we were training an object detection model.

class BurnScarsSegmentationModel(L.LightningModule):
    """
    LightningModule for binary burn scar segmentation with a U-Net backbone.

    With the defaults (``classes=1``, ``activation=None``) the network emits
    one channel of raw logits per pixel. The loss and IoU in ``one_step``
    assume logits, so keep ``activation=None`` when using them.

    Parameters
    ----------
    encoder_name, encoder_depth, encoder_weights, in_channels, classes,
    activation, **kwargs
        Forwarded to ``Backbone`` (and from there to ``smp.Unet``).
    lr : float
        Learning rate for the AdamW optimizer.
    """

    def __init__(
        self,
        encoder_name="resnet18",
        encoder_depth=5,
        encoder_weights=None,  # don't use pretrained weights
        in_channels=5,
        classes=1,
        activation=None,
        lr=1e-3,  # learning rate
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()  # saves all hparams as self.hparams
        self.model = Backbone(
            encoder_name=encoder_name,
            encoder_depth=encoder_depth,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation,
            **kwargs,
        )

    def forward(self, xb):
        """Return the backbone output for the image batch ``xb``."""
        return self.model(xb)

    def configure_optimizers(self):
        """Use AdamW with the learning rate saved in ``self.hparams.lr``."""
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        return optimizer

    def one_step(self, batch):
        """
        Shared forward pass and loss/metric computation for train and val.

        Parameters
        ----------
        batch : tuple of torch.Tensor
            ``(images, masks)`` as produced by ``xr_collate_fn``.

        Returns
        -------
        loss : torch.Tensor
            Binary cross-entropy between the logits and the mask.
        iou : torch.Tensor
            Micro-averaged Intersection over Union at probability 0.5.
        """
        xb, yb = batch
        # Collapse singleton axes between the batch axis and the trailing
        # (bands, y, x) / (y, x) axes. Unlike a bare squeeze(), this keeps
        # the batch axis even when a batch holds a single chip.
        img = xb.reshape(xb.shape[0], -1, *xb.shape[-2:]).to(dtype=torch.float32)
        mask = yb.reshape(yb.shape[0], *yb.shape[-2:])  # e.g. [4, 512, 512]

        # Per-pixel logits: [B, 1, H, W] -> [B, H, W]
        logits = self(img).squeeze(dim=1)

        # Compute Binary Cross Entropy loss
        loss = F.binary_cross_entropy_with_logits(
            input=logits, target=mask.to(dtype=torch.float32)
        )
        # Compute IoU on probabilities. Thresholding raw logits at 0.5 would
        # really mean p > sigmoid(0.5) ~= 0.62.
        tp, fp, fn, tn = smp.metrics.get_stats(
            output=torch.sigmoid(logits.detach()),
            target=mask.long(),
            mode="binary",
            threshold=0.5,
        )
        iou = smp.metrics.iou_score(tp=tp, fp=fp, fn=fn, tn=tn, reduction="micro")

        return loss, iou

    def training_step(self, batch, batch_idx):
        """Log per-step and per-epoch train loss/IoU and return the loss."""
        loss, iou = self.one_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log(
            "train_iou",
            iou,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and IoU."""
        loss, iou = self.one_step(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_iou", iou, prog_bar=True, logger=True)

Now we will create the trainer object, along with instances of the DataModule and the custom model defined above. Notice that we specify the “resnet18” encoder and that our images consist of 5 channels (Red, Green, Blue, NIR, SWIR).

# Smoke-test configuration: fast_dev_run runs a single training batch and a
# single validation batch instead of full epochs.
trainer = L.Trainer(
    callbacks=[DeviceStatsMonitor()],  # monitor and log device stats
    min_epochs=1,
    max_epochs=2,
    fast_dev_run=True,
)

datamodule = BurnScarsDataModule(
    batch_size=4,
    image_query=image_query,
    vector_url=vector_url,
)

model = BurnScarsSegmentationModel(in_channels=5, encoder_name="resnet18")
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  warning_cache.warn(
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.

We can check a batch of our training data.

# Lightning's stage names are "fit", "validate", "test" and "predict".
# setup() builds the train/val DataPipes for any non-None stage.
datamodule.setup(stage="fit")
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
# Plot the first training batch. This reads HLS imagery from NASA's protected
# LP DAAC endpoint ("lp-prod-protected" in the URLs below).
# NOTE(review): the recorded "not recognized as a supported file format"
# error looks like an unauthenticated request. Confirm your Earthdata
# credentials / cookie setup before running.
datamodule.show_batch()
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
---------------------------------------------------------------------------
CPLE_OpenFailedError                      Traceback (most recent call last)
File rasterio/_base.pyx:310, in rasterio._base.DatasetBase.__init__()

File rasterio/_base.pyx:221, in rasterio._base.open_dataset()

File rasterio/_err.pyx:221, in rasterio._err.exc_wrap_pointer()

CPLE_OpenFailedError: '/vsicurl/https://data.lpdaac.earthdatacloud.nasa.gov/lp-prod-protected/HLSS30.020/HLS.S30.T11SLA.2021231T183919.v2.0/HLS.S30.T11SLA.2021231T183919.v2.0.B04.tif' not recognized as a supported file format.

During handling of the above exception, another exception occurred:

RasterioIOError                           Traceback (most recent call last)
File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/rio_reader.py:326, in AutoParallelRioReader._open(self)
    325 try:
--> 326     ds = SelfCleaningDatasetReader(
    327         self.url, sharing=False
    328     )
    329 except Exception as e:

File rasterio/_base.pyx:312, in rasterio._base.DatasetBase.__init__()

RasterioIOError: '/vsicurl/https://data.lpdaac.earthdatacloud.nasa.gov/lp-prod-protected/HLSS30.020/HLS.S30.T11SLA.2021231T183919.v2.0/HLS.S30.T11SLA.2021231T183919.v2.0.B04.tif' not recognized as a supported file format.

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[9], line 1
----> 1 datamodule.show_batch()

Cell In[4], line 75, in BurnScarsDataModule.show_batch(self)
     73 def show_batch(self):
     74     it = iter(self.dp_train)
---> 75     batch = next(it)
     77     fig, axes = plt.subplots(nrows=self.bs, ncols=2)
     78     for i in range(self.bs):

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/torch/utils/data/datapipes/_hook_iterator.py:173, in hook_iterator.<locals>.wrap_generator(*args, **kwargs)
    171         response = gen.send(None)
    172 else:
--> 173     response = gen.send(None)
    175 while True:
    176     datapipe._number_of_samples_yielded += 1

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/torchdata/datapipes/iter/util/randomsplitter.py:184, in SplitterIterator.__iter__(self)
    182 def __iter__(self):
    183     self.main_datapipe.reset()
--> 184     for sample in self.main_datapipe.source_datapipe:
    185         if self.main_datapipe.draw() == self.target:
    186             yield sample

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/torch/utils/data/datapipes/_hook_iterator.py:173, in hook_iterator.<locals>.wrap_generator(*args, **kwargs)
    171         response = gen.send(None)
    172 else:
--> 173     response = gen.send(None)
    175 while True:
    176     datapipe._number_of_samples_yielded += 1

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/torch/utils/data/datapipes/iter/callable.py:122, in MapperIterDataPipe.__iter__(self)
    121 def __iter__(self) -> Iterator[T_co]:
--> 122     for data in self.datapipe:
    123         yield self._apply_fn(data)

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/torch/utils/data/datapipes/_hook_iterator.py:173, in hook_iterator.<locals>.wrap_generator(*args, **kwargs)
    171         response = gen.send(None)
    172 else:
--> 173     response = gen.send(None)
    175 while True:
    176     datapipe._number_of_samples_yielded += 1

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/torch/utils/data/datapipes/iter/grouping.py:70, in BatcherIterDataPipe.__iter__(self)
     68 def __iter__(self) -> Iterator[DataChunk]:
     69     batch: List = []
---> 70     for x in self.datapipe:
     71         batch.append(x)
     72         if len(batch) == self.batch_size:

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/torch/utils/data/datapipes/_hook_iterator.py:173, in hook_iterator.<locals>.wrap_generator(*args, **kwargs)
    171         response = gen.send(None)
    172 else:
--> 173     response = gen.send(None)
    175 while True:
    176     datapipe._number_of_samples_yielded += 1

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/torch/utils/data/datapipes/iter/combining.py:589, in ZipperIterDataPipe.__iter__(self)
    587 def __iter__(self) -> Iterator[Tuple[T_co]]:
    588     iterators = [iter(datapipe) for datapipe in self.datapipes]
--> 589     yield from zip(*iterators)

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/torch/utils/data/datapipes/_hook_iterator.py:144, in hook_iterator.<locals>.IteratorDecorator.__next__(self)
    142         return self._get_next()
    143 else:  # Decided against using `contextlib.nullcontext` for performance reasons
--> 144     return self._get_next()

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/torch/utils/data/datapipes/_hook_iterator.py:132, in hook_iterator.<locals>.IteratorDecorator._get_next(self)
    128 r"""
    129 Return next with logic related to iterator validity, profiler, and incrementation of samples yielded.
    130 """
    131 _check_iterator_valid(self.source_dp, self.iterator_id)
--> 132 result = next(self.iterator)
    133 if not self.self_and_has_next_method:
    134     self.source_dp._number_of_samples_yielded += 1

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/torch/utils/data/datapipes/iter/combining.py:163, in _ForkerIterDataPipe.get_next_element_by_instance(self, instance_id)
    161 self.leading_ptr = self.child_pointers[instance_id]
    162 try:
--> 163     return_val = next(self._datapipe_iterator)  # type: ignore[arg-type]
    164     self.buffer.append(return_val)
    165 except StopIteration:

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/torch/utils/data/datapipes/_hook_iterator.py:173, in hook_iterator.<locals>.wrap_generator(*args, **kwargs)
    171         response = gen.send(None)
    172 else:
--> 173     response = gen.send(None)
    175 while True:
    176     datapipe._number_of_samples_yielded += 1

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/zen3geo/datapipes/xbatcher.py:107, in XbatcherSlicerIterDataPipe.__iter__(self)
    105 def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]:
    106     for dataarray in self.source_datapipe:
--> 107         for chip in dataarray.batch.generator(
    108             input_dims=self.input_dims, **self.kwargs
    109         ):
    110             yield chip

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/xbatcher/generators.py:416, in BatchGenerator.__iter__(self)
    414 def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]:
    415     for idx in self._batch_selectors.selectors:
--> 416         yield self[idx]

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/xbatcher/generators.py:461, in BatchGenerator.__getitem__(self, idx)
    459         batch_ds = self.ds.isel(self._batch_selectors.selectors[idx][0])
    460         if self.preload_batch:
--> 461             batch_ds.load()
    462         return _maybe_stack_batch_dims(
    463             batch_ds,
    464             list(self.input_dims),
    465         )
    466 else:

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/xarray/core/dataarray.py:1108, in DataArray.load(self, **kwargs)
   1090 def load(self: T_DataArray, **kwargs) -> T_DataArray:
   1091     """Manually trigger loading of this array's data from disk or a
   1092     remote source into memory and return this array.
   1093 
   (...)
   1106     dask.compute
   1107     """
-> 1108     ds = self._to_temp_dataset().load(**kwargs)
   1109     new = self._from_temp_dataset(ds)
   1110     self._variable = new._variable

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/xarray/core/dataset.py:825, in Dataset.load(self, **kwargs)
    822 chunkmanager = get_chunked_array_type(*lazy_data.values())
    824 # evaluate all the chunked arrays simultaneously
--> 825 evaluated_data = chunkmanager.compute(*lazy_data.values(), **kwargs)
    827 for k, data in zip(lazy_data, evaluated_data):
    828     self.variables[k].data = data

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/xarray/core/daskmanager.py:70, in DaskManager.compute(self, *data, **kwargs)
     67 def compute(self, *data: DaskArray, **kwargs) -> tuple[np.ndarray, ...]:
     68     from dask.array import compute
---> 70     return compute(*data, **kwargs)

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/dask/threaded.py:89, in get(dsk, keys, cache, num_workers, pool, **kwargs)
     86     elif isinstance(pool, multiprocessing.pool.Pool):
     87         pool = MultiprocessingPoolExecutor(pool)
---> 89 results = get_async(
     90     pool.submit,
     91     pool._max_workers,
     92     dsk,
     93     keys,
     94     cache=cache,
     95     get_id=_thread_get_id,
     96     pack_exception=pack_exception,
     97     **kwargs,
     98 )
    100 # Cleanup pools associated to dead threads
    101 with pools_lock:

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/dask/local.py:511, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
    509         _execute_task(task, data)  # Re-execute locally
    510     else:
--> 511         raise_exception(exc, tb)
    512 res, worker_id = loads(res_info)
    513 state["cache"][key] = res

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/dask/local.py:319, in reraise(exc, tb)
    317 if exc.__traceback__ is not tb:
    318     raise exc.with_traceback(tb)
--> 319 raise exc

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/dask/local.py:224, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    222 try:
    223     task, data = loads(task_info)
--> 224     result = _execute_task(task, data)
    225     id = get_id()
    226     result = dumps((result, id))

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/to_dask.py:185, in fetch_raster_window(reader_table, slices, dtype, fill_value)
    178 # Only read if the window we're fetching actually overlaps with the asset
    179 if windows.intersect(current_window, asset_window):
    180     # NOTE: when there are multiple assets, we _could_ parallelize these reads with our own threadpool.
    181     # However, that would probably increase memory usage, since the internal, thread-local GDAL datasets
    182     # would end up copied to even more threads.
    183 
    184     # TODO when the Reader won't be rescaling, support passing `output` to avoid the copy?
--> 185     data = reader.read(current_window)
    187     if all_empty:
    188         # Turn `output` from a broadcast-trick array to a real array, so it's writeable
    189         if (
    190             np.isnan(data)
    191             if np.isnan(fill_value)
    192             else np.equal(data, fill_value)
    193         ).all():
    194             # Unless the data we just read is all empty anyway

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/rio_reader.py:385, in AutoParallelRioReader.read(self, window, **kwargs)
    384 def read(self, window: Window, **kwargs) -> np.ndarray:
--> 385     reader = self.dataset
    386     try:
    387         result = reader.read(
    388             window=window,
    389             masked=True,
   (...)
    392             **kwargs,
    393         )

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/rio_reader.py:381, in AutoParallelRioReader.dataset(self)
    379 with self._dataset_lock:
    380     if self._dataset is None:
--> 381         self._dataset = self._open()
    382     return self._dataset

File ~/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/rio_reader.py:337, in AutoParallelRioReader._open(self)
    332             warnings.warn(msg)
    333             return NodataReader(
    334                 dtype=self.dtype, fill_value=self.fill_value
    335             )
--> 337         raise RuntimeError(msg) from e
    338 if ds.count != 1:
    339     ds.close()

RuntimeError: Error opening 'https://data.lpdaac.earthdatacloud.nasa.gov/lp-prod-protected/HLSS30.020/HLS.S30.T11SLA.2021231T183919.v2.0/HLS.S30.T11SLA.2021231T183919.v2.0.B04.tif': RasterioIOError("'/vsicurl/https://data.lpdaac.earthdatacloud.nasa.gov/lp-prod-protected/HLSS30.020/HLS.S30.T11SLA.2021231T183919.v2.0/HLS.S30.T11SLA.2021231T183919.v2.0.B04.tif' not recognized as a supported file format.")
This exception is thrown by __iter__ of XbatcherSlicerIterDataPipe(input_dims={'time': 1, 'y': 512, 'x': 512}, kwargs={}, source_datapipe=StackSTACStackerIterDataPipe)

Train the model

Now we call the `.fit()` method on the trainer object to run model training. As training progresses, the trainer logs the loss, metrics, and any other information specified in the callbacks at each training step.

# Start training: fit `model` on the data supplied by the `datamodule` wrapper.
# Loss, metrics, and callback output are logged at each training step while this runs.
trainer.fit(model=model, datamodule=datamodule)
  | Name  | Type     | Params
-----------------------------------
0 | model | Backbone | 11.2 M
-----------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.685    Total estimated model params size (MB)