7

如何将训练好的pytorch模型部署到安卓设备上

 2 years ago
source link: https://mp.weixin.qq.com/s?__biz=MjM5ODkzMzMwMQ%3D%3D&%3Bmid=2650429903&%3Bidx=5&%3Bsn=9d8b73133eb5f729fba06b6bf6b44bcb
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

如何将训练好的pytorch模型部署到安卓设备上

点这里关注我→ AINLP 2022-04-10 05:46

来源:投稿 作者:AI浩

编辑:学姐

640?wx_fmt=jpeg

这篇文章演示如何将训练好的pytorch模型部署到安卓设备上。我也是刚开始学安卓,代码写的简单。

环境:pytorch版本:1.10.0

# 模型转化

pytorch_android支持的模型是.pt模型,我们训练出来的模型是.pth。所以需要转化才可以用。

先看官网上给的转化方式:

import torchimport torchvisionfrom torch.utils.mobile_optimizer import optimize_for_mobilemodel = torchvision.models.mobilenet_v3_small(pretrained=True)model.eval()example = torch.rand(1, 3, 224, 224)traced_script_module = torch.jit.trace(model, example)optimized_traced_model = optimize_for_mobile(traced_script_module)optimized_traced_model._save_for_lite_interpreter("app/src/main/assets/model.ptl")

这个模型在安卓对应的包:

