PyTorch第一个图片识别模型

仿照PyTorch官网的衣服的例子自己写了个图片识别模型,算是自己的第一个PyTorch模型,记录下过程。

一路踩坑是免不了的。

自定义Dataset

官网给出的CustomImageDataset的代码大体可用,需要根据自己的图片存入路径和csv文件格式稍做修改。

图片尺寸

官网例子的图片都是28×28的,自己所用的图片尺寸也要统一大小。我是自己写了个方法统一缩放成了128×128。

灰度处理

官网例子的图片都是灰白的,也就是经过灰度处理的,自己的图片也要处理成这种格式。

transform=transforms.Compose([
    transforms.Grayscale(1), #转为单通道灰度图像
])

NeuralNetwork

定义NeuralNetwork时,第一个Linear函数的第一个参数要与图片尺寸保持一致,不然会报错 mat1 and mat2 shapes cannot be multiplied,翻译过来就是mat1和mat2不能相乘。

pytorch RuntimeError: expected scalar type Float but found Byte

这个错误比较意外,查了下说需要把图片的RGB数值转为Float类型,最简单的方法就是直接除以255。

if self.transform:
    image = self.transform(image)
    image = image / 255 # 要求是float类型,最简单的就是直接除以255

准确度

由于数据集过小,两个分类加起来只有55张图片,第一个分类51张,第二个分类4张,最后训练出来的模型准确率是92.7%,正好是第二个分类的4张图片预测错误。准备多找些图片再试试。

PyTorch第一个图片识别模型就算完成了!

Leave a Comment

豫ICP备19001387号-1