timm 库中有三种主要的 Dataset 类

  1. ImageDataset
  2. IterableImageDataset
  3. AugMixDataset

在这份文档中,我们将逐一介绍它们,并探讨这些 Dataset 类的各种用例。

ImageDataset

class ImageDataset(root: str, parser: Union[ParserImageInTar, ParserImageFolder, str] = None, class_map: Dict[str, str] = '', load_bytes: bool = False, transform: List = None) -> Tuple[Any, Any]:

ImageDataset 可用于创建训练和验证数据集,其功能与 torchvision.datasets.ImageFolder 非常相似,并带有一些不错的附加功能。

Parser

parser 使用 create_parser 工厂方法自动设置。parserroot 目录中查找所有图像和目标,其中 root 文件夹的结构如下所示:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

parser 设置一个 class_to_idx 字典,将类映射到整数,看起来像这样:

{'dog': 0, 'cat': 1, ..}

并且还有一个名为 samples 的属性,它是一个元组列表,看起来像这样:

[('root/dog/xxx.png', 0), ('root/dog/xxy.png', 0), ..., ('root/cat/123.png', 1), ('root/cat/nsdf3.png', 1), ...]

这个 parser 对象是可下标访问的,通过执行 parser[index] 可以返回 self.samples 中该特定索引处的样本。因此,执行 parser[0] 将返回 ('root/dog/xxx.png', 0)

__getitem__(index: int) → Tuple[Any, Any]

一旦设置了 parserImageDataset 就会根据 index 从这个 parser 获取图像和目标。

img, target = self.parser[index]

然后,它根据 load_bytes 参数,将图像读取为 PIL.Image 并转换为 RGB,或者将图像读取为字节。

最后,它对图像进行转换并返回目标。如果目标为 None,则返回一个虚拟目标 torch.tensor(-1)

用法

这个 ImageDataset 也可以用作 torchvision.datasets.ImageFolder 的替代品。考虑到我们有 imagenette2-320 数据集,其结构看起来像这样:

imagenette2-320
├── train
│   ├── n01440764
│   ├── n02102040
│   ├── n02979186
│   ├── n03000684
│   ├── n03028079
│   ├── n03394916
│   ├── n03417042
│   ├── n03425413
│   ├── n03445777
│   └── n03888257
└── val
    ├── n01440764
    ├── n02102040
    ├── n02979186
    ├── n03000684
    ├── n03028079
    ├── n03394916
    ├── n03417042
    ├── n03425413
    ├── n03445777
    └── n03888257

每个子文件夹都包含属于该类的一组 .JPEG 文件。

# run only once
wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz
gunzip imagenette2-320.tgz
tar -xvf imagenette2-320.tar

那么,就可以像这样创建一个 ImageDataset

from timm.data.dataset import ImageDataset

dataset = ImageDataset('./imagenette2-320')
dataset[0]

(<PIL.Image.Image image mode=RGB size=426x320 at 0x7FF7F4880460>, 0)

我们还可以看到 dataset.parserParserImageFolder 的一个实例

dataset.parser

<timm.data.parsers.parser_image_folder.ParserImageFolder at 0x7ff7f4880d90>

最后,让我们看一下 parser 中的 class_to_idx 字典映射

dataset.parser.class_to_idx

{'n01440764': 0,
 'n02102040': 1,
 'n02979186': 2,
 'n03000684': 3,
 'n03028079': 4,
 'n03394916': 5,
 'n03417042': 6,
 'n03425413': 7,
 'n03445777': 8,
 'n03888257': 9}

并且,还有前五个样本,如下所示

dataset.parser.samples[:5]

[('./imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG', 0),
 ('./imagenette2-320/train/n01440764/ILSVRC2012_val_00002138.JPEG', 0),
 ('./imagenette2-320/train/n01440764/ILSVRC2012_val_00003014.JPEG', 0),
 ('./imagenette2-320/train/n01440764/ILSVRC2012_val_00006697.JPEG', 0),
 ('./imagenette2-320/train/n01440764/ILSVRC2012_val_00007197.JPEG', 0)]

