# Model Training with PyTorch Lightning

In this lesson, we will train a PyTorch Lightning model using the
datapipe created in the prior lesson. We'll cover how to wrap a
torch DataPipe in a PyTorch Lightning DataModule to quickly gain the
advantages of PyTorch Lightning for model development, which include:
 * running your model training on any hardware: CPU, GPU, or TPU
 * distributed training for large models and datasets
 * easier logging, metrics, and visualizations

To set up the Jupyter environment, set up hatch as described in the README, then:
```
hatch -e nb shell
exit
jupyter lab
```

Let's start by importing the required packages we will use. You will notice
some of them require authentication. Again, we use compositional programming
throughout, so by importing a package you can call methods from that package.

In [1]:
import os
import urllib

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import pystac
import segmentation_models_pytorch as smp
import stackstac
import torch
import torch.nn as nn  # PyTorch NN (neural network) module
import torch.nn.functional as F
import torch.optim as optim  # Training optimizer module
import torchdata
import xarray as xr
import zen3geo
from lightning.pytorch.callbacks import DeviceStatsMonitor
from torch.utils.data import DataLoader
from torchdata.dataloader2 import DataLoader2

Just like before, we will create our query to constrain a search for data
hosted on [NASA's LP DAAC](https://lpdaac.usgs.gov) using spatial and
temporal bounds.

In [2]:
bbox = [-119.1, 36.2, -118.2, 36.9]  # West, South, East, North
time_range = ["2021-08-15T00:00:00Z", "2021-09-15T23:59:59Z"]
collection_ids = ["HLSS30.v2.0"]  # Harmonized Landsat-8 Sentinel-2 (HLS)

The following should look familiar. We are creating:
- A Python dictionary for querying the source HLS imagery
- A URL pointing to the burn scar vector labels
- Some configurations related to the GDAL environment
- Pre-processing functions

In [3]:
image_query = dict(bbox=bbox, datetime=time_range, collections=collection_ids)
vector_url = "https://gist.githubusercontent.com/weiji14/286032ac2498d10e050ba585257dd50d/raw/c897c7c1b3b8354ec8c6e8327df38fcfee79b4ef/burn_scars.geojson"

gdal_env = stackstac.DEFAULT_GDAL_ENV.updated(
    always=dict(
        GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR",
        GDAL_HTTP_MERGE_CONSECUTIVE_RANGES="YES",
        GDAL_HTTP_COOKIEFILE=os.path.expanduser("~/cookies.txt"),
        GDAL_HTTP_COOKIEJAR=os.path.expanduser("~/cookies.txt"),
    )
)


def cloud_cover_filter(item: pystac.Item, threshold=20) -> bool:
    """
    Return True if cloud cover is less than or equal to the threshold
    (default 20%), else False.
    """
    return item.properties["eo:cloud_cover"] <= threshold


def zero_mask_filter(dataarray: xr.DataArray) -> bool:
    """
    Return True if input xarray.DataArray contains non-zero values, else False.
    """
    return bool(dataarray.max() > 0)


def xr_collate_fn(samples: tuple) -> tuple:
    """
    Converts xarray.DataArray objects to torch.Tensor objects, and stacks them
    into a tuple of two batched tensors: (images, masks).
    """
    image_tensors = [
        torch.as_tensor(data=np.nan_to_num(sample[0].data).astype(dtype="int16"))
        for sample in samples
    ]
    mask_tensors = [
        torch.as_tensor(data=sample[1].data.astype(dtype="uint8")) for sample in samples
    ]
    return torch.stack(tensors=image_tensors), torch.stack(tensors=mask_tensors)
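
As a quick illustration of how the cloud cover filter behaves, here is a minimal sketch that applies the same comparison to a few stand-in items. The `SimpleNamespace` objects are hypothetical substitutes for `pystac.Item` (only `.id` and `.properties` are mimicked), and the scene names and cloud cover values are made up.

```python
from types import SimpleNamespace


def cloud_cover_filter(item, threshold=20) -> bool:
    """Return True if the item's cloud cover is at or below `threshold` percent."""
    return item.properties["eo:cloud_cover"] <= threshold


# Hypothetical stand-ins for pystac.Item objects; values are invented.
items = [
    SimpleNamespace(id="scene-a", properties={"eo:cloud_cover": 5}),
    SimpleNamespace(id="scene-b", properties={"eo:cloud_cover": 20}),
    SimpleNamespace(id="scene-c", properties={"eo:cloud_cover": 63}),
]

# Keep only the scenes that pass the filter with the default threshold.
kept = [item.id for item in items if cloud_cover_filter(item)]
print(kept)
```

Because the comparison is `<=`, a scene with exactly 20% cloud cover is kept under the default threshold.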

Now we will take all of the steps from the prior lesson and compile them into
a module based on the
[Lightning DataModule structure](https://lightning.ai/docs/pytorch/2.0.3/data/datamodule.html).
Note that this module also performs partitioning into train/validation sets,
and supports separate dataloaders for training/validation. This way, we can
train the model on a training set, and hold out other sets for hyperparameter
tuning and final evaluation.

In [4]:
class BurnScarsDataModule(L.LightningDataModule):
    def __init__(self, image_query, vector_url, batch_size):
        super().__init__()
        self.image_query = image_query
        self.vector_url = vector_url
        self.bs = batch_size

    def setup(self, stage):
        if stage is not None:  # train/val/test/predict
            # Datapipe for the raster imagery
            dp = torchdata.datapipes.iter.IterableWrapper(iterable=[self.image_query])

            dp_hls_stack = (
                dp.search_for_pystac_item(
                    catalog_url="https://cmr.earthdata.nasa.gov/stac/LPCLOUD",
                )
                .list_pystac_items_by_search()
                # .sharding_filter()
                .filter(filter_fn=cloud_cover_filter)
                .stack_stac_items(
                    assets=["B04", "B03", "B02", "B08", "B12"],  # RGB+NIR+SWIR bands
                    epsg=32611,  # UTM Zone 11N
                    resolution=30,  # Spatial resolution of 30 metres
                    xy_coords="center",  # pixel centroid coords instead of topleft corner
                    dtype=np.float16,  # Use a lightweight data type
                    rescale=False,  # Don't apply scale and offset to prevent UFuncTypeError
                    # https://github.com/gjoseph92/stackstac/issues/133
                    gdal_env=gdal_env,
                )
            )

            # DataPipe for the vector labels
            _, dp_stream = (
                torchdata.datapipes.iter.IterableWrapper(iterable=[self.vector_url])
                .read_from_http()  # outputs a tuple of (url, I/O stream)
                .unzip(sequence_length=2)  # read just the I/O stream from the tuple
            )
            dp_pyogrio = dp_stream.read_from_pyogrio()

            # Fuse raster and vector datapipes
            dp_xbatcher = dp_hls_stack.slice_with_xbatcher(
                input_dims={"time": 1, "y": 512, "x": 512}
            )
            dp_chip_canvas, dp_chip_image = dp_xbatcher.fork(num_instances=2)
            dp_datashader = (
                dp_chip_canvas.canvas_from_xarray().rasterize_with_datashader(
                    vector_datapipe=dp_pyogrio
                )
            )
            dp_hls_mask_filtered = dp_chip_image.zip(dp_datashader)
            # .filter(  # not working due to missing __len__
            #     filter_fn=zero_mask_filter, input_col=1
            # )

            # Batch and Collate, and split into train/val batches
            dp_collate = dp_hls_mask_filtered.batch(batch_size=self.bs).collate(
                collate_fn=xr_collate_fn
            )

            len_dp_collate = len(dp_collate)
            self.dp_train, self.dp_val = dp_collate.random_split(
                total_length=len_dp_collate,
                weights={
                    "train": round(len_dp_collate * 0.7),
                    "val": round(len_dp_collate * 0.3),
                },
                seed=42,
            )

    def graph_dp(self):
        return torchdata.datapipes.utils.to_graph(dp=self.dp_train)

    def show_batch(self):
        it = iter(self.dp_train)
        batch = next(it)

        fig, axes = plt.subplots(nrows=self.bs, ncols=2)
        for i in range(self.bs):
            image = torch.clip(batch[0][i][0, [3, 2, 1], :, :].transpose(2, 0), min=0)
            axes[i][0].imshow(X=image / 2**16)
            axes[i][1].imshow(batch[1][i] / 2**8)
            # ax.set_title(tgt)
        plt.tight_layout()

    def train_dataloader(self):
        return DataLoader(dataset=self.dp_train, batch_size=None)

    def val_dataloader(self):
        return DataLoader(dataset=self.dp_val, batch_size=None)
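
To make the 70/30 partitioning in `setup` more concrete, the sketch below shows the general idea of a seeded random split using only the Python standard library. It is not how `random_split` is implemented in torchdata; the helper name `split_batches` and the batch count of 10 are assumptions for illustration.

```python
import random


def split_batches(total_length, train_fraction=0.7, seed=42):
    """Shuffle batch indices with a fixed seed and cut them into two groups."""
    indices = list(range(total_length))
    # A dedicated Random instance keeps the shuffle reproducible.
    random.Random(seed).shuffle(indices)
    n_train = round(total_length * train_fraction)
    return sorted(indices[:n_train]), sorted(indices[n_train:])


# Hypothetical dataset of 10 batches.
train_idx, val_idx = split_batches(total_length=10)
print(len(train_idx), len(val_idx))
```

Fixing the seed means the same batches land in the training and validation groups on every run, which keeps validation metrics comparable between experiments.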

## Model architecture setup

Now we will set up our model module. We will again adopt PyTorch Lightning's
structure, this time leveraging PyTorch's
[nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module)
(short for neural network module) to set up the core 'backbone' neural
network model.

Oftentimes, we want to speed up the model's loss convergence by using a
pre-trained backbone model. This approach, however, can get tricky when our
images have a different number of channels than the pre-trained model we
want to fine-tune expects. There are solutions that simplify the integration
of backbone models with custom training data, one of which we will implement
below.

### How to handle single channel or multispectral imagery?

[Segmentation Models Pytorch (SMP)](https://smp.readthedocs.io/en/latest)
abstracts away the complication of using a segmentation ML model with a
source dataset that has a different number of bands/channels than the
original computer vision dataset the model was designed for. This can be
useful for multispectral or hyperspectral imagery which has more bands than
typical RGB images.
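
One common way to reuse RGB-pretrained weights for imagery with more bands is to repeat the existing input-channel weights cyclically and rescale them. The sketch below illustrates that general idea on plain Python lists; it is an assumption about the technique in general, not a statement of SMP's exact internals, and the function name and weight values are hypothetical.

```python
def expand_input_weights(rgb_weights, new_in_channels):
    """Reuse per-channel weights cyclically, rescaled by old/new channel count."""
    old_in_channels = len(rgb_weights)
    scale = old_in_channels / new_in_channels
    return [
        rgb_weights[i % old_in_channels] * scale for i in range(new_in_channels)
    ]


# Hypothetical weights of a single filter for the R, G and B input channels.
rgb_weights = [0.5, -1.0, 2.0]

# Extend to 5 input bands (for example Red, Green, Blue, NIR, SWIR).
five_band_weights = expand_input_weights(rgb_weights, new_in_channels=5)
print(five_band_weights)
```

Note that this only matters when pre-trained encoder weights are loaded. With randomly initialized weights, the first layer can simply be created with the desired number of input channels.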

The next cell is the backbone model class, which we will use to set up a
[Unet](https://arxiv.org/abs/1505.04597) model that will be called within our
custom model class later.

In [5]:
class Backbone(nn.Module):
    def __init__(
        self,
        encoder_name,
        encoder_depth,
        encoder_weights,
        in_channels,
        classes,
        activation,
        **kwargs,
    ):
        super().__init__()
        self.backbone = smp.Unet(
            encoder_name=encoder_name,
            encoder_depth=encoder_depth,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation,
            **kwargs,
        )

    def forward(self, xb):
        return self.backbone(xb)

#### Custom model class

We will leverage the PyTorch LightningModule structure to create a custom
model that uses the backbone class. Note that we can also add more metrics to
optimize and log during training, such as Intersection over Union (IoU) for
segmentation, or mean average precision (mAP) if we were training an object
detection model.

In [6]:
class BurnScarsSegmentationModel(L.LightningModule):
    def __init__(
        self,
        encoder_name="resnet18",
        encoder_depth=5,
        encoder_weights=None,  # don't use pretrained weights
        in_channels=5,
        classes=1,
        activation=None,
        lr=1e-3,  # learning rate
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()  # saves all hparams as self.hparams
        self.model = Backbone(
            encoder_name=encoder_name,
            encoder_depth=encoder_depth,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation,
            **kwargs,
        )

    def forward(self, xb):
        return self.model(xb)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        return optimizer

    def one_step(self, batch):
        xb, yb = batch
        img = xb.squeeze().to(dtype=torch.float32)  # torch.Size([4, 5, 512, 512])
        mask = yb.squeeze()  # torch.Size([4, 512, 512])

        # Get predicted mask from model
        out = self(img)

        # Compute Binary Cross Entropy loss
        loss = F.binary_cross_entropy_with_logits(
            input=out.squeeze(), target=mask.to(dtype=torch.float32)
        )
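        # Compute Intersection over Union (IoU) metric. The model outputs raw
        # logits (activation=None), so convert them to probabilities before
        # applying the 0.5 threshold
        tp, fp, fn, tn = smp.metrics.get_stats(
            output=out.squeeze().sigmoid(), target=mask, mode="binary", threshold=0.5
        )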
        iou = smp.metrics.iou_score(tp=tp, fp=fp, fn=fn, tn=tn, reduction="micro")

        return loss, iou

    def training_step(self, batch, batch_idx):
        loss, iou = self.one_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log(
            "train_iou",
            iou,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        loss, iou = self.one_step(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_iou", iou, prog_bar=True, logger=True)
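
To show what the loss and metric in `one_step` measure, here is a minimal pure-Python sketch of binary cross-entropy on logits and of a binary IoU score. The helper names and the six logit/label pairs are hypothetical, and the IoU here thresholds sigmoid probabilities at 0.5, which is an assumption of this sketch rather than a description of the SMP implementation.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed directly from raw logits."""
    # Numerically stable form: max(x, 0) - x*y + log(1 + exp(-|x|))
    losses = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)


def binary_iou(logits, targets, prob_threshold=0.5):
    """Intersection over Union after thresholding sigmoid probabilities."""
    preds = [1 if 1.0 / (1.0 + math.exp(-x)) >= prob_threshold else 0 for x in logits]
    tp = sum(p == 1 and t == 1 for p, t in zip(preds, targets))
    fp = sum(p == 1 and t == 0 for p, t in zip(preds, targets))
    fn = sum(p == 0 and t == 1 for p, t in zip(preds, targets))
    # Assumes at least one predicted or labelled positive pixel.
    return tp / (tp + fp + fn)


# Hypothetical logits and labels for a flattened 2x3 chip (made-up values).
logits = [2.0, -1.0, 0.3, -3.0, 1.5, -0.2]
targets = [1, 0, 0, 0, 1, 1]

loss = bce_with_logits(logits, targets)
iou = binary_iou(logits, targets)
print(round(loss, 4), iou)
```

A lower loss means the predicted probabilities sit closer to the labels, while IoU only looks at the thresholded mask: here two burn pixels are found, one is missed, and one is a false alarm.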

Now we will create the trainer object, along with instances of the
datamodule and custom model classes defined above. Notice that we specify
the "resnet18" encoder and that our images consist of 5 channels
(Red, Green, Blue, NIR, SWIR).

In [7]:
trainer = L.Trainer(
    min_epochs=1,
    max_epochs=2,
    callbacks=[DeviceStatsMonitor()],
    fast_dev_run=True,  # only run 1 training and 1 validation batch as a test
)

datamodule = BurnScarsDataModule(
    image_query=image_query, vector_url=vector_url, batch_size=4
)

model = BurnScarsSegmentationModel(encoder_name="resnet18", in_channels=5)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


We can check a batch of our training data.

In [8]:
datamodule.setup(stage="train")

In [None]:
datamodule.show_batch()

## Train the model

Now we call the `.fit` method on the trainer object to execute the model
training. While the training progresses, the trainer will log the loss,
metrics, and other information specified in the callbacks at each training
step.

In [10]:
trainer.fit(model=model, datamodule=datamodule)


  | Name  | Type     | Params
-----------------------------------
0 | model | Backbone | 11.2 M
-----------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.685    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]