54

使用迁移学习快速训练识别特定风格的图片

 6 years ago
source link: https://xiaozhuanlan.com/topic/4672501389?amp%3Butm_medium=referral
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

前几天接到一个任务,需要从我们app的feed流中的筛选一些「优质」图片,作为运营同学的精选feed候选池。这里「优质」的参考就是以前运营同学手工筛序的精选feed图片。问题并不难,最容易想到的方向有两个:

  1. 机器学习方向,训练一个能够识别这种「优质」风格图片的模型。
  2. 过滤推荐方向,利用用户来测试feed图片质量(根据点赞、评论、观看张数、停留时间等指标),使用用户来筛选优质feed图片(用户的偏好千奇百怪,筛选结果可能未必如你所想,典型如今日头条……)。

今天我们介绍如何使用机器学习解决这个问题。具体来讲,由于时间紧,任务重,我们决定使用迁移学习来完成这个任务。后面如果有时间,我们也会尝试一下使用用户来过滤和筛选优质图片。

什么是迁移学习

迁移学习 (Transfer learning) 顾名思义就是就是把已学训练好的模型参数迁移到新的模型来帮助新模型训练。考虑到大部分数据或任务是存在相关性的,所以通过迁移学习我们可以将已经学到的模型参数(也可理解为模型学到的知识)通过某种方式来分享给新模型从而加快并优化模型的学习效率不用像大多数网络那样从零学习。

为什么使用迁移学习

  • 很多时候,你可能并没有足够大的数据集来训练模型,更不用说带有高质量标签的数据集了。使用已经训练好的网络,可以降低用于训练的数据集大小要求。
  • 从零开始训练一个深度网络是非常消耗算力和时间的。如果再将模型调整、超参数调整等有点玄学的流程加进去,消耗的时间会更多。对于创业公司来说,很多时候是很难给出这么多的时间预算来解决一个模型问题的。
  • 基于迁移学习训练一个模型往往只需要训练有限的几层网络,或者使用已有网络作为特征生成器,使用常规机器学习方法(如svm)来训练分类器。整体训练时间大幅降低。效果可能不是最好的,但是往往能够在短时间内帮你训练出一个够用的模型,解决当前的实际问题。

也就是说,近几年深度学习的各种突破本质上还是建立在数据集的完善和算力的提升。算法方面的提升带来的突破其实不如前两者明显。如果你是一个开发者,具体到要使用机器学习解决特定问题的时候,你一定想清楚你能否搞定数据集和算力的问题,如果不能,不妨尝试一下迁移学习。

如何进行迁移学习

我们的任务是筛选优质feed图片,其实就是一个优质图片与普通图片的二分类问题。

运营给出的「优质」参考图片:

RjaEZfi.jpg!web

直观感受是,健身摆拍图、美食图和少量风光照是她们眼中的优质图片:joy:

运营给出的「普通」参考图片:

rIjUrqQ.jpg!web

直观感受是,屏幕截图和没什么特点的图片被认为是普通图片。

我们迁移学习的过程就是复用训练好的(部分)网络和权重,然后构建我们自己的模型进行训练:

vyymQvq.jpg!web

迁移学习在选择预训练网络时有一点需要注意:预训练网络与当前任务差距不大,否则迁移学习的效果会很差。这里根据我们的任务类型,我们选择了深度残差网络 ResNet50, 权重选择imagenet数据集。选择 RetNet 的主要原因是之前我们训练的图片鉴黄模型是参考雅虎开源的 open NSFW , 而这个模型使用的就是残差网络,模型效果让我们影响深刻。完整代码如下(keras + tensorflow):

```

from keras import applications

from keras.preprocessing.image import ImageDataGenerator

from keras import optimizers

from keras.models import Sequential, Model

from keras.layers import Dropout, Flatten, Dense, GlobalAveragePooling2D

from keras import backend as k

from keras.callbacks import ModelCheckpoint, LearningRateScheduler, TensorBoard, EarlyStopping

img_width, img_height = 256, 256

train_data_dir = "tf_files/codoon_photos"

validation_data_dir = "tf_files/codoon_photos"

nb_train_samples = 4125

nb_validation_samples = 466

batch_size = 16

epochs = 50

model = applications.ResNet50(include_top=False, weights='imagenet', input_shape=(img_width, img_height, 3))

Freeze the layers which you don't want to train. Here I am freezing the all layers.

for layer in model.layers[:]:

layer.trainable = False

Adding custom Layer

We only add

x = model.output

x = Flatten()(x)

Adding even more custom layers

x = Dense(1024, activation="relu")(x)

x = Dropout(0.5)(x)

x = Dense(1024, activation="relu")(x)

predictions = Dense(2, activation="softmax")(x)

creating the final model

model_final = Model(input = model.input, output = predictions)

compile the model

model_final.compile(loss = "categorical_crossentropy", optimizer = optimizers.SGD(lr=0.0001, momentum=0.9), metrics=["accuracy"])

Initiate the train and test generators with data Augumentation

train_datagen = ImageDataGenerator(

rescale = 1./255,

horizontal_flip = True,

fill_mode = "nearest",

zoom_range = 0.3,

width_shift_range = 0.3,

height_shift_range=0.3,

rotation_range=30)

test_datagen = ImageDataGenerator(

rescale = 1./255,

horizontal_flip = True,

fill_mode = "nearest",

zoom_range = 0.3,

width_shift_range = 0.3,

height_shift_range=0.3,

rotation_range=30)

train_generator = train_datagen.flow_from_directory(

train_data_dir,

target_size = (img_height, img_width),

batch_size = batch_size,

class_mode = "categorical")

validation_generator = test_datagen.flow_from_directory(

validation_data_dir,

target_size = (img_height, img_width),

class_mode = "categorical")

Save the model according to the conditions

checkpoint = ModelCheckpoint("resnet50_retrain.h5", monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=False, mode='auto', period=1)

early = EarlyStopping(monitor='val_acc', min_delta=0, patience=10, verbose=1, mode='auto')


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK