深入浅出GAN生成对抗网络 原理剖析与TensorFlow实践

978-7-115-51795-1
作者: 廖茂文 潘志宏
译者:
编辑: 俞彬

图书目录:

详情

本书首先从Python 基本语法开始讨论,逐步介绍必备的数学知识与神经网络的基本知识,并利用讨论的内容编写一个深度学习框架TensorPy,有了这些知识作为铺垫后,就开始讨论生成对抗网络(GAN)相关的内容。然后,本书使用比较简单的语言来描述GAN 涉及的思想、模型与数学原理,接着便通过TensorFlow实现传统的GAN,并讨论为何一定需要生成器或判别器。接下来,重点介绍GAN 各种常见的变体,包括卷积生成对抗网络、条件对抗生成网络、循环一致性、改进生成对抗网络、渐近增强式生成对抗网络等内容。 本书从模型与数学的角度来理解GAN 变体,希望通过数学符号表达出不同GAN 变体的核心思想,适合人工智能、机器学习、计算机视觉相关专业的人员学习使用。

图书摘要

内容提要

本书从Python基本语法入手,逐步介绍必备的数学知识与神经网络的基本知识,并利用介绍的内容编写一个深度学习框架TensorPy,有了这些知识作为铺垫后,就开始介绍生成对抗网络(GAN)相关的内容。然后,本书使用比较简单的语言来描述GAN涉及的思想、模型与数学原理,接着便通过TensorFlow实现传统的GAN,并讨论为何一定需要生成器或判别器。接下来,重点介绍GAN各种常见的变体,包括卷积生成对抗网络、条件生成对抗网络、循环一致性、改进生成对抗网络、渐近增强式生成对抗网络等内容。

本书从模型与数学的角度来理解GAN变体,希望通过数学符号表达出不同GAN变体的核心思想,适合人工智能、机器学习、计算机视觉相关专业的人员学习使用。

推荐序

随着人工智能技术的迅速发展,图像识别、语音识别、机器翻译等技术正在改变着我们的生活方式。目前,生成对抗网络(GAN)在图像和计算机视觉领域应用非常活跃,它既可以生成让人类已经难以分辨的逼真图像,还可以实现图像修复、模糊图像高清化、视频生成。除此之外,GAN还被应用于自然语言处理、信息安全等领域。

本书作者既有在企业一线的开发工程师,又有在应用型本科院校从教多年的教师,因而能基于丰富的开发经验和教学经验,精心设计内容,使之兼顾理论与实战。本书内容全面且有深度,介绍了传统GAN、DCGAN、CGAN、CycleGAN、InfoGAN、SeqGAN等各种不同类别的GAN模型,并且从生成器、判别器、损失定义、具体训练逻辑等多个方面展开讨论,从数学层面去推导证实,突出不同类别GAN架构的底层思想。此外,本书利用Tensorflow深度学习框架实现各种不同类别的GAN模型,实战性强。无论是深度学习的学习者,还是已经具备深度学习基础想进行生成对抗网络项目实战与应用的读者,都能从书中获益。我衷心希望这本书能够帮助更多读者深入理解各种GAN模型的理论,并帮助他们更好地利用GAN解决实际项目中的问题,为人工智能应用型人才的培养发挥积极作用。

汤庸 教授/学者网创始人
华南师范大学计算机学院院长
广东省服务计算工程研究开发中心主任

前言

生成对抗网络(GAN),在2014年被提出,随后便引起了大量深度学习研究人员与从业者的关注。GAN以两个网络相互对抗的方式进行训练从而获得有价值的模型,相比于传统的无监督学习的思路,GAN更加清晰易懂,并避开了传统无监督学习中遇到的困难,是这几年论文发表数最多的主题之一。据统计,至2018年年底,每20分钟就会有一篇GAN方面的论文。

目前,GAN在图像和视觉领域运用是最广泛的,GAN已经可以生成超高清的逼真图像,人类已经难以分辨出这些图像是生成的还是真实的,通过这种方法可以实现图像的修复、模糊图像高清化、马赛克去除、视频生成、为其他模型提供训练数据等,除此之外,GAN还在自然语言处理、强化学习、音频视频以及安全领域大展手脚,可以看出GAN拥有巨大的研究与运用前景。

GAN可以实现传统程序难以实现的目标,如马赛克去除、图像修复、逼真图像的生成等,笔者在编写实现不同GAN变体的过程中,常常会惊叹于其提供的各种巧妙的解决方法,在感受GAN魅力的同时,也感受到了其模型底层对应的数学之美。

国内关于GAN的书很少,没有原创的书籍,笔者结合多年的开发经验,编写了本书,以帮助读者快速学习生成对抗网络。

本书特色

1. 容易入门:本书会讨论线性代数、微积分、概率论、信息论等内容,尽力只提及后面内容需要的数学知识,并从原理角度去讲解这块内容,为后面篇幅做好铺垫。

2. 内容更深:介绍GAN的各种变体时,除了介绍架构以外,还会讲解目标函数为何要这样设计,并从数学层面去推导证实,可以说本书比较重视不同类别GAN架构的底层思想,并从数学上表示它。

3. 涉及面广:囊括了GAN的各个应用领域,包括传统GAN、DCGAN、CGAN、ColorGAN、CycleGAN、StarGAN、DTN、XGAN、WGAN、WGAN-GP、SN-GAN、StackGAN-v1、StackGAN-v2、PGGAN等10多个方向。

4. 实战性强:提供了很多代码,并给出运行结果。考虑到篇幅原因,并没有将每个类别的所有代码都放上去,而是主要讲解生成器、判别器、损失定义、具体训练逻辑等主要内容。

本书内容

本书分为3个部分。

第一部分(第1~3章)介绍背景知识,包括Python的基础用法和一些进阶技巧、线性代数积分、概率论、信息率、神经网络以及优化算法以及实现自己的深度学习框架TensorPy。

第二部分(第4章和第5章)介绍GAN的基础知识,包括传统GAN的模型结构、数学原理以及TensorFlow实现方法,同时探讨为何不可以单独使用生成器或单独使用判别器进行图像生成。

第三部分(第6~11章)介绍各种GAN变体,包括DCGAN、CGAN、ColorGAN、CycleGAN、StarGAN、DTN、XGAN、WGAN、WGAN-GP、SN-GAN、StackGAN-v1、StackGAN-v2、PGGAN、InfoGAN、VAE-GAN等。

作者介绍

廖茂文:游戏AI研究员、高级工程师、中国人工智能学会高级会员。研究兴趣为自然语言处理、生成对抗网络、游戏AI,曾参与多项机器学习项目。

潘志宏:高级工程师,中山大学新华学院“百名骨干教师”,中国人工智能学会高级会员、中国计算机学会会员。研究兴趣为机器学习、深度学习、物联网。主持和参与省市级、校级项目10余项,其中主持广东省普通高校青年创新人才项目、教育部产学合作协同育人项目各一项。发表论文18篇,其中SCI、EI、北大核心期刊12篇,第一作者论文获得北大核心期刊优秀论文、东莞市计算机学会优秀论文。申请发明专利、实用新型专利共8项,其中已授权3项,获得软件著作权3项,已出版教材3部。指导学生获得国家级和省级竞赛奖项50余项,多次获得国家级和省级优秀指导教师奖。

本书读者对象

□ 深度学习相关程序员

□ 算法工程师

□ 人工智能开发人员

□ 游戏开发人员

□ 计算机视觉开发人员

□ 各类院校的学生

□ 其他对GAN感兴趣的各类人员

第1章 优雅Python

本书选择Python作为主要的开发语言,原因其实很简单,首先,Python的语法结构比较简单,即便读者没有接触过Python,只要有其他编程语言的开发经验也非常容易上手Python。其次,Python是目前机器学习的主流语言,大多数知名的机器学习框架都支持Python语言。本书后面涉及深度学习与生成对抗网络的内容都会使用TensorFlow框架来构建相应的神经网络结构,而TensorFlow对Python来说是具有良好支持的框架。基于以上原因,我们选择Python作为本书主要的开发语言。

虽然Python具有众多优点,但其有个明显的缺点就是运行速度慢,这是因为通常深度学习会涉及大量的运算,所以,为了扬长避短,大多数机器学习框架底层都是用C/C++等语言开发的,然后在这些底层逻辑之上使用Python进行封装,实现易用与快速运行这两个优点。

为了让读者方便理解本书后面的内容,本章先简单地介绍一下Python。

1.1 Anaconda

  conda create -n tfpy36 python=3.6
  #进入虚拟环境
  source activate tfpy36
  #退出虚拟环境
  source deactivate
  #进入虚拟环境
  activate tfpy36
  #退出虚拟环境
  deactivate tfpy26
  # 安装仅CPU版
  pip install tensorflow==1.9
  # 安装GPU版
  pip install tensorflow-gpu==1.9
  In [1]: import tensorflow as tf
  In [2]: a = tf.constant(1.0, tf.float32)
  In [3]: b = tf.constant(2.0)
  In [4]: sess = tf.Session()
  2018-08-12 09:43:53.060073: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your
CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
  In [5]: print(a,b)
  Tensor("Const:0", shape=(), dtype=float32) Tensor("Const_1:0", shape=(), dtype=float32)
  In [6]: print(sess.run([a,b]))
  [1.0, 2.0]

首先Python有两个系列的版本,分别是Python 2与Python 3。两个系列版本是相互不兼容的,造成这个现象的历史原因不多提及,读者只需知道通过Python 3编写的代码并不一定能通过Python 2运行,反之亦然。Python 2的最高版本是Python 2.7,官方会对其维护到2020年,随后便不再支持。本书使用Python 3作为开发语言,具体的版本为Python 3.6.7。

为了方便后面的开发,这里通过Anaconda的方式来安装Python。Anaconda是Python的免费增值开源发行版,它直接为我们安装好了各种用于科学计算的依赖库。如果我们直接地安装Python 3,那么这些第三方依赖库还需自己手动去下载。在下载安装使用的过程中可能还会遇到依赖冲突等问题。为了避免这些问题带来的困扰,应直接下载并安装Anaconda。Anaconda同样分为以Python 2为基础的版本与以Python 3为基础的版本,这里推荐直接下载并安装以Python 3为基础的版本。

Anaconda除了帮助我们预先安装了各种常用的科学计算的依赖库外,还提供了包管理和部署工具conda。我们可以通过conda来创建一个专门用于开发深度学习项目的Python虚拟环境。

首先来聊一下所谓的Python虚拟环境,通常使用Python开发时,为了提高开发效率,都会使用各种第三方库,如科学计算库numpy、scipy,图像处理库pillow、opencv等。随着编写项目的增加,就会在本地环境中安装各种各样的工具,此时就会显得混乱,难以管理。一个常见的情况就是,在开发项目A时使用了1.0版B库,此时开发一个新的项目也要使用B库,但版本要求是2.0。如果升级B库,此前开发的项目A就可能会出现问题。如果不升级,新项目开发就遇到阻碍。为了避免这种情况,最好的做法就是单独创建不同的Python虚拟环境。每个虚拟环境都是一个独立的不会影响系统原本Python环境的空间,在这个空间中编写程序和安装依赖库都不会影响系统本身的Python环境以及其他Python虚拟环境,这样不仅方便管理,也避免了很多包冲突的问题。

下面使用conda来创建一个名为tfpy36的虚拟环境。

这样conda就会为我们自动创建一个Python 3.6的虚拟环境,名为tfpy36。如果无法直接使用conda命令,则需要对系统的环境变量进行相应的修改。

等待创建完成后,就可以进入该虚拟环境进行操作了。

□ Mac/Linux进入方式

□ Windows进入方式

进入虚拟环境后,就可以在该虚拟环境中安装各种依赖库,以及使用该虚拟环境进行模型的开发了,这里直接通过pip来安装TensorFlow,方便后面直接使用。

TensorFlow迭代速度较快,在编写本书时,TensorFlow版本为1.9,所以这里推荐安装1.9版本的TensorFlow,不同的系统在TensorFlow的安装上会有一些差异,可以参考官方提供的安装文档。

安装完成后,可以简单测试使用一下,首先通过pip安装增强式Python交互环境IPython,pip install ipython,然后在命令行中输入ipython进入增强式Python交互环境,导入TensorFlow并进行一个简单的测试,以检查TensorFlow是否安装成功,具体代码如下。

当开发比较复杂的项目时,通常会使用相应的IDE进行开发,这里推荐使用PyCharm作为Python的开发工具,下载安装后新建一个名为tfgan的项目,新建项目时PyCharm本身支持为该项目创建独立的Python虚拟环境,这里直接导入此前创建好的Python虚拟环境即可。如果每个项目都创建一个单独的虚拟环境,个人觉得太冗余与繁杂了。如果每个项目都创建一个虚拟环境,那么在每个项目都要重复安装常用依赖库的过程,比较好的做法是同类型项目使用同一个虚拟环境,如图1.1所示。

如果是已经存在的项目,要使用conda创建的虚拟环境,则需要打开PyCharm的设置面板,进行图1.2所示的选择。

值得一提的是,Anaconda通常将不同的虚拟环境都放置在根目录的envs文件夹下,所以在使用conda创建的虚拟环境时,导入envs目录不同Python虚拟环境bin目录下的python即可,具体如图1.3所示。

1.2 Python基础

Python语言是一种动态语言,也是一种强类型语言。本节先简单地介绍一下Python常用的数据类型、流程控制与函数定义。只要学会了这些,你就可以开始编写程序了。

1.2.1 常用数据类型

  In:
  a = 1
  b = 1.
  c = True
  d = 'python'
  print(type(a), type(b), type(c), type(d))
  
  Out:
  <class 'int'><class 'float'><class 'bool'><class 'str'>
  In:
  print(d[0])
  print(d[0:4])#切片
  
  Out:
  'p'
  'pyth'
  In:
  l = [1,1,2,2,3,4,5]
  print(l[0])
  print(l[0:4])
  l.append(6)
  l.remove(4)
  print(l)
  
  Out:
  1
  [1, 1, 2, 2]
  [1, 1, 2, 2, 3, 5, 6]
  In: 
  t = (1,1,2,2,3,4,5)
  print( t[0])
  t[0:4]
  t[0]=6
  
  Out:
  1
  (1, 1, 2, 2)
  ---------------------------------------------------------------------------
  TypeError                                 Traceback (most recent call last)
  <ipython-input-18-d83209c3a892> in <module>
  ----> 1 t[0]=6
  TypeError: 'tuple' object does not support item assignment
  In: 
  d = {'name':'ayu','like':'python'}
  d.get('name','') #查询
  d['age']=28 #添加
  print(d)
  print(d.keys()) # 获取所有的key
  print(d.values()) #获取所有的values
  
  Out:
  'ayu'
  {'name': 'ayu', 'like': 'python', 'age': 28}
  dict_keys(['name', 'like', 'age'])
  dict_values(['ayu', 'python', 28])
  In:
  s = {1,1,2,2,3,4,5}
  print(type(s))
  print(s)
  s.add(6) #添加
  s.discard(1) #删除
  print(s)
  
  Out:
  set
  {1, 2, 3, 4, 5}
  {1, 2, 3, 4, 5, 6}
  {2, 3, 4, 5, 6}

Python支持的基本类型有int、float、bool、str等,同时也提供几种标准数据,包括list、tuple、dict、set。

首先来使用以下基本类型。

其中1.表示float类型,它其实是1.0的缩写。这里不再多介绍int、float与bool类型,但需要讲讲str(字符串)类型。Python对字符串有非常强大的支持,让我们可以很轻松地使用字符串,在Python中一个字符串可以看成是由多个字符元素组成的数组,我们可以使用下标以及切片的方式来操作字符串。

标准数据也是Python中常用的类型,从list(列表)开始介绍,list是一种有序的容器,该容器中的对象都是可变对象,我们可以对存入list中的元素进行增、删、改、查等操作。

从上面的操作代码可以看出,list同样支持下标与切片取值的操作。在list中,可以存入重复的数据,其使用append()方法存入数据,通过remove()方法删除数据。

虽然tuple(元组)与list类似,但是两者仍有较大的区别:一方面,list用中括号表示,tuple用小括号表示;另一方面,tuple中的元素是不可变的,即无法修改。

从上面的操作代码可以看出,tuple同样支持下标与切片操作,但tuple不支持增、删、改操作。

dict(字典)类型是Python中除list外最灵活的标准数据类型。list是有序的容器,而dict是无序的容器,dict通过key来获得对应的value。

可以看出,dict同样支持增、删、改、查操作,只是dict通过key来进行操作,而list通过下标来进行。

set类型是无序且不重复的容器,通常可以使用set对数据进行去重操作,其基本操作如下。

至此,Python中关于数据类型的操作就介绍完了,这些都是Python中最基础的内容,更深入的内容请读者自行了解。

1.2.2 流程控制

  num = 6
  if num == 1:
       #do some thing
       print(num)
  elif num > 10:
       # do some thing
       print(num)
  else:
       # do some thing
       print(num)
  mylist = [1,2,3,4,5]
  # for迭代
  for i in mylist:
        # do some thing
        print(i)
  i = 1
  # while循环
  while i <= 5:
        # do some thing
        print(i)
        i += 1

非常简单地介绍完Python数据类型后,接着来介绍一下Python中的流程控制。流程控制主要分两类:一类是判断,另一类是循环。理论上而言,凭借判断与循环就可以编写任何程序了。

这里先从判断开始,在Python中主要使用if来实现判断(Python中没有switch/case结构),使用方式如下。

Python的语法糖与其他语言有比较大的差别,很多语言都使用{}将一个程序块括起来,而Python中使用相同个数的空格缩进表示一个程序块。例如,if判断下的语句都需要相对于if判断语句本身多缩进4个空格,表示其下的语句是if判断的一个程序块。

上述Python代码就是常见的if判断代码,if关键字或elif关键字后连接具体的判断条件。如果满足判断条件,则只需执行该条件下的代码逻辑。

Python中的循环语法结构也类似,在Python中可以使用for关键字与while关键字来实现循环,两者效果是类似的。

在Python中虽然两者都可以实现循环,但还是有差异的。对for关键字而言,它执行的是迭代(iterate)操作,即按某种顺序逐个访问容器中的每一项的行为;而对while关键字而言,它执行的就是我们常说的循环(loop),即满足一定条件时,重复执行同一段代码的行为。

1.2.3 函数定义

  def add(x,y):
       return x+y
  print(add(1,2))
  def add(x,y=10):
       return x+y
  print(add(1))
  def add(*args, **kwargs):
       sum = 0
       for i in args:
             sum += i
       #循环获得dict中的值
       for k,v in kwargs.items():
             sum += kwargs.get(k,0)
       return sum
  print(add(1,1,x=1,y=1))

当编写程序时,如果遇到一些需要重复使用的逻辑,就可以将其封装成一个函数,在需要使用的地方调用该函数即可,从而降低了代码的冗余度。

在Python中使用def关键字来定义函数,常见方式如下。

上述代码中定义了名为add()的函数方法,该函数的作用就是返回两个参数的累加值,有时我们会给参数赋予默认的值。

有时为了考虑通用性,不一定会传入2个值,还有传入3个或4个值等各种可能。这种不知道具体会传什么参数的方式可以使用*args关键字与**kwargs关键字,代码如下。

从上述代码中可以看出,*args会接收所有没有指定参数名的值,如一开始的两个1,而**kwargs会接收指定参数名的所有值。其中args其实是list类型,而kwargs则是dict类型,此时使用for循环取出args对象与kwargs对象中的值并累加,最后返回累加值。

1.3 Python进阶

前面关于Python基础内容的讲解较为简单,接下来介绍Python中比较常用的进阶技巧,这些技巧在后面编写神经网络模型时都会使用到,在此做个铺垫。

1.3.1 生成式

  In:
  l = [i for i in range(10) if i%2 == 0] #列表生成器
  print(l)
  print(type(l))
  d = {k:v for (k,v) in [('a',1),('b',2)]} #字典生成器
  print(d)
  print(type(d))
  
  Out:
  [0, 2, 4, 6, 8]
  list
  {'a': 1, 'b': 2}
  dict

Python中列表与字典都可以通过生成式的方式来生成。

可以发现,生成式的写法就是将使用for循环创建list或dict的逻辑代码缩短成一行。

1.3.2 可迭代对象与迭代器

  In:
  l = [1,2,3,4,5]
  next(l)
  
  Out:
  ---------------------------------------------------------------------------
  TypeError                                 Traceback (most recent call last)
  <ipython-input-4-101c36968c6d> in <module>()
  ----> 1 next(l)
  TypeError: 'list' object is not an iterator
  from itertools import islice
  class Fib:
      ''
  获得斐波那契数列
      '''
      def __init__(self):
          self.prev = 0
          self.curr = 1
      def __iter__(self):
          return self
      def __next__(self):
          value = self.curr
          self.curr += self.prev
          self.prev = value
          return value
  f = Fib()
  print(list(islice(f, 0, 20)))
  [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, 6765]
  class myIterable(object):
      def __init__(self, mylist):
          self.mylist = mylist
      def __getitem__(self, item):
          return self.mylist[item]
  l = myIterable([1,2,3])
  for i in l:
      print(i)
  1
  2
  3

为了加深对for关键字的理解,需要讨论一下Python中关键的概念——可迭代对象与迭代器。

在Python中,任意对象只要定义了__iter__方法或者定义了可以支持下标索引的__getitem__方法,它就是一个可迭代对象。可以通过内置的dir()方法来查看某个对象是否定义了这两个方法中的一个,从而判断该对象是否为可迭代对象。其中list与str就是可迭代对象,在Python中还有很多可迭代对象,例如文件流files、网络流sockets等。

任何对象只要定义了__iter__方法和__next__方法,它就是一个迭代器。由此可知,迭代器一定是可迭代对象。因为迭代器需要定义__iter__方法,而只要定义了__iter__方法,就可以认为该对象是可迭代对象。以list为例,通过dir()方法查看list列表对象定义时,可以发现list定义了__iter__,则list就是一个可迭代对象。而迭代器相对于可迭代对象通常多定义了一个__next__方法,当然也有例外的情况(如可迭代对象没有定义__iter__方法,只定义了__getitem__方法的情况)。由于list中没有__next__方法,可知list不是一个迭代器,只是一个可迭代对象,当通过next()方法调用list时,会因该list不是迭代器而报错。

接着我们来定义一个迭代器,只需在自定义对象中定义__iter__方法和__next__方法即可。

一般而言,在定义迭代器时,会希望通过迭代器对象本身来取得其中的值,所以__iter__方法只需返回迭代器自身。上述代码中定义了Fib类,该类实例化的对象是一个迭代器,通过该迭代器可以获得一个无限的斐波那契数列,这里使用了islice方法限制其只获取前20个数,其输出如下。

每次运行next()方法获取迭代器中下一次的值时,next()方法主要做了两件事:一是返回此次调用next()方法生成的返回结果,二是为下一次调用next()方法修改状态。当然实现一个斐波那契数列根本不需要动用迭代器,设计一下,用一个简单循环就可以了。那么为何还要有迭代器呢?因为使用迭代器省内存。如果你需要打印前1000万个斐波那契数,单纯地使用循环,就需要将这1000万个值都存到内存中,这会消耗大量的内存。如果使用迭代器,就不会出现大量消耗内存的情况。迭代器很懒、很健忘,只有在需要某值的时候才执行函数内的逻辑,返回相应的值,然后就将它忘了,这样就几乎不消耗内存了。当你需要读入大量数据对模型进行训练时,就可以通过这种方式减少内存的占用。

值得一提的是,for兼容了两种机制:第一种就是上面提及的,对于有定义__iter__方法的可迭代对象,for会通过__iter__方法来实现迭代;第二种就是一些指定了__getitem__方法的可迭代对象。对于第二种可迭代对象,for没有__iter__方法可以调用,那么就会改用下标迭代的方式来实现迭代,一个具体的示例如下。

上述代码可以输出如下内容。

1.3.3 生成器

  #生成器
  def generator(n):
      for i in range(n):
          yield i+1
  for i in generator(5):
      print(i)
  def Fib(prev, curr,n):
      while curr<n:
          yield curr
  prev,curr = curr, prev+curr
  
  for f in Fib(0,1,20):
      print(f)

前面我们了解了Python中的可迭代对象与迭代器,这有益于我们理解Python的生成器。在Python中生成器的定义很简单,就是用yield关键字的函数直接来定义一个生成器。

在上述代码中,generator()方法就是一个生成器,一个明显的特征就是使用yield关键字替换常用的return关键字,其作用是返回yield关键字后面表达式的值,同时将程序中断,并保存程序运行到当前这一步的上下文。这一句话可能有点绕,拆分来看,yield在这段代码中的作用就是返回yield关键字后面表达式的值,这里返回i+1的值;中断程序,将程序运行停止在这一步;保存程序运行到这一步的上下文;在程序恢复运行时,使用此前保存好的上下文。需要注意的是,生成器中不允许使用return关键字。yield关键字后面表达式的值不会在函数被调用时就立刻返回,而是当next()方法被调用时才会返回。

使用生成器的一个明显的优势就是非常节省内存,它可以轻松地将十几GB的文件逐步读入程序中进行处理,非常适合深度学习中模型训练数据的读入。可以说生成器就是迭代器的另一种更加优雅的实现方式,生成器利用yield关键字实现了迭代器的所有功能,同时让代码变得更加简明。下面通过生成器的方式实现斐波那契数列的计算。

可见,相比于迭代器的实现方式,生成器的代码明显简化了很多。

1.3.4 装饰器

  def speed_time(func):
      def print_time(*args, **kwargs):
          func_name = func.__name__
          t0 = time.perf_counter()
          res = func(*args, **kwargs)
          t1 = time.perf_counter()
  print('%s run time is (%s), the res is (%s)' % (func_name,t1-t0, res))
      return print_time
  @speed_time
  def for_10000():
      sum = 0
      for i in range(10000):
          sum += i
      return sum
  for_10000()
  for_10000 run time is (0.0012948440271429718), the res is (49995000)
  def logger(level):
      def decorate(func):
          def wrapper(*args, **kwargs):
              if level == 'warn':
                    print('Warn Info')
              elif level == 'error':
                    print('error Info')
              return func(*args)
          return wrapper
      return decorate
  @logger(level='error')
  def myname(name='ayuliao'):
  print('My name is %s'%name)
  myname()
  error Info
  My name is ayuliao
  class Logger(object):
      def __init__(self, func):
          self._func = func
      def __call__(self):
          print(self._func.__name__ + ' is running')
  self._func()
  @Logger
  def ayu():
      print('ayu')
  ayu()
  ayu is running
  ayu

装饰器是Python中比较特殊的用法,装饰器本质上就是一个利用闭包特性的Python函数,其作用是装饰已存在的函数。善用Python的装饰器可以很好地优化代码的结构。

下面举一个具体的例子——实现一个性能测试的装饰器,其核心功能是打印函数运行前和运行后的时间差,具体代码如下。

在上述代码中,speed_time()函数的参数其实也是一个函数,该函数也就是被装饰的函数,speed_time()函数内部是print_time()函数,该函数的逻辑就是打印被装饰函数运行的时间差。简单来看,speed_time()方法的作用就是将func()被装饰函数替换成print_time()函数。

简单使用一下,代码如下。

可以获得的结果如下。

一般而言,装饰器是为了在不修改被装饰函数的情况下给被装饰函数添加一些新的功能,本质上是返回一个具有相应功能的新函数来代替被装饰函数。对于装饰器,不使用@关键字,直接使用for_10000=speed_time(for_10000),效果是一样的,但显而易见的是使用@更加方便。

还需要注意的是装饰器的运行时间,函数装饰器会在导入模块时就立即执行,而被装饰的函数只有在明确调用时才会运行。在实际情况中,装饰器通常都在一个模块中定义,然后应用到其他模块上,那么在引用import时,装饰器就已经被调用了。

除上面实现的简单装饰器外,还有带参数的装饰器。带参数的装饰器可以实现更加复杂的逻辑,例如可以在装饰器中打印指定级别的日志,代码如下。

输出结果如下。

可以发现,所谓带参数的装饰器就是对原有装饰器的一个函数封装,并返回一个装饰器。解释器看到@logger(level='error')时,Python能发现最外层封装,它会将参数传递给内部装饰器环境,@logger(level='error')等于@decorate。

实现装饰器的方式不局限于函数,类同样也可以实现一个装饰器,而且类装饰器的灵活性、封装性都比函数实现的装饰器好。

先写一个简单的类装饰器,用于打印日志,代码如下。

其运行结果如下。

从上述代码中可以看出,所谓的类装饰器,主要就是定义了__call__方法,当使用@调用类装饰器时,Python解释器就会调用该方法。

1.4 小结

本章简单地介绍了Anaconda环境以及Python的相关内容,其中包括Python的基础内容和进阶内容。通过本章的学习,相信大家对Python有了一定的了解,这非常有益于大家理解后续章节的有关模型代码的编写内容。因为本书不是专门讨论Python的书籍,对于其中的很多细节并没有提及,如果想要进一步理解Python,可以参考其他优秀的Python书籍。

相关图书

图神经网络:基础、前沿与应用
图神经网络:基础、前沿与应用
图神经网络前沿
图神经网络前沿
Python神经网络项目实战
Python神经网络项目实战
TensorFlow深度学习项目实战
TensorFlow深度学习项目实战
PyTorch深度学习和图神经网络(卷2)——开发应用
PyTorch深度学习和图神经网络(卷2)——开发应用
PyTorch深度学习和图神经网络(卷1)——基础知识
PyTorch深度学习和图神经网络(卷1)——基础知识

相关文章

相关课程