卷积神经网络实现咖啡豆分类 - P7

2023-09-15 21:31:56


环境

  • 系统: Linux
  • 语言: Python3.8.10
  • 深度学习框架: Pytorch2.0.0+cu118
  • 显卡:A5000 24G

步骤

环境设置

包引用

import torch
import torch.nn as nn # 网络
import torch.optim as optim # 优化器
from torch.utils.data import DataLoader, random_split # 数据集划分
from torchvision import datasets, transforms # 数据集加载,转换

import pathlib, random, copy # 文件夹遍历,实现模型深拷贝
from PIL import Image # python自带的图像类
import matplotlib.pyplot as plt # 图表
import numpy as np 
from torchinfo import summary # 打印模型参数

全局设备对象

方便将模型和数据统一拷贝到目标设备中

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

数据准备

查看图像的信息

data_path = 'coffee_data'
data_lib = pathlib.Path(data_path)
coffee_images = list(data_lib.glob('*/*'))

# 打印5张图像的信息
for _ in range(5):
	image = random.choice(coffee_images)
	print(np.array(Image.open(str(image))).shape)

图像信息
通过打印的信息,可以看出图像的尺寸都是224x224的,这是一个CV经常使用的图像大小,所以后面我就不使用Resize来缩放图像了。

# 打印20张图像粗略的看一下
plt.figure(figsize=(20, 4))
for i in range(20):
	plt.subplot(2, 10, i+1)
	plt.axis('off')
	image = random.choice(coffee_images) # 随机选出一个图像
	plt.title(image.parts[-2]) # 通过glob对象取出它的文件夹名称,也就是分类名
	plt.imshow(Image.open(str(image))) # 展示

数据集预览
通过展示,对数据集内的图像有个大概的了解

制作数据集

先编写数据的预处理过程,用来使用pytorch的api加载文件夹中的图像

transform = transforms.Compose([
	transforms.ToTensor(), # 先把图像转成张量
	transforms.Normalize( # 对像素值做归一化,将数据范围弄到-1,1
       mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225],
	),
])

加载文件夹

dataset = datasets.ImageFolder(data_path, transform=transform)

从数据中取所有的分类名

class_names = [k for k in dataset.class_to_idx]
print(class_names)

图像分类名
将数据集划分出训练集和验证集

train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

将数据集按划分成批次,以便使用小批量梯度下降

batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

模型设计

在一开始时,直接手动创建了Vgg-16网络,发现少数几个迭代后模型就收敛了,于是开始精简模型。

手动搭建的vgg16网络

class Vgg16(nn.Module):
	def __init__(self, num_classes):
		super().__init__()
		
		self.block1 = nn.Sequential(
			nn.Conv2d(3, 64, 3, padding=1),
			nn.BatchNorm2d(64),
			nn.ReLU(),
			nn.Conv2d(64, 64, 3, padding=1),
			nn.BatchNorm2d(64),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)
		self.block2 = nn.Sequential(
			nn.Conv2d(64, 128, 3, padding=1),
			nn.BatchNorm2d(128),
			nn.ReLU(),
			nn.Conv2d(128, 128, 3, padding=1),
			nn.BatchNorm2d(128),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)
		self.block3 = nn.Sequential(
			nn.Conv2d(128, 256, 3, padding=1),
			nn.BatchNorm2d(256),
			nn.ReLU(),
			nn.Conv2d(256, 256, 3, padding=1),
			nn.BatchNorm2d(256),
			nn.ReLU(),
			nn.Conv2d(256, 256, 3, padding=1),
			nn.BatchNorm2d(256),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)
		self.block4 = nn.Sequential(
			nn.Conv2d(256, 512, 3, padding=1),
			nn.BatchNorm2d(512),
			nn.ReLU(),
			nn.Conv2d(512, 512, 3, padding=1),
			nn.BatchNorm2d(512),
			nn.ReLU(),
			nn.Conv2d(512, 512, 3, padding=1),
			nn.BatchNorm2d(512),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)
		self.block5 = nn.Sequential(
			nn.Conv2d(512, 512, 3, padding=1),
			nn.BatchNorm2d(512),
			nn.ReLU(),
			nn.Conv2d(512, 512, 3, padding=1),
			nn.BatchNorm2d(512),
			nn.ReLU(),
			nn.Conv2d(512, 512, 3, padding=1),
			nn.BatchNorm2d(512),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)
		self.pool = nn.AdaptiveAvgPool2d(7)
		self.classifier = nn.Sequential(
			nn.Linear(7*7*512, 4096),
			nn.Dropout(0.5),
			nn.ReLU(),
			nn.Linear(4096, 4096),
			nn.Dropout(0.5),
			nn.ReLU(),
			nn.Linear(4096, num_classes),
		)

	def forward(self, x):
		x = self.block1(x)
		x = self.block2(x)
		x = self.block3(x)
		x = self.block4(x)
		x = self.block5(x)
		x = self.pool(x)
		x = x.view(x.size(0),-1)
		x = self.classifier(x)
		return x
vgg = Vgg16(len(class_names)).to(device)
summary(vgg, input_size=(32, 3, 224, 224))

VGG16模型
通过模型结构的打印可以发现,VGG-16网络共有134285380个可训练参数(我加了BatchNorm,和官方的比会稍微多出一些),参数量非常巨大,对于咖啡豆识别这种小场景,这么多可训练参数肯定浪费,于是对原始的VGG-16网络结构进行精简。

精简后的咖啡豆识别网络

class Network(nn.Module):
	def __init__(self, num_classes):
		super().__init__()
		
		self.block1 = nn.Sequential(
			nn.Conv2d(3, 64, 3, padding=1),
			nn.BatchNorm2d(64),
			nn.ReLU(),
			nn.Conv2d(64, 64, 3, padding=1),
			nn.BatchNorm2d(64),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)

		self.block2 = nn.Sequential(
			nn.Conv2d(64, 128, 3, padding=1),
			nn.BatchNorm2d(128),
			nn.ReLU(),
			nn.Conv2d(128, 128, 3, padding=1),
			nn.BatchNorm2d(128),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)
		
		self.block3 = nn.Sequential(
			nn.Conv2d(128, 64, 3, padding=1),
			nn.BatchNorm2d(64),
			nn.ReLU(),
			nn.Conv2d(64, 64, 3, padding=1),
			nn.BatchNorm2d(64),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)

		self.pool = nn.AdaptiveAvgPool2d(7),

		self.classifier = nn.Sequential(
			nn.Linear(7*7*64, 64),
			nn.Dropout(0.4),
			nn.ReLU(),
			nn.Linear(64, num_classes)
		)
	
	def forward(self, x):
		x = self.block1(x)
		x = self.block2(x)
		x = self.block3(x)
		x = self.pool(x)
		x = x.view(x.size(0), -1)
		x = self.classifier(x)
		return x
model = Network(len(class_names)).to(device)
summary(model, input_size=(32, 3, 224, 224))

精简后的模型
可以看到精简后的网络模型参数量还不到原来的1/10,但是其在测试集上的正确率依然能够达到100%!

模型训练

编写训练函数

def train(train_loader, model, loss_fn, optimizer):
	size = len(train_loader.dataset)
	num_batches = len(train_loader)

	train_loss, train_acc = 0, 0
	for x, y in train_loader:
		x, y = x.to(device), y.to(device)
	
		pred = model(x)
		loss = loss_fn(pred, y)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		train_loss += loss.item()
		train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

	train_loss /= num_batches
	train_acc /= size

	return train_loss, train_acc

编写测试函数

def test(test_loader, model, loss_fn):
	size = len(test_loader.dataset)
	num_batches = len(test_loader)
	
	test_loss, test_acc = 0, 0
	for x, y in test_loader:
		x, y = x.to(device), y.to(device)

		pred = model(x)
		loss = loss_fn(pred, y)

		test_loss += loss.item()
		test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

	test_loss /= num_batches
	test_acc /= size

	return test_loss, test_acc

开始训练

首先定义损失函数,优化器设置学习率,这里我们再弄一个学习率的衰减,再加上总的迭代次数,最佳模型的保存位置

epochs = 30
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda epoch: 0.92**(epoch//2))
best_model_path = 'best_coffee_model.pth'

然后编写训练+测试的循环,并记录训练过程的数据

best_acc = 0
train_loss, train_acc = [], []
test_loss, test_acc = [], []
for epoch in epochs:
	model.train()
	epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)
	scheduler.step()

	model.eval()
	with torch.no_grad();
		epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)
	
	train_loss.append(epoch_train_loss)
	train_acc.append(epoch_train_acc)
	test_loss.append(epoch_test_loss)
	test_acc.append(epoch_test_acc)

	lr = optimizer.state_dict()['param_groups'][0]['lr']

	if best_acc < epoch_test_acc:
		best_acc = epoch_test_acc
		best_model = copy.deepcopy(model)

	print(f"Epoch: {epoch+1}, Lr:{lr}, TrainAcc: {epoch_train_acc*100:.1f}, TrainLoss: {epoch_train_loss:.3f}, TestAcc: {epoch_test_acc*100:.1f}, TestLoss: {epoch_test_loss:.3f}")

print(f"Saving Best Model with Accuracy: {best_acc*100:.1f} to {best_model_path}")
torch.save(best_model.state_dict(), best_model_path)
print('Done')

训练过程日志
可以看出,模型在测试集上的正确率最高达到了100%

展示训练过程

epoch_ranges = range(epochs)
plt.figure(figsize=(20,6))
plt.subplot(121)
plt.plot(epoch_ranges, train_loss, label='train loss')
plt.plot(epoch_ranges, test_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')

plt.figure(figsize=(20,6))
plt.subplot(122)
plt.plot(epoch_ranges, train_acc, label='train accuracy')
plt.plot(epoch_ranges, test_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

训练过程参数

模型效果展示

model.load_state_dict(torch.load(best_model_path))
model.to(device)
model.eval()

plt.figure(figsize=(20,4))
for i in range(20):
	plt.subplot(2, 10, i+1)
	plt.axis('off')
	image = random.choice(coffee_images)
	input = transform(Image.open(str(image))).to(device).unsqueeze(0)
	pred = model(input)
	plt.title(f'T:{image.parts[-2]}, P:{class_names[pred.argmax()]}')
	plt.imshow(Image.open(str(image)))

模型效果展示
通过结果可以看出,确实是所有的咖啡豆都正确的识别了。

总结与心得体会

  • 因为目前网络还是很快就收敛到一个很高的水平,所以应该还有很大的精简的空间,但是可能会稍微牺牲一些正确率。
  • 模型的选取要根据实际任务来确定,像咖啡豆种类识别这种任务,使用VGG-16太浪费了。
  • 在精简的过程中,没有感觉到训练速度有明显的变化 ,说明参数量和训练速度并没有直接的相关关系。
  • 连续多层参数一样的卷积操作好像比只用一层效果要好。
更多推荐

【lesson8】操作系统的理解和类比

文章目录操作系统是什么?为什么要有操作系统?怎么做?学校的例子(理解管理)银行的例子(类比操作系统)操作系统是什么?操作系统是一款软件,是为了进行软硬件资源管理的软件。为什么要有操作系统?操作系统是为了给用户提供一个良好,安全,简单的运行环境。这就是操作系统的目的。怎么做?上面的两个话题我们在Linux发展史这篇博客中

设计模式之代理模式的懂静态代理和动态代理

目录1概述1.1如何实现?1.2优点1.3缺点1.4适用场景2静态代理实现3JDK动态代理实现4CGlib动态代理实现5总结1概述代理模式(ProxyPattern)是一种结构型设计模式,它的概念很简单,它通过创建一个代理对象来控制对原始对象的访问。代理模式主要涉及两个角色:代理角色和真实角色。代理类负责代理真实类,为

mybatis简介&idea导入mybatis

mybatis简介Mybatis是Apache的一个Java开源项目,是一个支持动态Sql语句的持久层框架。Mybatis可以将Sql语句配置在XML文件中,避免将Sql语句硬编码在Java类中。与JDBC相比:1)Mybatis通过参数映射方式,可以将参数灵活的配置在SQL语句中的配置文件中,避免在Java类中配置参

设计模式:解释器模式

目录组件代码示例优缺点总结解释器模式(InterpreterPattern)是一种行为型设计模式,它定义了一种语言的文法,并且定义了该语言中各个元素的解释器。通过使用解释器,可以解析和执行特定的语言表达式。解释器模式的核心思想是将一个语言的文法表示为一个类的层次结构,并使用该类的实例来表示语言中的各个元素。每个元素都有

优化代码,提升代码性能

文章目录一、方法1.尽量指定类、方法的final修饰符二、变量1.循环内不要不断创建对象引用2.基本类型转换成字符串3.如果变量的初值会被覆盖,就没有必要给变量赋初值4.尽量使用基本数据类型,避免不必要的装箱、拆箱和空指针判断三、常量1.将常量声明为staticfinal,并以大写命名2.禁止使用JSON转化对象四、对

nvme prp模型代码处理流程分析

以下函数是prp相关的源码。/**prp模型,除了第一个dmaaddr不是page_size对齐的其余的dmaaddr都要求是page_size对齐的*/staticblk_status_tnvme_pci_setup_prps(structnvme_dev*dev,structrequest*req,structnv

Google Data Fusion构建数据ETL任务

Google云平台提供了一个DataFusion的产品,是基于开源的CDAP做的一个图形化的编辑工具,可以很方便的来完成数据处理的任务,而无需编写代码。假设我们现在要构建一个ETL的任务,从Kafka中消费一些数据,经过处理之后把数据存放到Bigquery中。首先我们要准备一些测试数据发送到Kafka。这里我是在GKE

2023年腾讯云轻量应用服务器16核32G28M配置测评

腾讯云轻量应用服务器16核32G28M配置优惠价3468元15个月(支持免费续3个月/送同配置3个月),轻量应用服务器具有100%CPU性能,系统盘为380GBSSD盘,28M带宽下载速度3584KB/秒,月流量6000GB,折合每天200GB流量,超出月流量包的流量按照0.8元每GB的价格支付流量费,地域节点可选广州

【自学开发之旅】Flask-restful-Jinjia页面编写template-回顾(五)

restful是web编程里重要的概念–一种接口规范也是一种接口设计风格设计接口:要考虑:数据返回、接收数据的方式、url、方法统一风格rest–表现层状态转移web–每一类数据–资源资源通过http的动作来实现状态转移GET、PUT、POST、DELETEpath组成:/{version}/{resources}/{

分布式运用之企业级日志ELFK+logstash的过滤模块

一、ELFK集群部署(Filebeat+ELK)在搭建ELK的基础上安装Filebeat服务,Filebeat服务可以布置在以下任意一台主机,本次实验将布置在apache服务器的节点上步骤一:安装Filebeat(在apache节点操作)#上传软件包filebeat-6.7.2-linux-x86_64.tar.gz到

面向对象进阶

文章目录面向对象进阶一.static1.静态变量2.静态方法3.static的注意事项二.继承1.概述2.特点3.子类可以继承父类中的内容4.继承中成员变量的访问特点5.继承中成员方法的访问特点6.继承中构造方法的访问特点7.this和super使用总结三.多态1.认识多态2.多态中调用成员的特点3.多态的优势和弊端四

热文推荐