在 timm
库中有三种主要的 Dataset 类
ImageDataset
IterableImageDataset
AugMixDataset
在这份文档中,我们将逐一介绍它们,并探讨这些 Dataset 类的各种用例。
class ImageDataset(root: str, parser: Union[ParserImageInTar, ParserImageFolder, str] = None, class_map: Dict[str, str] = '', load_bytes: bool = False, transform: List = None) -> Tuple[Any, Any]:
ImageDataset
可用于创建训练和验证数据集,其功能与 torchvision.datasets.ImageFolder 非常相似,并带有一些不错的附加功能。
parser
使用 create_parser
工厂方法自动设置。parser
在 root
目录中查找所有图像和目标,其中 root
文件夹的结构如下所示:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
parser
设置一个 class_to_idx
字典,将类映射到整数,看起来像这样:
{'dog': 0, 'cat': 1, ..}
并且还有一个名为 samples
的属性,它是一个元组列表,看起来像这样:
[('root/dog/xxx.png', 0), ('root/dog/xxy.png', 0), ..., ('root/cat/123.png', 1), ('root/cat/nsdf3.png', 1), ...]
这个 parser
对象是可下标访问的,通过执行 parser[index]
可以返回 self.samples
中该特定索引处的样本。因此,执行 parser[0]
将返回 ('root/dog/xxx.png', 0)
。
一旦设置了 parser
,ImageDataset
就会根据 index
从这个 parser
获取图像和目标。
img, target = self.parser[index]
然后,它根据 load_bytes
参数,将图像读取为 PIL.Image
并转换为 RGB
,或者将图像读取为字节。
最后,它对图像进行转换并返回目标。如果目标为 None,则返回一个虚拟目标 torch.tensor(-1)
。
这个 ImageDataset
也可以用作 torchvision.datasets.ImageFolder
的替代品。考虑到我们有 imagenette2-320
数据集,其结构看起来像这样:
imagenette2-320
├── train
│ ├── n01440764
│ ├── n02102040
│ ├── n02979186
│ ├── n03000684
│ ├── n03028079
│ ├── n03394916
│ ├── n03417042
│ ├── n03425413
│ ├── n03445777
│ └── n03888257
└── val
├── n01440764
├── n02102040
├── n02979186
├── n03000684
├── n03028079
├── n03394916
├── n03417042
├── n03425413
├── n03445777
└── n03888257
每个子文件夹都包含属于该类的一组 .JPEG
文件。
# run only once
wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz
gunzip imagenette2-320.tgz
tar -xvf imagenette2-320.tar
那么,就可以像这样创建一个 ImageDataset
:
from timm.data.dataset import ImageDataset
dataset = ImageDataset('./imagenette2-320')
dataset[0]
(<PIL.Image.Image image mode=RGB size=426x320 at 0x7FF7F4880460>, 0)
我们还可以看到 dataset.parser
是 ParserImageFolder
的一个实例
dataset.parser
<timm.data.parsers.parser_image_folder.ParserImageFolder at 0x7ff7f4880d90>
最后,让我们看一下 parser 中的 class_to_idx
字典映射
dataset.parser.class_to_idx
{'n01440764': 0,
'n02102040': 1,
'n02979186': 2,
'n03000684': 3,
'n03028079': 4,
'n03394916': 5,
'n03417042': 6,
'n03425413': 7,
'n03445777': 8,
'n03888257': 9}
并且,还有前五个样本,如下所示
dataset.parser.samples[:5]
[('./imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG', 0),
('./imagenette2-320/train/n01440764/ILSVRC2012_val_00002138.JPEG', 0),
('./imagenette2-320/train/n01440764/ILSVRC2012_val_00003014.JPEG', 0),
('./imagenette2-320/train/n01440764/ILSVRC2012_val_00006697.JPEG', 0),
('./imagenette2-320/train/n01440764/ILSVRC2012_val_00007197.JPEG', 0)]
timm
还提供了一个 IterableImageDataset
,类似于 PyTorch 的 IterableDataset,但有一个关键区别——IterableImageDataset
在生成图像和目标之前,会对 image
应用转换。
当数据来自流或数据长度未知时,这类数据集特别有用。
timm
会对 image
延迟应用转换,并在目标为 None
的情况下将目标设置为一个虚拟目标 torch.tensor(-1, dtype=torch.long)
。
与上面的 ImageDataset
类似,IterableImageDataset
首先创建一个 parser,它根据 root
目录获取样本的元组。
如前所述,parser 返回一个图像,目标是图像所在的对应文件夹。
IterableImageDataset
没有定义 __getitem__
方法,因此它是不可下标访问的。如果 dataset
是 IterableImageDataset
的一个实例,执行 dataset[0]
之类的操作将返回错误。IterableImageDataset
内部的 __iter__
方法首先从 self.parser
获取图像和目标,然后对图像延迟应用转换。在两者返回之前,也会将目标设置为一个虚拟值。
from timm.data import IterableImageDataset
from timm.data.parsers.parser_image_folder import ParserImageFolder
from timm.data.transforms_factory import create_transform
root = '../../imagenette2-320/'
parser = ParserImageFolder(root)
iterable_dataset = IterableImageDataset(root=root, parser=parser)
parser[0], next(iter(iterable_dataset))
iterable_dataset
不可下标访问。
iterable_dataset[0]
> >
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
<ipython-input-14-9085b17eda0c> in <module>
----> 1 iterable_dataset[0]
~/opt/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataset.py in __getitem__(self, index)
30
31 def __getitem__(self, index) -> T_co:---> 32 raise NotImplementedError 33
34 def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
NotImplementedError:
class AugmixDataset(dataset: ImageDataset, num_splits: int = 2):
AugmixDataset
接受一个 ImageDataset
并将其转换为 Augmix 数据集。
什么是 Augmix 数据集以及我们什么时候需要这样做?
让我们借助 Augmix 论文来回答这个问题。