repositories {    jcenter()}dependencies {    implementation 'org.pytorch:pytorch_android_lite:1.9.0'    implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'}

注:pytorch_android_lite版本和转化模型用的版本要一致,不一致就会报各种错误。

目前用这种方法有点问题,我采用的另一种方法。

转化代码如下:

import torchimport torch.utils.data.distributed# pytorch环境中model_pth = 'model_31_0.96.pth' #模型的参数文件mobile_pt ='model.pt' # 将模型保存为Android可以调用的文件model = torch.load(model_pth)model.eval() # 模型设为评估模式device = torch.device('cpu')model.to(device)# 1张3通道224*224的图片input_tensor = torch.rand(1, 3, 224, 224) # 设定输入数据格式mobile = torch.jit.trace(model, input_tensor) # 模型转化mobile.save(mobile_pt) # 保存文件

定义模型文件和转化后的文件路径。

load模型。(这里要注意,如果保存模型)

torch.save(model,'models.pth')

加载模型则是

model=torch.load('models.pth')

如果保存模型是

torch.save(model.state_dict(),"models.pth")

加载模型则是

model.load_state_dict(torch.load('models.pth'))

定义输入数据格式。

模型转化,然后再保存模型。

# 安卓部署

新建安卓项目,选择Empy Activity,然后选择Next

640?wx_fmt=png

然后,填写项目信息,选择安卓版本,我用的4.4,点击完成

640?wx_fmt=png

导入pytorch_android的包

//pytorchimplementation 'org.pytorch:pytorch_android:1.10.0'implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
640?wx_fmt=png

如果有参数报错请参照我的完整的配置,代码如下:

plugins {    id 'com.android.application'}android {    compileSdk 32    defaultConfig {        applicationId "com.example.myapplication"        minSdk 21        targetSdk 32        versionCode 1        versionName "1.0"        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"    }    buildTypes {        release {            minifyEnabled false            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'        }    }    compileOptions {        sourceCompatibility JavaVersion.VERSION_1_8        targetCompatibility JavaVersion.VERSION_1_8    }}dependencies {    implementation 'androidx.appcompat:appcompat:1.3.0'    implementation 'com.google.android.material:material:1.4.0'    implementation 'androidx.constraintlayout:constraintlayout:2.0.4'    testImplementation 'junit:junit:4.13.2'    androidTestImplementation 'androidx.test.ext:junit:1.1.3'    androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'    //pytorch    implementation 'org.pytorch:pytorch_android:1.10.0'    implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'}

页面的配置如下:

<?xml version="1.0" encoding="utf-8"?><FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"    xmlns:tools="http://schemas.android.com/tools"    android:layout_width="match_parent"    android:layout_height="match_parent"    tools:context=".MainActivity">    <ImageView        android:id="@+id/image"        android:layout_width="match_parent"        android:layout_height="match_parent"        android:scaleType="fitCenter" />    <TextView        android:id="@+id/text"        android:layout_width="match_parent"        android:layout_height="wrap_content"        android:layout_gravity="top"        android:textSize="24sp"        android:background="#80000000"        android:textColor="@android:color/holo_red_light" /></FrameLayout>

这个页面只有两个空间,一个展示图片,一个显示文字。

640?wx_fmt=png

新增assets文件夹,然后将转化的模型和待测试的图片放进去。

640?wx_fmt=png

新增ImageNetClasses类,这个类存放类别名字。

640?wx_fmt=png

代码如下:

package com.example.myapplication;public class ImageNetClasses {    public static String[] IMAGENET_CLASSES = new String[]{            "Black-grass",            "Charlock",            "Cleavers",            "Common Chickweed",            "Common wheat",            "Fat Hen",            "Loose Silky-bent",            "Maize",            "Scentless Mayweed",            "Shepherds Purse",            "Small-flowered Cranesbill",            "Sugar beet",    };}

在MainActivity类中,增加模型推理的逻辑。

完成代码如下:

package com.example.myapplication;import android.content.Context;import android.graphics.Bitmap;import android.graphics.BitmapFactory;import android.os.Bundle;import android.util.Log;import android.widget.ImageView;import android.widget.TextView;import org.pytorch.IValue;import org.pytorch.Module;import org.pytorch.Tensor;import org.pytorch.torchvision.TensorImageUtils;import org.pytorch.MemoryFormat;import java.io.File;import java.io.FileOutputStream;import java.io.IOException;import java.io.InputStream;import java.io.OutputStream;import androidx.appcompat.app.AppCompatActivity;public class MainActivity extends AppCompatActivity {    @Override    protected void onCreate(Bundle savedInstanceState) {        super.onCreate(savedInstanceState);        setContentView(R.layout.activity_main);        Bitmap bitmap = null;        Module module = null;        try {            // creating bitmap from packaged into app android asset 'image.jpg',            // app/src/main/assets/image.jpg            bitmap = BitmapFactory.decodeStream(getAssets().open("1.png"));            // loading serialized torchscript module from packaged into app android asset model.pt,            // app/src/model/assets/model.pt            module = Module.load(assetFilePath(this, "models.pt"));        } catch (IOException e) {            Log.e("PytorchHelloWorld", "Error reading assets", e);            finish();        }        // showing image on UI        ImageView imageView = findViewById(R.id.image);        imageView.setImageBitmap(bitmap);        // preparing input tensor        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);        // running the model        final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();        // getting tensor content as java array of floats        final float[] scores = outputTensor.getDataAsFloatArray();        // searching for the index with maximum score        float maxScore = -Float.MAX_VALUE;        int maxScoreIdx = -1;        for (int i = 0; i < scores.length; i++) {            if (scores[i] > maxScore) {                maxScore = scores[i];                maxScoreIdx = i;            }        }        System.out.println(maxScoreIdx);        String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];        // showing className on UI        TextView textView = findViewById(R.id.text);        textView.setText(className);    }    /**     * Copies specified asset to the file in /files app directory and returns this file absolute path.     *     * @return absolute file path     */    public static String assetFilePath(Context context, String assetName) throws IOException {        File file = new File(context.getFilesDir(), assetName);        if (file.exists() && file.length() > 0) {            return file.getAbsolutePath();        }        try (InputStream is = context.getAssets().open(assetName)) {            try (OutputStream os = new FileOutputStream(file)) {                byte[] buffer = new byte[4 * 1024];                int read;                while ((read = is.read(buffer)) != -1) {                    os.write(buffer, 0, read);                }                os.flush();            }            return file.getAbsolutePath();        }    }}

然后运行。

640?wx_fmt=png
0?wx_fmt=png
AINLP
一个有趣有AI的自然语言处理公众号:关注AI、NLP、机器学习、推荐系统、计算广告等相关技术。公众号可直接对话双语聊天机器人,尝试自动对联、作诗机、藏头诗生成器,调戏夸夸机器人、彩虹屁生成器,使用中英翻译,查询相似词,测试NLP相关工具包。
343篇原创内容
Official Account
进技术交流群请添加AINLP小助手微信(id: ainlper)
请备注具体方向+所用到的相关技术点
640?wx_fmt=jpeg

关于AINLP

AINLP 是一个有趣有AI的自然语言处理社区,专注于 AI、NLP、机器学习、深度学习、推荐算法等相关技术的分享,主题包括文本摘要、智能问答、聊天机器人、机器翻译、自动生成、知识图谱、预训练模型、推荐系统、计算广告、招聘信息、求职经验分享等,欢迎关注!加技术交流群请添加AINLPer(id:ainlper),备注工作/研究方向+加群目的。

640?wx_fmt=jpeg

阅读至此了,分享、点赞、在看三选一吧🙏


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK