目标检测
深度学习和目标检测系列教程 9-凯发ag旗舰厅登录网址下载
@author:runsen
上次对xml文件进行提取,使用到一个albumentation模块。albumentation模块是一个数据增强的工具,目标检测图像预处理通过使用“albumentation”来应用的,这是一个易于与pytorch数据转换集成的python库。
albumentation 是一种工具,可以在将(图像/图片)插入模型之前自定义 处理(弹性、网格、运动模糊、移位、缩放、旋转、转置、对比度、亮度等])到图像/图片。
对此,albumentation 官方文档:
- https://albumentations.ai/
为什么要看看这个东西?因为将 torchvision 代码重构为 albumentation 的效果最好,运行更快。
上图是使用 intel xeon platinum 8168 cpu 在 imagenet中通过 2000 个验证集图像的测试结果。每个单元格中的值表示在单个核心中处理的图像数量。可以看到 albumentation在许多转换方面比所有其他库至少高出 2 倍。
albumentation github 的官方 cpu 基准测试https://github.com/albumentations-team/albumentations
下面,我导入了下面的模块:
from pil import image import time import torch import torchvision from torch.utils.data import dataset from torchvision import transforms import albumentations import albumentations.pytorch from matplotlib import pyplot as plt import cv2 import numpy as np为了演示的目的,我找了一张前几天毕业回校拍的照片
原始 torchvision 数据管道
创建一个 dataloader 来使用 pytorch 和 torchvision 处理图像数据管道。
- 创建一个简单的 pytorch 数据集类
- 调用图像并进行转换
- 用 100 个循环测量整个处理时间
首先,从torch.utils.data获取 dataset抽象类,并创建一个 torchvision数据集类。然后我插入图像并使用__getitem__方法进行转换。另外,我用来total_time = (time.time() - start_t测量需要多长时间
class torchvisiondataset(dataset):def __init__(self, file_paths, labels, transform=none):self.file_paths = file_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.file_paths)def __getitem__(self, idx):label = self.labels[idx]file_path = self.file_paths[idx]image = image.open(file_path)start_t = time.time()if self.transform:image = self.transform(image)total_time = (time.time() - start_t)return image, label, total_time然后将图像大小调整为 256x256(高度 * 重量)并随机裁剪到 224x224 大小。然后以 50% 的概率应用水平翻转并将其转换为张量。输入文件路径应该是您的图像所在的 google drive 的路径。
torchvision_transform = transforms.compose([transforms.resize((256, 256)),transforms.randomcrop(224),transforms.randomhorizontalflip(),transforms.totensor(), ])torchvision_dataset = torchvisiondataset(file_paths=["demo.jpg"],labels=[1],transform=torchvision_transform, )下面计算从 torchvision_dataset 中提取样本图像并对其进行转换所花费的时间,然后运行 100 次循环以检查它所花费的平均毫秒。
torchvision time/sample: 7.31137752532959 ms在torch中的gpu,原始 torchvision 数据管道数据预处理的速度大约是0.0731137752532959 ms。最后输出的图像则为 224x224而且发生了翻转!
albumentation 数据管道
现在创建了一个 albumentations dataset 类,具体的transform和原始 torchvision 数据管道完全一样。
from pil import image import time import torch import torchvision from torch.utils.data import dataset from torchvision import transforms import albumentations import albumentations.pytorch from matplotlib import pyplot as plt import cv2 import numpy as npclass albumentationsdataset(dataset):"""__init__ and __len__ functions are the same as in torchvisiondataset"""def __init__(self, file_paths, labels, transform=none):self.file_paths = file_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.file_paths)def __getitem__(self, idx):label = self.labels[idx]file_path = self.file_paths[idx]# read an image with opencvimage = cv2.imread(file_path)# by default opencv uses bgr color space for color images,# so we need to convert the image to rgb color space.image = cv2.cvtcolor(image, cv2.color_bgr2rgb)start_t = time.time()if self.transform:augmented = self.transform(image=image)image = augmented['image']total_time = (time.time() - start_t)return image, label, total_timealbumentations_transform = albumentations.compose([albumentations.resize(256, 256),albumentations.randomcrop(224, 224),albumentations.horizontalflip(), # same with transforms.randomhorizontalflip()albumentations.pytorch.transforms.totensor() ]) albumentations_dataset = albumentationsdataset(file_paths=["demo.jpg"],labels=[1],transform=albumentations_transform, )total_time = 0 for i in range(100):sample, _, transform_time = albumentations_dataset[0]total_time = transform_timeprint("albumentations time/sample: {} ms".format(total_time*10))plt.figure(figsize=(10, 10)) plt.imshow(transforms.topilimage()(sample)) plt.show()具体输出如下:
albumentations time/sample: 0.5056881904602051 ms在torch中的gpu,albumentation 数据管道 数据管道数据预处理的速度大约是0.005056881904602051 ms。
因此,在真正的工业落地,基本需要将原始 torchvision 数据管道改写成albumentation 数据管道,因为落地项目的速度很重要。
albumentation数据增强
最后,我将展示如何使用albumentations中oneof函数进行书增强,我个人觉得这个函数在 albumentation 中非常有用。
from pil import image import time import torch import torchvision from torch.utils.data import dataset from torchvision import transforms import albumentations import albumentations.pytorch from matplotlib import pyplot as plt import cv2class albumentationsdataset(dataset):"""__init__ and __len__ functions are the same as in torchvisiondataset"""def __init__(self, file_paths, labels, transform=none):self.file_paths = file_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.file_paths)def __getitem__(self, idx):label = self.labels[idx]file_path = self.file_paths[idx]image = cv2.imread(file_path)image = cv2.cvtcolor(image, cv2.color_bgr2rgb)if self.transform:augmented = self.transform(image=image)image = augmented['image']return image, label# oneof随机采用括号内列出的变换之一。 # 我们甚至可以将发生的概率放在函数本身中。例如,如果 ([…], p=0.5) 之一,它会以 50% 的机会跳过整个变换,并以 1/6 的机会随机选择三个变换之一。 albumentations_transform_oneof = albumentations.compose([albumentations.resize(256, 256),albumentations.randomcrop(224, 224),albumentations.oneof([albumentations.horizontalflip(p=1),albumentations.randomrotate90(p=1),albumentations.verticalflip(p=1)], p=1),albumentations.oneof([albumentations.motionblur(p=1),albumentations.opticaldistortion(p=1), albumentations.gaussnoise(p=1)], p=1),albumentations.pytorch.totensor() ])albumentations_dataset = albumentationsdataset(file_paths=["demo.jpg"],labels=[1],transform=albumentations_transform_oneof, )num_samples = 5 fig, ax = plt.subplots(1, num_samples, figsize=(25, 5)) for i in range(num_samples):ax[i].imshow(transforms.topilimage()(albumentations_dataset[0][0]))ax[i].axis('off')plt.show()
上面的oneof是在水平翻转、旋转、垂直翻转中随机选择,在模糊、失真、噪声中随机选择。所以在这种情况下,我们允许 3x3 = 9 种组合
总结
以上是凯发ag旗舰厅登录网址下载为你收集整理的深度学习和目标检测系列教程 9-300:torchvision和albumentation性能对比,如何使用albumentation对图片数据做数据增强的全部内容,希望文章能够帮你解决所遇到的问题。
如果觉得凯发ag旗舰厅登录网址下载网站内容还不错,欢迎将凯发ag旗舰厅登录网址下载推荐给好友。
- 上一篇:
- 下一篇: