理解 PyTorch Lightning DataModules

引言

PyTorch Lightning 旨在让我们的 PyTorch 代码更加结构化和易读,这种优化不仅仅局限于 PyTorch 模型本身,也包括数据处理部分。在标准的 PyTorch 中,我们通常使用 DataLoaders 来训练或测试模型。虽然我们仍然可以在 PyTorch Lightning 中使用 DataLoaders 来训练模型,但 PyTorch Lightning 为我们提供了一种更优雅的解决方案,那就是 DataModules。

DataModule 是一个可复用且可共享的类,它将 DataLoaders 以及处理数据所需的步骤封装在一起。手动创建 dataloaders 可能会变得非常杂乱,因此将数据集以 DataModule 的形式进行组织是一个更好的选择。在开始之前,建议您先了解如何使用 PyTorch Lightning 定义神经网络

安装 PyTorch Lightning

安装 Lightning 的过程与安装 Python 中其他任何库一样简单。

pip install pytorch-lightning

或者,如果您想在 conda 环境中安装它,可以使用以下命令:

conda install -c conda-forge pytorch-lightning

Pytorch Lightning DataModule 格式

为了定义一个 Lightning DataModule,我们需要遵循以下格式:

import pytorch-lightning as pl
from torch.utils.data import random_split, DataLoader

class DataModuleClass(pl.LightningDataModule):
    def __init__(self):
        #在此处定义所需的参数
    
    def prepare_data(self):
        # 定义只需在一个 GPU 上执行的步骤,
        # 例如下载数据。
    
    def setup(self, stage=None):
        # 定义需要在每个 GPU 上执行的步骤,
        # 例如分割数据、应用变换等。
    
    def train_dataloader(self):
        # 在此处返回训练数据的 DataLoader
    
    def val_dataloader(self):
        # 在此处返回验证数据的 DataLoader
    
    def test_dataloader(self):
        # 在此处返回测试数据的 DataLoader

注意: 上述函数的名称必须完全一致。

深入理解 DataModule 类

在本文中,我将使用 MNIST 数据集作为示例。正如我们所见,创建 Lightning DataModule 的首要要求是继承 pytorch-lightning 中的 LightningDataModule 类:

import pytorch-lightning as pl
from torch.utils.data import random_split, DataLoader

class DataModuleMNIST(pl.LightningDataModule):

init() 方法

此方法用于存储关于批次大小、数据变换等信息。

def __init__(self):
    super().__init__()
    self.download_dir = ‘‘
    self.batch_size = 32
    self.transform = transforms.Compose([
        transforms.ToTensor()
    ])

prepare_data() 方法

此方法用于定义那些只需由单个 GPU 执行的过程。它通常用于处理数据下载的任务。

def prepare_data(self):
    datasets.MNIST(self.download_dir,
           train=True, download=True)
           
    datasets.MNIST(self.download_dir, train=False,        
           download=True)

setup() 方法

此方法用于定义那些旨在由所有可用 GPU 执行的过程。它通常用于处理数据加载的任务。

def setup(self, stage=None):
    data = datasets.MNIST(self.download_dir,
             train=True, transform=self.transform)
             
    self.train_data, self.valid_data = random_split(data, [55000, 5000])
        
    self.test_data = datasets.MNIST(self.download_dir,
                        train=False, transform=self.transform)

train_dataloader() 方法

此方法用于创建训练数据的 dataloader。在此函数中,通常只需返回训练数据的 dataloader。

def train_dataloader(self):
    return DataLoader(self.train_data, batch_size=self.batch_size)

val_dataloader() 方法

此方法用于创建验证数据的 dataloader。在此函数中,通常只需返回验证数据的 dataloader。

def val_dataloader(self):
   return DataLoader(self.valid_data, batch_size=self.batch_size)

test_dataloader() 方法

此方法用于创建测试数据的 dataloader。在此函数中,通常只需返回测试数据的 dataloader。

def test_dataloader(self):
   return DataLoader(self.test_data, batch_size=self.batch_size)

使用 DataModule 训练 Pytorch Lightning 模型

在 Pytorch Lightning 中,我们使用 Trainer() 来训练模型,在这里我们可以以 DataLoader 或 DataModule 的形式传递数据。让我们以我在这篇文章中定义的模型为例:

“`

class model(pl.LightningModule):

def init(self):

super(model, self).init()

self.fc1 = nn.Linear(28*28, 256)

self.fc2 = nn.Linear(256, 128)

self.out = nn.Linear(128, 10)

self.lr = 0.01

self.loss = nn.CrossEntropyLoss()

def forwa

声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。如需转载,请注明文章出处豆丁博客和来源网址。https://shluqu.cn/42751.html
点赞
0.00 平均评分 (0% 分数) - 0