Model Training with PyTorch Lightning#
In this lesson, we will train a PyTorch Lightning model using the DataPipe created in the prior lesson. We’ll cover how to wrap a torchdata DataPipe in a PyTorch Lightning DataModule (a minimal skeleton is sketched after this list) to quickly gain the advantages of PyTorch Lightning for model development, which include:
running your model training on any hardware: CPU, GPU, or TPU
distributed training for large models and datasets
easier logging, metrics, and visualizations
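As a rough sketch (not the full module we build later in this lesson; DataPipeModule and its contents are placeholder names), a LightningDataModule only needs a few hooks to wrap a DataPipe:
import lightning as L
import torchdata
from torch.utils.data import DataLoader

class DataPipeModule(L.LightningDataModule):
    def setup(self, stage=None):
        # Build or transform the datapipe here; Lightning calls this hook
        # automatically before fitting/validating/testing.
        self.dp = torchdata.datapipes.iter.IterableWrapper(iterable=range(8)).batch(batch_size=2)

    def train_dataloader(self):
        # batch_size=None because the datapipe already yields batches
        return DataLoader(dataset=self.dp, batch_size=None)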
To set up the Jupyter environment, set up hatch as described in the README, then run:
hatch -e nb shell
exit
jupyter lab
Let’s start by importing the packages we will use. You will notice that some of them require authentication. As before, we rely on compositional programming throughout: simply importing a package (such as zen3geo) makes its methods available for chaining onto our datapipes.
import os
import urllib
import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import pystac
import segmentation_models_pytorch as smp
import stackstac
import torch
import torch.nn as nn  # PyTorch neural network (nn) module
import torch.nn.functional as F
import torch.optim as optim # Training optimizer module
import torchdata
import xarray as xr
import zen3geo
from lightning.pytorch.callbacks import DeviceStatsMonitor
from torch.utils.data import DataLoader
from torchdata.dataloader2 import DataLoader2
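For example, a quick check (not part of the original lesson) shows that importing zen3geo registers its geospatial DataPipes with torchdata's functional API, which is why methods like search_for_pystac_item and slice_with_xbatcher can be chained onto any IterDataPipe below:
dp = torchdata.datapipes.iter.IterableWrapper(iterable=[])
print(hasattr(dp, "search_for_pystac_item"))  # True once zen3geo is imported
print(hasattr(dp, "slice_with_xbatcher"))  # True once zen3geo is imported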
Just like before, we will create our query to constrain a search for data hosted on NASA’s LP DAAC using spatial and temporal bounds.
bbox = [-119.1, 36.2, -118.2, 36.9] # West, South, East, North
time_range = ["2021-08-15T00:00:00Z", "2021-09-15T23:59:59Z"]
collection_ids = ["HLSS30.v2.0"] # Harmonized Landsat-8 Sentinel-2 (HLS)
The following should look familiar. We are creating:
A Python dictionary for querying the source HLS imagery
A URL pointing to the burn scar vector labels
Some configurations related to the GDAL environment
Pre-processing functions
image_query = dict(bbox=bbox, datetime=time_range, collections=collection_ids)
vector_url = "https://gist.githubusercontent.com/weiji14/286032ac2498d10e050ba585257dd50d/raw/c897c7c1b3b8354ec8c6e8327df38fcfee79b4ef/burn_scars.geojson"
gdal_env = stackstac.DEFAULT_GDAL_ENV.updated(
always=dict(
GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR",
GDAL_HTTP_MERGE_CONSECUTIVE_RANGES="YES",
GDAL_HTTP_COOKIEFILE=os.path.expanduser("~/cookies.txt"),
GDAL_HTTP_COOKIEJAR=os.path.expanduser("~/cookies.txt"),
)
)
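The GDAL_HTTP_COOKIEFILE/GDAL_HTTP_COOKIEJAR settings assume you have NASA Earthdata Login credentials available, since the HLS files are behind authentication. One common approach (a sketch, not part of this lesson; the EARTHDATA_USERNAME and EARTHDATA_PASSWORD environment variables are hypothetical) is to write a ~/.netrc entry for urs.earthdata.nasa.gov, which GDAL/curl can use during the login redirect while caching session cookies in ~/cookies.txt:
import os
import pathlib

netrc_path = pathlib.Path.home() / ".netrc"
if not netrc_path.exists():
    # Standard .netrc machine/login/password entry for Earthdata Login
    netrc_path.write_text(
        "machine urs.earthdata.nasa.gov "
        f"login {os.environ['EARTHDATA_USERNAME']} "
        f"password {os.environ['EARTHDATA_PASSWORD']}\n"
    )
    netrc_path.chmod(0o600)  # keep the credentials readable only by you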
def cloud_cover_filter(item: pystac.Item, threshold=20) -> bool:
"""
    Return True if the item's cloud cover is less than or equal to the threshold (default 20%), else False.
"""
return item.properties["eo:cloud_cover"] <= threshold
def zero_mask_filter(dataarray: xr.DataArray) -> bool:
"""
    Return True if the input xarray.DataArray contains non-zero values, else False.
"""
return bool(dataarray.max() > 0)
def xr_collate_fn(samples: tuple) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Convert the image and mask xarray.DataArray objects in each sample to
    torch.Tensors, and stack them into batched image and mask tensors.
"""
image_tensors = [
torch.as_tensor(data=np.nan_to_num(sample[0].data).astype(dtype="int16"))
for sample in samples
]
mask_tensors = [
torch.as_tensor(data=sample[1].data.astype(dtype="uint8")) for sample in samples
]
return torch.stack(tensors=image_tensors), torch.stack(tensors=mask_tensors)
Now we will take all of the steps from the prior lesson and compile them into a module based on the Lightning DataModule structure. Note that this module also partitions the data into training and validation sets and provides a separate dataloader for each. This way, we can fit the model on the training set while holding out data for hyperparameter tuning and evaluation.
class BurnScarsDataModule(L.LightningDataModule):
def __init__(self, image_query, vector_url, batch_size):
super().__init__()
self.image_query = image_query
self.vector_url = vector_url
self.bs = batch_size
def setup(self, stage):
if stage is not None: # train/val/test/predict
# Datapipe for the raster imagery
dp = torchdata.datapipes.iter.IterableWrapper(iterable=[self.image_query])
dp_hls_stack = (
dp.search_for_pystac_item(
catalog_url="https://cmr.earthdata.nasa.gov/stac/LPCLOUD",
)
.list_pystac_items_by_search()
# .sharding_filter()
.filter(filter_fn=cloud_cover_filter)
.stack_stac_items(
assets=["B04", "B03", "B02", "B08", "B12"], # RGB+NIR+SWIR bands
epsg=32611, # UTM Zone 11N
resolution=30, # Spatial resolution of 30 metres
xy_coords="center", # pixel centroid coords instead of topleft corner
dtype=np.float16, # Use a lightweight data type
rescale=False, # Don't apply scale and offset to prevent UFuncTypeError
# https://github.com/gjoseph92/stackstac/issues/133
gdal_env=gdal_env,
)
)
# DataPipe for the vector labels
_, dp_stream = (
torchdata.datapipes.iter.IterableWrapper(iterable=[self.vector_url])
.read_from_http() # outputs a tuple of (url, I/O stream)
.unzip(sequence_length=2) # read just the I/O stream from the tuple
)
dp_pyogrio = dp_stream.read_from_pyogrio()
# Fuse raster and vector datapipes
dp_xbatcher = dp_hls_stack.slice_with_xbatcher(
input_dims={"time": 1, "y": 512, "x": 512}
)
dp_chip_canvas, dp_chip_image = dp_xbatcher.fork(num_instances=2)
dp_datashader = (
dp_chip_canvas.canvas_from_xarray().rasterize_with_datashader(
vector_datapipe=dp_pyogrio
)
)
dp_hls_mask_filtered = dp_chip_image.zip(dp_datashader)
# .filter( # not working due to missing __len__
# filter_fn=zero_mask_filter, input_col=1
# )
# Batch and Collate, and split into train/val batches
dp_collate = dp_hls_mask_filtered.batch(batch_size=self.bs).collate(
collate_fn=xr_collate_fn
)
len_dp_collate = len(dp_collate)
self.dp_train, self.dp_val = dp_collate.random_split(
total_length=len_dp_collate,
weights={
"train": round(len_dp_collate * 0.7),
"val": round(len_dp_collate * 0.3),
},
seed=42,
)
def graph_dp(self):
return torchdata.datapipes.utils.to_graph(dp=self.dp_train)
def show_batch(self):
it = iter(self.dp_train)
batch = next(it)
fig, axes = plt.subplots(nrows=self.bs, ncols=2)
for i in range(self.bs):
image = torch.clip(batch[0][i][0, [3, 2, 1], :, :].transpose(2, 0), min=0)
axes[i][0].imshow(X=image / 2**16)
axes[i][1].imshow(batch[1][i] / 2**8)
# ax.set_title(tgt)
plt.tight_layout()
def train_dataloader(self):
return DataLoader(dataset=self.dp_train, batch_size=None)
def val_dataloader(self):
return DataLoader(dataset=self.dp_val, batch_size=None)
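As a quick sanity check (a sketch, assuming Earthdata authentication is configured so the HLS files can be read), we can instantiate the DataModule, run setup, and inspect one collated batch from the training dataloader:
dm = BurnScarsDataModule(image_query=image_query, vector_url=vector_url, batch_size=4)
dm.setup(stage="fit")
images, masks = next(iter(dm.train_dataloader()))
print(images.shape, images.dtype)  # stacked int16 image chips
print(masks.shape, masks.dtype)  # stacked uint8 burn scar masks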
Model architecture setup#
Now we will set up our model module. We will again adopt PyTorch Lightning’s structure, this time leveraging PyTorch’s nn.Module (short for neural network module) to define the core ‘backbone’ network.
Oftentimes, we want to speed up the model’s loss convergence by starting from a pre-trained backbone. This approach, however, can get tricky when our images have a different number of channels than the pre-trained model expects. There are solutions that simplify integrating backbone models with custom training data, one of which we implement below.
How to handle single channel or multispectral imagery?#
Segmentation Models PyTorch (SMP) abstracts away the complication of using a segmentation model with a source dataset that has a different number of bands/channels than the computer vision dataset the model was originally designed for. This is useful for multispectral or hyperspectral imagery, which has more bands than typical RGB images.
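For instance, a short sketch (downloading the ImageNet weights requires internet access) shows SMP adapting a pretrained ResNet-18 encoder's first convolution to accept our 5-band (RGB+NIR+SWIR) chips:
unet = smp.Unet(
    encoder_name="resnet18",
    encoder_weights="imagenet",  # weights pretrained on 3-channel ImageNet images
    in_channels=5,  # SMP re-maps the first convolution to accept 5 bands
    classes=1,
)
print(unet.encoder.conv1)  # Conv2d(5, 64, kernel_size=(7, 7), ...)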
The next cell defines the backbone model class, which we will use to set up a U-Net model that gets called from our custom model class later.
class Backbone(nn.Module):
def __init__(
self,
encoder_name,
encoder_depth,
encoder_weights,
in_channels,
classes,
activation,
**kwargs,
):
super().__init__()
self.backbone = smp.Unet(
encoder_name=encoder_name,
encoder_depth=encoder_depth,
encoder_weights=encoder_weights,
in_channels=in_channels,
classes=classes,
activation=activation,
**kwargs,
)
def forward(self, xb):
return self.backbone(xb)
Custom model class#
We will leverage the PyTorch LightningModule structure to create a custom model that uses the backbone class. Note that we can also add more metrics to track and log during training, such as Intersection over Union (IoU) for segmentation, or mean average precision (mAP) if we were training an object detection model.
class BurnScarsSegmentationModel(L.LightningModule):
def __init__(
self,
encoder_name="resnet18",
encoder_depth=5,
encoder_weights=None, # don't use pretrained weights
in_channels=5,
classes=1,
activation=None,
lr=1e-3, # learning rate
**kwargs,
):
super().__init__()
self.save_hyperparameters() # saves all hparams as self.hparams
self.model = Backbone(
encoder_name=encoder_name,
encoder_depth=encoder_depth,
encoder_weights=encoder_weights,
in_channels=in_channels,
classes=classes,
activation=activation,
**kwargs,
)
def forward(self, xb):
return self.model(xb)
def configure_optimizers(self):
optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
return optimizer
def one_step(self, batch):
xb, yb = batch
img = xb.squeeze().to(dtype=torch.float32) # torch.Size([4, 5, 512, 512])
mask = yb.squeeze() # torch.Size([4, 512, 512])
# Get predicted mask from model
out = self(img)
# Compute Binary Cross Entropy loss
loss = F.binary_cross_entropy_with_logits(
input=out.squeeze(), target=mask.to(dtype=torch.float32)
)
# Compute Intersection over Union (IoU) metric
tp, fp, fn, tn = smp.metrics.get_stats(
output=out.squeeze(), target=mask, mode="binary", threshold=0.5
)
iou = smp.metrics.iou_score(tp=tp, fp=fp, fn=fn, tn=tn, reduction="micro")
return loss, iou
def training_step(self, batch, batch_idx):
loss, iou = self.one_step(batch)
self.log("train_loss", loss, on_step=True, on_epoch=True, logger=True)
self.log(
"train_iou",
iou,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
return loss
def validation_step(self, batch, batch_idx):
loss, iou = self.one_step(batch)
self.log("val_loss", loss, prog_bar=True, logger=True)
self.log("val_iou", iou, prog_bar=True, logger=True)
Now we will instantiate the trainer object, which brings together the datamodule and the custom model defined above. Notice that we specify the "resnet18" encoder and that our images consist of 5 channels (Red, Green, Blue, NIR, SWIR).
trainer = L.Trainer(
min_epochs=1,
max_epochs=2,
callbacks=[DeviceStatsMonitor()],
fast_dev_run=True, # only run 1 training and 1 validation batch as a test
)
datamodule = BurnScarsDataModule(
image_query=image_query, vector_url=vector_url, batch_size=4
)
model = BurnScarsSegmentationModel(encoder_name="resnet18", in_channels=5)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
warning_cache.warn(
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
We can check a batch of our training data.
datamodule.setup(stage="train")
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
(the same UserWarning is repeated many times; duplicates omitted)
datamodule.show_batch()
/home/runner/micromamba-root/envs/mlpipeline/lib/python3.9/site-packages/stackstac/prepare.py:364: UserWarning: The argument 'infer_datetime_format' is deprecated and will be removed in a future version. A strict version of it is now the default, see https://pandas.pydata.org/pdeps/0004-consistent-to-datetime-parsing.html. You can safely remove this argument.
  times = pd.to_datetime(
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
(traceback through the torchdata, zen3geo/xbatcher and stackstac internals truncated)
RuntimeError: Error opening 'https://data.lpdaac.earthdatacloud.nasa.gov/lp-prod-protected/HLSS30.020/HLS.S30.T11SLA.2021231T183919.v2.0/HLS.S30.T11SLA.2021231T183919.v2.0.B04.tif': RasterioIOError("'/vsicurl/https://data.lpdaac.earthdatacloud.nasa.gov/lp-prod-protected/HLSS30.020/HLS.S30.T11SLA.2021231T183919.v2.0/HLS.S30.T11SLA.2021231T183919.v2.0.B04.tif' not recognized as a supported file format.")
This exception is thrown by __iter__ of XbatcherSlicerIterDataPipe(input_dims={'time': 1, 'y': 512, 'x': 512}, kwargs={}, source_datapipe=StackSTACStackerIterDataPipe)
Train the model#
Now we call the .fit method on the trainer object to execute the model training. As training progresses, the trainer will log the loss, metrics, and other information specified in the callbacks at each training step.
trainer.fit(model=model, datamodule=datamodule)
| Name | Type | Params
-----------------------------------
0 | model | Backbone | 11.2 M
-----------------------------------
11.2 M Trainable params
0 Non-trainable params
11.2 M Total params
44.685 Total estimated model params size (MB)
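Once the fast_dev_run smoke test passes, a full training run could look like the following sketch (max_epochs=10 and the ModelCheckpoint settings are illustrative choices, not from this lesson):
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(monitor="val_iou", mode="max", save_top_k=1)
trainer = L.Trainer(max_epochs=10, callbacks=[DeviceStatsMonitor(), checkpoint_cb])
trainer.fit(model=model, datamodule=datamodule)
print(checkpoint_cb.best_model_path)  # best checkpoint according to val_iou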