2

使用PyTorch开始自己的首次视觉竞赛系列(1)--如何使用DataLoader

 3 years ago
source link: https://muyun.work/Pytorch1.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

使用PyTorch开始自己的首次视觉竞赛系列(1)--如何使用DataLoader

写这篇文章的初衷是希望在众多的开源 base­line 中,我想要形成我自己的一套 pipeline。所以我将在近期的几个竞赛中开始尝试并逐渐整理出一套简洁易用的 pipeline。为后面的毕设工作也做好准备。我将为每个模块都进行部分知识的整理,并完成最后的系列博客

1 视觉模型的大体框架

    • DataLoader
    • Transform
    • 我们应该使用怎样的Backbone
    • 如何恰定义我们的优化器和损失
  • 训练及多折验证
  • TTA(Test Time Augmentation) 及预测最终结果

2 如何开始写DataLoader

2.1 使用 torch.utils.data.dataset.Dataset 来组织数据

将所有的数据集都表示为从 key 到 data sam­ples 的映射,使用方法是继承该类,重写子类。所有的子类都必须重载 __len__() 函数以及 __getitem__() 函数,前者返回数据集的大小,后者将根据给定的整数 key 取得 data sam­ple。

上面是 Py­Torch 官方文档中所说到的,下面结合实例来讲述具体的使用。这里是我们的数据集格式,如下图所示

├─test           # 其中是测试集的图片,序号从0开始
└─train          # 其中是训练集的图片,序号从0开始
└─train.csv     # 训练集的标注,其格式是filename label两列数据
image-20200311175143003.png#vwid=414&vhei=374
image-20200311175211910.png#vwid=464&vhei=376
class MyDataset(data.Dataset):#需要继承data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass

    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        pass

    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0
2.1.1 对 init() 函数进行重载

接下来我们需要对三个函数进行重载,首先对 init () 函数进行重载

  • 首先对于我们定义一个df用来存放train.csv的标注数据
  • 定义 transforms 用来对数据进行增强操作,在此处实现自己的增强策略。例如标准化和随机水平翻转等,
  • 定义 mode 变量用来区分训练集以及测试结合,默认是训练集
def __init__(self, df, transform, mode='train'):
    self.df = df
    self.transform = transform
    self.mode = mode
2.1.2 对 getitem() 函数进行重载

接着我们需要对 getitem () 函数进行重载

对训练集来说需要做下面几步,测试集则不用返回标签。除此之外,训练集的 df 是从 csv 文件中读取的,而测试集的 df 我们需要将测试集中的文件名组织为列表即可(可参照以下实例)

  • 用 Image 这个库将图片加载进来,并转换为 ‘RGB’ 格式
  • 进行数据增强
  • 返回的结果是增强后的图片以及图片的标签
def __getitem__(self, index):
    if self.mode == 'train':
        img = Image.open(self.df['filename'].iloc[index]).convert('RGB')
        img = self.transform(img)
        return img, torch.from_numpy(np.array(self.df['label'].iloc[index]))
    else:
        img = Image.open(self.df[index]).convert('RGB')
        img = self.transform(img)
        return img, torch.from_numpy(np.array(0))

# 测试集的组织方式实例
test_path_list = ['{}/{}.jpg'.format(config.image_test_path, x) for x in range(0, data_len)]
test_df = np.array(test_path_list)
2.1.3 对 len() 函数进行重载

len () 函数的重载很简单,返回 len (self.df) 即可

def __len__(self):
    return len(self.df)

2.2 使用 torchvision.datasets.ImageFolder 来组织数据

Im­age­Folder 组织数据比较局限,但我最开始使用的便是这种方式,其必须要求的数据格式如下:一个类的图像放在一个文件夹中。操作很不灵活,我刚开始用的时候花了一些时间将数据组织成要求的格式。现在还是重载 Dat­aLoader 类会方便很多。

20200311184057.png#vwid=220&vhei=192
torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)

2.3 使用 torch.utils.data.DataLoader 来加载数据

在数据组织完成之后,我们需要构建训练集、验证集和测试集的 Dat­aLoader。Dat­aLoader 能够有效的帮你进行批量的数据迭代,下面是 Dat­aLoader 的构造函数,下面简单解释一下常用的一些参数。

  • dataset:即上述的 DataSet 类或者 ImageFolder 类
  • batch_size:即进行批量训练的数据数量
  • shuffle:是否打乱,一般来说训练集为True,验证集和测试集为False
  • num_workers:即加载数据所用到的线程数量
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None):
    pass

一般来说,我们定义好 Dat­aLoader 后需要去测试 Dat­aLoader 是否写对了,我们就用一个函数来测试一下,并解释一下 Dat­aLoader 的迭代过程。我们有两种方法迭代 Dat­aLoader,如下面的代码:

for step, (batch_x, batch_y) in enumerate(dataloader):
    pass
for batch_x, batch_y in dataloader:
    pass

接下来则是完整的测试过程:

# 验证dataloader是否正确的取到数据
def dataloader_test(dataloader):
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(dataloader):
            print("单个batch的size: ", batch_x.shape)
            print("单个batch的label张量: ", batch_y.numpy())

            # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
            img = batch_x[0]
            img = img.numpy()                    # FloatTensor转为ndarray
            img = np.transpose(img, (1, 2, 0))  # 把channel那一维放到最后

            plt.imshow(img)
            plt.show()
            break
        break

2.4 后言

至此,我们已经能够完成对数据集的读取,写出 Dat­aLoader 的代码并测试自己编写代码的正确性了。在总结这篇文章的过程中我也体会到,其实 Py­Torch 的官方文档真的写的超详细了。在编写整个系列中,我都会尽可能地多参考官方文档。下面我将对数据集很重要的一部分:数据增强进行单独的讲述,分享我在这个过程中学到的知识。

3. 参考资料


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK