引文
Gradio是一个是用友好的web界面演示机器学习模型的最快方法,它的操作非常简便,很方便上手。
MNIST是几乎每个接触机器学习的同学都使用过的数据集,内含0~9共10个数字的上千张手写图片,其每张图片的大小均为28px*28px
Pytorch是一款非常方便的深度学习库,可以轻松搭建深度神经网络。
本文将使用Pytorch训练卷积神经网络(CNN)来进行手写数字识别,然后使用Gradio中的手写板功能输入手写数字,进行识别测试。在文章末尾可以下载本文中的全部代码(ipynb格式)
现状
目前网络上的中文教程大多局限于搭建MNIST识别模型,并使用数据集的内建测试集进行测试,并未涉及自行输入图片进行测试,而使用Gradio进行可视化展示的更是少之又少。
Gradio官方文档中的使用方法利用了外部模型,并未涉及自行训练。而外网大部分教程都使用tensorflow。
流程
导入必备包
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import gradio
设定一些超参数
BATCH_SIZE=512 #大概需要2G的显存
EPOCHS=10 # 总共训练批次
DEVICE = torch.device("mps") # mps, cuda or cpu
超参数可以理解为决定模型如何进行训练的设置参数。BATCH_SIZE代表一次进入训练的图片数量;EPOCHS代表训练多少个周期;DEVICE代表使用什么硬件进行训练——本文使用了MacBook的M1芯片,因此选择mps。N卡用户请选择cuda,其余可使用cpu进行训练(或使用其他核心,在此不详述)。
加载内建的训练数据和测试数据
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=BATCH_SIZE, shuffle=True)
定义卷积神经网络
class ConvNet (nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Conv2d(1,10,5)
self.conv2=nn.Conv2d(10,20,5)
self.fc1=nn.Linear(20*8*8,640)
self.fc2=nn.Linear(640,10)
def forward (self,x):
in_size = x.size(0)
out=self.conv1(x)
out=F.relu(out)
out=F.max_pool2d(out,2,2)
out = self.conv2(out)
out = F.relu(out)
out = out.view(in_size, -1)
out=self.fc1(out)
out=self.fc2(out)
out = F.log_softmax(out, dim=1)
return out
神经网络使用了卷积层--->relu激活层--->2*2最大池化层--->卷积层--->relu激活层--->两个全连接层--->softmax
激活层的结构。
定义损失函数和优化器
model=ConvNet().to(DEVICE)
optimizer=optim.Adam(model.parameters())
封装训练和测试函数
def train(model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output=model(data)
loss=F.nll_loss(output,target)
loss.backward()
optimizer.step()
if (batch_idx + 1) % 30 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(model,device,test_loader):
model.eval()
test_loss=0
correct=0
with torch.no_grad():
for data,target in test_loader:
data,target=data.to(device),target.to(device)
output=model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
开始训练
for epoch in range(EPOCHS):
#test(model, DEVICE, test_loader)
train(model,DEVICE,train_loader,optimizer,epoch)
test(model,DEVICE,test_loader)
输出样例:
Train Epoch: 0 [14848/60000 (25%)] Loss: 0.327558
Train Epoch: 0 [30208/60000 (50%)] Loss: 0.129759
Train Epoch: 0 [45568/60000 (75%)] Loss: 0.135278
Test set: Average loss: 0.0902, Accuracy: 9718/10000 (97%)
加载Gradio
定义一个预测函数
不同于之前版本可以加载pytorch模型,当前版本的Gradio必须自行书写函数传入。在Gradio启动后,手写板中书写的数字将会以单通道(B/W)的形式传入预测函数中。通过type( )
函数查看,该图片为numpy类,因此需要先将其通过transforms.ToTensor()
将其转化为pytorch的张量tensor形式。随后,拓展图片维度以适应神经网络要求,随后将其送入DEVICE中进行推理。代码如下:
def predict(inp):
img = transforms.ToTensor()(inp)
#img = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
img_ = img.unsqueeze(0) # 拓展维度, 拓展batch_size那一维
img_ = img_.to(DEVICE)
# 推理过程
output = model(img_) # net是提前读取的模型
pred_index = int(torch.argmax(output, dim=1))
return pred_index
启动Gradio
inp = gradio.Sketchpad()
io = gradio.Interface(fn=predict,inputs=inp, outputs="text",live=True)
io.launch()
gradio.Sketchpad()
是指加载Gradio的手写板,将其接收到的笔划赋值给inp,gradio.Interface()
是指启动Gradio界面,fn代表使用的函数,这里用到了上面的predict()
函数,即输入inp,输出pred_index,输出方式为“text”文本格式。
代码下载
使用Colab 在线运行:(需要连接Google)
https://colab.research.google.com/drive/1CgcdxfgQkHth98IqHo3uQlMFMeN0bM5C#scrollTo=0zrX_u-0KN6e
国内云盘下载:
gradio.ipynb: https://url80.ctfile.com/f/35431880-763654876-9c0f8c?p=9119 (密码:9119)
参考
[1] 加载外部图片进行测试
[2] 旧版Gradio的使用方式(现在已经不再适用)
[3] Gradio官网
Comments | 4 条评论
博主 SkEy
《Swift》
《PHP》
《Makefile》
《Objective-C》
连SDATA都出来了,蚌埠住了
博主 Astrophel
@SkEy 确实难绷
博主 1439429910
神!
博主 ShawnJR
(╯°口°)╯(┴—┴