
Higher-level PyTorch APIs: A short introduction to PyTorch Lightning 

In recent years, the PyTorch community developed several different libraries and APIs on top of PyTorch. PyTorch Lightning (Lightning for short) is one of them, and it makes training deep neural networks simpler by removing much of the boilerplate code. However, while Lightning’s focus lies in simplicity and flexibility, it also allows us to use many advanced features such as multi-GPU support and fast low-precision training, which you can learn about in the official documentation at https://pytorch-lightning.rtfd.io/en/latest/. 
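For instance, enabling multi-GPU and low-precision training is mostly a matter of passing additional arguments to Lightning's Trainer, which we will use later in this section. The exact argument names depend on the installed Lightning version; the following sketch assumes a 1.x release, whereas 2.x releases use accelerator="gpu", devices=2, and precision="16-mixed" instead:

import pytorch_lightning as pl

# A minimal sketch (Lightning 1.x argument names); adjust to your installed version.
trainer = pl.Trainer(
    max_epochs=10,
    gpus=2,        # train on two GPUs (Lightning 2.x: accelerator="gpu", devices=2)
    precision=16,  # mixed-precision training (Lightning 2.x: precision="16-mixed")
)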

This article is an excerpt from the book Machine Learning with PyTorch and Scikit-Learn, the PyTorch version of the widely acclaimed and bestselling Python Machine Learning series, fully updated and expanded to cover PyTorch, transformers, graph neural networks, and best practices. 

In this section, we’ll implement a multilayer perceptron for classifying handwritten digits in the MNIST dataset using PyTorch Lightning. 

Setting up the PyTorch Lightning model 

We start by implementing the model.  

Defining a model for PyTorch Lightning is relatively straightforward as it is based on regular Python and PyTorch code. All that is required to implement a Lightning model is to subclass LightningModule instead of the regular PyTorch nn.Module. To take advantage of Lightning's convenience functions, such as the Trainer API and automatic logging, we just need to define a few specifically named methods:

import pytorch_lightning as pl 
import torch  
import torch.nn as nn   

from torchmetrics import Accuracy  

class MultiLayerPerceptron(pl.LightningModule): 
    def __init__(self, image_shape=(1, 28, 28), hidden_units=(32, 16)):
        super().__init__()      

        # New PL attributes for tracking accuracy
        # (note: newer torchmetrics releases may require
        # Accuracy(task="multiclass", num_classes=10) instead):
        self.train_acc = Accuracy()
        self.valid_acc = Accuracy()
        self.test_acc = Accuracy()

        # Model similar to previous section: 
        input_size = image_shape[0] * image_shape[1] * image_shape[2] 
        all_layers = [nn.Flatten()] 
        for hidden_unit in hidden_units:  
            layer = nn.Linear(input_size, hidden_unit)  
            all_layers.append(layer)  
            all_layers.append(nn.ReLU())  
            input_size = hidden_unit    

        all_layers.append(nn.Linear(hidden_units[-1], 10))  
        all_layers.append(nn.Softmax(dim=1))  
        self.model = nn.Sequential(*all_layers)  

    def forward(self, x): 
        x = self.model(x) 
        return x 

    def training_step(self, batch, batch_idx): 
        x, y = batch 
        logits = self(x) 
        loss = nn.functional.cross_entropy(logits, y)  # reuse the forward pass from above
        preds = torch.argmax(logits, dim=1) 
        self.train_acc.update(preds, y) 
        self.log("train_loss", loss, prog_bar=True) 
        return loss 

    def training_epoch_end(self, outs): 
        self.log("train_acc", self.train_acc.compute())      

    def validation_step(self, batch, batch_idx): 
        x, y = batch 
        logits = self(x) 
        loss = nn.functional.cross_entropy(logits, y)  # reuse the forward pass from above
        preds = torch.argmax(logits, dim=1) 
        self.valid_acc.update(preds, y) 
        self.log("valid_loss", loss, prog_bar=True) 
        self.log("valid_acc", self.valid_acc.compute(), prog_bar=True) 
        return loss  

    def test_step(self, batch, batch_idx): 
        x, y = batch 
        logits = self(x) 
        loss = nn.functional.cross_entropy(logits, y)  # reuse the forward pass from above
        preds = torch.argmax(logits, dim=1) 
        self.test_acc.update(preds, y) 
        self.log("test_loss", loss, prog_bar=True) 
        self.log("test_acc", self.test_acc.compute(), prog_bar=True) 
        return loss  

    def configure_optimizers(self): 
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001) 
        return optimizer 

As you can see, the __init__ constructor includes the accuracy attributes, such as self.train_acc = Accuracy(). These will allow us to track the accuracy during training. Accuracy was imported from the torchmetrics module, which should be automatically installed with Lightning. If you cannot import torchmetrics, you can try to install it via pip install torchmetrics. More information can be found at https://torchmetrics.readthedocs.io/en/latest/pages/quickstart.html. 
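To make the update/compute pattern used in the steps below more concrete, here is a small, self-contained sketch of how a torchmetrics Accuracy object accumulates results over batches (assuming an older torchmetrics release; newer releases require Accuracy(task="multiclass", num_classes=10)):

import torch
from torchmetrics import Accuracy

acc = Accuracy()  # newer torchmetrics: Accuracy(task="multiclass", num_classes=10)

# Simulate two batches of predictions and labels:
acc.update(torch.tensor([0, 1, 2]), torch.tensor([0, 1, 1]))  # 2 of 3 correct
acc.update(torch.tensor([3, 3]), torch.tensor([3, 0]))        # 1 of 2 correct

print(acc.compute())  # accuracy over all updates so far: 3/5 = tensor(0.6000)
acc.reset()           # clear the internal state before the next accumulation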

The forward method implements a simple forward pass that returns the outputs of our network when we call the model on the input data. (In the code, these outputs are stored in a variable named logits, although, since the last layer of our nn.Sequential model is a softmax layer, they are technically class-membership probabilities.) These outputs, computed via the forward method by calling self(x), are used for the training, validation, and test steps, which we'll describe next.

The training_step, training_epoch_end, validation_step, test_step, and configure_optimizers methods are specifically recognized by Lightning. For instance, training_step defines a single forward pass during training, where we also keep track of the accuracy and loss so that we can analyze them later. Note that we update the accuracy via self.train_acc.update(preds, y) but don't log it yet. The training_step method is executed on each individual batch during training, and via the training_epoch_end method, which is executed at the end of each training epoch, we compute the training set accuracy from the accuracy values accumulated during training.
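Note that training_epoch_end follows the Lightning 1.x API that this code was written for; in Lightning 2.x, this hook was removed. A hedged sketch of the closest equivalent, an on_train_epoch_end hook, would look like this:

    def on_train_epoch_end(self):
        # Log the accuracy accumulated over the epoch and reset the metric
        # so that the next epoch starts with a fresh accumulation.
        self.log("train_acc", self.train_acc.compute())
        self.train_acc.reset()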

The validation_step and test_step methods define, analogously to the training_step method, how the validation and test evaluation processes should be computed. Similar to training_step, each validation_step and test_step receives a single batch, and we track the accuracy via the respective accuracy attributes derived from torchmetrics' Accuracy. However, note that validation_step is only called in certain intervals, for example, after each training epoch. This is why we log the validation accuracy inside the validation step, whereas with the training accuracy, we only log it after each training epoch; otherwise, the accuracy plot that we inspect later would look too noisy.
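How often validation_step runs is controlled by the Trainer, which we will meet in a later subsection. As a small sketch (these argument names exist in both Lightning 1.x and 2.x):

trainer = pl.Trainer(
    max_epochs=10,
    check_val_every_n_epoch=1,   # run the validation loop after every training epoch (the default)
    # val_check_interval=0.25,   # alternatively: validate four times within each training epoch
)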

Finally, via the configure_optimizers method, we specify the optimizer used for training. The following two subsections will discuss how we can set up the dataset and how we can train the model. 
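Before moving on, one small note on configure_optimizers: besides a bare optimizer, it can also return a learning rate scheduler. The following is a minimal sketch, not part of the original model; the exponential decay schedule is purely an illustrative choice:

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        # Illustrative choice: decay the learning rate by 10% after each epoch
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}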

Setting up the data loaders for Lightning 

There are three main ways in which we can prepare the dataset for PyTorch Lightning. We can: 

  • Make the dataset part of the model 
  • Set up the data loaders as usual and feed them to the fit method of a Lightning Trainer (the Trainer is introduced in the next subsection); a short sketch of this option follows the list
  • Create a LightningDataModule 
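For the second option, a minimal sketch could look as follows, assuming train_dataset and val_dataset are regular PyTorch Dataset objects that you have already created (they are hypothetical here); the data loaders are passed positionally because the keyword name changed between Lightning versions (train_dataloaders in newer releases, train_dataloader in older ones):

from torch.utils.data import DataLoader

# Hypothetical datasets; any torch.utils.data.Dataset works here:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)

trainer = pl.Trainer(max_epochs=10)
trainer.fit(MultiLayerPerceptron(), train_loader, val_loader)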

Here, we are going to use a LightningDataModule, which is the most organized approach. The LightningDataModule consists of five main methods, as we can see in the following:  

from torch.utils.data import DataLoader 
from torch.utils.data import random_split 
from torchvision.datasets import MNIST 
from torchvision import transforms  

class MnistDataModule(pl.LightningDataModule): 
    def __init__(self, data_path='./'): 
        super().__init__() 
        self.data_path = data_path 
        self.transform = transforms.Compose([transforms.ToTensor()])          

    def prepare_data(self): 
        MNIST(root=self.data_path, download=True)

    def setup(self, stage=None): 
        # stage is either 'fit', 'validate', 'test', or 'predict' 
        # here, it is not relevant
        mnist_all = MNIST(  
            root=self.data_path, 
            train=True, 
            transform=self.transform,   
            download=False 
        )   

        self.train, self.val = random_split( 
            mnist_all, [55000, 5000], generator=torch.Generator().manual_seed(1) 
        )  

        self.test = MNIST(  
            root=self.data_path, 
            train=False, 
            transform=self.transform,   
            download=False 
        )   

    def train_dataloader(self): 
        return DataLoader(self.train, batch_size=64, shuffle=True, num_workers=4)

    def val_dataloader(self): 
        return DataLoader(self.val, batch_size=64, num_workers=4)  

    def test_dataloader(self): 
        return DataLoader(self.test, batch_size=64, num_workers=4) 

In the prepare_data method, we define general steps, such as downloading the dataset; Lightning calls this method from a single process, so that, for example, the download does not happen multiple times when training on multiple devices. In the setup method, which is called on every device, we define the datasets used for training, validation, and testing. Note that MNIST does not have a dedicated validation split, which is why we use the random_split function to divide the 60,000-example training set into 55,000 examples for training and 5,000 examples for validation.
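The Trainer calls these methods for us at the appropriate times, but a LightningDataModule can also be exercised by hand, which is handy for quick sanity checks. A minimal sketch, using the MnistDataModule defined above:

dm = MnistDataModule()
dm.prepare_data()        # downloads MNIST if it is not already on disk
dm.setup(stage="fit")    # creates the training and validation splits

x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)  # expected: torch.Size([64, 1, 28, 28]) torch.Size([64])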

The data loader methods are self-explanatory and define how the respective datasets are loaded. Now, we can initialize the data module and use it for training, validation, and testing in the next subsections: 

torch.manual_seed(1)
mnist_dm = MnistDataModule() 

Training the model using the PyTorch Lightning Trainer class 

Now we can reap the rewards from setting up the model with the specifically named methods, as well as the Lightning data module. Lightning implements a Trainer class that makes training models super convenient by taking care of all the intermediate steps, such as calling zero_grad(), backward(), and optimizer.step() for us. Also, as a bonus, it lets us easily specify one or more GPUs to use (if available):

mnistclassifier = MultiLayerPerceptron()  

if torch.cuda.is_available(): # if you have GPUs
    trainer = pl.Trainer(max_epochs=10, gpus=1)  # newer Lightning versions use accelerator="gpu", devices=1
else: 
    trainer = pl.Trainer(max_epochs=10)  

trainer.fit(model=mnistclassifier, datamodule=mnist_dm) 

Via the preceding code, we train our multilayer perceptron for 10 epochs. During training, we see a handy progress bar that keeps track of the epoch and core metrics such as the training and validation losses: 

Epoch 9: 100% 939/939 [00:12<00:00, 73.16it/s, loss=1.5, v_num=0, train_loss=1.590, valid_loss=1.520, valid_acc=0.927] 
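Because we also defined a test_step and a test data loader, we can evaluate the trained model on the held-out test set at any point after training. A minimal sketch, reusing the trainer, model, and data module from above:

# Runs the test_step defined in our LightningModule over the MNIST test set
# and prints the logged test_loss and test_acc:
trainer.test(model=mnistclassifier, datamodule=mnist_dm)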

Evaluating the model using TensorBoard 

Another nice feature of Lightning is its logging capabilities. Recall that we specified several self.log calls in our Lightning model earlier. During and after training, we can visualize the logged metrics in TensorBoard. TensorBoard can be installed via pip or conda, depending on your preference.

By default, Lightning tracks the training in a subfolder named lightning_logs. To visualize the training runs, you can execute the following code in the command-line terminal, which will open TensorBoard in your browser: 

tensorboard --logdir lightning_logs/ 

Alternatively, if you are running the code in a Jupyter notebook, you can add the following code to a notebook cell to show the TensorBoard dashboard in the notebook directly: 

%load_ext tensorboard 
%tensorboard --logdir lightning_logs/ 
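Lightning's default logger writes to lightning_logs; if you want to group runs under a more descriptive experiment name, you can configure the TensorBoard logger explicitly. A minimal sketch (the experiment name "mlp-mnist" is just an illustrative choice):

from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir="lightning_logs", name="mlp-mnist")
trainer = pl.Trainer(max_epochs=10, logger=logger)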

Figure 1 shows the TensorBoard dashboard with the logged training and validation accuracy. Note that there is a version_0 toggle shown in the lower-left corner. If you run the training code multiple times, PyTorch Lightning will track them as separate subfolders: version_0, version_1, version_2, and so forth: 

 Figure 1: TensorBoard dashboard 

By looking at the training and validation accuracies in Figure 1, we can hypothesize that training the model for a few additional epochs can improve performance. 
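If we want to act on that hypothesis, we do not have to start from scratch: we can continue training from the checkpoint that Lightning saved automatically under lightning_logs/version_0/checkpoints/. A hedged sketch for Lightning 1.x follows (the checkpoint filename below is hypothetical; use the file that was actually created, and note that Lightning 2.x expects the path via trainer.fit(..., ckpt_path=...) instead):

# Hypothetical checkpoint filename; check lightning_logs/version_0/checkpoints/
ckpt_path = "lightning_logs/version_0/checkpoints/epoch=9-step=8600.ckpt"

trainer = pl.Trainer(max_epochs=15, resume_from_checkpoint=ckpt_path)
trainer.fit(model=mnistclassifier, datamodule=mnist_dm)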

Learn the PyTorch essentials and how to create models using popular libraries and more with the book Machine Learning with PyTorch and Scikit-Learn. 

About the authors 

Sebastian Raschka is an Assistant Professor of Statistics at the University of Wisconsin-Madison, focusing on machine learning and deep learning research. As Lead AI Educator at Grid.ai, Sebastian plans to continue following his passion for helping people get into machine learning and artificial intelligence.  

Yuxi (Hayden) Liu is a Machine Learning Software Engineer at Google. He is developing and improving machine learning models and systems for ads optimization on the largest search engine in the world.  

Vahid Mirjalili is a Deep Learning Researcher focusing on CV applications. Vahid received a Ph.D. degree in both Mechanical Engineering and Computer Science from Michigan State University. 
