本章通过一个食物图片分类的例子介绍如何自定义自己的数据集。
什么是自定义数据集?
自定义数据集是你需要的数据集合。
如果我们正在构建像 Nutrify 这样的食物图像分类应用程序,我们的自定义数据集可能是食物图像。
如果我们试图建立一个模型来分类网站上基于文本的评论是正面的还是负面的,我们的自定义数据集可能是现有客户评论及其评级的示例。
如果我们试图构建一个声音分类应用程序,我们的自定义数据集可能是声音样本及其样本标签。
PyTorch 包含许多现有函数,用于加载自定义数据集: TorchVision
, TorchText
, TorchAudio
and TorchRec
但有时内置的函数不够实现想要的功能。此时, 可以通过继承 torch.utils.data.Dataset
自定义我们的数据集。
1. 准备数据
我们使用Food101 dataset的一部分来自定义我们的数据集。Food101 是流行的计算机视觉基准,它包含 101 种不同食物的 1000 张图像(750个训练, 250个测试),总共 101,000张图像。

但是,我们并不想对101种食物分类,而是从3种食物开始:比萨、牛排和寿司。
作者已经将这三类食物的图片的压缩包上传到了github上,只需要下载解压即可。
import requests
import zipfile
from pathlib import Path
# Setup path to data folder
data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"
# If the image folder doesn't exist, download it and prepare it...
if image_path.is_dir():
print(f"{image_path} directory exists.")
else:
print(f"Did not find {image_path} directory, creating one...")
image_path.mkdir(parents=True, exist_ok=True)
# Download pizza, steak, sushi data
with open(data_path / "pizza_steak_sushi.zip", "wb") as f:
request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",verify=False)
print("Downloading pizza, steak, sushi data...")
f.write(request.content)
# Unzip pizza, steak, sushi data
with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", "r") as zip_ref:
print("Unzipping pizza, steak, sushi data...")
zip_ref.extractall(image_path)
2. 探索数据
下载解压完数据后,得到文件夹pizza_steak_sushi
,结构如下:
pizza_steak_sushi/ <- 总的文件夹
train/ <- 训练集
pizza/ <- 类名作为文件夹名称
image01.jpg
image02.jpg
...
steak/
image24.jpg
image25.jpg
...
sushi/
image37.jpg
...
test/ <- 测试集
pizza/
image101.jpg
image102.jpg
...
steak/
image154.jpg
image155.jpg
...
sushi/
image167.jpg
...
在我们的例子中,我们有标准图像分类格式的披萨、牛排和寿司的图像。图像分类格式在以特定类名命名的单独目录中包含单独的图像类。例如,“pizza”的所有图像都包含在“pizza/”目录中。这种格式在许多不同的图像分类基准中都很流行,包括 ImageNet。
打开任意一张图片后,发现是一张512x512像素大小的jpg格式图片。
(也可以用PIL库来查看图片数据。)
3. 数据变换
想将图像数据加载到PyTorch我们需要:
1. 把它变成张量(图像的数字表示)。
2. 将其转换为
torch.utils.data.Dataset
和随后的torch.utils.data.DataLoader
,我们将它们简称为Dataset
和DataLoader
。
PyTorch 有几种不同类型的预构建数据集和数据集加载器,具体取决于您正在处理的问题。
| Problem space | Pre-built Datasets and Functions |
| Vision | torchvision.datasets |
| Audio | torchaudio.datasets |
| Text | torchtext.datasets |
| Recommendation system | torchrec.datasets |
由于我们正在处理视觉问题,我们将查看 torchvision.datasets
的数据加载功能以及 torchvision.transforms
用于准备我们的数据。
让我们先导入相关的库。
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
3.1 用 torchvision.transforms
变换数据
我们有图像文件夹,为了能在PyTorch中使用,需要将它们转换为张量。
通常使用“torchvision.transforms”模块来实现。
torchvision.transforms
包含许多方法来格式化图像,将它们转换为张量,甚至进行数据增强。
要获得使用 torchvision.transforms
的经验,让我们编写一系列转换步骤:
1. 使用
transforms.Resize()
调整图像大小。2. 使用
transforms.RandomHorizontalFlip()
在水平方向上随机翻转我们的图像(这个可以被认为是一种数据增强形式)。3. 使用
transforms.ToTensor()
。
我们可以使用 torchvision.transforms.Compose()
组合所有这些步骤。
data_transform = transforms.Compose([
# Resize the images to 64x64
transforms.Resize(size=(64, 64)),
# Flip the images randomly on the horizontal
transforms.RandomHorizontalFlip(p=0.5), # p = probability of flip, 0.5 = 50% chance
# Turn the image into a torch.Tensor
transforms.ToTensor() # this also converts all pixel values from 0 to 255 to be between 0.0 and 1.0
])
在一张图片上验证我们的变换:
def plot_transformed_images(image_paths, transform, n=3, seed=42):
"""Plots a series of random images from image_paths.
Will open n image paths from image_paths, transform them
with transform and plot them side by side.
Args:
image_paths (list): List of target image paths.
transform (PyTorch Transforms): Transforms to apply to images.
n (int, optional): Number of images to plot. Defaults to 3.
seed (int, optional): Random seed for the random generator. Defaults to 42.
"""
random.seed(seed)
random_image_paths = random.sample(image_paths, k=n)
for image_path in random_image_paths:
with Image.open(image_path) as f:
fig, ax = plt.subplots(1, 2)
ax[0].imshow(f)
ax[0].set_title(f"Original \nSize: {f.size}")
ax[0].axis("off")
# Transform and plot image
# Note: permute() will change shape of image to suit matplotlib
# (PyTorch default is [C, H, W] but Matplotlib is [H, W, C])
transformed_image = transform(f).permute(1, 2, 0)
ax[1].imshow(transformed_image)
ax[1].set_title(f"Transformed \nSize: {transformed_image.shape}")
ax[1].axis("off")
fig.suptitle(f"Class: {image_path.parent.stem}", fontsize=16)
plot_transformed_images(image_path_list,
transform=data_transform,
n=3)
4. 选项1: 使用 ImageFolder
加载图像
是时候将我们的图片用Dataset
变成PyTorch可用的数据了。由于我们的数据是标准的图片分类类型,可以使用 torchvision.datasets.ImageFolder
.
我们可以将目标图像目录的文件路径以及我们希望对图像执行的一系列转换传递给它。
让我们在我们的数据文件夹 train_dir
和 test_dir
上进行测试,通过 transform=data_transform
将我们的图像转换为张量。
# Use ImageFolder to create dataset(s)
from torchvision import datasets
train_data = datasets.ImageFolder(root=train_dir, # target folder of images
transform=data_transform, # transforms to perform on data (images)
target_transform=None) # transforms to perform on labels (if necessary)
test_data = datasets.ImageFolder(root=test_dir,
transform=data_transform)
print(f"Train data:\n{train_data}\nTest data:\n{test_data}")
查看一下数据集信息
# Get class names as a list
class_names = train_data.classes
class_names
# Can also get class names as a dict
class_dict = train_data.class_to_idx
class_dict
# Check the lengths
len(train_data), len(test_data)
获取数据:
img, label = train_data[0][0], train_data[0][1]
print(f"Image tensor:\n{img}")
print(f"Image shape: {img.shape}")
print(f"Image datatype: {img.dtype}")
print(f"Image label: {label}")
print(f"Label datatype: {type(label)}")
4.1 使用 DataLoader
装载图像
我们已经将图像作为 PyTorch 的“Dataset”,但现在让我们将它们变成“DataLoader”。
我们将使用 torch.utils.data.DataLoader
来实现。将我们的 Dataset
转换为 DataLoader
使图像可迭代。
为简单起见,DataLoader的参数将使用 batch_size=1
和 num_workers=1
。num_workers
定义了将创建多少个子进程来加载您的数据。可以这样想,num_workers
设置的值越高,PyTorch 用于加载数据的计算能力就越大。就个人而言,我通常通过 Python 的 os.cpu_count()
将其设置为我机器上的 CPU 总数。这确保了 DataLoader
使用尽可能多的核心来加载数据。
注意:更多参数你可以在PyTorch文档中查看
torch.utils.data.DataLoader
来熟悉。
# Turn train and test Datasets into DataLoaders
from torch.utils.data import DataLoader
train_dataloader = DataLoader(dataset=train_data,
batch_size=1, # how many samples per batch?
num_workers=1, # how many subprocesses to use for data loading? (higher = more)
shuffle=True) # shuffle the data?
test_dataloader = DataLoader(dataset=test_data,
batch_size=1,
num_workers=1,
shuffle=False) # don't usually need to shuffle testing data
train_dataloader, test_dataloader
验证我们的dataloader
img, label = next(iter(train_dataloader))
# Batch size will now be 1, try changing the batch_size parameter above and see what happens
print(f"Image shape: {img.shape} -> [batch_size, color_channels, height, width]")
print(f"Label shape: {label.shape}")
5 选项2:使用自定义Dataset加载数据
如果不存在像 torchvision.datasets.ImageFolder()
这样的预构建“数据集”怎么办?你可以建立自定义数据集。
创建自定义方式来加载“数据集”的优缺点:
Pros of creating a custom Dataset | Cons of creating a custom Dataset |
可以为几乎任何数据创建Dataset | 即使您可以用几乎任何东西创建一个“数据集”,但这并不意味着它会起作用。 |
不再被内置Dataset限制 | 使用自定义Dataset通常会导致编写更多代码,这可能容易出现错误或性能问题。 |
为了看到这一点,让我们通过继承 torch.utils.data.Dataset
(PyTorch 中所有 Dataset
的基类)来实现 torchvision.datasets.ImageFolder()
类似的功能。
我们将从导入我们需要的模块开始:
• Python 的
os
用于处理目录(我们的数据存储在目录中)。• Python 的
pathlib
用于处理文件路径(我们的每个图像都有一个唯一的文件路径)。•
torch
适用于 PyTorch 的所有内容。• PIL 的
Image
类用于加载图像。•
torch.utils.data.Dataset
继承并创建我们自己的自定义Dataset
。•
torchvision.transforms
将我们的图像转换为张量。• 来自 Python 的
typing
模块的各种类型,用于向我们的代码添加类型提示。
import os
import pathlib
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from typing import Tuple, Dict, List
5.1 Creating a helper function to get class names
让我们编写一个辅助函数,该函数能够在给定目录路径的情况下创建类名列表和"类名:索引"的字典。为此,我们将:
1. 使用
os.scandir()
获取类名,遍历一个目标目录(理想情况下该目录是标准图像分类格式)。2. 如果找不到类名,则引发错误(如果发生这种情况,目录结构可能有问题)。
3. 将类名转换成数字标签字典,每个类一个。
在我们编写完整函数之前,让我们看一下步骤 1 的一个小例子。
# Make function to find classes in target directory
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folder names in a target directory.
Assumes target directory is in standard image classification format.
Args:
directory (str): target directory to load classnames from.
Returns:
Tuple[List[str], Dict[str, int]]: (list_of_class_names, dict(class_name: idx...))
Example:
find_classes("food_images/train")
>>> (["class_1", "class_2"], {"class_1": 0, ...})
"""
# 1. Get the class names by scanning the target directory
classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
# 2. Raise an error if class names not found
if not classes:
raise FileNotFoundError(f"Couldn't find any classes in {directory}.")
# 3. Crearte a dictionary of index labels (computers prefer numerical rather than string labels)
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
现在我们准备构建我们自己的自定义Dataset
。我们将构建一个来复制torchvision.datasets.ImageFolder()
的功能。这将是一个很好的做法,此外,它还将揭示制作您自己的自定义“数据集”所需的一些步骤。
让我们将其分解为下面步骤(1、5、6是所有数据集都需要实现的,2、3、4则是我们这个问题需要的):
1. 继承
torch.utils.data.Dataset
.2. 用
targ_dir
(the target data directory)和transform
(so we have the option to transform our data if needed)参数初始化。3. 创建属性:
paths
(目标图像的路径)、transform
(我们可能想要使用的变换,可以是None
)、classes
和class_to_idx
(来自我们的find_classes()
函数)。4. 创建一个从文件加载图像并返回它们的函数,这可以使用
PIL
或[torchvision.io
](https://pytorch.org/vision/stable/io.html#image)(用于输入/ 视觉数据的输出)。5. 重写
torch.utils.data.Dataset
的__len__
方法以返回Dataset
中的样本数。这样您就可以调用len(Dataset)
。6. 重写
torch.utils.data.Dataset
的__getitem__
方法以从Dataset
返回单个样本。
# Write a custom dataset class (inherits from torch.utils.data.Dataset)
from torch.utils.data import Dataset
# 1. Subclass torch.utils.data.Dataset
class ImageFolderCustom(Dataset):
# 2. Initialize with a targ_dir and transform (optional) parameter
def __init__(self, targ_dir: str, transform=None) -> None:
# 3. Create class attributes
# Get all image paths
self.paths = list(pathlib.Path(targ_dir).glob("*/*.jpg")) # note: you'd have to update this if you've got .png's or .jpeg's
# Setup transforms
self.transform = transform
# Create classes and class_to_idx attributes
self.classes, self.class_to_idx = find_classes(targ_dir)
# 4. Make function to load images
def load_image(self, index: int) -> Image.Image:
"Opens an image via a path and returns it."
image_path = self.paths[index]
return Image.open(image_path)
# 5. Overwrite the __len__() method (optional but recommended for subclasses of torch.utils.data.Dataset)
def __len__(self) -> int:
"Returns the total number of samples."
return len(self.paths)
# 6. Overwrite the __getitem__() method (required for subclasses of torch.utils.data.Dataset)
def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
"Returns one sample of data, data and label (X, y)."
img = self.load_image(index)
class_name = self.paths[index].parent.name # expects path in data_folder/class_name/image.jpeg
class_idx = self.class_to_idx[class_name]
# Transform if necessary
if self.transform:
return self.transform(img), class_idx # return data, label (X, y)
else:
return img, class_idx # return data, label (X, y)
哇!一大堆代码。
这是创建自己的自定义“数据集”的缺点之一。然而,现在我们已经编写了一次,我们可以将它与其他一些有用的数据函数一起移动到一个 .py
文件中,例如 data_loader.py
并在以后重用它。
在我们测试新的“ImageFolderCustom”类之前,让我们创建一些转换来准备我们的图像。
# Augment train data
train_transforms = transforms.Compose([
transforms.Resize((64, 64)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor()
])
# Don't augment test data, only reshape
test_transforms = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor()
])
现在到了关键时刻!
让我们使用我们自己的 ImageFolderCustom
类将我们的训练图像(包含在 train_dir
中)和我们的测试图像(包含在 test_dir
中)转换为 Dataset
。
train_data_custom = ImageFolderCustom(targ_dir=train_dir,
transform=train_transforms)
test_data_custom = ImageFolderCustom(targ_dir=test_dir,
transform=test_transforms)
train_data_custom, test_data_custom
我们可以通过一些方法来测试自定义数据集是否有效。例如用matplotlib来查看。见原notebook的5.2、5.3。
5.4 使用DataLoader加载
# Turn train and test custom Dataset's into DataLoader's
from torch.utils.data import DataLoader
train_dataloader_custom = DataLoader(dataset=train_data_custom, # use custom created train Dataset
batch_size=1, # how many samples per batch?
num_workers=0, # how many subprocesses to use for data loading? (higher = more)
shuffle=True) # shuffle the data?
test_dataloader_custom = DataLoader(dataset=test_data_custom, # use custom created test Dataset
batch_size=1,
num_workers=0,
shuffle=False) # don't usually need to shuffle testing data
train_dataloader_custom, test_dataloader_custom
测试
# Get image and label from custom DataLoader
img_custom, label_custom = next(iter(train_dataloader_custom))
# Batch size will now be 1, try changing the batch_size parameter above and see what happens
print(f"Image shape: {img_custom.shape} -> [batch_size, color_channels, height, width]")
print(f"Label shape: {label_custom.shape}")
6. 其它形式的转换(数据增强)
我们已经在我们的数据上看到了一些转换,但还有更多。您可以在 torchvision.transforms
文档 中查看它们。
转换的目的是以某种方式改变你的图像。这可能会将您的图像变成张量(正如我们之前所见)。或裁剪它或随机擦除一部分或随机旋转它们。
进行这种转换通常被称为数据增强(data augmentation)。数据增强是改变数据的过程,您可以人工增加训练集的多样性。在这个 人工 更改的数据集上训练模型有望产生一个能够更好地泛化 的模型。
您可以在 PyTorch 的 Illustration of Transforms 示例 中看到许多使用 torchvision.transforms
对图像执行数据增强的不同示例 )。
机器学习就是利用随机性的力量,研究表明随机变换(如 transforms.RandAugment()
和 transforms.TrivialAugmentWide()
)通常比手工挑选的变换表现更好。您有一组变换,并且随机选择其中的一些变换来在图像上执行,并且在给定范围之间以随机幅度执行(幅度越高意味着强度越大)。
在 transforms.TrivialAugmentWide()
中要注意的主要参数是 num_magnitude_bins=31
。它定义了强度值将被选择多少范围以应用特定变换,“0”表示无范围,“31”表示最大范围(最高强度的最高机会)。我们可以将 transforms.TrivialAugmentWide()
合并到 transforms.Compose()
中。
from torchvision import transforms
train_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.TrivialAugmentWide(num_magnitude_bins=31), # how intense
transforms.ToTensor() # use ToTensor() last to get everything between 0 & 1
])
# Don't need to perform augmentation on the test data
test_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
接下来查看数据增强的结果:
# Get all image paths
image_path_list = list(image_path.glob("*/*/*.jpg"))
# Plot random images
plot_transformed_images(
image_paths=image_path_list,
transform=train_transforms,
n=3,
seed=None
)







