CycleGAN with PaddlePaddle: a generative adversarial network for unpaired image translation
source link: https://my.oschina.net/u/4067628/blog/3234532
A Generative Adversarial Network[1] (GAN) is an unsupervised learning approach in which two neural networks learn by competing against each other, proposed by Ian Goodfellow et al. in 2014. A GAN consists of a generator network and a discriminator network. The generator takes a random sample from a latent space as input, and its output should imitate the real samples in the training set as closely as possible. The discriminator takes either a real sample or the generator's output as input and tries to distinguish the generated outputs from the real samples, while the generator tries to fool the discriminator. The two networks compete, continually adjusting their parameters. GANs are commonly used to generate photorealistic images, and have also been applied to video generation, 3D object modeling, and more.
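The competition described above can be written down as a pair of losses. The following is a minimal illustrative sketch of the original 2014 GAN objective in plain Python, not code from this project:

```python
import numpy as np

def discriminator_loss(d_real, d_fake):
    # The discriminator maximizes log D(x) + log(1 - D(G(z)));
    # minimizing the negative of that sum is the usual implementation.
    return -(np.log(d_real) + np.log(1.0 - d_fake))

def generator_loss(d_fake):
    # Non-saturating generator loss: the generator minimizes
    # -log D(G(z)), i.e. it tries to make D score its samples as real.
    return -np.log(d_fake)
```

At the theoretical equilibrium the discriminator outputs 0.5 everywhere, giving a discriminator loss of 2·ln 2 ≈ 1.386.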
Installation commands:
## CPU version
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle
## GPU version
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu
CycleGAN performs image-to-image translation with unpaired images: given two sets of images in different styles, it learns to convert between the styles automatically. A traditional GAN generates in one direction only, whereas CycleGAN generates in both directions; the network forms a ring, hence the name "Cycle". A particularly practical property of CycleGAN is that the two input image sets can be arbitrary, i.e. unpaired.
CycleGAN consists of two generator networks and two discriminator networks. Generator A takes images in style A and outputs images in style B; generator B takes images in style B and outputs images in style A. The encoder part of each generator uses convolution-norm-ReLU blocks as its basic structure, and the decoder part uses transposed convolution-norm-ReLU blocks; the discriminators are built mainly from convolution-norm-leaky-ReLU blocks. The generator loss combines the LSGAN loss, the cycle-reconstruction loss, and the identity loss; the discriminator loss is the LSGAN loss. The overall structure of CycleGAN comprises four networks: a generator (translation) network G: X → Y; a generator (translation) network F: Y → X; a discriminator Dx, which judges whether an input image belongs to X; and a discriminator Dy, which judges whether an input image belongs to Y.
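The LSGAN loss mentioned above replaces the usual log loss with a least-squares penalty. A minimal illustrative sketch (not the project's trainer.py code):

```python
import numpy as np

def lsgan_d_loss(d_real, d_fake):
    # Discriminator: push scores on real images toward 1
    # and scores on generated images toward 0.
    return 0.5 * (np.mean((d_real - 1.0) ** 2) + np.mean(d_fake ** 2))

def lsgan_g_loss(d_fake):
    # Generator: push the discriminator's score on generated
    # images toward 1.
    return np.mean((d_fake - 1.0) ** 2)
```

Compared with the log loss, the quadratic penalty also punishes samples that are classified correctly but lie far from the decision boundary, which tends to stabilize training.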
Taking horses (X) and zebras (Y) as an example: the G network translates horse images into zebra images; the F network translates zebra images into horse images; the Dx network judges whether an input image is a horse; and the Dy network judges whether an input image is a zebra.
Although there are four networks, only two distinct architectures are used: G and F are both generator (translation) networks and share the same structure, while Dx and Dy are both discriminators and likewise share the same structure. CycleGAN sample results:
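The "Cycle" in the name comes from the reconstruction constraint: translating X → Y → X (or Y → X → Y) should return the original image. A numpy sketch of that L1 cycle-consistency term, for illustration only (the weight λ = 10 follows the original CycleGAN paper; this project's exact weighting may differ):

```python
import numpy as np

def cycle_consistency_loss(x, x_rec, y, y_rec, lam=10.0):
    # x_rec = F(G(x)) and y_rec = G(F(y)); penalize the L1 distance
    # between each image and its round-trip reconstruction.
    return lam * (np.mean(np.abs(x - x_rec)) + np.mean(np.abs(y - y_rec)))
```

This term is what prevents the generators from mapping every input to one arbitrary plausible-looking output: the translation must remain invertible.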
# Code structure
# ├── data_reader.py # Reads and preprocesses the data.
# ├── layers.py      # Defines the basic layers.
# ├── model.py       # Defines the generator and discriminator networks.
# ├── trainer.py     # Builds the losses and the training networks.
# ├── train.py       # Training script.
# └── infer.py       # Inference script.
!cd /home/aistudio/data/data10040/ && unzip -qo horse2zebra.zip
This project trains and tests on the horse2zebra dataset. The training set contains 1069 horse images and 1336 zebra images; the test set contains 121 horse images and 141 zebra images.
# Data preparation
# Expected directory layout:
# data
# |-- horse2zebra
# |   |-- testA
# |   |-- testA.txt
# |   |-- testB
# |   |-- testB.txt
# |   |-- trainA
# |   |-- trainA.txt
# |   |-- trainB
# |   `-- trainB.txt
# The data folder must sit in the same directory as the training script train.py. testA holds the horse test images and testB the zebra test images; testA.txt and testB.txt list the horse and zebra test image paths, one per line, in the following format:
# testA/n02381460_9243.jpg
# testA/n02381460_9244.jpg
# testA/n02381460_9245.jpg
# The training data is organized in the same way as the test data.
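A list file like testA.txt can be consumed with a few lines of Python. Here `read_image_list` is a hypothetical helper for illustration; the project's actual loading logic lives in data_reader.py:

```python
import os

def read_image_list(list_path, data_root):
    # Each line of trainA.txt / testA.txt holds a path relative to the
    # dataset root, e.g. "testA/n02381460_9243.jpg"; skip blank lines.
    with open(list_path) as f:
        return [os.path.join(data_root, line.strip())
                for line in f if line.strip()]
```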
# Install imageio and pin scipy to 1.2.1 (the scripts still call the deprecated scipy.misc.imsave, which was removed in later scipy versions)
!pip install imageio
!pip install scipy==1.2.1
Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/
Requirement already satisfied: imageio in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (2.6.1)
Requirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imageio) (1.16.4)
Requirement already satisfied: pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imageio) (6.2.0)
Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/
Collecting scipy==1.2.1
WARNING: Retrying (Retry(total=4, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ReadTimeoutError("HTTPSConnectionPool(host='pypi.mirrors.ustc.edu.cn', port=443): Read timed out. (read timeout=15)")': /simple/scipy/
WARNING: Retrying (Retry(total=4, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ReadTimeoutError("HTTPSConnectionPool(host='mirrors.ustc.edu.cn', port=443): Read timed out. (read timeout=15)")': /pypi/web/simple/scipy/
WARNING: Retrying (Retry(total=3, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ReadTimeoutError("HTTPSConnectionPool(host='mirrors.ustc.edu.cn', port=443): Read timed out. (read timeout=15)")': /pypi/web/simple/scipy/
WARNING: Retrying (Retry(total=2, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ReadTimeoutError("HTTPSConnectionPool(host='mirrors.ustc.edu.cn', port=443): Read timed out. (read timeout=15)")': /pypi/web/simple/scipy/
WARNING: Retrying (Retry(total=4, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ConnectTimeoutError(<pip._vendor.urllib3.connection.VerifiedHTTPSConnection object at 0x7f658b56e450>, 'Connection to mirrors.tuna.tsinghua.edu.cn timed out. (connect timeout=15)')': /pypi/web/simple/scipy/
Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/3e/7e/5cee36eee5b3194687232f6150a89a38f784883c612db7f4da2ab190980d/scipy-1.2.1-cp37-cp37m-manylinux1_x86_64.whl (24.8MB)
|████████████████████████████████| 24.8MB 141kB/s eta 0:00:01
Requirement already satisfied: numpy>=1.8.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scipy==1.2.1) (1.16.4)
Installing collected packages: scipy
Found existing installation: scipy 1.3.0
Uninstalling scipy-1.3.0:
Successfully uninstalled scipy-1.3.0
Successfully installed scipy-1.2.1
# Training
# Train on a single GPU:
!python cycle_gan/train.py --epoch=2 --use_gpu True
----------- Configuration Arguments -----------
batch_size: 1
epoch: 2
init_model: None
output: ./output
profile: False
run_ce: False
run_test: True
save_checkpoints: True
use_gpu: 1
------------------------------------------------
W0316 17:50:21.449174 346 device_context.cc:237] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 10.1, Runtime API Version: 9.0
W0316 17:50:21.453552 346 device_context.cc:245] device: 0, cuDNN Version: 7.3.
I0316 17:50:23.286839 346 parallel_executor.cc:440] The Program will be executed on CUDA using ParallelExecutor, 1 cards are used, so 1 programs are executed in parallel.
I0316 17:50:23.343647 346 build_strategy.cc:365] SeqOnlyAllReduceOps:0, num_trainers:1
I0316 17:50:23.495808 346 parallel_executor.cc:375] Garbage collection strategy is enabled, when FLAGS_eager_delete_tensor_gb = 0
I0316 17:50:24.629184 346 parallel_executor.cc:440] The Program will be executed on CUDA using ParallelExecutor, 1 cards are used, so 1 programs are executed in parallel.
I0316 17:50:24.633793 346 build_strategy.cc:365] SeqOnlyAllReduceOps:0, num_trainers:1
I0316 17:50:24.757025 346 parallel_executor.cc:440] The Program will be executed on CUDA using ParallelExecutor, 1 cards are used, so 1 programs are executed in parallel.
I0316 17:50:24.810863 346 build_strategy.cc:365] SeqOnlyAllReduceOps:0, num_trainers:1
I0316 17:50:26.055040 346 parallel_executor.cc:440] The Program will be executed on CUDA using ParallelExecutor, 1 cards are used, so 1 programs are executed in parallel.
I0316 17:50:26.059253 346 build_strategy.cc:365] SeqOnlyAllReduceOps:0, num_trainers:1
epoch0; batch0; g_A_loss: 14.300308227539062; d_B_loss: 1.703646183013916; g_B_loss: 13.938850402832031; d_A_loss: 1.6837800741195679; Batch_time_cost: 2.89
epoch0; batch50; g_A_loss: 6.143564224243164; d_B_loss: 0.43444156646728516; g_B_loss: 5.804959297180176; d_A_loss: 0.3155096173286438; Batch_time_cost: 0.12
epoch0; batch100; g_A_loss: 6.256138324737549; d_B_loss: 0.2702423632144928; g_B_loss: 5.725310325622559; d_A_loss: 0.3236050307750702; Batch_time_cost: 0.12
epoch0; batch150; g_A_loss: 7.066346645355225; d_B_loss: 0.31580042839050293; g_B_loss: 6.818117141723633; d_A_loss: 0.4519233703613281; Batch_time_cost: 0.13
epoch0; batch200; g_A_loss: 6.909765243530273; d_B_loss: 0.3620782196521759; g_B_loss: 7.228463172912598; d_A_loss: 0.36587968468666077; Batch_time_cost: 0.13
epoch0; batch250; g_A_loss: 5.192142009735107; d_B_loss: 0.24265910685062408; g_B_loss: 5.763035297393799; d_A_loss: 1.1450282335281372; Batch_time_cost: 0.13
epoch0; batch300; g_A_loss: 5.316933631896973; d_B_loss: 0.19533368945121765; g_B_loss: 4.987677097320557; d_A_loss: 0.22845549881458282; Batch_time_cost: 0.13
epoch0; batch350; g_A_loss: 7.186776638031006; d_B_loss: 0.16458189487457275; g_B_loss: 5.839504241943359; d_A_loss: 0.1668826937675476; Batch_time_cost: 0.12
epoch0; batch400; g_A_loss: 5.885252952575684; d_B_loss: 0.30592358112335205; g_B_loss: 5.449808120727539; d_A_loss: 0.1534576117992401; Batch_time_cost: 0.13
epoch0; batch450; g_A_loss: 6.440225124359131; d_B_loss: 0.11404794454574585; g_B_loss: 5.590849876403809; d_A_loss: 0.1966126412153244; Batch_time_cost: 0.13
epoch0; batch500; g_A_loss: 6.872323036193848; d_B_loss: 0.19951260089874268; g_B_loss: 6.571059226989746; d_A_loss: 0.23495124280452728; Batch_time_cost: 0.13
epoch0; batch550; g_A_loss: 5.0895466804504395; d_B_loss: 0.16031894087791443; g_B_loss: 5.121213912963867; d_A_loss: 0.25687021017074585; Batch_time_cost: 0.13
epoch0; batch600; g_A_loss: 6.153958797454834; d_B_loss: 0.21091848611831665; g_B_loss: 6.268405914306641; d_A_loss: 0.11446988582611084; Batch_time_cost: 0.13
epoch0; batch650; g_A_loss: 8.350132942199707; d_B_loss: 0.13534343242645264; g_B_loss: 8.49114990234375; d_A_loss: 0.17544472217559814; Batch_time_cost: 0.13
epoch0; batch700; g_A_loss: 4.763492584228516; d_B_loss: 0.25141626596450806; g_B_loss: 4.048842430114746; d_A_loss: 0.1611996442079544; Batch_time_cost: 0.13
epoch0; batch750; g_A_loss: 6.567712783813477; d_B_loss: 0.2876878082752228; g_B_loss: 6.591218948364258; d_A_loss: 0.17310750484466553; Batch_time_cost: 0.13
epoch0; batch800; g_A_loss: 4.201617240905762; d_B_loss: 0.28061026334762573; g_B_loss: 4.326346397399902; d_A_loss: 0.17125019431114197; Batch_time_cost: 0.13
epoch0; batch850; g_A_loss: 6.599902153015137; d_B_loss: 0.33524584770202637; g_B_loss: 5.3691182136535645; d_A_loss: 0.16773417592048645; Batch_time_cost: 0.12
epoch0; batch900; g_A_loss: 7.22122049331665; d_B_loss: 0.0716591626405716; g_B_loss: 6.713944435119629; d_A_loss: 0.16595764458179474; Batch_time_cost: 0.12
epoch0; batch950; g_A_loss: 6.072809219360352; d_B_loss: 0.17901702225208282; g_B_loss: 5.398962020874023; d_A_loss: 0.17636674642562866; Batch_time_cost: 0.13
epoch0; batch1000; g_A_loss: 5.913625717163086; d_B_loss: 0.0724739134311676; g_B_loss: 4.69740629196167; d_A_loss: 0.27375420928001404; Batch_time_cost: 0.13
epoch0; batch1050; g_A_loss: 9.782687187194824; d_B_loss: 0.12104632705450058; g_B_loss: 9.672157287597656; d_A_loss: 0.14316022396087646; Batch_time_cost: 0.13
epoch0; batch1100; g_A_loss: 4.571684837341309; d_B_loss: 0.06749674677848816; g_B_loss: 4.33621883392334; d_A_loss: 0.17319558560848236; Batch_time_cost: 0.13
epoch0; batch1150; g_A_loss: 5.194060802459717; d_B_loss: 0.13884581625461578; g_B_loss: 5.199503421783447; d_A_loss: 0.1418939232826233; Batch_time_cost: 0.13
epoch0; batch1200; g_A_loss: 4.909609794616699; d_B_loss: 0.12311942875385284; g_B_loss: 4.3402299880981445; d_A_loss: 0.2318917214870453; Batch_time_cost: 0.13
epoch0; batch1250; g_A_loss: 5.436093330383301; d_B_loss: 0.19272129237651825; g_B_loss: 6.465127944946289; d_A_loss: 0.23860839009284973; Batch_time_cost: 0.13
epoch0; batch1300; g_A_loss: 8.133463859558105; d_B_loss: 0.0819319561123848; g_B_loss: 7.404263496398926; d_A_loss: 0.10251586139202118; Batch_time_cost: 0.13
cycle_gan/train.py:134: DeprecationWarning: `imsave` is deprecated! `imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0. Use ``imageio.imwrite`` instead.
  (fake_B_temp + 1) * 127.5).astype(np.uint8))
cycle_gan/train.py:136: DeprecationWarning: `imsave` is deprecated! `imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0. Use ``imageio.imwrite`` instead.
  (fake_A_temp + 1) * 127.5).astype(np.uint8))
cycle_gan/train.py:138: DeprecationWarning: `imsave` is deprecated! `imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0. Use ``imageio.imwrite`` instead.
  (cyc_A_temp + 1) * 127.5).astype(np.uint8))
cycle_gan/train.py:140: DeprecationWarning: `imsave` is deprecated! `imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0. Use ``imageio.imwrite`` instead.
  (cyc_B_temp + 1) * 127.5).astype(np.uint8))
cycle_gan/train.py:142: DeprecationWarning: `imsave` is deprecated! `imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0. Use ``imageio.imwrite`` instead.
  (input_A_temp + 1) * 127.5).astype(np.uint8))
cycle_gan/train.py:144: DeprecationWarning: `imsave` is deprecated! `imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0. Use ``imageio.imwrite`` instead.
  (input_B_temp + 1) * 127.5).astype(np.uint8))
# (Save the generated images)
saved checkpoint to ./output/checkpoints/
epoch1; batch0; g_A_loss: 5.827764987945557; d_B_loss: 0.07218477874994278; g_B_loss: 5.5165019035339355; d_A_loss: 0.3245842158794403; Batch_time_cost: 0.12
epoch1; batch50; g_A_loss: 5.765468597412109; d_B_loss: 0.167547345161438; g_B_loss: 5.18796443939209; d_A_loss: 0.15456309914588928; Batch_time_cost: 0.12
epoch1; batch100; g_A_loss: 5.45959997177124; d_B_loss: 0.09713760763406754; g_B_loss: 4.907477855682373; d_A_loss: 0.06172751635313034; Batch_time_cost: 0.12
epoch1; batch150; g_A_loss: 6.784292221069336; d_B_loss: 0.11054498702287674; g_B_loss: 6.478621959686279; d_A_loss: 0.44314247369766235; Batch_time_cost: 0.12
epoch1; batch200; g_A_loss: 7.896561145782471; d_B_loss: 0.16192974150180817; g_B_loss: 7.6382269859313965; d_A_loss: 0.10663460940122604; Batch_time_cost: 0.12
epoch1; batch250; g_A_loss: 6.924862384796143; d_B_loss: 0.029113346710801125; g_B_loss: 6.796955585479736; d_A_loss: 0.07325638085603714; Batch_time_cost: 0.12
epoch1; batch300; g_A_loss: 4.884801864624023; d_B_loss: 0.050231873989105225; g_B_loss: 4.698901176452637; d_A_loss: 0.14451992511749268; Batch_time_cost: 0.13
epoch1; batch350; g_A_loss: 5.315099239349365; d_B_loss: 0.17729483544826508; g_B_loss: 5.134288787841797; d_A_loss: 0.18281573057174683; Batch_time_cost: 0.13
epoch1; batch400; g_A_loss: 7.244136810302734; d_B_loss: 0.2588624954223633; g_B_loss: 6.518290042877197; d_A_loss: 0.03540463373064995; Batch_time_cost: 0.12
epoch1; batch450; g_A_loss: 5.102941513061523; d_B_loss: 0.03143639117479324; g_B_loss: 4.737099647521973; d_A_loss: 0.22579550743103027; Batch_time_cost: 0.13
epoch1; batch500; g_A_loss: 6.038484573364258; d_B_loss: 0.07545653730630875; g_B_loss: 5.054262638092041; d_A_loss: 0.1474056839942932; Batch_time_cost: 0.12
epoch1; batch550; g_A_loss: 5.1528449058532715; d_B_loss: 0.05942108482122421; g_B_loss: 5.586222171783447; d_A_loss: 0.1414523869752884; Batch_time_cost: 0.13
epoch1; batch600; g_A_loss: 7.961068630218506; d_B_loss: 0.07359579205513; g_B_loss: 7.495856285095215; d_A_loss: 0.08565278351306915; Batch_time_cost: 0.12
epoch1; batch650; g_A_loss: 4.3616814613342285; d_B_loss: 0.17840822041034698; g_B_loss: 3.6735024452209473; d_A_loss: 0.15866130590438843; Batch_time_cost: 0.13
epoch1; batch700; g_A_loss: 4.804023742675781; d_B_loss: 0.18827767670154572; g_B_loss: 4.755307197570801; d_A_loss: 0.2991688549518585; Batch_time_cost: 0.13
epoch1; batch750; g_A_loss: 5.776893138885498; d_B_loss: 0.24969376623630524; g_B_loss: 5.4265851974487305; d_A_loss: 0.1804438978433609; Batch_time_cost: 0.13
epoch1; batch800; g_A_loss: 8.234237670898438; d_B_loss: 0.328776478767395; g_B_loss: 7.330533027648926; d_A_loss: 0.07429933547973633; Batch_time_cost: 0.13
epoch1; batch850; g_A_loss: 7.487462520599365; d_B_loss: 0.08431682735681534; g_B_loss: 7.30022668838501; d_A_loss: 0.1268233209848404; Batch_time_cost: 0.13
epoch1; batch900; g_A_loss: 3.5420680046081543; d_B_loss: 0.15772652626037598; g_B_loss: 2.79952335357666; d_A_loss: 0.2009810507297516; Batch_time_cost: 0.13
epoch1; batch950; g_A_loss: 6.323566436767578; d_B_loss: 0.09212709963321686; g_B_loss: 5.999871730804443; d_A_loss: 0.1642218381166458; Batch_time_cost: 0.12
epoch1; batch1000; g_A_loss: 4.416383743286133; d_B_loss: 0.4296090602874756; g_B_loss: 3.9065892696380615; d_A_loss: 0.14740586280822754; Batch_time_cost: 0.13
epoch1; batch1050; g_A_loss: 4.462809085845947; d_B_loss: 0.5468143820762634; g_B_loss: 4.7028489112854; d_A_loss: 0.10133746266365051; Batch_time_cost: 0.12
epoch1; batch1100; g_A_loss: 6.614782810211182; d_B_loss: 0.09065119922161102; g_B_loss: 6.871949672698975; d_A_loss: 0.265675812959671; Batch_time_cost: 0.12
epoch1; batch1150; g_A_loss: 4.825323104858398; d_B_loss: 0.029798883944749832; g_B_loss: 4.237264633178711; d_A_loss: 0.15145432949066162; Batch_time_cost: 0.13
epoch1; batch1200; g_A_loss: 5.156582355499268; d_B_loss: 0.31561416387557983; g_B_loss: 5.017688274383545; d_A_loss: 0.1431412398815155; Batch_time_cost: 0.13
epoch1; batch1250; g_A_loss: 5.140244007110596; d_B_loss: 0.12650391459465027; g_B_loss: 4.05291223526001; d_A_loss: 0.2101406753063202; Batch_time_cost: 0.12
epoch1; batch1300; g_A_loss: 3.725010633468628; d_B_loss: 0.25355270504951477; g_B_loss: 4.305200576782227; d_A_loss: 0.31700998544692993; Batch_time_cost: 0.13
saved checkpoint to ./output/checkpoints/
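Each log record printed by train.py has the shape `epochE; batchB; g_A_loss: …; d_B_loss: …; g_B_loss: …; d_A_loss: …; Batch_time_cost: …`. If you want to plot the loss curves, a small hypothetical helper (not part of the repo) can pull the four losses out of a line:

```python
import re

# Matches names ending in "_loss" followed by a numeric value.
LOG_RE = re.compile(r"(\w+_loss): ([0-9.]+)")

def parse_losses(line):
    # Returns e.g. {"g_A_loss": 14.3, "d_B_loss": 1.7, ...} for one record;
    # Batch_time_cost is ignored because its name does not end in "_loss".
    return {name: float(val) for name, val in LOG_RE.findall(line)}
```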
# Run inference with the frozen model (trained for 10 epochs); results are saved under output. Training for 150 epochs reaches the quality shown in the introduction above.
!python cycle_gan/infer.py \
--init_model="output/freeze" \
--input="./data/data10040/horse2zebra/testA/n02381460_4260.jpg" \
--input_style A \
--output="output/freeze_infer_result"
# Visualize the image before and after translation
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import cv2
img = cv2.imread('data/data10040/horse2zebra/testA/n02381460_4260.jpg')
result_img = cv2.imread('output/freeze_infer_result/fake_n02381460_4260.jpg')
# OpenCV loads images as BGR; convert to RGB so matplotlib shows the true colors.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
result_img = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
plt.subplot(1, 2, 1)
plt.imshow(img)
plt.subplot(1, 2, 2)
plt.imshow(result_img)
plt.show()
----------- Configuration Arguments -----------
init_model: output/freeze
input: ./data/data10040/horse2zebra/testA/n02381460_4260.jpg
input_style: A
output: output/freeze_infer_result
use_gpu: True
------------------------------------------------
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/executor.py:804: UserWarning: There are no operators in the program to be executed. If you pass Program manually, please use fluid.program_guard to ensure the current Program is being used.
  warnings.warn(error_info)
W0316 17:58:40.541411 441 device_context.cc:237] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 10.1, Runtime API Version: 9.0
W0316 17:58:40.545682 441 device_context.cc:245] device: 0, cuDNN Version: 7.3.
cycle_gan/infer.py:63: DeprecationWarning: `imsave` is deprecated! `imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0. Use ``imageio.imwrite`` instead.
  (fake_temp + 1) * 127.5).astype(np.uint8))
Try the project hands-on with one click on AI Studio: https://aistudio.baidu.com/aistudio/projectdetail/169459
>> Visit the PaddlePaddle official website to learn more.