如上图所示,最终的 Loss Output
实际上是分类损失与标签和模型在 Xorig、Xaugmix1 和 Xaugmix2 上的预测之间的 Jensen-Shannon 损失的 λ
倍之和。
因此,对于这种情况,我们将需要三个版本的批次——原始、augmix1 和 augmix2。那么我们如何实现这一点呢?当然是使用 AugmixDataset
!
augmix1
和 augmix2
是原始批次的数据增强版本,其中数据增强操作是从一个操作列表中随机选择的。首先,我们从 self.dataset
(它是传递给 AugmixDataset
构造函数的那个数据集)获取一个 X
和相应的标签 y
。接下来,我们对这个图像 X
进行归一化并将其添加到一个名为 x_list
的变量中。
接下来,根据默认值为 0 的 num_splits
参数,我们对 X
应用 augmentations
,对增强后的输出进行归一化,并将其追加到 x_list
中。
num_splits=2
,则 x_list
有两项——原始 + 增强
。如果 num_splits=3
,则 x_list
有三项——原始 + 增强1 + 增强2
。依此类推。from timm.data import ImageDataset, IterableImageDataset, AugMixDataset, create_loader
dataset = ImageDataset('../../imagenette2-320/')
dataset = AugMixDataset(dataset, num_splits=2)
loader_train = create_loader(
dataset,
input_size=(3, 224, 224),
batch_size=8,
is_training=True,
scale=[0.08, 1.],
ratio=[0.75, 1.33],
num_aug_splits=2
)
# Requires GPU to work
next(iter(loader_train))[0].shape
>> torch.Size([16, 3, 224, 224])
batch_size=8
,但 loader_train
返回的批次大小是 16?这是为什么呢?num_aug_splits=2
。在这种情况下,loader_train
包含前 8 张原始图像和接下来代表 augmix1
的 8 张图像。
如果我们传入 num_aug_splits=3
,那么有效的 batch_size
将是 24,其中前 8 张图像是原始图像,接下来 8 张代表 augmix1
,最后 8 张代表 augmix2
。