8

论文解读:Unsupervised Domain Adaptation by Backpropagation

 2 years ago
source link: https://weisenhui.top/posts/28120.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client
论文解读:Unsupervised Domain Adaptation by Backpropagation

会议:ICML 2015
论文题目:Unsupervised Domain Adaptation by Backpropagation
论文地址:http://proceedings.mlr.press/v37/ganin15.pdf
论文代码:https://github.com/fungtion/DANN

问题描述:深度学习的模型在source domain数据集上训练的很好(90%左右),但是迁移到target domain的效果就很差(54%左右),这种现象叫做domain shift。

Target Domain的图片是无标签

《Unsupervised domain Adaptation by Backpropagation》这篇论文发表于2015 ICML,目前引用量已经3000+,这篇文章讲得非常好,把对抗训练的思想应用到分布迁移上面。具体来说在原来feature mapping的基础上外接一个domain classifier(之前那个叫做label classifier),这个domain classifier的作用是判别当前样本是属于哪个domain的,如果你的数据集只有两个分布,那么这个classifier就是一个二分类任务。如果你正常进行梯度更新的话,feature mapping这个向量在不同的domain上就会dissimilar,但是如果你加了梯度反向层(让encoder这个部分关于domain的loss传回来的梯度是反向更新的),那么这个features mapping在不同的domain(数据集)上就可以表现出特征不变性(也就是体现出特征解耦)

DANN模型架构

本文核心思想:训练上面的模型,让domain classifier无法分辨出你输入的x属于哪个domain。这时,我们的目的就达到了,因为我们在分不清这个x是哪个domain的情况下,模型还是能用label predictor分类器做出x的label的预测。说明我们换个domain的数据,label预测任务的性能仍然较好。

不同模块的功能

  • Domain Classifier:区分输入图片是属于哪个domain的(source domain或target domain)
  • Label Predictor:识别source domain图片上的数字
  • Feature Extrator:帮助Label Predictor做预测;捅domain classifier一刀,做和他相反的事情。

下图来自李宏毅老师的课件,正好把论文中的架构图解释清楚了

训练时需要用到三组数据和标签

  1. source domain:黑色背景的4,图片上的数字为”4”
  2. source domain:黑色背景的4,图片来自domain0
  3. target domain:彩色背景的4,图片来自domain1

梯度反转层

核心代码部分

结合代码去理解上面的模型结构图就很简单了

1. 梯度翻转层的作用

通过梯度翻转层来提取domain-invariant特征的方法真的是最好的方法吗?其实你feature extractor最大化Ld loss(极端点,source domain图片预测成target,target domain图片预测成source),在某种程度是也是把source domain和target domain分开。所以它未必是最好的做法。不过用梯度翻转层这个trick也确实是有用的。 LHY 18:57

2. DANN和GAN的训练方式对比

DANN是通过梯度翻转层,来实现一步到位的模型更新的(不再需要像GAN一样交替训练网络)。

那这里我们再回顾下GAN是怎么更新模型的。

GAN的训练过程(交替训练)

  1. 训练Discriminator,固定Generator梯度不更新梯度不更新,使他能分清楚真假图片
  2. 训练Generator(梯度上升法),固定Discriminator,使Fake image输入到Discriminator的输出标签接近Real
  3. 交替训练1,2两步,最终我们希望Generator能骗过Discriminator(即真假图片输入到Discriminator,输出都是Real

3. Domain Adaptation的类型

本文介绍的是一种Adversarial based Domain Adaptation,但Domain Adaptation还有其他的类型,Divergence based Domain Adaptation,Reconstruction based Domain Adaptation

Divergence based Domain Adaptation

Reconstruction based Domain Adaptation

更多的内容推荐阅读博客:Medium - Understanding Domain Adaptation


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK