3

基于火炬闪电的图像分类

 3 years ago
source link: https://panchuang.net/2021/08/05/%e5%9f%ba%e4%ba%8e%e7%81%ab%e7%82%ac%e9%97%aa%e7%94%b5%e7%9a%84%e5%9b%be%e5%83%8f%e5%88%86%e7%b1%bb/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

在本教程中,我将使用PyTorch Lightning从https://www.kaggle.com/brsdincer/vehicle-detection-image-set数据集中对图像进行分类,目的是使用PyTorch Lightning将图像分类为车辆和非车辆。本教程假定您熟悉数据科学和Pytorch。https://www.kaggle.com/brsdincer/vehicle-detection-image-set

数据集包含两个文件夹,分别包含车辆和非车辆的图像。我们的任务是创建一个可以对车辆和非车辆进行分类的分类器。

笔记本电脑

此笔记本使用GPU在Kaggle笔记本上运行。请注意,在没有GPU的情况下运行此笔记本可能需要较长时间。原始笔记本可在此处找到here

如果运行下面的单元时出现错误,则说明您的环境中尚未安装一个或多个模块。

import pandas as pd
import os
from sklearn.model_selection import train_test_split
import glob

接下来,我们将创建一个包含路径和目标的数据帧,以使访问数据变得更容易。

vehicles = glob.glob(f"../input/vehicle-detection-image-set/data/vehicles/*.png")#returns a list of paths of images in Vehicles folder

接下来,我们将创建一个Pytorch-Lightning数据模块发送给培训师

X  = df["mask_id"]
y = df["mask"]
x_train , x_test , y_train, y_test = train_test_split(X,y,test_size = 0.25,random_state = 42 ,)
x_val , x_test , y_val, y_test = train_test_split(x_test,y_test,test_size = 0.25,random_state = 42 ,)

“模型”(The Model)

我们将使用Resnet50型号。如果你使用的是你自己的神经网络,那么可能需要更长的训练时间才能得到同样的结果。

neural_network = torchvision.models.resnet50(pretrained= True)
neural_network.fc = torch.nn.Linear(2048,2)# changing the number of output features to 2

接下来,我们将使用闪电模块来制作模型。我们将使用StepLr计划程序进行学习速率

loss_func = torch.nn.CrossEntropyLoss()

如果您到目前为止已经学习了本教程,则可以使用一个纪元来训练该模型。关于PyTorch Lightning最好的部分是,您可以通过简单地设置“GPU=[GPU数量]”来设置GPU的数量

%%time # Checking the amount of time the cell takes to run
from pytorch_lightning import Trainer
model = Vehicle_Model()
module = Vehicle_DataModule()
trainer = Trainer(max_epochs=1,gpus = 1,callbacks = [checkpoint_callback])
trainer.fit(model,module)
trainer.test()
predictons = trainer.predict()

由此可见,我们获得了97.9%的准确率

原创文章,作者:fendouai,如若转载,请注明出处:https://panchuang.net/2021/08/05/%e5%9f%ba%e4%ba%8e%e7%81%ab%e7%82%ac%e9%97%aa%e7%94%b5%e7%9a%84%e5%9b%be%e5%83%8f%e5%88%86%e7%b1%bb/


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK