仿照PyTorch官网的衣服的例子自己写了个图片识别模型,算是自己的第一个PyTorch模型,记录下过程。
一路踩坑是免不了的。
自定义Dataset
官网给出的CustomImageDataset
的代码大体可用,需要根据自己的图片存入路径和csv文件格式稍做修改。
图片尺寸
官网例子的图片都是28×28的,自己所用的图片尺寸也要统一大小。我是自己写了个方法统一缩放成了128×128。
灰度处理
官网例子的图片都是灰白的,也就是经过灰度处理的,自己的图片也要处理成这种格式。
transform=transforms.Compose([
transforms.Grayscale(1), #转为单通道灰度图像
])
NeuralNetwork
定义NeuralNetwork
时,第一个Linear
函数的第一个参数要与图片尺寸保持一致,不然会报错 mat1 and mat2 shapes cannot be multiplied
,翻译过来就是mat1和mat2不能相乘。
pytorch RuntimeError: expected scalar type Float but found Byte
这个错误比较意外,查了下说需要把图片的RGB数值转为Float类型,最简单的方法就是直接除以255。
if self.transform:
image = self.transform(image)
image = image / 255 # 要求是float类型,最简单的就是直接除以255
准确度
由于数据集过小,两个分类加起来只有55张图片,第一个分类51张,第二个分类4张,最后训练出来的模型准确率是92.7%,正好是第二个分类的4张图片预测错误。准备多找些图片再试试。
PyTorch第一个图片识别模型就算完成了!