PyTorch Dataloader:自定义Dataset与drop_last参数的深入理解

作者:问答酱2023.12.25 07:11浏览量:7

简介:Dataloader在PyTorch中是用于加载数据集的重要工具,它可以帮助我们方便地处理数据,如切分数据集,打乱数据,以及批处理数据等。在PyTorch中,我们可以通过继承`torch.utils.data.Dataset`类来创建自己的数据集。

千帆应用开发平台“智能体Pro”全新上线 限时免费体验

面向慢思考场景,支持低代码配置的方式创建“智能体Pro”应用

立即体验

Dataloader在PyTorch中是用于加载数据集的重要工具,它可以帮助我们方便地处理数据,如切分数据集,打乱数据,以及批处理数据等。在PyTorch中,我们可以通过继承torch.utils.data.Dataset类来创建自己的数据集。
但是,如果你正在处理一个序列标注问题(例如文本分类、命名实体识别等),你可能希望数据集中每一行都有不同的长度。例如,一个句子中的单词和另一个句子中的单词可能长度不同。在这种情况下,使用默认的DataLoader可能会遇到问题,因为默认情况下,它期望所有的样本都有相同的长度。
这就是torch.utils.data.DataLoaderdrop_last参数的用途。当你将drop_last设置为True时,如果数据集中存在长度不同的样本,那么在每个epoch结束时,最后一个批次将被丢弃,以确保所有的批次都有相同的长度。
下面是一个简单的例子,展示了如何创建一个自定义的数据集类,以及如何使用带有drop_last参数的DataLoader

  1. import torch
  2. from torch.utils.data import Dataset, DataLoader
  3. class MyDataset(Dataset):
  4. def __init__(self, data):
  5. self.data = data
  6. def __len__(self):
  7. return len(self.data)
  8. def __getitem__(self, idx):
  9. return self.data[idx]
  10. # 假设我们有以下数据:
  11. data = [torch.randn(2, 3), torch.randn(2, 6), torch.randn(2, 9)]
  12. dataset = MyDataset(data)
  13. # 使用 DataLoader 并设置 drop_last=True
  14. dataloader = DataLoader(dataset, batch_size=2, shuffle=True, drop_last=True)
  15. # 遍历 dataloader
  16. for batch in dataloader:
  17. print(batch)

在这个例子中,我们的数据集中的样本长度都不同(第一个样本长度为3,第二个样本长度为6,第三个样本长度为9)。当我们使用drop_last=TrueDataLoader时,由于存在长度不同的样本,最后一个批次将被丢弃。这意味着每个epoch将只使用两个批次的数据(即前两个样本)。

article bottom image

相关文章推荐

发表评论