4

基于PaddlePaddle的强化学习算法CycleGAN Fork 72 收藏

 3 years ago
source link: https://my.oschina.net/u/4067628/blog/3234532
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

基于PaddlePaddle的强化学习算法CycleGAN Fork 72 收藏

生成对抗网络(Generative Adversarial Network[1], 简称GAN) 是一种非监督学习的方式,通过让两个神经网络相互博弈的方法进行学习,该方法由lan Goodfellow等人在2014年提出。生成对抗网络由一个生成网络和一个判别网络组成,生成网络从潜在的空间(latent space)中随机采样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别网络的输入为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能的分辨出来。而生成网络则尽可能的欺骗判别网络,两个网络相互对抗,不断调整参数。 生成对抗网络常用于生成以假乱真的图片。此外,该方法还被用于生成影片,三维物体模型等。

下载安装命令

## CPU版本安装命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle

## GPU版本安装命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu

CycleGAN可以利用非成对的图片进行图像翻译,即输入为两种不同风格的不同图片,自动进行风格转换。传统的GAN是单向生成,而CycleGAN是互相生成,网络是个环形,所以命名为Cycle。并且CycleGAN一个非常实用的地方就是输入的两张图片可以是任意的两张图片,也就是unpaired。 

CycleGAN由两个生成网络和两个判别网络组成,生成网络A是输入A类风格的图片输出B类风格的图片,生成网络B是输入B类风格的图片输出A类风格的图片。生成网络中编码部分的网络结构都是采用convolution-norm-ReLU作为基础结构,解码部分的网络结构由transpose convolution-norm-ReLU组成,判别网络基本是由convolution-norm-leaky_ReLU作为基础结构。生成网络损失函数由LSGAN的损失函数,重构损失和自身损失组成,判别网络的损失函数由LSGAN的损失函数组成。CycleGAN的结构如下: Cycle-Gan总结构有四个网络,第一个网络为生成(转化)网络命名为G:X---->Y;第二个网络为生成(转化)网络命名为F:Y--->X;第三个网络为对抗网络命名为Dx,鉴别输入图像是不是X;第四个网络为对抗网络命名为Dy,鉴别输入图像是不是Y。

如上图,以马(X)和斑马(Y)为例,G网络将马的图像转化为斑马图像;F网络将斑马的图像转化为马的图像;Dx网络鉴别输入的图像是不是马;Dy网络鉴别输入图像是不是斑马;

这四个网络仅有两个网络结构,即G和F都是生成(转化)网络,这两者的网络结构相同,Dx和Dy都是对抗性网络,这两者的网络结构相同。CycleGAN 效果展示:

In[1]
#代码结构
# ├── data_reader.py  # 读取、处理数据。
# ├── layers.py   # 封装定义基础的layers。
# ├── model.py   # 定义基础生成网络和判别网络。
# ├── trainer.py   # 构造loss和训练网络。
# ├── train.py     # 训练脚本。
# └── infer.py    # 预测脚本。
In[2]
!cd /home/aistudio/data/data10040/ && unzip -qo horse2zebra.zip

本项目使用 horse2zebra 数据集 来进行模型的训练测试工作,horse2zebra训练集包含1069张野马图片,1336张斑马图片。测试集包含121张野马图片和141张斑马图片。

In[3]
# 数据准备
# 本教程使用 horse2zebra 数据集 来进行模型的训练测试工作,horse2zebra训练集包含1069张野马图片,1336张斑马图片。测试集包含121张野马图片和141张斑马图片。
# 以下路径结构:
# data
# |-- horse2zebra
# |   |-- testA
# |   |-- testA.txt
# |   |-- testB
# |   |-- testB.txt
# |   |-- trainA
# |   |-- trainA.txt
# |   |-- trainB
# |   `-- trainB.txt
# 以上数据文件中,data文件夹需要放在训练脚本train.py同级目录下。testA为存放野马测试图片的文件夹,testB为存放斑马测试图片的文件夹,testA.txt和testB.txt分别为野马和斑马测试图片路径列表文件,格式如下:
# testA/n02381460_9243.jpg
# testA/n02381460_9244.jpg
# testA/n02381460_9245.jpg
# 训练数据组织方式与测试数据相同。
In[4]
#安装scipy
!pip install imageio
!pip install scipy==1.2.1
Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/
Requirement already satisfied: imageio in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (2.6.1)
Requirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imageio) (1.16.4)
Requirement already satisfied: pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imageio) (6.2.0)
Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/
Collecting scipy==1.2.1
  WARNING: Retrying (Retry(total=4, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ReadTimeoutError("HTTPSConnectionPool(host='pypi.mirrors.ustc.edu.cn', port=443): Read timed out. (read timeout=15)")': /simple/scipy/
  WARNING: Retrying (Retry(total=4, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ReadTimeoutError("HTTPSConnectionPool(host='mirrors.ustc.edu.cn', port=443): Read timed out. (read timeout=15)")': /pypi/web/simple/scipy/
  WARNING: Retrying (Retry(total=3, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ReadTimeoutError("HTTPSConnectionPool(host='mirrors.ustc.edu.cn', port=443): Read timed out. (read timeout=15)")': /pypi/web/simple/scipy/
  WARNING: Retrying (Retry(total=2, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ReadTimeoutError("HTTPSConnectionPool(host='mirrors.ustc.edu.cn', port=443): Read timed out. (read timeout=15)")': /pypi/web/simple/scipy/
  WARNING: Retrying (Retry(total=4, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ConnectTimeoutError(<pip._vendor.urllib3.connection.VerifiedHTTPSConnection object at 0x7f658b56e450>, 'Connection to mirrors.tuna.tsinghua.edu.cn timed out. (connect timeout=15)')': /pypi/web/simple/scipy/
  Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/3e/7e/5cee36eee5b3194687232f6150a89a38f784883c612db7f4da2ab190980d/scipy-1.2.1-cp37-cp37m-manylinux1_x86_64.whl (24.8MB)
     |████████████████████████████████| 24.8MB 141kB/s eta 0:00:01
Requirement already satisfied: numpy>=1.8.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scipy==1.2.1) (1.16.4)
Installing collected packages: scipy
  Found existing installation: scipy 1.3.0
    Uninstalling scipy-1.3.0:
      Successfully uninstalled scipy-1.3.0
Successfully installed scipy-1.2.1
In[7]
#训练
#在GPU单卡上训练:
!python cycle_gan/train.py --epoch=2 --use_gpu True
-----------  Configuration Arguments -----------
batch_size: 1
epoch: 2
init_model: None
output: ./output
profile: False
run_ce: False
run_test: True
save_checkpoints: True
use_gpu: 1
------------------------------------------------
W0316 17:50:21.449174   346 device_context.cc:237] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 10.1, Runtime API Version: 9.0
W0316 17:50:21.453552   346 device_context.cc:245] device: 0, cuDNN Version: 7.3.
I0316 17:50:23.286839   346 parallel_executor.cc:440] The Program will be executed on CUDA using ParallelExecutor, 1 cards are used, so 1 programs are executed in parallel.
I0316 17:50:23.343647   346 build_strategy.cc:365] SeqOnlyAllReduceOps:0, num_trainers:1
I0316 17:50:23.495808   346 parallel_executor.cc:375] Garbage collection strategy is enabled, when FLAGS_eager_delete_tensor_gb = 0
I0316 17:50:24.629184   346 parallel_executor.cc:440] The Program will be executed on CUDA using ParallelExecutor, 1 cards are used, so 1 programs are executed in parallel.
I0316 17:50:24.633793   346 build_strategy.cc:365] SeqOnlyAllReduceOps:0, num_trainers:1
I0316 17:50:24.757025   346 parallel_executor.cc:440] The Program will be executed on CUDA using ParallelExecutor, 1 cards are used, so 1 programs are executed in parallel.
I0316 17:50:24.810863   346 build_strategy.cc:365] SeqOnlyAllReduceOps:0, num_trainers:1
I0316 17:50:26.055040   346 parallel_executor.cc:440] The Program will be executed on CUDA using ParallelExecutor, 1 cards are used, so 1 programs are executed in parallel.
I0316 17:50:26.059253   346 build_strategy.cc:365] SeqOnlyAllReduceOps:0, num_trainers:1
epoch0; batch0; g_A_loss: 14.300308227539062; d_B_loss: 1.703646183013916; g_B_loss: 13.938850402832031; d_A_loss: 1.6837800741195679; Batch_time_cost: 2.89
epoch0; batch50; g_A_loss: 6.143564224243164; d_B_loss: 0.43444156646728516; g_B_loss: 5.804959297180176; d_A_loss: 0.3155096173286438; Batch_time_cost: 0.12
epoch0; batch100; g_A_loss: 6.256138324737549; d_B_loss: 0.2702423632144928; g_B_loss: 5.725310325622559; d_A_loss: 0.3236050307750702; Batch_time_cost: 0.12
epoch0; batch150; g_A_loss: 7.066346645355225; d_B_loss: 0.31580042839050293; g_B_loss: 6.818117141723633; d_A_loss: 0.4519233703613281; Batch_time_cost: 0.13
epoch0; batch200; g_A_loss: 6.909765243530273; d_B_loss: 0.3620782196521759; g_B_loss: 7.228463172912598; d_A_loss: 0.36587968468666077; Batch_time_cost: 0.13
epoch0; batch250; g_A_loss: 5.192142009735107; d_B_loss: 0.24265910685062408; g_B_loss: 5.763035297393799; d_A_loss: 1.1450282335281372; Batch_time_cost: 0.13
epoch0; batch300; g_A_loss: 5.316933631896973; d_B_loss: 0.19533368945121765; g_B_loss: 4.987677097320557; d_A_loss: 0.22845549881458282; Batch_time_cost: 0.13
epoch0; batch350; g_A_loss: 7.186776638031006; d_B_loss: 0.16458189487457275; g_B_loss: 5.839504241943359; d_A_loss: 0.1668826937675476; Batch_time_cost: 0.12
epoch0; batch400; g_A_loss: 5.885252952575684; d_B_loss: 0.30592358112335205; g_B_loss: 5.449808120727539; d_A_loss: 0.1534576117992401; Batch_time_cost: 0.13
epoch0; batch450; g_A_loss: 6.440225124359131; d_B_loss: 0.11404794454574585; g_B_loss: 5.590849876403809; d_A_loss: 0.1966126412153244; Batch_time_cost: 0.13
epoch0; batch500; g_A_loss: 6.872323036193848; d_B_loss: 0.19951260089874268; g_B_loss: 6.571059226989746; d_A_loss: 0.23495124280452728; Batch_time_cost: 0.13
epoch0; batch550; g_A_loss: 5.0895466804504395; d_B_loss: 0.16031894087791443; g_B_loss: 5.121213912963867; d_A_loss: 0.25687021017074585; Batch_time_cost: 0.13
epoch0; batch600; g_A_loss: 6.153958797454834; d_B_loss: 0.21091848611831665; g_B_loss: 6.268405914306641; d_A_loss: 0.11446988582611084; Batch_time_cost: 0.13
epoch0; batch650; g_A_loss: 8.350132942199707; d_B_loss: 0.13534343242645264; g_B_loss: 8.49114990234375; d_A_loss: 0.17544472217559814; Batch_time_cost: 0.13
epoch0; batch700; g_A_loss: 4.763492584228516; d_B_loss: 0.25141626596450806; g_B_loss: 4.048842430114746; d_A_loss: 0.1611996442079544; Batch_time_cost: 0.13
epoch0; batch750; g_A_loss: 6.567712783813477; d_B_loss: 0.2876878082752228; g_B_loss: 6.591218948364258; d_A_loss: 0.17310750484466553; Batch_time_cost: 0.13
epoch0; batch800; g_A_loss: 4.201617240905762; d_B_loss: 0.28061026334762573; g_B_loss: 4.326346397399902; d_A_loss: 0.17125019431114197; Batch_time_cost: 0.13
epoch0; batch850; g_A_loss: 6.599902153015137; d_B_loss: 0.33524584770202637; g_B_loss: 5.3691182136535645; d_A_loss: 0.16773417592048645; Batch_time_cost: 0.12
epoch0; batch900; g_A_loss: 7.22122049331665; d_B_loss: 0.0716591626405716; g_B_loss: 6.713944435119629; d_A_loss: 0.16595764458179474; Batch_time_cost: 0.12
epoch0; batch950; g_A_loss: 6.072809219360352; d_B_loss: 0.17901702225208282; g_B_loss: 5.398962020874023; d_A_loss: 0.17636674642562866; Batch_time_cost: 0.13
epoch0; batch1000; g_A_loss: 5.913625717163086; d_B_loss: 0.0724739134311676; g_B_loss: 4.69740629196167; d_A_loss: 0.27375420928001404; Batch_time_cost: 0.13
epoch0; batch1050; g_A_loss: 9.782687187194824; d_B_loss: 0.12104632705450058; g_B_loss: 9.672157287597656; d_A_loss: 0.14316022396087646; Batch_time_cost: 0.13
epoch0; batch1100; g_A_loss: 4.571684837341309; d_B_loss: 0.06749674677848816; g_B_loss: 4.33621883392334; d_A_loss: 0.17319558560848236; Batch_time_cost: 0.13
epoch0; batch1150; g_A_loss: 5.194060802459717; d_B_loss: 0.13884581625461578; g_B_loss: 5.199503421783447; d_A_loss: 0.1418939232826233; Batch_time_cost: 0.13
epoch0; batch1200; g_A_loss: 4.909609794616699; d_B_loss: 0.12311942875385284; g_B_loss: 4.3402299880981445; d_A_loss: 0.2318917214870453; Batch_time_cost: 0.13
epoch0; batch1250; g_A_loss: 5.436093330383301; d_B_loss: 0.19272129237651825; g_B_loss: 6.465127944946289; d_A_loss: 0.23860839009284973; Batch_time_cost: 0.13
epoch0; batch1300; g_A_loss: 8.133463859558105; d_B_loss: 0.0819319561123848; g_B_loss: 7.404263496398926; d_A_loss: 0.10251586139202118; Batch_time_cost: 0.13
cycle_gan/train.py:134: DeprecationWarning: `imsave` is deprecated!
`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
  (fake_B_temp + 1) * 127.5).astype(np.uint8))
cycle_gan/train.py:136: DeprecationWarning: `imsave` is deprecated!
`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
  (fake_A_temp + 1) * 127.5).astype(np.uint8))