IterableImageDataset

timm 还提供了一个 IterableImageDataset,类似于 PyTorch 的 IterableDataset,但有一个关键区别——IterableImageDataset 在生成图像和目标之前,会对 image 应用转换。

当数据来自流或数据长度未知时,这类数据集特别有用。

timm 会对 image 延迟应用转换,并在目标为 None 的情况下将目标设置为一个虚拟目标 torch.tensor(-1, dtype=torch.long)

与上面的 ImageDataset 类似,IterableImageDataset 首先创建一个 parser,它根据 root 目录获取样本的元组。

如前所述,parser 返回一个图像,目标是图像所在的对应文件夹。

__iter__

IterableImageDataset 内部的 __iter__ 方法首先从 self.parser 获取图像和目标,然后对图像延迟应用转换。在两者返回之前,也会将目标设置为一个虚拟值。

用法

from timm.data import IterableImageDataset
from timm.data.parsers.parser_image_folder import ParserImageFolder
from timm.data.transforms_factory import create_transform 

root = '../../imagenette2-320/'
parser = ParserImageFolder(root)
iterable_dataset = IterableImageDataset(root=root, parser=parser)
parser[0], next(iter(iterable_dataset))
((<_io.BufferedReader name='../../imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG'>,
  0),
 (<_io.BufferedReader name='../../imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG'>,
  0))

iterable_dataset 不可下标访问。

iterable_dataset[0]
> > 
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-14-9085b17eda0c> in <module>
----> 1 iterable_dataset[0]

~/opt/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataset.py in __getitem__(self, index)
     30 
     31     def __getitem__(self, index) -> T_co:---> 32         raise NotImplementedError     33 
     34     def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':

NotImplementedError:

AugmixDataset

class AugmixDataset(dataset: ImageDataset, num_splits: int = 2):

AugmixDataset 接受一个 ImageDataset 并将其转换为 Augmix 数据集。

什么是 Augmix 数据集以及我们什么时候需要这样做?

让我们借助 Augmix 论文来回答这个问题。

Augmix

如上图所示,最终的 Loss Output 实际上是分类损失与标签和模型在 Xorig、Xaugmix1 和 Xaugmix2 上的预测之间的 Jensen-Shannon 损失的 λ 倍之和。

因此,对于这种情况,我们将需要三个版本的批次——原始、augmix1 和 augmix2。那么我们如何实现这一点呢?当然是使用 AugmixDataset

__getitem__(index: int) -> Tuple[Any, Any]

首先,我们从 self.dataset(它是传递给 AugmixDataset 构造函数的那个数据集)获取一个 X 和相应的标签 y。接下来,我们对这个图像 X 进行归一化并将其添加到一个名为 x_list 的变量中。

接下来,根据默认值为 0 的 num_splits 参数,我们对 X 应用 augmentations,对增强后的输出进行归一化,并将其追加到 x_list 中。

用法

from timm.data import ImageDataset, IterableImageDataset, AugMixDataset, create_loader

dataset = ImageDataset('../../imagenette2-320/')
dataset = AugMixDataset(dataset, num_splits=2)
loader_train = create_loader(
    dataset, 
    input_size=(3, 224, 224), 
    batch_size=8, 
    is_training=True, 
    scale=[0.08, 1.], 
    ratio=[0.75, 1.33], 
    num_aug_splits=2
)
# Requires GPU to work

next(iter(loader_train))[0].shape

>> torch.Size([16, 3, 224, 224])

因为我们传入了 num_aug_splits=2。在这种情况下,loader_train 包含前 8 张原始图像和接下来代表 augmix1 的 8 张图像。

如果我们传入 num_aug_splits=3,那么有效的 batch_size 将是 24,其中前 8 张图像是原始图像,接下来 8 张代表 augmix1,最后 8 张代表 augmix2