cycle_gan/train.py:138: DeprecationWarning: `imsave` is deprecated!
`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
  (cyc_A_temp + 1) * 127.5).astype(np.uint8))
cycle_gan/train.py:140: DeprecationWarning: `imsave` is deprecated!
`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
  (cyc_B_temp + 1) * 127.5).astype(np.uint8))
cycle_gan/train.py:142: DeprecationWarning: `imsave` is deprecated!
`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
  (input_A_temp + 1) * 127.5).astype(np.uint8))
cycle_gan/train.py:144: DeprecationWarning: `imsave` is deprecated!
`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
  (input_B_temp + 1) * 127.5).astype(np.uint8))  #保存生成的图
saved checkpoint to ./output/checkpoints/
epoch1; batch0; g_A_loss: 5.827764987945557; d_B_loss: 0.07218477874994278; g_B_loss: 5.5165019035339355; d_A_loss: 0.3245842158794403; Batch_time_cost: 0.12
epoch1; batch50; g_A_loss: 5.765468597412109; d_B_loss: 0.167547345161438; g_B_loss: 5.18796443939209; d_A_loss: 0.15456309914588928; Batch_time_cost: 0.12
epoch1; batch100; g_A_loss: 5.45959997177124; d_B_loss: 0.09713760763406754; g_B_loss: 4.907477855682373; d_A_loss: 0.06172751635313034; Batch_time_cost: 0.12
epoch1; batch150; g_A_loss: 6.784292221069336; d_B_loss: 0.11054498702287674; g_B_loss: 6.478621959686279; d_A_loss: 0.44314247369766235; Batch_time_cost: 0.12
epoch1; batch200; g_A_loss: 7.896561145782471; d_B_loss: 0.16192974150180817; g_B_loss: 7.6382269859313965; d_A_loss: 0.10663460940122604; Batch_time_cost: 0.12
epoch1; batch250; g_A_loss: 6.924862384796143; d_B_loss: 0.029113346710801125; g_B_loss: 6.796955585479736; d_A_loss: 0.07325638085603714; Batch_time_cost: 0.12
epoch1; batch300; g_A_loss: 4.884801864624023; d_B_loss: 0.050231873989105225; g_B_loss: 4.698901176452637; d_A_loss: 0.14451992511749268; Batch_time_cost: 0.13
epoch1; batch350; g_A_loss: 5.315099239349365; d_B_loss: 0.17729483544826508; g_B_loss: 5.134288787841797; d_A_loss: 0.18281573057174683; Batch_time_cost: 0.13
epoch1; batch400; g_A_loss: 7.244136810302734; d_B_loss: 0.2588624954223633; g_B_loss: 6.518290042877197; d_A_loss: 0.03540463373064995; Batch_time_cost: 0.12
epoch1; batch450; g_A_loss: 5.102941513061523; d_B_loss: 0.03143639117479324; g_B_loss: 4.737099647521973; d_A_loss: 0.22579550743103027; Batch_time_cost: 0.13
epoch1; batch500; g_A_loss: 6.038484573364258; d_B_loss: 0.07545653730630875; g_B_loss: 5.054262638092041; d_A_loss: 0.1474056839942932; Batch_time_cost: 0.12
epoch1; batch550; g_A_loss: 5.1528449058532715; d_B_loss: 0.05942108482122421; g_B_loss: 5.586222171783447; d_A_loss: 0.1414523869752884; Batch_time_cost: 0.13
epoch1; batch600; g_A_loss: 7.961068630218506; d_B_loss: 0.07359579205513; g_B_loss: 7.495856285095215; d_A_loss: 0.08565278351306915; Batch_time_cost: 0.12
epoch1; batch650; g_A_loss: 4.3616814613342285; d_B_loss: 0.17840822041034698; g_B_loss: 3.6735024452209473; d_A_loss: 0.15866130590438843; Batch_time_cost: 0.13
epoch1; batch700; g_A_loss: 4.804023742675781; d_B_loss: 0.18827767670154572; g_B_loss: 4.755307197570801; d_A_loss: 0.2991688549518585; Batch_time_cost: 0.13
epoch1; batch750; g_A_loss: 5.776893138885498; d_B_loss: 0.24969376623630524; g_B_loss: 5.4265851974487305; d_A_loss: 0.1804438978433609; Batch_time_cost: 0.13
epoch1; batch800; g_A_loss: 8.234237670898438; d_B_loss: 0.328776478767395; g_B_loss: 7.330533027648926; d_A_loss: 0.07429933547973633; Batch_time_cost: 0.13
epoch1; batch850; g_A_loss: 7.487462520599365; d_B_loss: 0.08431682735681534; g_B_loss: 7.30022668838501; d_A_loss: 0.1268233209848404; Batch_time_cost: 0.13
epoch1; batch900; g_A_loss: 3.5420680046081543; d_B_loss: 0.15772652626037598; g_B_loss: 2.79952335357666; d_A_loss: 0.2009810507297516; Batch_time_cost: 0.13
epoch1; batch950; g_A_loss: 6.323566436767578; d_B_loss: 0.09212709963321686; g_B_loss: 5.999871730804443; d_A_loss: 0.1642218381166458; Batch_time_cost: 0.12
epoch1; batch1000; g_A_loss: 4.416383743286133; d_B_loss: 0.4296090602874756; g_B_loss: 3.9065892696380615; d_A_loss: 0.14740586280822754; Batch_time_cost: 0.13
epoch1; batch1050; g_A_loss: 4.462809085845947; d_B_loss: 0.5468143820762634; g_B_loss: 4.7028489112854; d_A_loss: 0.10133746266365051; Batch_time_cost: 0.12
epoch1; batch1100; g_A_loss: 6.614782810211182; d_B_loss: 0.09065119922161102; g_B_loss: 6.871949672698975; d_A_loss: 0.265675812959671; Batch_time_cost: 0.12
epoch1; batch1150; g_A_loss: 4.825323104858398; d_B_loss: 0.029798883944749832; g_B_loss: 4.237264633178711; d_A_loss: 0.15145432949066162; Batch_time_cost: 0.13
epoch1; batch1200; g_A_loss: 5.156582355499268; d_B_loss: 0.31561416387557983; g_B_loss: 5.017688274383545; d_A_loss: 0.1431412398815155; Batch_time_cost: 0.13
epoch1; batch1250; g_A_loss: 5.140244007110596; d_B_loss: 0.12650391459465027; g_B_loss: 4.05291223526001; d_A_loss: 0.2101406753063202; Batch_time_cost: 0.12
epoch1; batch1300; g_A_loss: 3.725010633468628; d_B_loss: 0.25355270504951477; g_B_loss: 4.305200576782227; d_A_loss: 0.31700998544692993; Batch_time_cost: 0.13
saved checkpoint to ./output/checkpoints/
In[8]
#应用固化的模型(训练10轮)进行预测,结果保存在output,训练150轮,可以达到上面简介模块展示的效果
!python cycle_gan/infer.py \
    --init_model="output/freeze" \
    --input="./data/data10040/horse2zebra/testA/n02381460_4260.jpg" \
    --input_style A \
    --output="output/freeze_infer_result"
    
# 可视化转换前后的效果
%matplotlib inline
import matplotlib.pyplot as plt  
import numpy as np
import cv2

img= cv2.imread('data/data10040/horse2zebra/testA/n02381460_4260.jpg')
result_img= cv2.imread('output/freeze_infer_result/fake_n02381460_4260.jpg')

plt.subplot(1, 2, 1)
plt.imshow(img)
plt.subplot(1, 2, 2)
plt.imshow(result_img)
plt.show()
-----------  Configuration Arguments -----------
init_model: output/freeze
input: ./data/data10040/horse2zebra/testA/n02381460_4260.jpg
input_style: A
output: output/freeze_infer_result
use_gpu: True
------------------------------------------------
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/executor.py:804: UserWarning: There are no operators in the program to be executed. If you pass Program manually, please use fluid.program_guard to ensure the current Program is being used.
  warnings.warn(error_info)
W0316 17:58:40.541411   441 device_context.cc:237] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 10.1, Runtime API Version: 9.0
W0316 17:58:40.545682   441 device_context.cc:245] device: 0, cuDNN Version: 7.3.
cycle_gan/infer.py:63: DeprecationWarning: `imsave` is deprecated!
`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
  (fake_temp + 1) * 127.5).astype(np.uint8))
up-8481efaf00da3aa3e73094e847f8bbc5a5d.png

使用AI Studio一键上手实践项目吧:https://aistudio.baidu.com/aistudio/projectdetail/169459 

下载安装命令

## CPU版本安装命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle

## GPU版本安装命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu

>> 访问 PaddlePaddle 官网,了解更多相关内容


Recommend

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK