oohcode

$\bigodot\bigodot^H \rightarrow CODE$

go modules

依赖的发展历史

2009 年 1 月

刚开始 Go 语言雏形还不完善,用户只能使用官方库和自己开发的库来构建自己的程序,这时候使用还是 Makefile, 但是 Go 的设计目标是使写代码变得更加容易。 所以后来 Go 命令不使用 Makefiles 或者其他配置文件来指导程序的构建。Go 使用源代码来查找依赖关系并确定构建的条件。

2010 年 2 月

在 2010 年 2 月, 一篇名为:goinstall: an experiment in package installation 的文章在 Go 的讨论组中发布,主要是建立一个代码共享的机制,通过 goinstall 命令把依赖的代码库下载到本地,并通过 import 来引用这些代码。goinstall 的设计是下载的地址和引用的地址保持一致,比如 :

-  goinstall github.com/user/project.git

那么使用的时候也是

1
import "github.com/user/project.git"

这种方式在社区引起了广泛的讨论, 主要包括引用的名字是否应该包含完整的路径,是否要考虑引用的版本,甚至有人提出把版本好加到引用路径中, 在当时的情况下 Russ Cox 选择了继续使用当前的方式,但是后来也证明有些建议是有用的,比如版本号。

2011 年 12 月

到了 2011 年 12 月, Go 语言统一了使用 go 命令代替之前的 gomake/gotest/goinstall等命令。其中 go get 命令代替之前的 goinstall 正式加入 Go 语言中。

2013 年 9 月

go get 虽然可以下载外部依赖,但是却没有版本的概念,这样就无法向用户保证每次下载代码的预期。 于是一些开源方案出现了。这些开源方案借鉴了其它语言的依赖管理方式,并结合 Go 语言自身的特点产生了各种各样的方案,其中最早的就是 godep, godep 主要是提供了一个依赖文件,这个文件记录了每个的依赖的具体的版本和路径,编译的时候把这些依赖下载到一个特定的目录 workspace 中, 然后 checkout 到特定的版本,再把这个目录加入到 GOPATH中,这样就能利用 Go 的编译查找路径编译特定的版本了。

2013 年 11 月

Go 官方同样注意到了这个问题,为了改善中情况在 Go 1.2 中添加了有关版本控制的相关建议:

Packages intended for public use should try to maintain backwards compatibility as they evolve. The Go 1 compatibility guidelines are a good reference here: don’t remove exported names, encourage tagged composite literals, and so on. If different functionality is required, add a new name instead of changing an old one. If a complete break is required, create a new package with a new import path.

主要是让用户遵循兼容性原则,不要随便删除和修改已经提供的方法,如果要修改最好是增加新的方法而不是修改旧的方法,如果真的需要不兼容则重新建立一个包。

但是这种建议是没有约束力的,开发者如果没有遵循这个建议还是会产生各种问题,无法从根本上保证。

2014 年 3 月

还有一种方案是 gopkg.in 采用的, 那就是对不同的版本进行不同的命名,通过引入路径的不同区分不同的版本。
例如: gopkg.in/yaml.v1, 我们使用这个版本就要使用 import "gopkg.in/yaml.v1"。对于 gopkg.in 来说只是提供了一种利用 URL 版本化解决版本依赖的问题, 实际的存储地址在 github.com 上,如果版本号不发生改变就要保证向前兼容。

这种做法需要我们引入的时候带上版本号,并且之前 go get 的方案也有人提出了使用带版本的引入方案,但是都没有得到官方的认可,认为这种方法具有一定的侵入性。

2014 年 10 月

Glide 的发布提出了 _vendor 的概念,把所有依赖都放入这个目录,这其实与 godep 的 workspace 作用一样。Glide 还有一些优势就是可以添加包的映射关系,因为随着时间的推移很多包的地址发生了改变,或者报名发生了改变,但是为了不让使用者修改代码,可以使用这个映射来解决。 这也是后来 go module 方案借鉴的。

2015 年 6 月

随着社区的发展,官方也开始接纳社区的一些方案,总于在 Go 1.5 版本中正式加入了 vendor 概念。1.5 版本以后 Go 的编译从原来依赖 GOPATH和 GOROOT 有增加了一个 vendor 目录的依赖,这个目录存在于当前库的根目录,编译的时候会优先使用 vendor 目录,然后再去 GOPATH 下查找。
这个方案推出以后很多开源软件也很快进行跟进,把原来依赖 GOPATH 的修改改为依赖 vendor 目录。

之前GOPATH 的问题: Go程序通常由来自许多不同来源的程序包组成。 这些来源中的每一个都可以从GOPATH或标准库中获取。 但是,只有他们的项目受其自己的源代码管理。 依赖的包不受他们的管理和控制,如果依赖包发生更改或消失时就会影响项目的编译。 通常为了避免这些问题,有以下几种做法:

  1. 将依赖包复制到项目目录中,并重写引用它的导入。
  2. 将相关程序包复制到项目代码库中,并修改GOPATH变量以包括项目特定的子目录。
  3. 将依赖包版本写到一个文件中,然后将现有的GOPATH软件包更新为该版本。

但是这些做的后果同样会产生对应的问题:

  1. 需要修改导入路径,还会存在包名冲突的问题, 分不清自己的和依赖的包。
  2. 存在多个 GOPATH, 而且是嵌套的,很容易出现查找不准等问题。
  3. 在正常的GOPATH中修改程序包要求每个项目都具有唯一的GOPATH。 否则不同项目就会产生相互干扰。

vendor 的加入是为了解决 GOPATH 的不足, 开发者可以把依赖放到 vendor 中而不影响其它项目,也可以防止依赖的变更和丢失。对于第二个作用也会产生一个副作用:代码库依赖太庞大,照成存储上的浪费。
vendor 还有一个问题是版本不明确,无法通过 vendor 很好的进行版本管理。

2016 年 8 月

前面说开源工具比较混乱,是因为大家使用不同的依赖管理工具,有一个问题就是你的间接依赖的版本如何确定? 因为某些间接依赖的版本根据依赖管理工具的使用来标明的,如果使用不同的依赖管理工具就无法获取你的间接依赖版本,或者需要兼容你得依赖管理工具,这样做太难了。最好的方式当然是统一依赖管理工具,这样能够很好的进行依赖的分析。
2016 年 GopherCon 大会后, 一个计划改善 Go 依赖管理的民间组织成立了。这个组织的目的就是为了开发一个能够完善的解决包依赖管理问题的方案。这个方案就是: dep 。这个小组甚至邀请了一些开源的依赖管理工具的作者:

  • Daniel Theophanes (govendor)
  • Dave Cheney (gb)
  • Matt Farina (Glide)
  • Steve Francia (Hugo, Cobra)

作为顾问。
dep 工具和开源的依赖管理工具思路是一样的:

  1. 依赖 GOPATH 判断包的依赖地址和版本
  2. 依赖 vendor 下载依赖到项目下,保证编译版本一致

为了让你能够从其它依赖管理工具迁移到 dep, dep 还会分析已经项目中已经存在的常见的依赖管理文件,并在初始化的时候与里面的依赖版本保持一致。

2018 年 2 月

最终还是来到了 2018 年,这年 2 月 Russ Cox 在自己的个人网站上发布了 7 篇博文,并与 3 月份在 Go 官方网站开启了此提案, 并且于 5 月 21 日(日子不错) 被接受。文章发布社区顿时就炸了,特别是 dep 的成员。因为这个提案完全不同于之前的工作方式,让之前的工作都付诸东流,甚至有人觉得 Go 的官方不尊重社区,太独裁了。随后 Russ Cox 又发表了 4 篇博文, 加上之前的一共发布了 11 篇博文来论证新的依赖管理方案: vgo。
关于 vgo 的方案下面会详细介绍。

2018 年 8 月

随着 Go 1.11 的发布,vgo 化身为 go modules, 对应命令 go mod 加入到了官方的工具链当中,并且从语言底层进行了支持。Go Modules 的发布意味着 Go 语言的依赖管理方案最终又要走向官方的统一方案,开源方案也纷纷表示不再支持更新,建议大家选用官方方案。

以上就是 Go 语言依赖管理发展的前世今生。

当然这不表示这个方案已经完美,Go Modules 还在不断发展当中,很多特性也会加入,但是总体的思想和设计方案不再会有变动。

Go Modules

Go Modules 已经被官方发出,所以后面的趋势就是会统一 Go 语言依赖管理的方案,结束现在混乱的状态,所以每个 Go 语言开发者都应该学习和使用 Go Modules。 下面会花更多的篇幅对Go Modules的设计思想和使用进行详细介绍。

Go 依赖管理的三个原则

Go 从出生开始的动机就是为了简化软件工程。这种动机也体现在 Go 的依赖管理设计上,分别是:兼容性,可重复性和合作性三个原则。
Go 依赖管理的设计原则是 Go Modules 设计与其它管理方案如: Dep,Cargo,Bundler等设计不同的原因。 也是为什么没有选择 Dep 的原因。

1 兼容性 (引入版本号到 import 语句中)

兼容性或者成为稳定性,是指程序中名称的含义不应随时间而发生改变。 如果确实需要发生改变, 则需要跟之前的引入路径保持不同。
对于兼容性主要有一些异议:

  1. 美学
    对于 Go 来说良好的软件工程性比一些主观上的美学更加重要。 视觉美学我们会习惯它们, 但我们更重视它们带来的精确性和简单性。对于 Go 的设计可以举两个例子:
    1.1 Go 语言去掉 export 关键字而使用首字母大小写代替是否可导出,从编程习惯是需要适应,但是也能够一眼看出哪些调用是可导出的。
    1.2 导入路径看上去有点儿长,但是能够更加精细的表示导入的模块,从而避免了不必要的重复。
  2. 需要修改导入路径
    我们把版本号加入导入路径使导入路径在语义上保持精确。另一个好处是,当您从软件包的v2升级到v3时,您可以逐步,分阶段地(一次可能一个软件包)逐步更新程序。这种方案之前也有提出,但是被否定了,就是处于美观的考虑,但是这种方式除了不太美观没有其他副作用,而且能够更好的解决问题。
  3. 构建中的多个主要版本
    其它依赖管理工具不允许同一个代码库的不同版本同时被依赖, 这样能够降低开发者的复杂性,但是对用户来说则可能会更加复杂,而且时间情况中无法避免同一个项目的代码在他的所有依赖文件中出现不同的版本。

2 可重复性 (最小版本选择)

当您构建特定版本的程序包时,构建应以可重复的方式决定要使用的依赖项版本,该依赖关系不会随着时间的推移而改变。
对这个原则的主要异议在于,很多人认为使用最新的版本是一项基本的诉求,很多人希望能够使用最新的版本到达更新一些 bug 的修复和性能的提升等。但是同时大家都会认为构建的可重复性是更加重要的,因为这能够保证程序的稳定性,这是最基本的要求。一些依赖工具例如 Dep 会优先使用最新版本,但是为了保持稳定性还有一个 lock 文件保证某些依赖的版本不变。Dep 的 lock 文件问题在于只会对当前代码库生效,如果这个代码库是其他库的一部分,那么就无法保证这种稳定性。所以 Go Modules 通过最小版本依赖原则来优先保证构建的稳定性。

3 合作性 (共同致力于兼容性)

为了维护Go软件包生态系统,我们必须共同努力, 致力于保证兼容性。工具无法解决缺乏合作的问题。
无论是基于 SAT 算法的 Dep 还是基于最小版本原则的 Go Modules 都无法保证兼容性。但是最小版本选择比 SAT 更加具有稳定性( 证明可以看这里: Go Modules 与 SAT 解决方案对比)

语义化版本

Go Modules 引入了语义化版本来标识软件的版本,而且是强制的。语义化版本就是把版本号分为三位,分别代表不同的含义:

  1. 第一位版本号(major version): 一些不兼容的修改
  2. 第二位版本号(minor version): 一些新特性的增加
  3. 第三位版本号(patch version): 一些 bug 的修复

小版本的更新不会破坏你的代码(向后兼容), 如果出现了错误你应该报告给作者,让他进行修改;大版本的更新可能会让你的代码无法编译,因为大版本可能不是向后兼容的,但是这种改变可能正式作者期望的。

这个规范是建议性的并没有约束力,而且根据 hyrums 定律:

With a sufficient number of users of an API, it does not matter what you promise in the contract. All observable behaviors of your system will be depended on by somebody.
当接口的使用者数量达到一定数量后,你指定的规则就不重要了,所有的表现都取决于他的使用者。

也就是无法保证大家都按照这个规范执行了,但是语义版本控制仍然是一种构架对发布之间关系的期望的有用方法。目前还没有其它更好的代替方法。

关于语义化版本的引入, Russ Cox 还说:

A year ago, I believed that putting versions in import paths like this was ugly, undesirable, and probably avoidable. But over the past year, I’ve come to understand just how much clarity and simplicity they bring to the system. In this post I hope to give you a sense of why I changed my mind. @Russ Cox

可见大佬的想法也会随时间而改变, 没有一个东西开始就是完美的,但是谨慎的引入,原则的坚持总不会错的。

最小版本原则

在 Go Modules 之前,Go 版本选择有两种情况:

  1. 第一种算法是默认行为go get:如果您有本地版本,请使用该版本,否则请下载并使用最新版本。此模式可能使用的版本过旧:如果您已安装B 1.1并运行go get以下载A,则go get不会更新为B 1.2,从而导致构建失败或错误。

  2. 第二种算法的行为是go get -u:下载并使用所有内容的最新版本。此模式通过使用太新的版本而失败:如果您运行go get -u下载A,它将正确更新为B 1.2,但也会更新为C 1.3和E 1.3,这不是A所要求的,可能没有经过测试,可能无法正常工作。

可见这两种情况都不能保证构建的稳定性。

最小版本原则的计算方式

为了说明最小版本原则,我们通过 Russ Cox 博客中给出的例子进行讲解。首先给出一个初始的依赖关系:

为了验证我在自己的 github 页面建立的对应的代码库,和对应的标签。 但是有一个循环依赖的问题我认为是无法满足的,所以做了一些改变, 但是并不影响整体的逻辑。 有关循环依赖的这个问题可以参考 #issue24098

1. 构建需求列表

我们根据初始的依赖关系,可以按照下图中黄色的部分表示查找的路径:
查找路径

对于同一个模块的不同版本,我们会选择最高的版本,下面就是这个算法的选择过程:
算法演进

最终我们的依赖关系可以在原来的基础上表示,黄色的部分我们保留,表示之前的查找路径,红色部分就是表示最终的依赖路径和依赖的模块版本:

选择路径

我们也可以使用 go list 命令查看响应的依赖关系:

1
2
3
4
5
6
$go list -m all
github.com/two/a
github.com/two/b v1.2.0
github.com/two/c v1.2.0
github.com/two/d v1.4.0
github.com/two/e v1.2.0

这里的最小版本原则就是指及时 C 发布了更新的版本 C1.3,但是如果不主动更新这个版本,每次构建的依赖还是 C1.2 版本,最大程度的保证了构建稳定性。 而 dep 这些依赖管理工具则会选择最新版本,很容易就会发现依赖发生了非常大的变化,每次都见都会随着版本的发布而改变,不可预期的可能性大大提高了。

2. 升级所有模块

如果我们想升级目前依赖的所有模块到最新版本(这里的最新版本是指经过发布的语义化版本), 保持之前的黄色模块依赖,红色部分表示最新的依赖关系和选择的模块。
升级所有模块

升级所有模块的命令是:

1
$go get -u

升级后可以查看模块选择:

1
2
3
4
5
6
7
8
$go list -m all
github.com/two/a
github.com/two/b v1.2.0
github.com/two/c v1.3.0
github.com/two/d v1.4.0
github.com/two/e v1.3.0
github.com/two/f v1.1.0
github.com/two/g v1.2.0

3. 计算最小需求列表

对与使用 go.mod 的项目, 如果依赖的包也使用了 go.mod, 那其依赖的版本在go.mod 中已经做了说明了,所以当前的项目就可以不写这个依赖文件了。 例如前面升级所有模块后的 go.mod 文件长这样:

1
2
3
4
5
6
7
8
9
10
11
module github.com/two/a

go 1.13

require (
github.com/two/b v1.2.0
github.com/two/c v1.3.0
github.com/two/d v1.4.0 // indirect
github.com/two/e v1.3.0 // indirect
github.com/two/g v1.2.0 // indirect
)

这里只是列出了必须的模块的依赖, 例如 github.com/two/f 模块可以通过 github.com/two/c模块的依赖表示,就没有必要写进去。
但是 github.com/two/f 引用的是 github.com/two/gv1.1.0, 这里由于需要的是 v1.2.0, 所以回单独列出来。

4. 升级单个模块

大多数情况下我们并不会一次升级所有模块,因为这样带来的不确定性太大,我们一般都会根据需求来升级模块,如果你的依赖没有你需要的新的特性或者 bug 的修复你就没有必要升级。 假如我们要在初始的依赖关系中升级 C1.2 到 C1.3 版本, 我们可以使用命令:

1
go get -u github.com/two/c

升级完成后我们可以看一下依赖的关系:

1
2
3
4
5
6
7
8
$go list -m all
github.com/two/a
github.com/two/b v1.2.0
github.com/two/c v1.3.0
github.com/two/d v1.4.0
github.com/two/e v1.2.0
github.com/two/f v1.1.0
github.com/two/g v1.1.0

根据这个依赖关系,可以画出对应的依赖关系图,红色表示最终的依赖。
升级单个模块

这里注意一个问题, C1.3 并不依赖 D1.4, 而 B1.2 依赖的是 D1.3, 也就是说没有模块依赖 D1.4, 但是在最终的版本选中却保留了 D1.4 而不是 D1.3。如果我们降级 D 则会带来一些非预期的结果,并且我们的依赖不再稳定,也违背了最小更改的原则。 所以 Go Modules 要保证不能为你带来非预期的自动降级。

5. 降级

假如我们发现 D1.4 有一个 bug , 这个 Bug 是 D1.3 引入的,我们需要将 D1.3 降级到 D1.2, 这时我们也需要将引入 D1.3 及以上版本的 B1.2 和 C1.2进行降级,因为如果只降级了 D 则 B1.2 和 C1.2 很可能使用了 D1.3以上版本的功能导致不可预期的事情发生。 可以用灰色来表示不可用的模块版本:
不可用模块

下面我们对 D 进行降级:

1
go get github.com/two/d@v1.2.0

降级后的依赖关系:

1
2
3
4
5
6
go list -m all
github.com/two/a
github.com/two/b v1.1.0
github.com/two/c v1.1.0
github.com/two/d v1.2.0
github.com/two/e v1.2.0

由于我们要遵循最小更改原则,所以我们不会主动降级下一层的依赖 E1.2, 最终我们的依赖关系用图来表示就是:
降级模块

生成的 Go Module 文件如下:

1
2
3
4
5
6
7
8
9
10
module github.com/two/a

go 1.13

require (
github.com/two/b v1.1.0
github.com/two/c v1.1.0
github.com/two/d v1.2.0 // indirect
github.com/two/e v1.2.0 // indirect
)

依赖文件中的 // indirect 是非常有用的,特别是对于一些升级和降级的模块来说,我们不能直接按照新的规则来计算依赖的版本,而要保持最小的变动原则,这个注释就是指那些发生了模块的变动但是需要保持版本和依赖的一些模块。 如果我们不保留这个模块的描述最终的依赖就会是跟新生成的依赖一样,不会有这些版本的要求,会存在一些非预期的风险, 无法尽可能的保证依赖的稳定性。

假如我们删除这两行,改为:

1
2
3
4
5
6
7
8
9
10
module github.com/two/a

go 1.13

require (
github.com/two/b v1.1.0
github.com/two/c v1.1.0
// github.com/two/d v1.2.0 // indirect
// github.com/two/e v1.2.0 // indirect
)

通过命令查看依赖关系:

1
2
3
4
5
6
go list -m all
github.com/two/a
github.com/two/b v1.1.0
github.com/two/c v1.1.0
github.com/two/d v1.1.0
github.com/two/e v1.1.0

对应的依赖关系图:
降级模块

这种降级是没有必要的,因为我们认为 D1.2 要比 D1.1 更好,而 E1.2也会比 E1.1 更好, 而且他们都做到了向下兼容,如果都进行了降级反而会产生一些不好的结果,稳定性也下降。

Go Modules 引入的变化

有了以上的理论和算法基础,Go Modules 的引入就变得不那么难了,Module 的概念就是公共引用前缀,是版本控制的单位。Module 的引入带来了几个变化:

  1. 提倡使用明确的发行版本而不是某个提交的 ID,可以清楚地表明预期。
  2. 引入了代理的概念(GOPROXY),
    2.1 不依赖各种版本控制工具下载,防止碎片化,都改为使用 HTTP 协议, 代码库都是以 zip 的形式存在。
    2.2 通过 GOPROXY 缓存依赖,保证可用性(可重复下载)和安全性(安全检测)
    2.3 将来还会引入共享代理,可以默认使用共享代理(类似其它语言的集中式管理)
  3. 通过独立的版本控制在单个代码库中开发多个模块
    有两种方式:

    1. 使用单独的分支表示不同的版本
    2. 通过子目录表示不同的版本

      vgo 两种都支持,但是第二种能够更加平滑
  4. 不再需要 vendor
    vendor 目录有两个作用。首先,他们通过其内容指定要在期间使用的依赖项的确切版本。其次,即使原始副本消失了,它们也可以确保这些依赖项的可用性。
    但是 vendor 也有一个弊端就是代码的副本太多了,每个代码库都要提交依赖,占用了大量的存储库的信息。vgo 已经解决了 vendor 所带来的两个好处,所以它就没有存在的必要性了。
    在编译时 vgo 会忽略 vendor 目录的存在,如果你还想强制使用 vendor 可以使用: go build -mod=vendor

  5. 不再需要 GOPATH
    注意你的代码库不再依赖 GOPATH, 但是目前如果还在 GOPATH 下,默认是不开启 Go Modules 的,如果你不在 GOPATH 则会开启 Go Modules。不依赖 GOPATH 不意味着不需要 GOPATH, 只是你自己的代码不需要放到这个目录下,但是一些编译的输出,和依赖库的缓存还是会放到这个目录下。

为了保证再现性,可验证性和经过验证的构建方式, Go Modules 做了以下几点:

  1. 通过最低版本原则保证每次构建下载的都是同一版本的代码
  2. 通过 hash 值保证每个版本代码都有一个唯一的标识
  3. 通过 hash 对比保证每个依赖的版本与 hash 对应,防止篡改

使用

go.mod & go.sum 文件

go.mod 和 go.sum 文件成对同时出现在模块的根目录中。
go.mod 文件记录了依赖的关系和版本。 go.mod 文件支持四种命令:

  • module:
    出现在文件第一行表示的是此模块的名字,这个模块名字就是被其他模块引用时的名字,如果不一致就会报错。
  • require:
    记录依赖的包及版本号
  • replace
    有些包的地址发生了变化,可以通过这个来指向另一个包, 不必修改 import 的地址
  • exclude
    有些包的某个版本有问题,需要在这里明确支持不使用某个版本

下面是一个 go.mod 文件的例子:

1
2
3
4
5
6
7
8
module github.com/two/a

go 1.13

require (
github.com/two/b v1.2.0
github.com/two/c v1.3.0
)

go.sum 记录用于安全性和完整性校验的信息。 go.sum 文件 每行的格式都是:

1
<模块> <版本> [/ go.mod] <哈希>

没有 /go.mod 的是表示这个版本的模块的源代码的 hash 值(使用 SHA-256算法)
带有 /go.mod 的表示这个版本的模块的 go.sun 文件的 hash 值

1
2
3
4
5
6
7
8
9
10
11
12
github.com/two/b v1.2.0 h1:1w6ZrvIUmiXBRX/cmUlzAy1fA76mgBY55/5LlpxfbiA=
github.com/two/b v1.2.0/go.mod h1:I1qidS2xpjDqFf5kPEEWqlYII81sIAsewITgM3NnpnM=
github.com/two/c v1.3.0 h1:BkIIZs1in6e4+8E/JrPG02IdO3Cw+k4YtlwXteCQaiI=
github.com/two/c v1.3.0/go.mod h1:3Xpyx3nev6KimFlAAv4VqnfklWKOM6EWQHId5qr7cvw=
github.com/two/d v1.3.0 h1:CowHIy3VdlGGBzmobCQXjS+8xkvAwTXqGJOEdSODwPQ=
github.com/two/d v1.3.0/go.mod h1:fAg6MyXvtnCUuov3tcLoCmXTV7c99ECaEKVw442oC/c=
github.com/two/e v1.2.0 h1:jJIqW7+D0MJh8g5B1amYojQvtQdhKsYqEe17QWU6oEw=
github.com/two/e v1.2.0/go.mod h1:RczHMX2xzAngm2z2jYPtHRxyJWl99oJDlZI7RHy7zfo=
github.com/two/f v1.1.0 h1:dNFUxmhP64A4BIxlI4F9vraG6snBV/UjGe0Vod8CXho=
github.com/two/f v1.1.0/go.mod h1:buQ/ZEGBVlMU2xGMfgKig/Sn1d/6addjRnsFb1mH2DQ=
github.com/two/g v1.1.0 h1:eBWE3BIwdZ3/tcnA/4KZWKRBxHnWGWopcn72wbhN/M8=
github.com/two/g v1.1.0/go.mod h1:a/jbi0S1ZL9XI+Fqd3Ca1618vwkl8Js166KP13u/EHw=

go.sum 文件并不是必须的,go.sum 中记录的模块主要是用户本地校验,既使用下载到本地的模块与 go.sum 中的校验值进行对比,如果 go.sum 中没有对应的模块,则会到一个远程的校验数据库进行校验,这个远程数据库通过 Go 语言环境变量 GOSUMDB 来设置,当然也可以选择不校验。

相关环境变量

通过 go env 命令我们可以看到所有的环境变量,其中跟 Go Modules 相关的主要有下面几个:

注意: 这里是针对 Go 1.13 版本进行说明。由于 Go Modules 还在不断发展中,很多东西不太成熟,会存在一些新增的环境变量。

  • GOPATH:
    这个变量大家都很熟悉了,但是在 Go Modules 中他的含义发生了变化,用户的代码不需要放到这个路径下,但是依然需要这个路径,Go Modules 下载的依赖放在这个路径下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
$GOPATH/pkg/
├── linux_amd64
└── mod
├── cache
│   ├── download
│   │   └── github.com
│   │   └── two
│   │   ├── b
│   │   │   └── @v
│   │   │   ├── list
│   │   │   ├── list.lock
│   │   │   ├── v1.2.0.lock
│   │   │   ├── v1.2.0.mod
│   │   │   ├── v1.2.0.zip
│   │   │   └── v1.2.0.ziphash
│   │   ├── c
│   │   │   └── @v
│   │   │   ├── list
│   │   │   ├── list.lock
│   │   │   ├── v1.3.0.lock
│   │   │   ├── v1.3.0.mod
│   │   │   ├── v1.3.0.zip
│   │   │   └── v1.3.0.ziphash
. .
. .
. .
│   └── lock
└── github.com
└── two
├── b@v1.2.0
├── c@v1.3.0
├── d@v1.3.0
├── e@v1.2.0
├── f@v1.1.0
└── g@v1.1.0

其中 mod/cache 目录下是从代理服务器下载下来的原始文件,包括打包好的源代码,依赖文件和hash 值等信息。
mod/github.com 目录下放的是解压后的代码,其中每个代码库都会放到以 module@version 格式命名的目录中,这也解决了同一个环境下放置多个版本文件的问题,解决了环境的相互干扰。
另外说明的一点是: 下载的依赖都是针对特定版本的,不会包含提交的历史信息,大大减少依赖的文件大小。

如果我们要清楚缓存文件,可以使用 go clean -modcache 命令, 这个命令会删除 $GOPATH/pkg/mod 路径下的所有内容。

  • GO111MODULE:
    这个变量表示是否开启 Go Modules 功能,因为是在 Go 1.11 版本加入的,所以叫 GO111MODULE。这个变量又三个值,分别是:

    • GO111MODULE=off : 强制不使用,依赖 GOPATH
    • GO111MODULE=on : 强制使用
    • GO111MODULE=auto : 如果当前模块在 $GOPATH/src 中不使用,如果不在 $GOPATH/src 目录中并且存在 go.mod 文件则使用
  • GOPROXY & GONOPROXY:
    GOPROXY 是指代理服务器的地址, 用户可以指定代理服务器来下载依赖,目前有几个知名的代理服务器,包括 goproxy.io, goproxy.cn 以及官方提供的 proxy.golang.org。Go Modules 发布后有很多企业没有使用还有一个原因是所有的代码库都要走代理,但是有些代码库是企业内部的,并且有权限的控制的,这些不想走代理,于是在 Go 1.13 又加入了一个 GONOPROXY 变量,通过设置这个变量可以决定哪些依赖不走代理,而是直接通过原来的方式从源地址下载。

  • GOSUMDB & GONOSUMDB:
    这两变量是控制完整性校验的。通过 GOSUMDB 我们可以设置远程的校验地址,当我们下载的依赖不在 go.sum 文件中时,我们就要通过这个地址远程校验。当然我们也可以忽略校验,通过设置 GONOSUMDB 可以选择哪些模块不需要进行校验,特别是一些私有的代码库。

  • GOPRIVATE:
    这个其实可以看成是 GOPRIVATE = GONOPROXY + GONOSUMDB。 如果设置了 GOPRIVATE 就相当于同时设置了 GONOPROXYGONUSUMDB

  • GOMOD:
    当前模块的 go.mod 文件路径,是动态的,如果不在 go 模块中则默认是: /dev/null

相关命令行

go mod

运行 go help mod 命令可以看到支持下列几个参数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
Go mod provides access to operations on modules.

Note that support for modules is built into all the go commands,
not just 'go mod'. For example, day-to-day adding, removing, upgrading,
and downgrading of dependencies should be done using 'go get'.
See 'go help modules' for an overview of module functionality.

Usage:

go mod <command> [arguments]

The commands are:

download download modules to local cache
edit edit go.mod from tools or scripts
graph print module requirement graph
init initialize new module in current directory
tidy add missing and remove unused modules
vendor make vendored copy of dependencies
verify verify dependencies have expected content
why explain why packages or modules are needed

Use "go help mod <command>" for more information about a command.

一般我们的使用步骤是这样:

初始化

go mod init [package name]
如果在 GOPATH下面之行,则默认使用的是相对 GOPATH 的路径,如果不是则需要加包名作为参数。建立一个空的包:
例如:
go mod init example.com/m

生成 go.mod 文件:

1
2
3
module example.com/m

go 1.13

自动添加

我们可以不用特意添加自己的依赖,如果我们执行 go build 或者 go test 等命令则会自动帮我们下载依赖并写入到 go.mod 文件中。版本选择,如果我们的依赖中有依赖文件并且写明了版本则会根据最小版本原则进行选择,如果我们的依赖没有指名版本则会下载最新的版本,这个最新版本是指语义优化标签最大的版本,如果没有语义化标签则自动生成一个标签: (v0.0.0)-(提交时间, UTC 时间)-(commit id) 作为版本的标识。

自动整理

go mod tidy, 通过这个命令我们可以自动整理当前的依赖文件,把需要的文件加入,把不再需要的依赖删除。 其实如果我们执行 go build 或者 go test 也会自动更新这个依赖文件。Go 语言的目的是尽量减轻用户的操作,所以能帮你做的都做了。

更新

go get -u [package name] 可以用来更新依赖的版本,如果加了具体的包名则只更新这一个包,如果没加则更新所有依赖包。

本地调试

使用 Go Module 后如果我们需要依赖一个或者多个包来进行调试,而这个依赖的包还没有正式提交,我们该如何做呢? 这里有两种方案:

  1. 使用 replace:
    我们可以借助 go modules 的 replace 关键字把依赖的地址改为本地未提交的版本所在的地址, 例如我们依赖 example.com/m2,但是这个包还在开发中并未发布,我们可以在 go mod 文件中加上:

    1
    replace example.com/m2 => /local/path/to/my/module

    这样我们就可以在本地随时修改依赖包进行调试了。 但是这样做有一个缺点就是我们需要修改 go.mod 文件,如果我们调试完不小心提交了这个文件,可能会照成一些不必要的麻烦,我们还有第二种方式。

  2. 使用 vendor:
    虽然 vendor 目录已经不是 go modules 所需要的了,但是这个目录不会从 Go 语言中去掉,相关的讨论可以看 vgo & vendoring 的内容。默认如果开启了 go module 那么 go build 将会忽略 vendor 目录的内容,但是如果我们在编译的时候加上一个参数: go build -mod=vendor 则还是会优先查找 vendor 下的依赖。如果我们将在开发中的依赖放入 vendor 目录中,编译的时候加上这个 mod=vendor 参数则可以满足我们调试的需求。如果我们将 vendor 目录加入版本控制之外,则不会影响我们的正常开发和提交。

go list

  • go list -m [all]: 列出当面 module, all 表示所有依赖的模块也列出来
  • go list -m -versions (package name): 列出某个依赖的所有版本

go clean

清除一些编译的缓存文件,go moduels 之后加了一个新参数: -modcache, 如果执行:
go clean -modcache 则会清楚所有$GOPATH/pkg/mod 目录下的已经下载的依赖文件,这个在一些第三方依赖存在错误,或者校验失败的情况下可以使用。

参考文献

go context

Contex 的作用

在 Go 服务器中, 每个请求都是由一个独立的 goroutine 进行处理的。请求处理程序
往往会启动其它的 goroutine 来访问后端,比如数据库和 RPC 服务。 处理请求的
goroutine 通常需要访问特定的值,比如用户身份的标识,token, 请求的超时时间。
当一个请求被取消或者超时时,处理改请求的所有 goroutine 都应该迅速退出,这样
系统就可以回收他们正在使用的资源。

对于 Go 语言,由于是单进程的模式 goroutine 之间内存是共享的,那么 goroutine 是
如何获取自己的上下文数据的呢?对于一些多线程模式运行的语言中,比如 Java 可以
通过 ThreadLocal 来传递线程间的上下文,但是 Go 语言并不提倡这种模式,Go 语言中
你甚至无法知道 goroutine 的编号,一切都是 Go 自己帮你管理的。为了解决这个问题
Go 使用的就是传递 Context 参数。

这种方式是 Go 语言比较特殊的地方,也是很多人诟病的地方,如果你要传递上下文正规
的方式就是这种,Go 语言甚至规定了它的具体用法:

  1. 不要把它放到一个结构体中, 而是在需要的地方直接传递它
  2. 放到函数的第一个参数中,并且命名为 ctx

对于第一个限制,后面会详细讲解其原因。

Context 结构

Context 本质上是为了传递上下文,这个上下文不止是一些变量,还包括传递事件。
下面我们给一个使用的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
package main

import (
"context"
)

func main() {
ctx1 := context.Background()

ctx2, _ := context.WithCancel(ctx1)
ctx3, _ := context.WithCancel(ctx1)

ctx4, _ := context.WithCancel(ctx2)
ctx5, _ := context.WithCancel(ctx2)

ctx6, _ := context.WithCancel(ctx3)
ctx7, _ := context.WithCancel(ctx3)

ctx8, _ := context.WithCancel(ctx4)
ctx9, _ := context.WithCancel(ctx4)

ctx10, _ := context.WithCancel(ctx5)
ctx11, _ := context.WithCancel(ctx5)

ctx12, _ := context.WithCancel(ctx6)
ctx13, _ := context.WithCancel(ctx6)

ctx14, _ := context.WithCancel(ctx7)
ctx15, _ := context.WithCancel(ctx7)

println(ctx8)
println(ctx9)
println(ctx10)
println(ctx11)
println(ctx12)
println(ctx13)
println(ctx14)
println(ctx15)
}

对于前面这个例子,最终会形成一个 Context 的树形结构,结构如下:

根节点

前面的结构中 ctx1 是根节点, 根节点通过 context.Background() 函数创建的,
函数源码如下:

1
2
3
4
5
6
7
8
var (
background = new(emptyCtx)
todo = new(emptyCtx)
)

func Background() Context {
return background
}

可以看到根节点是一个 emptyCtx 类型的数据, 这个结构实现了 Context 接口:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
type Context interface {
Deadline() (deadline time.Time, ok bool)
Done() <-chan struct{}
Err() error
Value(key interface{}) interface{}
}

type emptyCtx int

func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
return
}

func (*emptyCtx) Done() <-chan struct{} {
return nil
}

func (*emptyCtx) Err() error {
return nil
}

func (*emptyCtx) Value(key interface{}) interface{} {
return nil
}

可以看到 emptyCtx 实现很简单,基本都是返回 nil。对于根节点来说它并不能够
真正的传递一些信息和事件,其它的 context 则是依赖这个作为根节点来实现的。

cancelCtx

树的建立

cancelCtx 是一个可以传递 cancel 事件的 context, 通过 WithCancel 函数可以获取
一个 cancelCtx 类型的结构,并且还会返回它对应的 cancel 函数,当我们调用这个函数
时就会把事件传递到这个结构及他的所有子节点。源码实现如下:

1
2
3
4
5
6
7
8
9
10
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
c := newCancelCtx(parent) // 新建一个 cancelCtx
propagateCancel(parent, &c) // 把当前新节点放到 parent 的子节点中
return &c, func() { c.cancel(true, Canceled) } // 返回 cancel 函数
}

// newCancelCtx returns an initialized cancelCtx.
func newCancelCtx(parent Context) cancelCtx {
return cancelCtx{Context: parent}
}

通过 newCancelCtx 函数我们把当前节点的 Context 字段指向了 parent, 也就是父节点。

通过调用 propagateCancel 函数我们可以把当前新建的节点放到对应的 context 树中,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

// propagateCancel arranges for child to be canceled when parent is.
func propagateCancel(parent Context, child canceler) {
if parent.Done() == nil { // 对于 parent.Done() == nil 直接返回,
// 因为这个节点是空节点, 没有 children 字段
return // parent is never canceled
}
if p, ok := parentCancelCtx(parent); ok {
p.mu.Lock()
if p.err != nil {
// parent has already been canceled
child.cancel(false, p.err) // 如果父节点存在错误信息,
// 证明父节点已经被取消,那么子节点也应该取消
} else {
if p.children == nil {
p.children = make(map[canceler]struct{})
}
p.children[child] = struct{}{} // 放入子节点
}
p.mu.Unlock()
} else {
go func() {
select {
case <-parent.Done():
child.cancel(false, parent.Err())
case <-child.Done():
}
}()
}
}

对于根节点,执行到下面这里就会返回:

1
2
3
if parent.Done() == nil {
return // parent is never canceled
}

也就是根节点是没有 children 字段的, 无法通过根节点查找子节点, 但是子节点
可以通过 Context 字段找到父节点。

我们再看另一种情况 ctx4 是如何挂载到 ctx2。前面的操作基本都一致,但是会走到
下面这段逻辑:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
if p, ok := parentCancelCtx(parent); ok {
p.mu.Lock()
if p.err != nil {
// parent has already been canceled
child.cancel(false, p.err)
} else {
if p.children == nil {
p.children = make(map[canceler]struct{})
}
p.children[child] = struct{}{}
}
p.mu.Unlock()
} else {
...
}

首先调用 parentCancelCtx 函数判断父节点的类型:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
func parentCancelCtx(parent Context) (*cancelCtx, bool) {
for {
switch c := parent.(type) {
case *cancelCtx:
return c, true
case *timerCtx:
return &c.cancelCtx, true
case *valueCtx:
parent = c.Context
default:
return nil, false
}
}
}

对于 cancelCtx, timerCtx 这些类型会返回 true, 然后判断父节点是否已经有错误信息,
如果有错误信息表示父节点已经调用了 cancel, 那么为了传播这个事件,子节点也应该调用
cancel, 对于没有 cancel 的父节点则把当前节点放到父节点的 children 结构中。

如果 parentCancelCtx 返回 false 呢?也就是不属于前面几种类型。这个跟前面解释的 context
的使用原则: “1. 不要把它放到一个结构体中, 而是在需要的地方直接传递它” 相关的,也就是当我们
把 context 放到结构体中进行传递则会满足这个条件,走到下面的逻辑:

1
2
3
4
5
6
7
go func() {
select {
case <-parent.Done():
child.cancel(false, parent.Err())
case <-child.Done():
}
}()

会开起一个新的 goroutine 来监听这个节点,而不会放到树中。 可以看出其实是监听了父节点和
本身节点, 因为如果父节点 cancel 了,子节点也需要 cancel ,因为父节点的事件要传播到子节点;
本身节点也是需要监听,调用 cancel 后也要结束这个 goroutine,如果不监听则需要依赖父节点,
如果父节点不接受这个节点即使调用了 cancel 也无法结束,所以两者缺一不可。

事件传递

前面说 WithCancel 函数会返回一个 cancel 函数,如果我们调用的话会传递这个消息到所有的
子节点中。 我们修改一下前面的代码:

1
2
3
4
...
ctx4, cancel4 := context.WithCancel(ctx2)
cancel4()
...

当我调用 cancel4(), 其实调用的是 cancel(true, err), 第一个参数传 true 表示需要从树中
删除其子节点,第二个参数传取消的错误信息。 具体实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
func (c *cancelCtx) cancel(removeFromParent bool, err error) {
if err == nil {
panic("context: internal error: missing cancel error")
}
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return // already canceled
}
c.err = err
// 通过 chan 传递给所有监听的程序
if c.done == nil {
c.done = closedchan
} else {
close(c.done)
}
// 递归传递消息给所有子节点
for child := range c.children {
// NOTE: acquiring the child's lock while holding parent's lock.
child.cancel(false, err)
}
c.children = nil
c.mu.Unlock()

// 从树中删除子节点
if removeFromParent {
removeChild(c.Context, c)
}
}

cancel 函数主要是有三个作用:

  1. 通过 close(chan) 传递给所有的监听的程序这个消息
    监听的程序如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    func (c *cancelCtx) Done() <-chan struct{} {
    c.mu.Lock()
    if c.done == nil {
    c.done = make(chan struct{})
    }
    d := c.done
    c.mu.Unlock()
    return d
    }
  2. 消息传递个所有子节点,所有监听子节点的程序也收到消息

  3. 把当前子节点机器及其下面的所有节点从树中删除

前面可以看出消息传递给子节点的时候调用了 cancel 函数,但是第一个参数传递的是 false,
为什么呢? 显然后面把当前节点的子节点已经删除了, 没有必要在对其所有下面的节点执行
删除操作了, 否则就是重复删除。

下面用图来表示, 首先是消息的传播,红色表示收到了消息的节点:

然后把节点从树中删除:

timerCtx

timerCtx 是跟时间相关的 Contex, 可以通过这个设置过期时间,并且传播消息

1
2
3
4
5
6
type timerCtx struct {
cancelCtx
timer *time.Timer // Under cancelCtx.mu.

deadline time.Time
}

要想新建一个 timerCtx 需要通过 WithDeadline 函数(也可以通过 WithTimeout, 但这个
函数其实是 WithDeadline 的包装调用,无需详细讲解), 第一个参数是父节点,第二个参数是
过期时间。 源码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
// 由于父节点的过期时间比子节点的早,不需要单独设置过期,可以直接返回
if cur, ok := parent.Deadline(); ok && cur.Before(d) {
// The current deadline is already sooner than the new one.
return WithCancel(parent)
}
c := &timerCtx{
cancelCtx: newCancelCtx(parent),
deadline: d,
}
// 构建context 树
propagateCancel(parent, c)
dur := time.Until(d)
// 已经过了过期时间,直接返回, 不需要设置过期时间
if dur <= 0 {

c.cancel(true, DeadlineExceeded) // deadline has already passed
return c, func() { c.cancel(false, Canceled) }
}
c.mu.Lock()
defer c.mu.Unlock()
// 设置过期时间,在过期时间会调用 cancel 函数
if c.err == nil {
c.timer = time.AfterFunc(dur, func() {
c.cancel(true, DeadlineExceeded)
})
}
// 返回取消函数
return c, func() { c.cancel(true, Canceled) }
}

这个函数的实现由很多优化的地方,首先对于设置的过期时间会和父节点进行比较,如果父节点过期
时间比当前节点的过期时间早,则直接返回一个 cancelCtx, 不需要设置过期时间,因为父节点肯定
比子节点过期的早,会触发消息的传递,然后传递个子节点,子节点没有机会执行自己的消息传递。
其次,计算完时间后,如果发现已经过期了,直接调用子节点的 cancel 函数,这时已经出发了消息
传递。 上面两个条件都不满足,则会调用 time.AfterFunc 函数设置一个时间,到这个时间后会
主动调用 cancel 函数进行消息的传播。

函数也会返回对应的 cancel 函数,我们也可以主动调用, 这个函数实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
func (c *timerCtx) cancel(removeFromParent bool, err error) {
c.cancelCtx.cancel(false, err)
if removeFromParent {
// Remove this timerCtx from its parent cancelCtx's children.
removeChild(c.cancelCtx.Context, c)
}
c.mu.Lock()
if c.timer != nil {
c.timer.Stop()
c.timer = nil
}
c.mu.Unlock()
}

这里实际调用了 cancelCtx的 cancel 函数, 还有一点要注意这里调用了c.timer.Stop(),
这里是如果主动调用了 cancel 函数,则其对应的计时器就没有作用了,应该提前停止,这样
可以主动释放资源。

valueCtx

前面讲的都是如何利用 context 传递消息,这里讲的是如何通过 context 传递数据。
context 的数据传递是通过 valueCtx 来完成的,他的定义如下:

1
2
3
4
type valueCtx struct {
Context
key, val interface{}
}

主要是包含了 一对 key, val

valueCtx 的生成是通过 WithValue 来实现的:

1
2
3
4
5
6
7
8
9
10
func WithValue(parent Context, key, val interface{}) Context {
if key == nil {
panic("nil key")
}
// 判断类型是否可以使用 == 比较
if !reflectlite.TypeOf(key).Comparable() {
panic("key is not comparable")
}
return &valueCtx{parent, key, val}
}

首先判断当前的 key 是否可以使用 == 判断相等(这个定义在 runtime/alg.go
可以看到, 这里不是重点就不介绍了)。 然后返回一个 valueCtx 结构。valueCtx
Context 字段指向的是父节点。

valueCtx 实现的也是一个树结构, 但是跟前面的 cancelCtx 不同,这里的 valueCtx
没有指向子节点的指针,只有指向父节点的指针,也就是说只能子节点访问父节点,父节点
无法方位子节点。

通过 WithValue 可以给一个 valueCtx 设置 key 和 value, 这样就能携带一些信息。
构建的树如下:

对于 valueCtx, 我们通过 Value 函数来获取这个信息:

1
2
3
4
5
6
func (c *valueCtx) Value(key interface{}) interface{} {
if c.key == key {
return c.val
}
return c.Context.Value(key)
}

可以看到如何当前的 key 匹配到了,则返回对应的值,如果没有找打则会寻找父节点,这样递归的往上找,
直到不是 valueCtx 的节点, 返回 nil。 可见 value 的查找是非常低效的。最重要的是当你使用 context
传递数据时,可能会滥用,比如在过渡依赖 context, 在各个地方都会设置值:

  1. 查找的时候不一定会从哪个节点开始,如果从父节点查找,而值存在子节点你是查找不到的
  2. 如果 key 一致可能会无意中覆盖原来的值
  3. 如果多个几点都有查找的 key, 那么查找的结果不一定会是哪一个

对于 key 的限制,golint 有一条规则 : should not use basic type %s as key in context.WithValue
哪些是基本类型呢?golint 中定义如下:

1
2
3
4
5
6
7
8
var basicTypeKinds = map[types.BasicKind]string{
types.UntypedBool: "bool",
types.UntypedInt: "int",
types.UntypedRune: "rune",
types.UntypedFloat: "float64",
types.UntypedComplex: "complex128",
types.UntypedString: "string",
}

就是因为对于基本类型而言,复制会出现覆盖,查找出现不确定的情况。一般情况下建议
使用一些自定义类型作为 key, 避免与其他的key冲突。

数据结构之间的关系:

前面讲了 context 中好几种数据结构及其实现,其实他们之间是有这非常紧密的联系的,
为了更加直观的看出的他们的关系,这里用一张图来表示:

这些结构体基本上都实现了 Context 接口,但是一般每个结构的侧重不一样,对于一些
接口的函数都是默认的实现。 比如 cancelCtx 并没有定义 Value 函数, valueCtx
也没有具体实现 Done, 这些函数是什么都不做的。

context 使用举例

对于 context 的使用 context 包里有说明:

  • 不要将 Context 塞到结构体里。直接将 Context 类型作为函数的第一参数,而且一般都命名为 ctx。
  • 不要向函数传入一个 nil 的 context,如果你实在不知道传什么,标准库给你准备好了一个 context:todo。
  • 不要把本应该作为函数参数的类型塞到 context 中,context 存储的应该是一些共同的数据。
    例如:登陆的 session、cookie 等。
  • 同一个 context 可能会被传递到多个 goroutine,别担心,context 是并发安全的。

context 的使用常见主要有以下几个。下面分别做一下介绍。

传递数据

在 web 开发中,我们为了串联整个请求的路径,会在日志中记录每条请求的唯一 id, 并且在访问下游服务
的时候把这个 id 传递下去。通过这个 id, 我们就能够对本次请求的路径进行了解,并且在遇到问题的时候
很好的定位在哪一步出现了问题。下面我们使用 context 来传递这个数据:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
package main

import (
"context"
"fmt"
"math/rand"
)

type traceType string

func main() {
ctx := context.Background()
ctx = context.WithValue(ctx, traceType("traceId"), rand.Int())
process(ctx)
}

func process(ctx context.Context) {
traceID, ok := ctx.Value(traceType("traceId")).(int)
if ok {
fmt.Printf("traceType traceID=%d\n", traceID)
} else {
fmt.Println("no traceType tranceID")
}

traceID, ok = ctx.Value("traceId").(int)
if ok {
fmt.Printf("string type traceID=%d\n", traceID)
} else {
fmt.Println("no string type tranceID")
}
}

这里注意一点,WithValue 的 key 使用的是自定义的类型 traceType 而不是基本类型 string,
避免了查找冲突和覆盖的问题。所以输出结果为:

1
2
traceType traceID=5577006791947779410
no string type tranceID

在实际的开发中我们要需要在 server 端给每个请求都加上这个 ID, 这个数据优先是从 HEADER 里传过来。
所以一般实际业务中我们这么写:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import (
"context"
"fmt"
"net/http"
)

type requestType string

var traceID = requestType("traceID")

func main() {

h := hand{}
http.HandleFunc("/hi", hi)
http.ListenAndServe(":8000", h)
}

type hand struct{}

func (h hand) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
v := req.Header.Get("X-TRACE-ID")
ctx := context.WithValue(req.Context(), traceID, v)
reqCtx := req.WithContext(ctx)

http.DefaultServeMux.ServeHTTP(rw, reqCtx)
}

func hi(rw http.ResponseWriter, req *http.Request) {
v := req.Context().Value(traceID).(string)
resp := fmt.Sprintf("traceID = %s\n", v)
fmt.Fprintf(rw, resp)
}

防止 goroutine 泄露

参考文献中的例子:
有一个 goroutine 往 chan 发送信息:

1
2
3
4
5
6
7
8
9
10
11
12
// gen is a broken generator that will leak a goroutine.
func gen() <-chan int {
ch := make(chan int)
go func() {
var n int
for {
ch <- n
n++
}
}()
return ch
}

调用这个函数,当信息发送次数等于 5 就停止运行:

1
2
3
4
5
6
7
// The call site of gen doesn't have a 
for n := range gen() {
fmt.Println(n)
if n == 5 {
break
}
}

停止运行后有一个问题,就是 gen 函数里的 goroutine 会一直存在,不会退出。
这样就照成了 goroutine 泄露,下面我们利用 context 改进一下这个程序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// gen is a generator that can be cancellable by cancelling the ctx.
func gen(ctx context.Context) <-chan int {
ch := make(chan int)
go func() {
var n int
for {
select {
case <-ctx.Done():
return // avoid leaking of this goroutine when ctx is done.
case ch <- n:
n++
}
}
}()
return ch
}

加入了 context 参数,for 循环利用 select 监听取消的消息。调用的程序也改进了。
当 接收5次消息后会调用 cancel 函数发送消息,这样前面的 gen 就能够及时退出了。

1
2
3
4
5
6
7
8
9
10
11
12
ctx, cancel := context.WithCancel(context.Background())
defer cancel() // make sure all paths cancel the context to avoid context leak

for n := range gen(ctx) {
fmt.Println(n)
if n == 5 {
cancel()
break
}
}

// ...

超时控制

超时控制也是用的比较多的场景。在实际的工作场景中,我们对外提供服务要保证服务的可用性,
可用性的一个指标是响应时间。 一般上游访问我们都会有一个超时时间,当过了这个超时时间
上游就会结束访问,认为这次请求失败了,这时如果我们的服务还在处理响应的请求已经没有必要
了,所以我们应该及时退出,尽快回收资源,提高程序的性能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import (
"context"
"fmt"
"net/http"
"time"
)

func main() {

http.HandleFunc("/hi", hi)
http.ListenAndServe(":8000", nil)
}

func hi(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), time.Millisecond*100)
defer cancel()
reqCtx := req.WithContext(ctx)

for {
select {
case <-reqCtx.Context().Done():
return
case <-time.After(time.Second):
// do something
}
}
//...
}

这里要注意的是,如果已经进入了业务的处理内部,无法再回到 select 的阶段是无法取消这个
goroutine 的,也就是只有提前检查,或者周期性的检测才能使用。

参考

Go Concurrency Patterns: Context
深度解密Go语言之context
Using contexts to avoid leaking goroutines

etcd server 启动

入口

文件入口:

1
2
3
4
5
6
7
8
9
// main.go

package main

import "go.etcd.io/etcd/etcdmain"

func main() {
etcdmain.Main()
}

开始启动过程

  1. 平台支持检查
  2. 启动参数解析
  3. 启动服务或者 proxy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// etcdmain/main.go

func Main() {
// 检查支持的平台
checkSupportArch()

// 参数解析
if len(os.Args) > 1 {
cmd := os.Args[1]
if covArgs := os.Getenv("ETCDCOV_ARGS"); len(covArgs) > 0 {
args := strings.Split(os.Getenv("ETCDCOV_ARGS"), "\xe7\xcd")[1:]
rootCmd.SetArgs(args)
cmd = "grpc-proxy" // 如果设置了 ETCDCOV_ARGS 环境变量,就是以 grpc-proxy 方式启动
}
switch cmd {
case "gateway", "grpc-proxy": // 如果设置了 cmd变量,就以变量的形式启动,包括 gateway, grpc-proxy 两种方式
if err := rootCmd.Execute(); err != nil { // rootCmd.Execute 就等于调用 `etcd gateway` 或者 `etcd grpc-proxy`, 其它的参数不支持,调用的是 `etcd`
fmt.Fprint(os.Stderr, err)
os.Exit(1)
}
return
}
}

startEtcdOrProxyV2()
}

配置检查 & 启动 etcdServer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
// etcdmain/etcd.go

func startEtcdOrProxyV2() {
grpc.EnableTracing = false

cfg := newConfig() // 默认配置项
defaultInitialCluster := cfg.ec.InitialCluster // 对应配置文件中的 ETCD_INITIAL_CLUSTER 变量,默认的集群节点配置

...

defer func() {
logger := cfg.ec.GetLogger()
if logger != nil {
logger.Sync()
}
}()

// 重新修改一下 cluster 配置中的 cluster 相关的信息,比如获取当前计算机的 hostname 代替 0.0.0.0 或 localhost
defaultHost, dhErr := (&cfg.ec).UpdateDefaultClusterFromName(defaultInitialCluster)
...

var stopped <-chan struct{}
var errc <-chan error

which := identifyDataDirOrDie(cfg.ec.GetLogger(), cfg.ec.Dir) // 检查存储数据的目录, 返回数据存储的类型是成员还是 proxy 等
if which != dirEmpty { // 节点目录不为空,证明不是第一次使用,恢复之前的配置
...
switch which {
case dirMember: // 如果是成员类型,需要开启一个 etcd 服务
stopped, errc, err = startEtcd(&cfg.ec)
case dirProxy: // 如果是proxy, 则开启一个 proxy 服务
err = startProxy(cfg)
default:
...
}
} else { // 如果为空,则根据参数启动服务
shouldProxy := cfg.isProxy()
if !shouldProxy { // 如果不是 proxy, 则启动一个正常的 server
stopped, errc, err = startEtcd(&cfg.ec)
if derr, ok := err.(*etcdserver.DiscoveryError); ok && derr.Err == v2discovery.ErrFullCluster {
if cfg.shouldFallbackToProxy() {
...
shouldProxy = true
}
}
...
if shouldProxy { // 如果是 proxy ,则启动一个 proxy 服务
err = startProxy(cfg)
}
}
...
osutil.HandleInterrupts(lg) // 接收外界信号

// At this point, the initialization of etcd is done.
// The listeners are listening on the TCP ports and ready
// for accepting connections. The etcd instance should be
// joined with the cluster and ready to serve incoming
// connections.
notifySystemd(lg) // 把型号发送给正在运行的 etcd 守护进程

select { // 进入阻塞状态,除非出现错误或者服务关闭
case lerr := <-errc:
// fatal out on listener errors
if lg != nil {
lg.Fatal("listener failed", zap.Error(lerr))
} else {
plog.Fatal(lerr)
}
case <-stopped:
}
osutil.Exit(0)
}

// startEtcd runs StartEtcd in addition to hooks needed for standalone etcd.
func startEtcd(cfg *embed.Config) (<-chan struct{}, <-chan error, error) {
e, err := embed.StartEtcd(cfg) // 根据配置开启一个 etcd server
if err != nil {
return nil, nil, err
}
osutil.RegisterInterruptHandler(e.Close) // 注册通过信号关闭时的回调函数
select { // 进入阻塞,除非接收到下面的信号
case <-e.Server.ReadyNotify(): // wait for e.Server to join the cluster 阻塞,直到当前 server 注册到了集群中
case <-e.Server.StopNotify(): // publish aborted from 'ErrStopped' 注册失败
}
return e.Server.StopNotify(), e.Err(), nil
}

开启 peer , client 和 metrics server

  1. etcd Server 启动会开起多个 server
  2. peer server 用于 etcd 集群之间的选举,探活, 默认端口 2380
  3. client 用于监听客户端请求,处理客户端读写, 默认端口 2379
  4. metrics 用户监控集群状态,需要手动指定端口,否则不开启
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
// embed/etcd.go

// StartEtcd 回开启一个 etcd server, 并且接收 HTTP 请求,但是这个函数并不会保证加入到了集群中
// 加入集群是由 Etcd.Server.ReadyNotify() 来实现的
// StartEtcd launches the etcd server and HTTP handlers for client/server communication.
// The returned Etcd.Server is not guaranteed to have joined the cluster. Wait
// on the Etcd.Server.ReadyNotify() channel to know when it completes and is ready for use.
func StartEtcd(inCfg *Config) (e *Etcd, err error) {
if err = inCfg.Validate(); err != nil { // 检查一些参数是否合法
return nil, err
}
serving := false
e = &Etcd{cfg: *inCfg, stopc: make(chan struct{})}
cfg := &e.cfg
defer func() {
if e == nil || err == nil {
return
}
if !serving {
// errored before starting gRPC server for serveCtx.serversC
for _, sctx := range e.sctxs {
close(sctx.serversC)
}
}
e.Close()
e = nil
}()
// 开启 peer server, 默认端口 2380
if e.Peers, err = configurePeerListeners(cfg); err != nil { // 节点数据赋值
return e, err
}
// 开启 client server, 默认端口 2379,并且支持多协议(grpc,http, https)
if e.sctxs, err = configureClientListeners(cfg); err != nil { // client 数据赋值
return e, err
}

for _, sctx := range e.sctxs { // 当前 server ctx 记录
e.Clients = append(e.Clients, sctx.l)
}

// 注册 token
memberInitialized := true
if !isMemberInitialized(cfg) {
memberInitialized = false
urlsmap, token, err = cfg.PeerURLsMapAndToken("etcd")
if err != nil {
return e, fmt.Errorf("error setting up initial cluster: %v", err)
}
}

// AutoCompactionRetention defaults to "0" if not set.
if len(cfg.AutoCompactionRetention) == 0 {
cfg.AutoCompactionRetention = "0"
}
autoCompactionRetention, err := parseCompactionRetention(cfg.AutoCompactionMode, cfg.AutoCompactionRetention)
if err != nil {
return e, err
}

backendFreelistType := parseBackendFreelistType(cfg.ExperimentalBackendFreelistType)
// 根据配置新建一个 server 对象
if e.Server, err = etcdserver.NewServer(srvcfg); err != nil {
return e, err
}

// buffer channel 保证服务关闭的时候不会阻塞
// buffer channel so goroutines on closed connections won't wait forever
e.errc = make(chan error, len(e.Peers)+len(e.Clients)+2*len(e.sctxs))

// newly started member ("memberInitialized==false")
// does not need corruption check
if memberInitialized {
if err = e.Server.CheckInitialHashKV(); err != nil {
// set "EtcdServer" to nil, so that it does not block on "EtcdServer.Close()"
// (nothing to close since rafthttp transports have not been started)
e.Server = nil
return e, err
}
}
e.Server.Start() // 开启一个 etcd server

// 与每个 peer 保持通信
if err = e.servePeers(); err != nil {
return e, err
}
// 开启 server 与每个 client 保持通信, 可以同时支持多重协议
if err = e.serveClients(); err != nil {
return e, err
}
// 与每个 metrics 保持通信
if err = e.serveMetrics(); err != nil {
return e, err
}
...
serving = true
return e, nil
}

e.Server.Start()函数实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// Start performs any initialization of the Server necessary for it to
// begin serving requests. It must be called before Do or Process.
// Start must be non-blocking; any long-running server functionality
// should be implemented in goroutines.
func (s *EtcdServer) Start() {
s.start()
// 下面的函数将在 server 关闭后等执行完成
s.goAttach(func() { s.adjustTicks() })
s.goAttach(func() { s.publish(s.Cfg.ReqTimeout()) })
s.goAttach(s.purgeFile)
s.goAttach(func() { monitorFileDescriptor(s.getLogger(), s.stopping) })
s.goAttach(s.monitorVersions)
s.goAttach(s.linearizableReadLoop)
s.goAttach(s.monitorKVHash)
}

// start prepares and starts server in a new goroutine. It is no longer safe to
// modify a server's fields after it has been sent to Start.
// This function is just used for testing.
func (s *EtcdServer) start() {
...
// 通过一个 goroutine 开启服务
// TODO: if this is an empty log, writes all peer infos
// into the first entry
go s.run()
}

func (s *EtcdServer) run() {
lg := s.getLogger()
sn, err := s.r.raftStorage.Snapshot()

// asynchronously accept apply packets, dispatch progress in-order
sched := schedule.NewFIFOScheduler()
...
s.r.start(rh)

defer func() {
s.wgMu.Lock() // block concurrent waitgroup adds in goAttach while stopping
close(s.stopping)
s.wgMu.Unlock()
s.cancel()

sched.Stop()

// wait for gouroutines before closing raft so wal stays open
s.wg.Wait()

s.SyncTicker.Stop()

// must stop raft after scheduler-- etcdserver can leak rafthttp pipelines
// by adding a peer after raft stops the transport
s.r.stop()

// kv, lessor and backend can be nil if running without v3 enabled
// or running unit tests.
if s.lessor != nil {
s.lessor.Stop()
}
if s.kv != nil {
s.kv.Close()
}
if s.authStore != nil {
s.authStore.Close()
}
if s.be != nil {
s.be.Close()
}
if s.compactor != nil {
s.compactor.Stop()
}
close(s.done)
}()
var expiredLeaseC <-chan []*lease.Lease
if s.lessor != nil {
expiredLeaseC = s.lessor.ExpiredLeasesC()
}
for {
select {
case ap := <-s.r.apply(): // 数据更新
f := func(context.Context) { s.applyAll(&ep, &ap) }
sched.Schedule(f)
case leases := <-expiredLeaseC: // 租期过期处理
s.goAttach(func() {
...
})
case err := <-s.errorc: // 出现错误,退出 server
...
return
case <-getSyncC(): // 定期同步数据
if s.v2store.HasTTLKeys() {
s.sync(s.Cfg.ReqTimeout())
}
case <-s.stop: // 停止信号,停止 server
return
}
}
}

// etcdserver/raft.go.start
// raft node 启动,保持心跳
// start prepares and starts raftNode in a new goroutine. It is no longer safe
// to modify the fields after it has been started.
func (r *raftNode) start(rh *raftReadyHandler) {
}

多协议支持

如何做到监听一个端口, 开启多个协议的 server 进行处理呢? 借助 github.com/soheilhy/cmux 来完成的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126

func (e *Etcd) serveClients() (err error) {
if !e.cfg.ClientTLSInfo.Empty() {
if e.cfg.logger != nil {
e.cfg.logger.Info(
"starting with client TLS",
zap.String("tls-info", fmt.Sprintf("%+v", e.cfg.ClientTLSInfo)),
zap.Strings("cipher-suites", e.cfg.CipherSuites),
)
} else {
plog.Infof("ClientTLS: %s", e.cfg.ClientTLSInfo)
}
}

// Start a client server goroutine for each listen address
var h http.Handler
if e.Config().EnableV2 {
if len(e.Config().ExperimentalEnableV2V3) > 0 {
srv := v2v3.NewServer(e.cfg.logger, v3client.New(e.Server), e.cfg.ExperimentalEnableV2V3)
h = v2http.NewClientHandler(e.GetLogger(), srv, e.Server.Cfg.ReqTimeout())
} else {
h = v2http.NewClientHandler(e.GetLogger(), e.Server, e.Server.Cfg.ReqTimeout())
}
} else {
mux := http.NewServeMux()
etcdhttp.HandleBasic(mux, e.Server)
h = mux
}

gopts := []grpc.ServerOption{}
if e.cfg.GRPCKeepAliveMinTime > time.Duration(0) {
gopts = append(gopts, grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
MinTime: e.cfg.GRPCKeepAliveMinTime,
PermitWithoutStream: false,
}))
}
if e.cfg.GRPCKeepAliveInterval > time.Duration(0) &&
e.cfg.GRPCKeepAliveTimeout > time.Duration(0) {
gopts = append(gopts, grpc.KeepaliveParams(keepalive.ServerParameters{
Time: e.cfg.GRPCKeepAliveInterval,
Timeout: e.cfg.GRPCKeepAliveTimeout,
}))
}
// start client servers in each goroutine
for _, sctx := range e.sctxs {
go func(s *serveCtx) { // 这里正式启动多个 client server, 包括不同的协议
e.errHandler(s.serve(e.Server, &e.cfg.ClientTLSInfo, h, e.errHandler, gopts...))
}(sctx)
}
return nil
}

// 上面函数中 s.serve 函数的实现
// serve accepts incoming connections on the listener l,
// creating a new service goroutine for each. The service goroutines
// read requests and then call handler to reply to them.
// 这里使用了 github.com/soheilhy/cmux 来实现链接复用器
// 同一个链接可以根据其协议分发到不同的 server 协议处理
// 支持 grpc, ssh, http, https等
// 同一个链接依次只能是一个协议
// 根据协议头的前几个字段来判断协议
func (sctx *serveCtx) serve(
s *etcdserver.EtcdServer,
tlsinfo *transport.TLSInfo,
handler http.Handler,
errHandler func(error),
gopts ...grpc.ServerOption) (err error) {
logger := defaultLog.New(ioutil.Discard, "etcdhttp", 0)
<-s.ReadyNotify()

if sctx.lg == nil {
plog.Info("ready to serve client requests")
}

m := cmux.New(sctx.l)
v3c := v3client.New(s)
servElection := v3election.NewElectionServer(v3c)
servLock := v3lock.NewLockServer(v3c)

var gs *grpc.Server
defer func() {
if err != nil && gs != nil {
gs.Stop()
}
}()
// http 请求
if sctx.insecure {
gs = v3rpc.Server(s, nil, gopts...)
v3electionpb.RegisterElectionServer(gs, servElection)
v3lockpb.RegisterLockServer(gs, servLock)
if sctx.serviceRegister != nil {
sctx.serviceRegister(gs)
}
// 匹配 HTTP2 协议
grpcl := m.Match(cmux.HTTP2())
// 启动 grpc server
go func() { errHandler(gs.Serve(grpcl)) }()

var gwmux *gw.ServeMux
if s.Cfg.EnableGRPCGateway {
gwmux, err = sctx.registerGateway([]grpc.DialOption{grpc.WithInsecure()})
if err != nil {
return err
}
}

httpmux := sctx.createMux(gwmux, handler)

srvhttp := &http.Server{
Handler: createAccessController(sctx.lg, s, httpmux),
ErrorLog: logger, // do not log user error
}
// 匹配 HTTP1 协议
httpl := m.Match(cmux.HTTP1())
// 启动 client HTTP Server
go func() { errHandler(srvhttp.Serve(httpl)) }()
sctx.serversC <- &servers{grpc: gs, http: srvhttp}
...
}
// https 请求
if sctx.secure {
...
}
close(sctx.serversC)
return m.Serve()
}

golang sync.mutex

##

##

##

1
2
3
4
5
6
7
const (
mutexLocked = 1 << iota // 1 = 0b001
mutexWoken // 2 = 0b010
mutexStarving // 4 = 0b100
mutexWaiterShift = iota // 3 用来屏蔽低三位,取数量
starvationThresholdNs = 1e6 // 10^6 ns = 1 ms
)
1
2
3
4
type Mutex struct {
state int32
sema uint32
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// state
+-----------------------------+---+---+---+
|00000000000000000000000000000|0/1|0/1|0/1|
+-----------------------------+---+---+---+
| | | |
| | | | +-----------+
| | | +------->|mutexLocked|
| | | +-----------+
| | | +----------+
| | +----------->|mutexWoken|
| | +----------+
| | +-------------+
| +--------------->|mutexStarving|
| +-------------+
| +------------------+
+-------------------------------->| wait list count |
+------------------+
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
func (m *Mutex) lockSlow() {
var waitStartTime int64 // 开始等待锁的时间点
starving := false // 当前 goroutine 是否处于饥饿状态
awoke := false // 当前 goroutine 是否被唤醒
iter := 0 //
old := m.state // 当前锁的状态
for {
// Don't spin in starvation mode, ownership is handed off to waiters
// so we won't be able to acquire the mutex anyway.
// 饥饿模式不要自旋, 因为锁的所有权回直接交给 waiters, 所以我们不会获取到锁
// 条件翻译为伪代码: isLocked() && isNotStarving() && canSpin()
if old&(mutexLocked|mutexStarving) == mutexLocked && runtime_canSpin(iter) {
// Active spinning makes sense.
// Try to set mutexWoken flag to inform Unlock
// to not wake other blocked goroutines.
// 尝试设置 mutexWoken 标志来通知 Unlock 不唤醒其它被阻塞的 goroutine
// 条件可以转换为: 当前 goroutine 没有被唤醒 &&
锁状态没有被唤醒 &&
等待获取锁的 goroutine 不为 0 &&
锁的状态改从未唤醒更新为 被唤醒
if !awoke && old&mutexWoken == 0 && old>>mutexWaiterShift != 0 &&
atomic.CompareAndSwapInt32(&m.state, old, old|mutexWoken) {
awoke = true // 设置当前 goroutine 为唤醒状态
}
runtime_doSpin() // 进入自旋
iter++
old = m.state // 更新锁状态
continue
}

// 经过上一步后,锁和状态的组合有下面几个:
// 获取锁 + 正常模式
// 获取锁 + 饥饿模式
// 未获取锁 + 正常模式
// 未获取锁 + 饥饿模式
new := old
// Don't try to acquire starving mutex, new arriving goroutines must queue.
// 正常模式: 期望设置为获取锁
// 如果是饥饿模式, 新来的 goroutine 必须放到锁队列尾部排队
if old&mutexStarving == 0 {
new |= mutexLocked
}
// (锁被获取 || 饥饿模式): 等待锁的 goroutine 数量 +1
if old&(mutexLocked|mutexStarving) != 0 {
new += 1 << mutexWaiterShift
}
// The current goroutine switches mutex to starvation mode.
// But if the mutex is currently unlocked, don't do the switch.
// Unlock expects that starving mutex has waiters, which will not
// be true in this case.
// 当前 gorutine 是(饥饿状态 && 锁被获取): 期望设置锁状态为饥饿模式
if starving && old&mutexLocked != 0 {
new |= mutexStarving
}
// 当前 goroutine 处于被唤醒状态
if awoke {
// The goroutine has been woken from sleep,
// so we need to reset the flag in either case.
// 如果锁状态为被唤醒状态,证明存在冲突
if new&mutexWoken == 0 {
throw("sync: inconsistent mutex state")
}
// new 期望设置为非被唤醒状态
new &^= mutexWoken
}
// 更新锁状态为 从 old 变为 new
if atomic.CompareAndSwapInt32(&m.state, old, new) {
// old 原来锁状态不是被获取 && 锁状态不是饥饿状态
// 根据前面的条件 new 现在是获取锁状态
// old 和 new 交换成功,所以当前 goroutine 获取到了锁, 直接返回
if old&(mutexLocked|mutexStarving) == 0 {
break // locked the mutex with CAS
}
// 走到这里: old 是被获取 || old 是饥饿状态
// waitStartTime != 0 证明等待过, 否则未等待过
// If we were already waiting before, queue at the front of the queue.
queueLifo := waitStartTime != 0 // true or false
if waitStartTime == 0 {
waitStartTime = runtime_nanotime()
}
// 如果等待过则放到锁队列头
// 否则放到锁队列尾部
runtime_SemacquireMutex(&m.sema, queueLifo, 1)
// 如果等待时间超过了 starvationThresholdNs (1ms), 则设置当前 goroutine 为饥饿模式
starving = starving || runtime_nanotime()-waitStartTime > starvationThresholdNs
old = m.state
// 如果原来处于饥饿模式
if old&mutexStarving != 0 {
// If this goroutine was woken and mutex is in starvation mode,
// ownership was handed off to us but mutex is in somewhat
// inconsistent state: mutexLocked is not set and we are still
// accounted as waiter. Fix that.
// 如果当前 goroutine 处于饥饿模式, 但是 mutex出一些冲突的状态: mutexLocked 状态没有设置,当前 goroutine 仍处于 waiter 中
if old&(mutexLocked|mutexWoken) != 0 || old>>mutexWaiterShift == 0 {
throw("sync: inconsistent mutex state")
}
delta := int32(mutexLocked - 1<<mutexWaiterShift)
// 当前 goroutine 不是饥饿状态 || 等待的 gorouine == 1, 退出饥饿模式
if !starving || old>>mutexWaiterShift == 1 {
// 退出饥饿模式
// Exit starvation mode.
// Critical to do it here and consider wait time.
// Starvation mode is so inefficient, that two goroutines
// can go lock-step infinitely once they switch mutex
// to starvation mode.
delta -= mutexStarving
}
// 等待队列数量 -1 && 获取锁
atomic.AddInt32(&m.state, delta)
break
}
awoke = true
iter = 0
} else {
old = m.state // 更新 state
}
}

if race.Enabled {
race.Acquire(unsafe.Pointer(m))
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
// Lock locks m.
// If the lock is already in use, the calling goroutine
// blocks until the mutex is available.
func (m *Mutex) Lock() {
// Fast path: grab unlocked mutex.
if atomic.CompareAndSwapInt32(&m.state, 0, mutexLocked) {
if race.Enabled {
race.Acquire(unsafe.Pointer(m))
}
return
}

var waitStartTime int64 // 用来存当前goroutine等待的时间
starving := false // 用来存当前goroutine是否饥饿
awoke := false // 用来存当前goroutine是否已唤醒
iter := 0 // 用来存当前goroutine的循环次数(想一想一个goroutine如果循环了2147483648次咋办……)
old := m.state // 复制一下当前锁的状态
for { // 自旋
// 如果是饥饿情况之下,就不要自旋了,因为锁会直接交给队列头部的goroutine
// 如果锁是被获取状态,并且满足自旋条件(canSpin见后文分析),那么就自旋等锁
// 伪代码:if isLocked() && isNotStarving() && canSpin()
// old&(mutexLocked|mutexStarving) == mutexLocked 满足的条件为: (0x1) & (101) ; 不满足的条件为: (1xx) & (101) 或 (0x1) && (101)
if old&(mutexLocked|mutexStarving) == mutexLocked && runtime_canSpin(iter) {
// 将自己的状态以及锁的状态设置为唤醒,这样当Unlock的时候就不会去唤醒其它被阻塞的goroutine了
// 自己为未唤醒状态, 锁状态为未唤醒, 等待锁的goroutine 数量不为0, 将锁状态从未唤醒更新为唤醒
if !awoke && old&mutexWoken == 0 && old>>mutexWaiterShift != 0 &&
atomic.CompareAndSwapInt32(&m.state, old, old|mutexWoken) {
awoke = true // 当前 goroutine 状态更新为唤醒
}
runtime_doSpin() // 进行自旋(分析见后文)
iter++
old = m.state // 更新锁的状态(有可能在自旋的这段时间之内锁的状态已经被其它goroutine改变)
continue
}

// 当走到这一步的时候,可能会有以下的情况:
// 1. 锁被获取+ 饥饿
// 2. 锁被获取+ 正常
// 3. 锁空闲 + 饥饿
// 4. 锁空闲 + 正常

// goroutine的状态可能是唤醒以及非唤醒

// 复制一份当前的状态,目的是根据当前状态设置出期望的状态,存在new里面,
// 并且通过CAS来比较以及更新锁的状态
// old用来存锁的当前状态
new := old

// 如果说锁不是饥饿状态,就把期望状态设置为被获取(获取锁)
// 也就是说,如果是饥饿状态,就不要把期望状态设置为被获取
// 新到的goroutine乖乖排队去
// 伪代码:if isNotStarving()
if old&mutexStarving == 0 {
// 伪代码:newState = locked
new |= mutexLocked
}
// 如果锁是被获取状态,或者饥饿状态
// 就把期望状态中的等待队列的等待者数量+1(实际上是new + 8)
// (会不会可能有三亿个goroutine等待拿锁……)
if old&(mutexLocked|mutexStarving) != 0 {
new += 1 << mutexWaiterShift
}
// 如果说当前的goroutine是饥饿状态,并且锁被其它goroutine获取
// 那么将期望的锁的状态设置为饥饿状态
// 如果锁是释放状态,那么就不用切换了
// Unlock期望一个饥饿的锁会有一些等待拿锁的goroutine,而不只是一个
// 这种情况下不会成立
if starving && old&mutexLocked != 0 {
// 期望状态设置为饥饿状态
new |= mutexStarving
}
// 如果说当前goroutine是被唤醒状态,我们需要reset这个状态
// 因为goroutine要么是拿到锁了,要么是进入sleep了
if awoke {
// 如果说期望状态不是woken状态,那么肯定出问题了
// 这里看不懂没关系,wake的逻辑在下面
if new&mutexWoken == 0 {
throw("sync: inconsistent mutex state")
}
// 这句就是把new设置为非唤醒状态
// &^的意思是and not
new &^= mutexWoken
}
// 通过CAS来尝试设置锁的状态
// 这里可能是设置锁,也有可能是只设置为饥饿状态和等待数量
if atomic.CompareAndSwapInt32(&m.state, old, new) {
// 如果说old状态不是饥饿状态也不是被获取状态
// 那么代表当前goroutine已经通过CAS成功获取了锁
// (能进入这个代码块表示状态已改变,也就是说状态是从空闲到被获取)
if old&(mutexLocked|mutexStarving) == 0 {
break // locked the mutex with CAS
}
// 如果之前已经等待过了,那么就要放到队列头
queueLifo := waitStartTime != 0
// 如果说之前没有等待过,就初始化设置现在的等待时间
if waitStartTime == 0 {
waitStartTime = runtime_nanotime()
}
// 既然获取锁失败了,就使用sleep原语来阻塞当前goroutine
// 通过信号量来排队获取锁
// 如果是新来的goroutine,就放到队列尾部
// 如果是被唤醒的等待锁的goroutine,就放到队列头部
runtime_SemacquireMutex(&m.sema, queueLifo)

// 这里sleep完了,被唤醒

// 如果当前goroutine已经是饥饿状态了
// 或者当前goroutine已经等待了1ms(在上面定义常量)以上
// 就把当前goroutine的状态设置为饥饿
starving = starving || runtime_nanotime()-waitStartTime > starvationThresholdNs
// 再次获取一下锁现在的状态
old = m.state
// 如果说锁现在是饥饿状态,就代表现在锁是被释放的状态,当前goroutine是被信号量所唤醒的
// 也就是说,锁被直接交给了当前goroutine
if old&mutexStarving != 0 {
// 如果说当前锁的状态是被唤醒状态或者被获取状态,或者说等待的队列为空
// 那么是不可能的,肯定是出问题了,因为当前状态肯定应该有等待的队列,锁也一定是被释放状态且未唤醒
if old&(mutexLocked|mutexWoken) != 0 || old>>mutexWaiterShift == 0 {
throw("sync: inconsistent mutex state")
}
// 当前的goroutine获得了锁,那么就把等待队列-1
delta := int32(mutexLocked - 1<<mutexWaiterShift)
// 如果当前goroutine非饥饿状态,或者说当前goroutine是队列中最后一个goroutine
// 那么就退出饥饿模式,把状态设置为正常
if !starving || old>>mutexWaiterShift == 1 {
// Exit starvation mode.
// Critical to do it here and consider wait time.
// Starvation mode is so inefficient, that two goroutines
// can go lock-step infinitely once they switch mutex
// to starvation mode.
delta -= mutexStarving
}
// 原子性地加上改动的状态
atomic.AddInt32(&m.state, delta)
break
}
// 如果锁不是饥饿模式,就把当前的goroutine设为被唤醒
// 并且重置iter(重置spin)
awoke = true
iter = 0
} else {
// 如果CAS不成功,也就是说没能成功获得锁,锁被别的goroutine获得了或者锁一直没被释放
// 那么就更新状态,重新开始循环尝试拿锁
old = m.state
}
}

if race.Enabled {
race.Acquire(unsafe.Pointer(m))
}
}

Go Interface 使用

本文基于go1.12.4源码

Duck Typing

面相对象

实现多个接口

下面举一个例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
package main

type Mather interface {
Sub(a, b int64) int64
Add(a, b int32) int32
}

type Caller interface {
Name() string
}

type Adder struct{ id int32 }

func main() {
adder := Adder{id: 6754}
CallAdd(adder)
CallSub(adder)
CallName(adder)
}

//go:noinline
func (adder Adder) Add(a, b int32) int32 { return a + b }

//go:noinline
func (adder Adder) Sub(a, b int64) int64 { return a - b }

//go:noinline
func (adder Adder) Name() string { return "Adder" }

func CallAdd(m Mather) {
m.Add(12, 2)
}

func CallSub(m Mather) {
m.Sub(19, 4)
}

func CallName(c Caller) {
c.Name()
}

Adder实现了两个接口MatherCaller, 定义CallAddCallSub调用Mather类型,可以看到:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
0x0031 00049 (main.go:16)   MOVL    $6754, (SP)
0x0038 00056 (main.go:16) CALL runtime.convT32(SB)
0x003d 00061 (main.go:16) MOVQ 8(SP), AX; 返回值, data 字段
0x0042 00066 (main.go:16) MOVQ AX, ""..autotmp_1+40(SP)
0x0047 00071 (main.go:16) LEAQ go.itab."".Adder,"".Mather(SB), CX ; Mather 类型
0x004e 00078 (main.go:16) MOVQ CX, (SP) ; 类型从 CX 放到栈底
0x0052 00082 (main.go:16) MOVQ AX, 8(SP); 值 data 从 AX 放到 8(SP)位置
0x0057 00087 (main.go:16) CALL "".CallAdd(SB) ; 前面 (SP)和8(SP)加起来就是一个 adder 实现的 Mather 类型,作为这个函数调用的参数
0x005c 00092 (main.go:17) MOVL "".adder+20(SP), AX
0x0060 00096 (main.go:17) MOVL AX, (SP)
0x0063 00099 (main.go:17) CALL runtime.convT32(SB)
0x0068 00104 (main.go:17) MOVQ 8(SP), AX
0x006d 00109 (main.go:17) MOVQ AX, ""..autotmp_2+32(SP)
0x0072 00114 (main.go:17) LEAQ go.itab."".Adder,"".Mather(SB), CX
0x0079 00121 (main.go:17) MOVQ CX, (SP)
0x007d 00125 (main.go:17) MOVQ AX, 8(SP)
0x0082 00130 (main.go:17) CALL "".CallSub(SB); 处理同上 CallAdd

接着调用CallName, 需要的是一个Caller类型的参数:

1
2
3
4
5
6
7
8
9
0x0087 00135 (main.go:18)   MOVL    "".adder+20(SP), AX; adder 的值放到 AX
0x008b 00139 (main.go:18) MOVL AX, (SP); 放到栈底
0x008e 00142 (main.go:18) CALL runtime.convT32(SB)
0x0093 00147 (main.go:18) MOVQ 8(SP), AX; 返回处理后的值 unsafe.Pointer
0x0098 00152 (main.go:18) MOVQ AX, ""..autotmp_3+24(SP)
0x009d 00157 (main.go:18) LEAQ go.itab."".Adder,"".Caller(SB), CX ; Caller 类型地址放到 CX
0x00a4 00164 (main.go:18) MOVQ CX, (SP); 类型地址从 CX 赋值到栈底
0x00a8 00168 (main.go:18) MOVQ AX, 8(SP); 值 data 从 AX 赋值到 8(SP)
0x00ad 00173 (main.go:18) CALL "".CallName(SB): 前面两行把 adder 转换为Caller 类型,并做为参数本函数

通过上面可以看出,一个实例实现了多个接口,在具体调用的地方会根据接口的类型转换为不同的接口

未实现接口

如果我们调用一个接口的方法,而对应的实例没有实现这个接口会出现什么问题呢?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
package main

type Adder struct{ id int32 }

type Empty interface {
F()
}

func main() {
adder := Adder{id: 6754}
CallF(adder)
}

func CallF(e Empty) {
e.F()
}

对上面的代码进行编译,得到:

1
2
3
# command-line-arguments
./main2.go:11:7: cannot use adder (type Adder) as type Empty in argument to CallF:
Adder does not implement Empty (missing F method)

可见编译器会在编译阶段对 AST 数据结构进行检查,如果发现没有实现对应的函数,就会报错。具体代码在cmd/compile/internal/gc/subr.go:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
if why != nil {
if isptrto(src, TINTER) {
*why = fmt.Sprintf(":\n\t%v is pointer to interface, not interface", src)
} else if have != nil && have.Sym == missing.Sym && have.Nointerface() {
*why = fmt.Sprintf(":\n\t%v does not implement %v (%v method is marked 'nointerface')", src, dst, missing.Sym)
} else if have != nil && have.Sym == missing.Sym {
*why = fmt.Sprintf(":\n\t%v does not implement %v (wrong type for %v method)\n"+
"\t\thave %v%0S\n\t\twant %v%0S", src, dst, missing.Sym, have.Sym, have.Type, missing.Sym, missing.Type)
} else if ptr != 0 {
*why = fmt.Sprintf(":\n\t%v does not implement %v (%v method has pointer receiver)", src, dst, missing.Sym)
} else if have != nil {
*why = fmt.Sprintf(":\n\t%v does not implement %v (missing %v method)\n"+
"\t\thave %v%0S\n\t\twant %v%0S", src, dst, missing.Sym, have.Sym, have.Type, missing.Sym, missing.Type)
} else {
*why = fmt.Sprintf(":\n\t%v does not implement %v (missing %v method)", src, dst, missing.Sym)
}
}

值接收与指针接收

实现接口方法的时候可以使用指针接收也可以使用值接收,他们有什么区别?不通的接收方式存在什么问题呢?针对实现和调用的方式,我们可以有四种组合,分别是:

  1. 值接收,值调用
  2. 值接收,指针调用
  3. 指针接收,指针调用
  4. 指针接收,值调用

值接收,值调用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
package main

import "fmt"

type notifer interface {
notify()
}

type user struct {
id int32
}

func (u user) notify() {
fmt.Printf("Sending user email to %d\n", u.id)
}

func main() {
u := user{9527}
sendNotification(u)
}

func sendNotification(n notifer) {
n.notify()
}

这种方式是最常见的,可以编译和调用,

1
2
3
4
5
6
7
8
9
10
0x001d 00029 (main3.go:18)  MOVL    $0, "".u+20(SP) ; 初始化 u 为空值
0x0025 00037 (main3.go:18) MOVL $9527, "".u+20(SP) ; 给 u 赋值
0x002d 00045 (main3.go:19) MOVL $9527, (SP) ; 放到栈底,作为下面函数调用的参数
0x0034 00052 (main3.go:19) CALL runtime.convT32(SB) ; 返回新申请的堆上的数据,并且返回
0x0039 00057 (main3.go:19) MOVQ 8(SP), AX; 函数的返回值放到 AX 中
0x003e 00062 (main3.go:19) MOVQ AX, ""..autotmp_1+24(SP) ;放到临时变量 autotmp_1中
0x0043 00067 (main3.go:19) LEAQ go.itab."".user,"".notifer(SB), CX ; 把 user 转换为 notifer 类型,并把 itab 地址放到 CX
0x004a 00074 (main3.go:19) MOVQ CX, (SP) ; _type 地址 赋值到栈底
0x004e 00078 (main3.go:19) MOVQ AX, 8(SP); data 赋值到 8(SP)
0x0053 00083 (main3.go:19) CALL "".sendNotification(SB) ; 前面两行组成的 interface 作为参数进行函数调用

这里详细分析一下下面这行代码:

1
LEAQ    go.itab."".user,"".notifer(SB), CX

详细解释一下这里的含义: go tool compile生成的是一个间接目标文件,还没有经过 链接器的链接, 符号没有把 package 名字填充上,如果填充上的话应该是这样:
go.itab.main.user,main.notifer(SB)(package是main), 这个代码可以看出其作用是为了把usernotifer关联起来,并且取出itab的地址。在汇编代码中还可以找出这样一段:

1
2
3
4
5
6
7
8
9
go.itab."".user,"".notifer SRODATA dupok size=32
0x0000 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 ................
0x0010 56 e9 47 80 00 00 00 00 00 00 00 00 00 00 00 00 V.G.............
rel 0+8 t=1 type."".notifer+0
rel 8+8 t=1 type."".user+0
rel 24+8 t=1 "".(*user).notify+0
go.itablink."".user,"".notifer SRODATA dupok size=8
0x0000 00 00 00 00 00 00 00 00 ........
rel 0+8 t=1 go.itab."".user,"".notifer+0

对上面的代码我们一句一句来分析,首先第一句是声明和符号和他的属性: go.itab."".user,"".notifer SRODATA dupok size=32
我们这里得到的是一个 32 字节的全局对象的符号,该符号将被存到二进制文件的 .rodata 段中

  • dupok表示: 该变量对应的标识符可能有多个, 链接时 只选择其中一个即可,一般用于合并相同的常量字符串,减少重复数据占用的空间
  • RODATA表示: 将变量定义在只读内存段,因此后续任何对此变量的修改操作将导致异常(recover()也无法捕获) )

后面的 两行表示的是这32个字节存储的数据内容, 也就是itab被序列化之后的表示方法。 我们再来回顾一下itab类型的定义:

1
2
3
4
5
6
7
type itab struct {       // 32 bytes on a 64bit arch
inter *interfacetype // offset 0x00 ($00)
_type *_type // offset 0x08 ($08)
hash uint32 // offset 0x10 ($16)
_ [4]byte // offset 0x14 ($20)
fun [1]uintptr // offset 0x18 ($24)
}

可以看出前面 32 字节中有内容的部分对应的就是itab.hash的四个字节
再往下:

  • rel 0+8 t=1 type."".notifer+0 : 告诉链接器需要将内容的前8个字节填充为全局符号 type."".notifer 的地址 , 也就是 itab.inter 字段
  • rel 8+8 t=1 type."".user+0 : 告诉链接器需要将内容的 8-16 字节填充为全局符号 type."".user 的地址 , 也就是 itab._type字段
  • rel 24+8 t=1 "".(*user).notify+0: 这里对应的是itab.func的值, 填充的是 user.notify 函数的地址

总结一下LEAQ go.itab."".user,"".notifer(SB), CX的含义就是:

  1. 使用 接口 notifer 和 类型 user 组合成一个 itab类型
  2. itab 地址加载到 CX 编译器

值接收,指针调用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
package main

import "fmt"

type notifer interface {
notify()
}

type user struct {
id int32
}

func (u user) notify() {
fmt.Printf("Sending user email to %d\n", u.id)
}

func main() {
u := &user{9527}
sendNotification(u)
}

func sendNotification(n notifer) {
n.notify()
}
1
2
3
4
5
6
7
8
9
10
11
12
13
0x001d 00029 (main3.go:18)  LEAQ    type."".user(SB), AX ; 获取 user 的类型_type 地址放到 AX
0x0024 00036 (main3.go:18) MOVQ AX, (SP) ; _type 地址放到栈底,作为参数
0x0028 00040 (main3.go:18) CALL runtime.newobject(SB) ; 调用 runtime.newobject 会从堆上申请内存 用来存放数据
0x002d 00045 (main3.go:18) MOVQ 8(SP), AX ; 返回值放到 AX
0x0032 00050 (main3.go:18) MOVQ AX, ""..autotmp_2+24(SP) ; 返回值赋值给 autotmp_2
0x0037 00055 (main3.go:18) MOVL $9527, (AX); 9527 赋值给 AX 所指向的地址的值
0x003d 00061 (main3.go:18) MOVQ ""..autotmp_2+24(SP), AX
0x0042 00066 (main3.go:18) MOVQ AX, "".u+16(SP) ; AX 赋值给变量 u, u 的地址是指向 runtime.newobject 新申请的地址,值为 9527
0x0047 00071 (main3.go:19) MOVQ AX, ""..autotmp_1+32(SP) ; AX 赋值给临时变量 autotmp_1
0x004c 00076 (main3.go:19) LEAQ go.itab.*"".user,"".notifer(SB), CX ; 获取 user 实现 notifer 接口类型的地址,放到 CX
0x0053 00083 (main3.go:19) MOVQ CX, (SP) ; itab 地址放到栈底
0x0057 00087 (main3.go:19) MOVQ AX, 8(SP); data 放到 8(SP)
0x005c 00092 (main3.go:19) CALL "".sendNotification(SB) ; 前面两行做给一个 interface 参数,调用此函数

值接收,值调用不通的点有:

  1. 变量uuser类型的变量的地址,需要通过runtime.newobject申请新的地址
  2. 获取itab类型的地址方式不一样: 在go.itabuser之间多了一个*

    1
    LEAQ    go.itab.*"".user,"".notifer(SB), CX

    关于这个符号的具体细节如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    go.itab.*"".user,"".notifer SRODATA dupok size=32
    0x0000 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 ................
    0x0010 c9 1b ab 4c 00 00 00 00 00 00 00 00 00 00 00 00 ...L............
    rel 0+8 t=1 type."".notifer+0
    rel 8+8 t=1 type.*"".user+0
    rel 24+8 t=1 "".(*user).notify+0
    go.itablink.*"".user,"".notifer SRODATA dupok size=8
    0x0000 00 00 00 00 00 00 00 00 ........
    rel 0+8 t=1 go.itab.*"".user,"".notifer+0

    可以看到唯一不一样的地方就是:
    rel 8+8 t=1 type.*"".user+0: 把 user地址类型放到_type字段的位置。

从上面的代码可以看出,不一样的地方就是 _type这个字段,函数的调用都是一样的,所以值接收 也可以用指针类型调用。

指针接收,指针调用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
package main

import "fmt"

type notifer interface {
notify()
}

type user struct {
id int32
}

func (u *user) notify() {
fmt.Printf("Sending user email to %d\n", u.id)
}

func main() {
u := &user{9527}
sendNotification(u)
}

func sendNotification(n notifer) {
n.notify()
}

这种方式其实跟值接收, 指针调用 基本上是一样的,interface.itabinterface.data都是一样的。

指针接收,值调用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
package main

import "fmt"

type notifer interface {
notify()
}

type user struct {
id int32
}

func (u *user) notify() {
fmt.Printf("Sending user email to %d\n", u.id)
}

func main() {
u := user{9527}
sendNotification(u)
}

func sendNotification(n notifer) {
n.notify()
}

这种方式编译时无法通过的,报错如下:

1
2
./main3.go:22:18: cannot use u (type user) as type notifer in argument to sendNotification:
user does not implement notifer (notify method has pointer receiver)

编译器认为user并没有实现notifer接口, 为什么呢?为了一探究竟,我们改一下上面的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
package main

import "fmt"

type notifer interface {
notify()
}

type user struct {
id int32
}

func (u user) notify() {
fmt.Printf("Sending user email to %d\n", u.id)
}

func (u *user) ptrnotify() {
fmt.Printf("Sending user email to %d\n", u.id)
}

func main() {
u1 := user{9527}
u2 := &user{9527}

sendNotification(u1)
sendNotification(u2)
}

func sendNotification(n notifer) {
n.notify()
}

notify 函数是值接收者,ptrnotify是指针接收者, 观察一下生成的汇编:

1
2
3
4
5
6
7
8
9
10
11
12
13
14

"".notifer.notify STEXT dupok size=92 args=0x10 locals=0x10
...
"".(*user).notify STEXT dupok size=108 args=0x8 locals=0x18
...
"".user.notify STEXT size=206 args=0x8 locals=0x80
...
"".(*user).ptrnotify STEXT size=229 args=0x8 locals=0x88
"".main STEXT size=185 args=0x0 locals=0x40
...
"".sendNotification STEXT size=68 args=0x10 locals=0x10
...
"".init STEXT size=100 args=0x0 locals=0x8
...

发现生成的函数中notify既实现了值接收类型的函数,又实现了指针接收类型的函数, 所以notify对于user*user类型都可以调用
ptrnoitfy函数只实现了指针接收 类型的函数,没有实现值接收类型的函数,所以无法通过user类型调用这个函数
那么为什么会有这个限制呢? 因为编辑器不是总能自动获取一个值得地址。, 看一下下面的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
package main

import "fmt"

type duration int

func (d *duration) pretty() string {
return fmt.Sprintf("Duration: %d", *d)
}

func main() {
duration(42).pretty()
}

运行时报错:

1
2
3
# command-line-arguments
./main2.go:12:14: cannot call pointer method on duration(42)
./main2.go:12:14: cannot take the address of duration(42)

如果改一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
package main

import "fmt"

type duration int

func (d *duration) pretty() string {
return fmt.Sprintf("Duration: %d", *d)
}

func main() {
d := duration(42)
d.pretty()
}

则可以正常运行,证明第一种方式没有中间变量, 所以duration(42)是一个常量,常量无法取地址。

Go 和 interface 探究
go addressable 详解
《go in action 中文版》p98-p103

Go Interface 源码解析

本文基于go1.12.4源码

源码

类型定义:

runtime/runtime2.go

不含methodinterface

1
2
3
4
type eface struct {
_type *_type
data unsafe.Pointer
}

包含methodinterface

1
2
3
4
type iface struct {
tab *itab
data unsafe.Pointer
}

eface 分析

首先写一个eface的具体case:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
package main

import (
"fmt"
"unsafe"
)

func main() {
num := 3
var inter interface{} = num
e := *(*eface)(unsafe.Pointer(&inter))
fmt.Printf("%+v\n", e)
fmt.Printf("%+v\n", *(*int)(e.data))
fmt.Printf("%+v\n", &num)
}

type eface struct {
_type *_type
data unsafe.Pointer
}

type _type struct {
size uintptr
ptrdata uintptr // size of memory prefix holding all pointers
hash uint32
tflag tflag
align uint8
fieldalign uint8
kind uint8
alg *typeAlg
// gcdata stores the GC type data for the garbage collector.
// If the KindGCProg bit is set in kind, gcdata is a GC program.
// Otherwise it is a ptrmask bitmap. See mbitmap.go for details.
gcdata *byte
str nameOff
ptrToThis typeOff
}

type tflag uint8

type typeAlg struct {
// function for hashing objects of this type
// (ptr to object, seed) -> hash
hash func(unsafe.Pointer, uintptr) uintptr
// function for comparing objects of this type
// (ptr to object A, ptr to object B) -> ==?
equal func(unsafe.Pointer, unsafe.Pointer) bool
}

type nameOff int32
type typeOff int32

对应汇编:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
"".main STEXT size=782 args=0x0 locals=0x128
0x0000 00000 (main.go:8) TEXT "".main(SB), ABIInternal, $296-0
...
0x0032 00050 (main.go:9) LEAQ type.int(SB), AX ; 通过 type.int 把 int 类型 转换为 *_type 类型,并把地址放到寄存器 AX
0x0039 00057 (main.go:9) MOVQ AX, (SP); 把 AX 内容放到栈底, 这里是下一个函数 newobject 调用的参数
0x003d 00061 (main.go:9) CALL runtime.newobject(SB) ; 函数调用
0x0042 00066 (main.go:9) MOVQ 8(SP), AX; newobject 返回值放到8(SP), 然后再放到 AX
0x0047 00071 (main.go:9) MOVQ AX, "".&num+128(SP); 返回值从 AX 放到 &num 变量位置, 表示 num 的地址
0x004f 00079 (main.go:9) MOVQ $3, (AX); (AX)表示 AX 的地址对应的值, 赋值为 3
0x0056 00086 (main.go:10) MOVQ "".&num+128(SP), AX ; num 地址的值赋值个 AX
0x005e 00094 (main.go:10) MOVQ (AX), AX ; AX 值是地址,其对应的值赋值给 AX, 也就是常量 3
0x0061 00097 (main.go:10) MOVQ AX, ""..autotmp_7+64(SP); 把这个值赋值给一个临时变量 autotmp_7
0x0066 00102 (main.go:10) MOVQ AX, (SP) ; 把值赋值给SP,栈底,是下一个函数 convT64 的参数
0x006a 00106 (main.go:10) CALL runtime.convT64(SB); 调用 runtime.convT64, 参数为 uint64, 返回值为 unsafe.Pointer
0x006f 00111 (main.go:10) MOVQ 8(SP), AX; convT64 返回值放到8(SP), 并且赋值到 AX
0x0074 00116 (main.go:10) MOVQ AX, ""..autotmp_8+80(SP); AX 的值放到临时变量 `autotmp_8`
0x0079 00121 (main.go:10) LEAQ type.int(SB), CX ; 通过 type.int 把 int 类型 转换为 *_type 类型,并把地址放到寄存器 CX
0x0080 00128 (main.go:10) MOVQ CX, "".inter+136(SP); 将 CX 中代表*_type 地址的值放到 inter 变量eface类型的 _type 变量中
0x0088 00136 (main.go:10) MOVQ AX, "".inter+144(SP); 将 AX 中代表 convT64返回值 3 的 unsafe.Pointer 类型 放到 inter 变量 eface 类型的 data 中
...

这里只关注main.go:9main.go:10的处理, 对应代码为:

1
8 num := 3

初始化变量numint类型,并赋值为3, 具体的过程已经在前面汇编代码中通过注释的方式标出,下面来看一些细节:

  1. 新建 int 类型变量需要申请内存, 通过runtime.newobject来申请
  2. type.int 可以获取 int类型的 _type 结构的地址 *_type (具体实现方式被编译优化了,需要再进一步深究)
  3. MOVQ AX, (SP)是为了把前面放到AX*_type 放到栈底, 这个位置下面调用函数runtime.newobject的参数, 具体函数实现:

    1
    2
    3
    func newobject(typ *_type) unsafe.Pointer {
    return mallocgc(typ.size, typ, true)
    }
  4. 执行完函数调用用会把返回值unsafe.Pointer放到8(SP), 然后在放入 AX

  5. MOV $3,(AX) 向表示寄存器AX包含的地址对应的值设置为常量3
  6. autotmp是一个临时变量,是为了在程序内复用全局临时变量, 防止变量被修改:
    https://github.com/golang/go/issues/21557,
    https://github.com/golang/go/issues/29547
    具体需要再深入研究
1
9    var inter interface{} = num

初始化变量inter,类型为interface{}, 并且指向num, 具体过程参考上面的注释部分,一些细节:

  1. 这里会调用runtime.convT64函数,定义如下: ( 在go 1.8版本调用的是runtime.convT2E, 在go1.10调用的是runtime.convT2E64):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    func convT64(val uint64) (x unsafe.Pointer) {
    if val == 0 {
    x = unsafe.Pointer(&zeroVal[0])
    } else {
    x = mallocgc(8, uint64Type, false)
    *(*uint64)(x) = val
    }
    return
    }

    这里的入参就是 num 的值 3, 返回值是转换为uint64的值,并且申请一个地址,值为3, 注意: 这里发生了值的copy

  2. 通过最后两行赋值inter, inter类型为eface, 定义前面提过,最终给eface._typeeface.data赋值
  3. eface.type的查看这里还没有找到好的方法来查看,dlv无法深入到内置的实现。

为了查看_type 类型,定义了一个跟 runtime 内部实现一样的数据结构,并且通过unsafe强制进行数据类型转换,可以得到_type的值。
借助dlv对上面的程序进行debug:

  1. e := *(*eface)(unsafe.Pointer(&inter))处打断点
  2. 执行到上面这行后打印出e的值:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    (dlv) p e
    main.eface {
    _type: *main._type {
    size: 8,
    ptrdata: 0,
    hash: 4149441018,
    tflag: 7,
    align: 8,
    fieldalign: 8,
    kind: 130,
    alg: *(*main.typeAlg)(0x57adf0),
    gcdata: *1,
    str: 1059,
    ptrToThis: 47520,},
    data: unsafe.Pointer(0xc000080018),}

    下面是*_type源码的定义:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    // 所有类型信息结构体的公共部分
    // src/rumtime/runtime2.go
    type _type struct {
    size uintptr // 类型的大小
    ptrdata uintptr // size of memory prefix holding all pointers
    hash uint32 // 类型的Hash值
    tflag tflag // 类型的Tags
    align uint8 // 结构体内对齐
    fieldalign uint8 // 结构体作为field时的对齐
    kind uint8 // 类型编号 定义于runtime/typekind.go
    alg *typeAlg // 类型元方法 存储hash和equal两个操作。map key便使用key的_type.alg.hash(k)获取hash值
    gcdata *byte // GC相关信息
    str nameOff // 类型名字的偏移
    ptrToThis typeOff
    }

为了查看eface.data的值,可以通过

1
fmt.Printf("%+v\n", *(*int)(e.data))

输出,可以看到运行结果为3, 正是num的值。

常量

对于上面的程序我们只修改两行:

1
2
3
4
5
func main() {
num := 3
var inter interface{} = num
e := *(*eface)(unsafe.Pointer(&inter))
}

改为:

1
2
3
4
func main() {
var inter interface{} = 3
e := *(*eface)(unsafe.Pointer(&inter))
}

然后再查看其对应的汇编:

1
2
3
4
0x002f 00047 (main.go:9)    LEAQ    type.int(SB), AX ; int 转为的_type类型
0x0036 00054 (main.go:9) MOVQ AX, "".inter+96(SP) ; 赋值给 inter._type
0x003b 00059 (main.go:9) LEAQ "".statictmp_0(SB), AX ; 取一个静态变量的是到 AX
0x0042 00066 (main.go:9) MOVQ AX, "".inter+104(SP) ; 赋值给 inter.data

statictmp_0代表的是一个全局静态变量,值是3:

1
2
"".statictmp_0 SRODATA size=8
0x0000 03 00 00 00 00 00 00 00

对于statictmp 时一个全局变量,一半的常量为了节省空间都会使用这个来代替,数据来源于’runtime/iface.go`:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
// staticbytes is used to avoid convT2E for byte-sized values.
var staticbytes = [...]byte{
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47,
0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f,
0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57,
0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f,
0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f,
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77,
0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87,
0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f,
0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97,
0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f,
0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7,
0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7,
0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7,
0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7,
0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf,
0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7,
0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef,
0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7,
0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff,
}

在编译阶段会调用这个, 在src/cmd/compile/internal/gc/walk.go 中有相关代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
if staticbytes == nil {
staticbytes = newname(Runtimepkg.Lookup("staticbytes")) // 重 runtime 包中查找这个变量
staticbytes.SetClass(PEXTERN)
staticbytes.Type = types.NewArray(types.Types[TUINT8], 256)
zerobase = newname(Runtimepkg.Lookup("zerobase"))
zerobase.SetClass(PEXTERN)
zerobase.Type = types.Types[TUINTPTR]
}

// Optimize convT2{E,I} for many cases in which T is not pointer-shaped,
// by using an existing addressable value identical to n.Left
// or creating one on the stack.
var value *Node
switch {
case fromType.Size() == 0:
// n.Left is zero-sized. Use zerobase.
cheapexpr(n.Left, init) // Evaluate n.Left for side-effects. See issue 19246.
value = zerobase
case fromType.IsBoolean() || (fromType.Size() == 1 && fromType.IsInteger()):
// n.Left is a bool/byte. Use staticbytes[n.Left].
n.Left = cheapexpr(n.Left, init)
value = nod(OINDEX, staticbytes, byteindex(n.Left)) // 编译时使用
value.SetBounded(true)

iface 分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package main

import (
"fmt"
"unsafe"
)

type Mather interface {
Add(a, b int32) int32
Sub(a, b int64) int64
}

type Adder struct{ id int32 }

//go:noinline
func (adder Adder) Add(a, b int32) int32 { return a + b }

//go:noinline
func (adder Adder) Sub(a, b int64) int64 { return a - b }

func main() {
adder := Adder{id: 6754}
m := Mather(adder)
m.Add(12, 2)
m.Sub(19, 4)
i := *(*iface)(unsafe.Pointer(&m))
fmt.Printf("%+v\n", i)
fmt.Printf("%+v\n", *(*Adder)(i.data))
}

type iface struct {
tab *itab
data unsafe.Pointer
}

type itab struct {
inter *interfacetype
_type *_type
hash uint32 // copy of _type.hash. Used for type switches.
_ [4]byte
fun [1]uintptr // variable sized. fun[0]==0 means _type does not implement inter
}

type interfacetype struct {
typ _type
pkgpath name
mhdr []imethod
}

// See reflect/type.go for details.
type name struct {
bytes *byte
}
type imethod struct {
name nameOff
ityp typeOff
}

type _type struct {
size uintptr
ptrdata uintptr // size of memory prefix holding all pointers
hash uint32
tflag tflag
align uint8
fieldalign uint8
kind uint8
alg *typeAlg
// gcdata stores the GC type data for the garbage collector.
// If the KindGCProg bit is set in kind, gcdata is a GC program.
// Otherwise it is a ptrmask bitmap. See mbitmap.go for details.
gcdata *byte
str nameOff
ptrToThis typeOff
}

type tflag uint8

type typeAlg struct {
// function for hashing objects of this type
// (ptr to object, seed) -> hash
hash func(unsafe.Pointer, uintptr) uintptr
// function for comparing objects of this type
// (ptr to object A, ptr to object B) -> ==?
equal func(unsafe.Pointer, unsafe.Pointer) bool
}

type nameOff int32
type typeOff int32

对应的汇编代码:

1
2
3
4
5
6
7
8
9
10
"".main STEXT size=377 args=0x0 locals=0xc8
0x002f 00047 (main.go:22) MOVL $0, "".adder+68(SP) ; 初始化一个addr, 默认值都是空的
0x0037 00055 (main.go:22) MOVL $6754, "".adder+68(SP) ; 对 Addr.id 字段进行赋值
0x003f 00063 (main.go:23) MOVL $6754, (SP) ; 将id的值放到栈底,作为 convT32的参数
0x0046 00070 (main.go:23) CALL runtime.convT32(SB) ; 入参是 id, 输出的 值是转换后的unsafe.Pointer
0x004b 00075 (main.go:23) MOVQ 8(SP), AX ; 返回值从 8(SP) 的位置,复制到寄存器 AX
0x0050 00080 (main.go:23) MOVQ AX, ""..autotmp_5+80(SP) ; 将返回值从 AX 复制到临时变量autotmp_5 的位置
0x0055 00085 (main.go:23) LEAQ go.itab."".Adder,"".Mather(SB), CX ; 将 Adder 的 itab转换为 Mather 类型,并将地址放到 CX
0x005c 00092 (main.go:23) MOVQ CX, "".m+104(SP) ; 将 CX 寄存器中的 itab 地址赋值给 m.tab
0x0061 00097 (main.go:23) MOVQ AX, "".m+112(SP) ; 将 AX 寄存器中的 unsafe.Pointer 的 表示的值,赋值给 m.data

重点关注main.go:22main.go:23, 首先看一下main.go:22:

1
adder := Adder{id: 6754}

需要注意的细节:

  1. struct 的 赋值是先对其赋值为空,然后再一个字段一个字段赋值,字段在栈中的排列跟定义的顺序有关系,是紧密排列的,并且存在对齐的问题

然后看如何把Adder类型转换为Mather接口类型的:

1
m := Mather(adder)

具体赋值的 步骤前面汇编部分的注释已经说明了, 这里需要注意的几个细节:

  1. iface类型与eface不通,有tabdata 两个字段,分别是*itabunsafe.Pointer两个类型
  2. go.itab."".Adder,"".Mather(SB), CX 将 Adder 的 itab转换为 Mather 类型,并将地址放到 CX, 具体如何实现的,还没有找到查看方法,需要继续研究

itab类型也无法直接查看,这里通过unsafe进行转换,在通过dlv进行查看:

  1. 转换代码:

    1
    i := *(*iface)(unsafe.Pointer(&m))
  2. 通过 dlv 在此处打断点查看:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    (dlv) p i
    main.iface {
    tab: *main.itab {
    inter: *(*main.interfacetype)(0x4c0360),
    _type: *(*main._type)(0x4c63e0),
    hash: 1633631626,
    _: [4]uint8 [0,0,0,0],
    fun: [1]uintptr [4867136],},
    data: unsafe.Pointer(0xc000080010),}

    (dlv) p i.tab
    *main.itab {
    inter: *main.interfacetype {
    typ: (*main._type)(0x4c0360),
    pkgpath: (*main.name)(0x4c0390),
    mhdr: []main.imethod len: 2, cap: 2, [
    (*main.imethod)(0x4c03c0),
    (*main.imethod)(0x4c03c8),
    ],},
    _type: *main._type {
    size: 4,
    ptrdata: 0,
    hash: 1633631626,
    tflag: 7,
    align: 4,
    fieldalign: 4,
    kind: 153,
    alg: *(*main.typeAlg)(0x57bde0),
    gcdata: *1,
    str: 14386,
    ptrToThis: 105952,},
    hash: 1633631626,
    _: [4]uint8 [0,0,0,0],
    fun: [1]uintptr [4867136],}

关于itab类型是包含接口的静态类型信息、数据的动态类型信息、函数表的结构, 在源码中的定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
type itab struct {
inter *interfacetype // 本实例所实现的接口的类型信息, 用于定位到具体的 interface, 这个是在编译时确定的
_type *_type // 本实例的具体数据的类型信息, 参考前面 _type 类型的定义
hash uint32 // copy of _type.hash. Used for type switches.
_ [4]byte
// fun 表示的 interface 里面的 method 的具体实现
// 这里放置和接口方法对应的具体数据类型的方法地址
// 实现接口调用方法的动态分派,一般在每次给接口赋值发生转换时
// 会更新此表,或者直接拿缓存的itab
fun [1]uintptr // variable sized. fun[0]==0 means _type does not implement inter.
}

// interfacetype 只是对于 _type 的一种包装,在其顶部空间还包装了额外的 interface 相关的元信息
type interfacetype struct {
typ _type // 所实现的接口的类型
pkgpath name // 所实现的接口的定义路径
mhdr []imethod // 所实现的接口在定义时的函数声明列表
}

//这里的 method 只是一种函数声明的抽象,比如 func Print() error
type imethod struct {
name nameOff
ityp typeOff
}

需要注意的点:

  1. func表示的 interface 里面的 method 的具体实现, 比如这里的两个方法SubAdd, 但是func的长度为1, 该如何表示多个方法呢?
    看一下函数调用:
    1
    2
    m.Add(12, 2)
    m.Sub(19, 4)

对应的汇编:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
0x0066 00102 (main.go:24)   TESTB   AL, (CX); 求与 AL & (CX), 检查 CX 是否为 nil, AL 是 AX 的低8位, AH 是 AX 的高8位
0x0068 00104 (main.go:24) MOVQ go.itab."".Adder,"".Mather+24(SB), CX ; Add函数的入口地址放到 CX
0x006f 00111 (main.go:24) MOVQ AX, (SP)
0x0073 00115 (main.go:24) MOVQ $8589934604, AX ; 将 12, 2两个参数赋值到 AX 中
0x007d 00125 (main.go:24) MOVQ AX, 8(SP); 将 AX 中的参数赋值到 8(SP) 位置
0x0082 00130 (main.go:24) CALL CX ; 调用 m.Add
0x0084 00132 (main.go:25) MOVQ "".m+104(SP), AX ; 前面看到这个位置是 m.tab 的值
0x0089 00137 (main.go:25) TESTB AL, (AX); 检查 AX 是否为 nil
0x008b 00139 (main.go:25) MOVQ 32(AX), AX ; (AX)地址 + 32 偏移,指向 Sub
0x008f 00143 (main.go:25) MOVQ "".m+112(SP), CX ; 这个位置是 m.data 的值
0x0094 00148 (main.go:25) MOVQ CX, (SP)
0x0098 00152 (main.go:25) MOVQ $19, 8(SP) ; 把参数 19 放到 8(SP) 位置
0x00a1 00161 (main.go:25) MOVQ $4, 16(SP) ; 把参数 4 放到 16(SP) 位置
0x00aa 00170 (main.go:25) CALL AX ; 调用 m.Sub

需要注意的细节:

  1. TESTB AL, (CX)是把 AL & (CX) 位与的值放到 (CX) 中, 参考: https://github.com/golang/go/issues/10432 & https://github.com/golang/go/issues/27180, 这个步骤其实是为了检查 CX 是否为 nil, 如果是 nil 就没法调用这个函数了

  2. MOVQ go.itab."".Adder,"".Mather+24(SB), CX 这个为什么是取到了Add的地址?
    看一下itab的定义:

    1
    2
    3
    4
    5
    6
    7
    type itab struct {
    inter *interfacetype
    _type *_type
    hash uint32
    _ [4]byte
    fun [1]uintptr
    }

    当前机器是64位的,所以可以看出func 相对于itab起始地址的偏移量为:
    8(*interfacetype) + 8(*_type) + 4(uint32) + 4(byte=uint8) = 24
    所以 MOVQ go.itab."".Adder,"".Mather+24(SB) 其实就是func的第一个函数Add的地址

  3. 函数Sub的地址为什么是32(AX)?
    可以从前面一句MOVQ "".m+104(SP), AX 得出: AX目前指向的是m.tab, 也就是itab类型的起始地址,32(AX)就是相对AX有32位的偏移,前面说了相对itab 24位的偏移其实时Add函数,然后对于64位系统,函数地址占8位,所以32(AX)就是下一个函数Sub的地址。

  4. 前面只有两个函数,我们如果调换一下接口定义中两个函数的位置,发现生成的汇编是一样的,也就是:函数顺序与定义的顺序无关, 如果增加几个函数就可以看出来,其实: 函数在func 中的顺序是按照函数名的字典顺序排列的

  5. MOVQ $8589934604, AX 为什么是参数赋值?
    这个其实我们可以对常量8589934604 进行分析,首先把它转化为二进制:

    1
    2
    echo 'obase=2;8589934604' | bc
    1000000000000000000000000000001100

    得到的数据其实是:

    1
    2
    3
    4
    5
    6
    7
    +------------------------------------------+
    | 0000001000000000000000000000000000001100 |
    +------------------------------------------+
    \______/\______________________________/
    +---+ +----+
    | 2 | | 12 |
    +---+ +----+

    其实就是212两个8字节的数据组合在一起放到了AX寄存器中, 正是Add(12,2)的两个参数。

断言

interface{} 是一个抽象的类型,如果需要转换为具体的类型,则需要类型断言, 类型断言其实有两个:

  1. 类型判断: 判断类型是否一致
  2. 类型转换: 类型一致取出具体的数据

下面看一个例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
package main

var j uint32
var r int32
var ok bool
var eface interface{}

func assertion() {
i := uint64(42)
eface = i
j = eface.(uint32)
r, ok = eface.(int32)
}

对应的汇编语言如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
0x0066 00102 (eface.go:11)  MOVL    $0, ""..autotmp_1+36(SP) ; 初始化0值
0x006e 00110 (eface.go:11) MOVQ "".eface+8(SB), AX ; 把 data 放到 AX 寄存器
0x0075 00117 (eface.go:11) MOVQ "".eface(SB), CX ; 把 _type 放到 CX 寄存器
0x007c 00124 (eface.go:11) LEAQ type.uint32(SB), DX ; 把 uint32的_type 值放到 DX 寄存器
0x0083 00131 (eface.go:11) CMPQ CX, DX; 比较 eface._type == uint32 ?
0x0086 00134 (eface.go:11) JEQ 138 ; JEQ = jump if equal, 如果类型相等就跳转到 138行
0x0088 00136 (eface.go:11) JMP 246 ; 类型不匹配, 跳转到 246 行, 出现 panic
0x008a 00138 (eface.go:11) MOVL (AX), AX ; JEQ 跳转的行 138, 把(AX)地址对应的值放到 AX 寄存器,也就是 eface.data
0x008c 00140 (eface.go:11) MOVL AX, ""..autotmp_1+36(SP) ; 临时变量赋值
0x0090 00144 (eface.go:11) MOVL AX, "".j(SB) ; 赋值给变量 j
0x0096 00150 (eface.go:12) MOVQ "".eface+8(SB), AX ; 把 data 放到 AX 寄存器
0x009d 00157 (eface.go:12) LEAQ type.int32(SB), CX ; 把 int32 的_type 值放到 DX 寄存器
0x00a4 00164 (eface.go:12) CMPQ "".eface(SB), CX ; 比较 eface._type == int32 ?
0x00ab 00171 (eface.go:12) JEQ 175 ; 如果类型相等就跳转到 175 行
0x00ad 00173 (eface.go:12) JMP 223 ; 跳转到 223 行,输出 panic
0x00af 00175 (eface.go:12) MOVL (AX), AX; 类型相等就把 data 放到 AX
0x00b1 00177 (eface.go:12) MOVL $1, CX 把常量 1 放到 CX
0x00b6 00182 (eface.go:12) JMP 184 ; 调到 184 行
0x00b8 00184 (eface.go:12) MOVL AX, ""..autotmp_2+32(SP)
0x00bc 00188 (eface.go:12) MOVB CL, ""..autotmp_3+31(SP) ; CL 是 CX 的低 8 位
0x00c0 00192 (eface.go:12) MOVL ""..autotmp_2+32(SP), AX
0x00c4 00196 (eface.go:12) MOVL AX, "".r(SB) ; AX 是 data 的值, 放到 r 变量中
0x00ca 00202 (eface.go:12) MOVBLZX ""..autotmp_3+31(SP), AX; MOVBLZX 用 0 扩展,放到 autotmp_3 变量
0x00cf 00207 (eface.go:12) MOVB AL, "".ok(SB); AL 低8位赋值给 ok ,因为ok 是 bool 类型的, 根据字节对齐,占 8 位
0x00d5 00213 (eface.go:13) MOVQ 56(SP), BP
0x00da 00218 (eface.go:13) ADDQ $64, SP
0x00de 00222 (eface.go:13) RET
0x00df 00223 (eface.go:13) XORL AX, AX ; eface.type != int32 情况下,执行本行, XOR是异或,所以 AX^AX , 结果为 0
0x00e1 00225 (eface.go:13) XORL CX, CX; 同上, CX 结果为 0
0x00e3 00227 (eface.go:12) JMP 184 ; 跳转到 184 行执行,这里要注意的是 AX, CX 寄存器已经为0, 所有后面 ok 的值也位0了
0x00e5 00229 (eface.go:10) LEAQ "".eface+8(SB), DI
0x00ec 00236 (eface.go:10) CALL runtime.gcWriteBarrier(SB)
0x00f1 00241 (eface.go:10) JMP 102
0x00f6 00246 (eface.go:11) MOVQ CX, (SP) ; 一个返回值表达式类型不匹配时,执行到这里, CX 类型值放到(SP)作为第一个参数
0x00fa 00250 (eface.go:11) MOVQ DX, 8(SP); 想要的类型从 DX 放到 8(SP) 作为第二个参数
0x00ff 00255 (eface.go:11) LEAQ type.interface {}(SB), AX ; interface 类型的地址放到 AX
0x0106 00262 (eface.go:11) MOVQ AX, 16(SP); AX 值放到 16(SP) 作为第三个参数
0x010b 00267 (eface.go:11) CALL runtime.panicdottypeE(SB); 执行函数调用,使用前面的三个参数,返回 panic

需要注意的点:

  1. 当使用返回值为一个的表达式时,如果出现类型不匹配,会触发panic
  2. 当使用两个返回值的表达式时, r, ok的值随着AX, CX的值 改变:
    分为两种情况:
    当类型相等时: AX 值为eface.data, CX 的值为1
    赋值的过程如下:

    1
    2
    3
    4
    5
    6
    0x00b8 00184 (eface.go:12)  MOVL    AX, ""..autotmp_2+32(SP)
    0x00bc 00188 (eface.go:12) MOVB CL, ""..autotmp_3+31(SP) ; CL 是 CX 的低 8 位, CX 是 1, 二进制是: 0000000000000001; CL 就是: 00000001
    0x00c0 00192 (eface.go:12) MOVL ""..autotmp_2+32(SP), AX
    0x00c4 00196 (eface.go:12) MOVL AX, "".r(SB) ; AX 是 data 的值, 放到 r 变量中
    0x00ca 00202 (eface.go:12) MOVBLZX ""..autotmp_3+31(SP), AX; MOVBLZX 用 0 扩展,放到 autotmp_3 变量, autotmp_3 是 00000001, 扩展后是: 0000000000000001
    0x00cf 00207 (eface.go:12) MOVB AL, "".ok(SB); AL 低8位赋值给 ok ,因为ok 是 bool 类型的, 根据字节对齐,占 8 位, ok 值为: 00000001

    当类型不相等时: AXCX的值都初始化位空

    1
    2
    3
    0x00df 00223 (eface.go:13)  XORL    AX, AX ; eface.type != int32 情况下,执行本行, XOR是异或,所以 AX^AX , 结果为 0
    0x00e1 00225 (eface.go:13) XORL CX, CX; 同上, CX 结果为 0
    0x00e3 00227 (eface.go:12) JMP 184 ; 跳转到 184 行执行,这里要注意的是 AX, CX 寄存器已经为0, 所有后面 ok 的值也位0了

    赋值过程如下:

    1
    2
    3
    4
    5
    6
    0x00b8 00184 (eface.go:12)  MOVL    AX, ""..autotmp_2+32(SP)
    0x00bc 00188 (eface.go:12) MOVB CL, ""..autotmp_3+31(SP) ; CL 是 CX 的低 8 位, CX 是 0, 二进制是: 0000000000000000; CL 就是: 00000000
    0x00c0 00192 (eface.go:12) MOVL ""..autotmp_2+32(SP), AX
    0x00c4 00196 (eface.go:12) MOVL AX, "".r(SB) ; AX 是 data 的值, 放到 r 变量中, AX 是空值,所以 r == nil
    0x00ca 00202 (eface.go:12) MOVBLZX ""..autotmp_3+31(SP), AX; MOVBLZX 用 0 扩展,放到 autotmp_3 变量, autotmp_3 是 00000000, 扩展后是: 0000000000000000
    0x00cf 00207 (eface.go:12) MOVB AL, "".ok(SB); AL 低8位赋值给 ok ,因为ok 是 bool 类型的, 根据字节对齐,占 8 位, ok 值为: 00000000

参考文献,

理解Go语言模型(1):interface底层详解
Go Data Structures: Interfaces
Interface Semantics
go-internals chapter2 interfacs

go converte between string and byte slice

String

Go第一版代码c实现, 在runtime/runtime.h里:

1
2
3
4
5
6
7
8
9
typedef struct  String      String;

struct String
{
byte* str;
intgo len;
};

extern String runtime·emptystring;

可以看到Go中的string类型其实就是String这个类型。
之后Go实现了自举,从runtime/string.go中可以看到之前的影子:

1
2
3
4
type stringStruct struct {
str unsafe.Pointer
len int
}

todo: 如何通过编译过程查找对应的类型定义

Byte

byte的类型定义在 builtin/builtin.go中:

1
2
3
4
5
6
7
8
// uint8 is the set of all unsigned 8-bit integers.
// Range: 0 through 255.
type uint8 uint8

// byte is an alias for uint8 and is equivalent to uint8 in all ways. It is
// used, by convention, to distinguish byte values from 8-bit unsigned
// integer values.
type byte = uint8

可以看到其实byteuint8的类型别名

Slice

1
2
3
4
5
type slice struct {
array unsafe.Pointer
len int
cap int
}

string to byte slice

写一个string强制类型转换为[]bytedemo:

1
2
3
4
5
6
7
8
9
package main

import "fmt"

func main() {
var s = "strings"
var b = []byte(s)
fmt.Printf("%v\n", b)
}

通过命令:

1
go tool compile -S -N -l main.go

编译出汇编指令:

1
2
3
4
5
6
7
8
9
10
11
"".main STEXT size=317 args=0x0 locals=0xa8
0x0000 00000 (main.go:5) TEXT "".main(SB), ABIInternal, $168-0
...
0x002f 00047 (main.go:6) LEAQ go.string."strings"(SB), AX
0x0036 00054 (main.go:6) MOVQ AX, "".s+80(SP) # 把string内容放到这个位置
0x003b 00059 (main.go:6) MOVQ $7, "".s+88(SP) # 把string长度放到这个位置
...
0x005a 00090 (main.go:7) CALL runtime.stringtoslicebyte(SB)
...
0x008e 00142 (main.go:8) CALL runtime.convTslice(SB)
...

上面可以看出当定义一个string时,其实会存储string的内容和长度, 对应前讲的string的结构:

1
2
3
4
5
struct String
{
byte* str;
intgo len;
};

然后又调用了runtime.stringtoslicebyte(SB), 在runtime/string.go中:

1
2
3
4
5
6
7
8
9
10
11
func stringtoslicebyte(buf *tmpBuf, s string) []byte {
var b []byte
if buf != nil && len(s) <= len(buf) { // 如果字符串的长度小于buf长度,直接使用buf
*buf = tmpBuf{}
b = buf[:len(s)]
} else {
b = rawbyteslice(len(s)) // 否则调用这个进行内存申请
}
copy(b, s) // 内存 copy
return b
}

buf默认值是32:

1
2
3
const tmpStringBufSize = 32

type tmpBuf [tmpStringBufSize]byte

如果不满足长度,申请的内存大小为len(s):

1
2
3
4
5
6
7
8
9
10
11
12
// rawbyteslice allocates a new byte slice. The byte slice is not zeroed.
func rawbyteslice(size int) (b []byte) {
cap := roundupsize(uintptr(size))
p := mallocgc(cap, nil, false)
if cap != uintptr(size) {
memclrNoHeapPointers(add(p, uintptr(size)), cap-uintptr(size))
}

// 下面是类型的转换,把申请的内存变成一个slice结构,赋值给b的地址
*(*slice)(unsafe.Pointer(&b)) = slice{p, size, int(cap)}
return
}

上面的过程重点有三个:

  1. 当长度小于32时,直接使用临时内存地址
  2. 当长度大于32时,需要申请新的长度为len(s)的内存地址
  3. 需要进行内存的copy

byte slice to string

下面返回来,把一个[]byte转换为string:

1
2
3
4
5
6
7
8
9
package main

import "fmt"

func main() {
var b = []byte{101, 102, 103}
var s = string(b)
fmt.Printf("%v\n", s)
}

生成的汇编代码是:

1
2
3
4
5
6
7
8
9
10
11
"".main STEXT size=371 args=0x0 locals=0xb8
0x0000 00000 (main.go:5) TEXT "".main(SB), ABIInternal, $184-0
...
0x005b 00091 (main.go:6) MOVQ AX, "".b+128(SP) # 把slice内容放到这个位置
0x0063 00099 (main.go:6) MOVQ $3, "".b+136(SP) # 把slice len 放到这个位置
0x006f 00111 (main.go:6) MOVQ $3, "".b+144(SP) # 把slice cap 放到这个位置
...
0x00a2 00162 (main.go:7) CALL runtime.slicebytetostring(SB) #调用这个函数进行转换
...
0x00c4 00196 (main.go:8) CALL runtime.convTstring(SB)
...

上面可以看出当定义一个slice时,其实会存储slice的内容和长度和容量, 对应之前讲的slice的结构:

1
2
3
4
5
type slice struct {
array unsafe.Pointer
len int
cap int
}

然后调用runtime.slicebytetostring函数, 在runtime/string.go中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// Buf is a fixed-size buffer for the result,
// it is not nil if the result does not escape.
func slicebytetostring(buf *tmpBuf, b []byte) (str string) {
l := len(b)
if l == 0 {
// Turns out to be a relatively common case.
// Consider that you want to parse out data between parens in "foo()bar",
// you find the indices and convert the subslice to string.
// 长度为0,直接返回空字符串
return ""
}

...

// 长度为1,直接返回staticbytes[b[0]]这个提前设定好的地址内容
if l == 1 {
// stringStruct结构的str字段指向对应的值得地址
stringStructOf(&str).str = unsafe.Pointer(&staticbytes[b[0]])
// stringStruct结构的len字段设置为1
stringStructOf(&str).len = 1
return
}

var p unsafe.Pointer
if buf != nil && len(b) <= len(buf) {
p = unsafe.Pointer(buf)
} else {
p = mallocgc(uintptr(len(b)), nil, false)
}
stringStructOf(&str).str = p
stringStructOf(&str).len = len(b)
memmove(p, (*(*slice)(unsafe.Pointer(&b))).array, uintptr(len(b)))
return
}

go select 原理

本篇主要介绍 select 的内部实现原理(基于go1.12), 通过源码和图形的方式展示 select 的内部结构及对select 进行操作的过程。

基本语法

Go官方给出的例子很简单:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
package main

import "fmt"

func fibonacci(c, quit chan int) {
x, y := 0, 1
for {
select {
case c <- x:
x, y = y, x+y
case <-quit:
fmt.Println("quit")
return
}
}
}

func main() {
c := make(chan int)
quit := make(chan int)
go func() {
for i := 0; i < 10; i++ {
fmt.Println(<-c)
}
quit <- 0
}()
fibonacci(c, quit)
}

可以看出SelectSwitch使用方法有点相似,不一样的地方在于:Selectcase条件必须是与chan相关的操作(从chan发送或者接收数据)

编译

关于select的编译过程可以从$GOROOT/src/cmd/compile/internal/gc/select.go中找到。

工具使用

为了研究代码的运行逻辑,我们可以借助针对Go开发的debug工具:dlv, 我们通过:

1
go build  -gcflags="all=-N -l" $GOROOT/src/cmd/compile

这样我们就可以使用对compile工具进行debug了, 通过下面的方式运行:

1
dlv exec compile $GOROOT/src/cmd/compile/internal/gc/select.go

然后分别给我们想要debug的地方打断点

编译过程

我们对typecheckselectwalkselect函数打断点,可以知道运行顺序是typecheckselect -> walkselect
首先我们先来看typecheckselect函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
// select
func typecheckselect(sel *Node) {
var def *Node
lno := setlineno(sel)
typecheckslice(sel.Ninit.Slice(), ctxStmt)
// 遍历检查所有的case
for _, ncase := range sel.List.Slice() {
// 处理之前对每个case先进行检查是否是空的
if ncase.Op != OXCASE {
setlineno(ncase)
Fatalf("typecheckselect %v", ncase.Op)
}
// case 后面是空条件,这种情况说明是default
if ncase.List.Len() == 0 {
// default
if def != nil {
yyerrorl(ncase.Pos, "multiple defaults in select (first at %v)", def.Line())
} else {
def = ncase
}
} else if ncase.List.Len() > 1 { // case的值不支持多个表达式
yyerrorl(ncase.Pos, "select cases cannot be lists")
} else { // case 只有一个表达式
ncase.List.SetFirst(typecheck(ncase.List.First(), ctxStmt))
n := ncase.List.First() // 把case的第一个表达式赋值给 n
ncase.Left = n
ncase.List.Set(nil)
switch n.Op { // 对 case 的具体操作进行检查
default: // 对于未知类型的case 进行下面的处理
pos := n.Pos
if n.Op == ONAME {
// We don't have the right position for ONAME nodes (see #15459 and
// others). Using ncase.Pos for now as it will provide the correct
// line number (assuming the expression follows the "case" keyword
// on the same line). This matches the approach before 1.10.
pos = ncase.Pos
}
// 打印错误,只接受下面的几个类型
yyerrorl(pos, "select case must be receive, send or assign recv")

// convert x = <-c into OSELRECV(x, <-c).
// remove implicit conversions; the eventual assignment
// will reintroduce them.
// 处理 case 为 x = <-c 的表达式
case OAS:
if (n.Right.Op == OCONVNOP || n.Right.Op == OCONVIFACE) && n.Right.Implicit() {
n.Right = n.Right.Left
}

if n.Right.Op != ORECV {
yyerrorl(n.Pos, "select assignment must have receive on right hand side")
break
}

n.Op = OSELRECV

// convert x, ok = <-c into OSELRECV2(x, <-c) with ntest=ok
// 处理 case 为 x, ok = <-c 的表达式
case OAS2RECV:
if n.Rlist.First().Op != ORECV {
yyerrorl(n.Pos, "select assignment must have receive on right hand side")
break
}

n.Op = OSELRECV2
n.Left = n.List.First()
n.List.Set1(n.List.Second())
n.Right = n.Rlist.First()
n.Rlist.Set(nil)

// convert <-c into OSELRECV(N, <-c)
// 处理 case 为 <-c 的表达式
case ORECV:
n = nodl(n.Pos, OSELRECV, nil, n)

n.SetTypecheck(1)
ncase.Left = n

case OSEND: // 无需要做特殊处理
break
}
}

typecheckslice(ncase.Nbody.Slice(), ctxStmt)
}

lineno = lno
}

然后再来看walkselect函数, 这个函数主要是对每个case进行处理,真正处理每个case的函数是walkselectcases:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
func walkselectcases(cases *Nodes) []*Node {
n := cases.Len()
sellineno := lineno

// optimization: zero-case select
if n == 0 {
return []*Node{mkcall("block", nil, nil)}
}

// optimization: one-case select: single op.
// TODO(rsc): Reenable optimization once order.go can handle it.
// golang.org/issue/7672.
// 处理只有一个 case 的情况
// 处理结果是优化成: if xx {}
if n == 1 {
cas := cases.First()
setlineno(cas)
l := cas.Ninit.Slice()
if cas.Left != nil { // not default:
n := cas.Left
l = append(l, n.Ninit.Slice()...)
n.Ninit.Set(nil)
var ch *Node
switch n.Op { // 根据操作符对齐进行词法分析,重新构造AST
default:
Fatalf("select %v", n.Op)

// ok already
case OSEND:
ch = n.Left

case OSELRECV, OSELRECV2:
ch = n.Right.Left
if n.Op == OSELRECV || n.List.Len() == 0 {
if n.Left == nil {
n = n.Right
} else {
n.Op = OAS // 转化为 Left = Right 表达式
}
break
}

if n.Left == nil {
nblank = typecheck(nblank, ctxExpr|ctxAssign)
n.Left = nblank
}

n.Op = OAS2 // 转化为多赋值表达式: List = Rlist (x, y, z = a, b, c)
n.List.Prepend(n.Left)
n.Rlist.Set1(n.Right)
n.Right = nil
n.Left = nil
n.SetTypecheck(0)
n = typecheck(n, ctxStmt)
}
// if ch == nil { block() }; n; // 转换为 if ch == nil { block() } 表达式
a := nod(OIF, nil, nil) // 转化为: if Ninit; Left { Nbody } else { Rlist }

a.Left = nod(OEQ, ch, nodnil()) // a.Left 转化为: Left == Right, 既: ch == nil
var ln Nodes
ln.Set(l)
a.Nbody.Set1(mkcall("block", nil, &ln)) // Body 变为 block 调用
l = ln.Slice()
a = typecheck(a, ctxStmt)
l = append(l, a, n)
}

l = append(l, cas.Nbody.Slice()...)
l = append(l, nod(OBREAK, nil, nil))
return l
}

// convert case value arguments to addresses.
// this rewrite is used by both the general code and the next optimization.
// 存在多个 case, 分别处理
for _, cas := range cases.Slice() {
setlineno(cas)
n := cas.Left
if n == nil {
continue
}
switch n.Op {
case OSEND:
...
case OSELRECV, OSELRECV2:
...
}
}

// optimization: two-case select but one is default: single non-blocking op.
// 处理只有两个 case, 并且其中一个是 default 的情况
if n == 2 && (cases.First().Left == nil || cases.Second().Left == nil) {
var cas *Node // 非 default case
var dflt *Node // default case
if cases.First().Left == nil {
cas = cases.Second()
dflt = cases.First()
} else {
dflt = cases.Second()
cas = cases.First()
}

n := cas.Left
setlineno(n)
r := nod(OIF, nil, nil)
r.Ninit.Set(cas.Ninit.Slice())
switch n.Op {
default:
Fatalf("select %v", n.Op)

case OSEND:
// if selectnbsend(c, v) { body } else { default body }
ch := n.Left
r.Left = mkcall1(chanfn("selectnbsend", 2, ch.Type), types.Types[TBOOL], &r.Ninit, ch, n.Right)

case OSELRECV:
// if selectnbrecv(&v, c) { body } else { default body }
...
r.Left = mkcall1(chanfn("selectnbrecv", 2, ch.Type), types.Types[TBOOL], &r.Ninit, elem, ch)
case OSELRECV2:
// if selectnbrecv2(&v, &received, c) { body } else { default body }
...
r.Left = mkcall1(chanfn("selectnbrecv2", 2, ch.Type), types.Types[TBOOL], &r.Ninit, elem, receivedp, ch)
}

r.Left = typecheck(r.Left, ctxExpr)
r.Nbody.Set(cas.Nbody.Slice())
r.Rlist.Set(append(dflt.Ninit.Slice(), dflt.Nbody.Slice()...))
return []*Node{r, nod(OBREAK, nil, nil)}
}

var init []*Node

// generate sel-struct
lineno = sellineno
selv := temp(types.NewArray(scasetype(), int64(n)))
r := nod(OAS, selv, nil)
r = typecheck(r, ctxStmt)
init = append(init, r)

order := temp(types.NewArray(types.Types[TUINT16], 2*int64(n)))
r = nod(OAS, order, nil)
r = typecheck(r, ctxStmt)
init = append(init, r)

// register cases
for i, cas := range cases.Slice() { // 其它 case 的情况处理
setlineno(cas)

init = append(init, cas.Ninit.Slice()...)
cas.Ninit.Set(nil)

// Keep in sync with runtime/select.go.
const (
caseNil = iota
caseRecv
caseSend
caseDefault
)

var c, elem *Node
var kind int64 = caseDefault

if n := cas.Left; n != nil {
init = append(init, n.Ninit.Slice()...)

switch n.Op {
default:
Fatalf("select %v", n.Op)
case OSEND:
kind = caseSend
c = n.Left
elem = n.Right
case OSELRECV, OSELRECV2:
kind = caseRecv
c = n.Right.Left
elem = n.Left
}
}

setField := func(f string, val *Node) {
r := nod(OAS, nodSym(ODOT, nod(OINDEX, selv, nodintconst(int64(i))), lookup(f)), val)
r = typecheck(r, ctxStmt)
init = append(init, r)
}

setField("kind", nodintconst(kind))
if c != nil {
c = convnop(c, types.Types[TUNSAFEPTR])
setField("c", c)
}
if elem != nil {
elem = convnop(elem, types.Types[TUNSAFEPTR])
setField("elem", elem)
}

// TODO(mdempsky): There should be a cleaner way to
// handle this.
if instrumenting {
r = mkcall("selectsetpc", nil, nil, bytePtrToIndex(selv, int64(i)))
init = append(init, r)
}
}
// run the select
lineno = sellineno
chosen := temp(types.Types[TINT])
recvOK := temp(types.Types[TBOOL])
r = nod(OAS2, nil, nil)
r.List.Set2(chosen, recvOK)
fn := syslook("selectgo")
r.Rlist.Set1(mkcall1(fn, fn.Type.Results(), nil, bytePtrToIndex(selv, 0), bytePtrToIndex(order, 0), nodintconst(int64(n))))
r = typecheck(r, ctxStmt)
init = append(init, r)

// selv and order are no longer alive after selectgo.
init = append(init, nod(OVARKILL, selv, nil))
init = append(init, nod(OVARKILL, order, nil))

// dispatch cases
for i, cas := range cases.Slice() {
setlineno(cas)

cond := nod(OEQ, chosen, nodintconst(int64(i)))
cond = typecheck(cond, ctxExpr)
cond = defaultlit(cond, nil)

r = nod(OIF, cond, nil)

if n := cas.Left; n != nil && n.Op == OSELRECV2 {
x := nod(OAS, n.List.First(), recvOK)
x = typecheck(x, ctxStmt)
r.Nbody.Append(x)
}

r.Nbody.AppendNodes(&cas.Nbody)
r.Nbody.Append(nod(OBREAK, nil, nil))
init = append(init, r)
}

return init
}

针对select不同case的情况编译的方式不用:

没有case

看一下相关代码:

1
2
3
if n == 0 {
return []*Node{mkcall("block", nil, nil)}
}

直接调用block
由于没又对应的chan处理,所以当前goroutine进入休眠状态,无法被唤醒

只有1case

对应源码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
if n == 1 {
cas := cases.First()
setlineno(cas)
l := cas.Ninit.Slice()
if cas.Left != nil { // not default:
...
switch n.Op {
default:
Fatalf("select %v", n.Op)

// ok already
case OSEND:
ch = n.Left

case OSELRECV, OSELRECV2:
...
}

// if ch == nil { block() }; n;
a := nod(OIF, nil, nil)

a.Left = nod(OEQ, ch, nodnil())
var ln Nodes
ln.Set(l)
a.Nbody.Set1(mkcall("block", nil, &ln))
l = ln.Slice()
a = typecheck(a, ctxStmt)
l = append(l, a, n)
}

l = append(l, cas.Nbody.Slice()...) // 指的是下面的具体case处理内容
l = append(l, nod(OBREAK, nil, nil))
return l
}

这里其实也分为两种方式,

一种是这个casedefault

直接把 case 对应的 body 放入 AST

1
2
3
4
select {
default:
println("default")
}

转换为:

1
println("default")

另一种是这个case不是default:

转换一下 case 表达式

1
2
3
4
select {
case v, ok <-ch:
// ...
}

转换为:

1
2
3
4
5
if ch == nil {
block()
}
v, ok := <-ch
// ...

2case, 其中一个是default

对于有两个, 但是其中一个为default的,具体处理代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// optimization: two-case select but one is default: single non-blocking op.
if n == 2 && (cases.First().Left == nil || cases.Second().Left == nil) {
var cas *Node
var dflt *Node
if cases.First().Left == nil {
cas = cases.Second()
dflt = cases.First()
} else {
dflt = cases.Second()
cas = cases.First()
}

n := cas.Left
setlineno(n)
r := nod(OIF, nil, nil)
r.Ninit.Set(cas.Ninit.Slice())
switch n.Op {
default:
Fatalf("select %v", n.Op)

case OSEND:
// if selectnbsend(c, v) { body } else { default body }
ch := n.Left
r.Left = mkcall1(chanfn("selectnbsend", 2, ch.Type), types.Types[TBOOL], &r.Ninit, ch, n.Right)

case OSELRECV:
// if selectnbrecv(&v, c) { body } else { default body }
...
r.Left = mkcall1(chanfn("selectnbrecv", 2, ch.Type), types.Types[TBOOL], &r.Ninit, elem, ch)

case OSELRECV2:
// if selectnbrecv2(&v, &received, c) { body } else { default body }
...
r.Rlist.Set(append(dflt.Ninit.Slice(), dflt.Nbody.Slice()...))
return []*Node{r, nod(OBREAK, nil, nil)}
}

可以看到真对每个case的具体操作可以转为不同的形式,具体的可以参考针对每个 case 所调用的函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// compiler implements
//
// select {
// case c <- v:
// ... foo
// default:
// ... bar
// }
//
// as
//
// if selectnbsend(c, v) {
// ... foo
// } else {
// ... bar
// }
//
func selectnbsend(c *hchan, elem unsafe.Pointer) (selected bool) {
return chansend(c, elem, false, getcallerpc())
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// compiler implements
//
// select {
// case v = <-c:
// ... foo
// default:
// ... bar
// }
//
// as
//
// if selectnbrecv(&v, c) {
// ... foo
// } else {
// ... bar
// }
//
func selectnbrecv(elem unsafe.Pointer, c *hchan) (selected bool) {
selected, _ = chanrecv(c, elem, false)
return
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// compiler implements
//
// select {
// case v, ok = <-c:
// ... foo
// default:
// ... bar
// }
//
// as
//
// if c != nil && selectnbrecv2(&v, &ok, c) {
// ... foo
// } else {
// ... bar
// }
//
func selectnbrecv2(elem unsafe.Pointer, received *bool, c *hchan) (selected bool) {
// TODO(khr): just return 2 values from this function, now that it is in Go.
selected, *received = chanrecv(c, elem, false)
return
}

2个以上case, 或两个case并且没有default

对于这种情况,主要是调用了 selectgo来处理

1
2
fn := syslook("selectgo")
r.Rlist.Set1(mkcall1(fn, fn.Type.Results(), nil, bytePtrToIndex(selv, 0), bytePtrToIndex(order, 0), nodintconst(int64(n))))

用图形来表示这时的select结构如下:

详细内容参考下面的 selectgo的分析

select 源码分析

前面主要是介绍select的基本语法和词法分析过程,下面针对select的运行时代码进行分析

源码

在分析源码之前,先写一个demo, 然后通过编译成汇编,看看内部是如何调用的, 还是使用官方给出的demo, 对其进行编译:

1
go tool compile -S select.go

输出汇编代码:

1
2
3
4
5
6
"".fibonacci STEXT size=354 args=0x10 locals=0xc8
0x0000 00000 (select2.go:5) TEXT "".fibonacci(SB), ABIInternal, $200-16
...
0x00d4 00212 (select2.go:8) CALL runtime.selectgo(SB)
...
0x00d4 00212 (select2.go:8) CALL runtime.selectgo(SB)

可以看出调用了runtime.selectgo函数,这个函数的实现在runtime/select.go:155:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
// selectgo implements the select statement.
//
// cas0 points to an array of type [ncases]scase, and order0 points to
// an array of type [2*ncases]uint16. Both reside on the goroutine's
// stack (regardless of any escaping in selectgo).
//
// selectgo returns the index of the chosen scase, which matches the
// ordinal position of its respective select{recv,send,default} call.
// Also, if the chosen scase was a receive operation, it reports whether
// a value was received.
// cas0 是指向类型为 [ncases]scase 的数组, 其实就是我们 select 写的 case 组成的数组
// order0 指向的是一个类型为 [2*ncases]uint16 的数组
// cas0 和 order0 都存在于 goroutine 的栈中(不考虑逃逸分析)
// selectgo 返回的是要执行的 case 的索引(index)
// 如果 case 是 recv 操作, 还没返回是否接收到了数据(第二个 bool 参数)
// ncases 表示的是 case 的个数
func selectgo(cas0 *scase, order0 *uint16, ncases int) (int, bool) {
...
cas1 := (*[1 << 16]scase)(unsafe.Pointer(cas0)) // 创建一个长度为[1<<16]的数组,并把cas0的地址赋值给cas1, 现在cas1表示了所有的case
order1 := (*[1 << 17]uint16)(unsafe.Pointer(order0)) // 创建一个长度为[1<<17]的数组,指向order0

scases := cas1[:ncases:ncases] // scases 包含了所有的case, 并且cap为case的个数
pollorder := order1[:ncases:ncases] // 长度为case个数的数组,其实里面的内容是要放case的执行顺序
lockorder := order1[ncases:][:ncases:ncases] // 指向order1的后面的存储空间, 内容是要存放根据chan的地址顺序排序的所有chan

// Replace send/receive cases involving nil channels with
// caseNil so logic below can assume non-nil channel.
// 先对所有case进行处理,忽略case.c = nil, 也就是对应的chan已经被关闭或者其他情况导致的nil
for i := range scases {
cas := &scases[i]
if cas.c == nil && cas.kind != caseDefault {
*cas = scase{}
}
}
...

// The compiler rewrites selects that statically have
// only 0 or 1 cases plus default into simpler constructs.
// The only way we can end up with such small sel.ncase
// values here is for a larger select in which most channels
// have been nilled out. The general code handles those
// cases correctly, and they are rare enough not to bother
// optimizing (and needing to test).
// 编译器已经把对于只有0或者1个case+default的形式给优化成了简单的结构
// 这个我们处理的是更多的select case的情况

// generate permuted order
// 对 case 进行随机排序
for i := 1; i < ncases; i++ {
j := fastrandn(uint32(i + 1)) // 快速求随机数, 范围: [0, i]
pollorder[i] = pollorder[j] // 第 i 个数据的值 = 第 j 个数据的值
pollorder[j] = uint16(i) // 第 j 个数的值赋值为 i
}

// sort the cases by Hchan address to get the locking order.
// simple heap sort, to guarantee n log n time and constant stack footprint.
// 对所有case 中的 hchan 按照地址进行堆排序
// 排完序后是为了对齐进行加锁,防止重复加锁

// 根据 hchan 地址构建大顶堆
for i := 0; i < ncases; i++ {
j := i
// Start with the pollorder to permute cases on the same channel.
c := scases[pollorder[i]].c
for j > 0 && scases[lockorder[(j-1)/2]].c.sortkey() < c.sortkey() {
k := (j - 1) / 2
lockorder[j] = lockorder[k]
j = k
}
lockorder[j] = pollorder[i]
}
// 进行堆排序
for i := ncases - 1; i >= 0; i-- {
o := lockorder[i]
c := scases[o].c
lockorder[i] = lockorder[0]
j := 0
for {
k := j*2 + 1
if k >= i {
break
}
if k+1 < i && scases[lockorder[k]].c.sortkey() < scases[lockorder[k+1]].c.sortkey() {
k++
}
if c.sortkey() < scases[lockorder[k]].c.sortkey() {
lockorder[j] = lockorder[k]
j = k
continue
}
break
}
lockorder[j] = o
}
...

// lock all the channels involved in the select
// 根据lockorder对scases的chan上锁,具体实现参考下面对sellock函数的介绍
sellock(scases, lockorder)

var (
gp *g
sg *sudog
c *hchan
k *scase
sglist *sudog
sgnext *sudog
qp unsafe.Pointer
nextp **sudog
)

loop:
// pass 1 - look for something already waiting
// 第一种情况,已经有满足的 case 条件
var dfli int
var dfl *scase
var casi int
var cas *scase
var recvOK bool
for i := 0; i < ncases; i++ { // 根据之前的随机顺序访问
casi = int(pollorder[i])
cas = &scases[casi]
c = cas.c

switch cas.kind {
case caseNil: // 如果当前 case 是一个 nil 的 chan, 则不处理,继续寻找其他的 case
continue

case caseRecv: // 如果是接收数据的 case
sg = c.sendq.dequeue()
if sg != nil { // sg != nil 这证明 buf已经满了,或者是一个不带buf的chan, 然后执行 recv函数,recv的过程可以参考 go channel 原理的介绍
goto recv
}
if c.qcount > 0 { // 如果 qcount > 0 这证明 buf 里是有数据的,所有从 buf 里取数据
goto bufrecv
}
if c.closed != 0 { // 对应的chan已经被关闭了
goto rclose
}

case caseSend: // 如果是发送数据
if raceenabled {
racereadpc(c.raceaddr(), cas.pc, chansendpc)
}
if c.closed != 0 { // 如果被关闭了,进入sclose, 最终会 panic
goto sclose
}
sg = c.recvq.dequeue()
if sg != nil { // recv队列不为空,则执行 send 函数, 参考 go channel 中 send 的处理
goto send
}
if c.qcount < c.dataqsiz { // 证明 buf 还没满,直接发送到 buf 中
goto bufsend
}

case caseDefault: // case 是 default case,这里其实时延迟赋值,如果运行了其他case, 就没必要赋值了
dfli = casi // default 的下标
dfl = cas // 给 default 赋值
// 继续寻找其他case
}
}

// 走到这里证明没有准备好的chan case能够执行,下面会优先执行 default

if dfl != nil { // default 不为空
selunlock(scases, lockorder) // 对已经上锁的进行解锁
casi = dfli
cas = dfl
goto retc // 直接返回 default 的 index
}

// pass 2 - enqueue on all chans
// 走到这里所有的case都无法运行
// 把所有的chan都进入阻塞状态
// 具体细节可以参考 go channel 的操作
gp = getg()
if gp.waiting != nil {
throw("gp.waiting != nil")
}
nextp = &gp.waiting
for _, casei := range lockorder {
casi = int(casei)
cas = &scases[casi]
if cas.kind == caseNil {
continue
}
c = cas.c
sg := acquireSudog()
sg.g = gp
sg.isSelect = true
// No stack splits between assigning elem and enqueuing
// sg on gp.waiting where copystack can find it.
sg.elem = cas.elem
sg.releasetime = 0
if t0 != 0 {
sg.releasetime = -1
}
sg.c = c
// Construct waiting list in lock order.
*nextp = sg
nextp = &sg.waitlink

switch cas.kind {
case caseRecv:
c.recvq.enqueue(sg)

case caseSend:
c.sendq.enqueue(sg)
}
}

// wait for someone to wake us up
// 这里当前goroutine会进入阻塞,让出CPU, 等待chan可以发送或者接收数据时就被唤醒,这个可以参考chan的实现:
// 当前goroutine进入recvq或sendq, 当任何一个chan被其他goroutine操作时,就会把当前goroutine唤醒
gp.param = nil
gopark(selparkcommit, nil, waitReasonSelect, traceEvGoBlockSelect, 1) //阻塞,直到被唤醒

sellock(scases, lockorder)

gp.selectDone = 0
sg = (*sudog)(gp.param)
gp.param = nil

// pass 3 - dequeue from unsuccessful chans
// otherwise they stack up on quiet channels
// record the successful case, if any.
// We singly-linked up the SudoGs in lock order.
casi = -1
cas = nil
sglist = gp.waiting
// Clear all elem before unlinking from gp.waiting.
for sg1 := gp.waiting; sg1 != nil; sg1 = sg1.waitlink {
sg1.isSelect = false
sg1.elem = nil
sg1.c = nil
}
gp.waiting = nil

for _, casei := range lockorder {
k = &scases[casei]
if k.kind == caseNil {
continue
}
if sglist.releasetime > 0 {
k.releasetime = sglist.releasetime
}
if sg == sglist {
// sg has already been dequeued by the G that woke us up.
casi = int(casei)
cas = k // 寻找当前被唤醒的case
} else {
c = k.c
if k.kind == caseSend {
c.sendq.dequeueSudoG(sglist) // 出队, 但是不处理对应的 chan 值
} else {
c.recvq.dequeueSudoG(sglist) // 出队, 但是不处理对应的 chan 值
}
}
sgnext = sglist.waitlink
sglist.waitlink = nil
releaseSudog(sglist) // 释放当前 sglist
sglist = sgnext // 继续处理下一个 sglist
}
// 如果没有被唤醒的case(在一些情况下, 如: close chan等)
if cas == nil {
// We can wake up with gp.param == nil (so cas == nil)
// when a channel involved in the select has been closed.
// It is easiest to loop and re-run the operation;
// we'll see that it's now closed.
// Maybe some day we can signal the close explicitly,
// but we'd have to distinguish close-on-reader from close-on-writer.
// It's easiest not to duplicate the code and just recheck above.
// We know that something closed, and things never un-close,
// so we won't block again.
goto loop
}

c = cas.c
...
if cas.kind == caseRecv {
recvOK = true
}
...
selunlock(scases, lockorder)
goto retc

bufrecv:
// can receive from buffer
// 处理从 buf recv 的情况
...
recvOK = true
qp = chanbuf(c, c.recvx)
if cas.elem != nil {
typedmemmove(c.elemtype, cas.elem, qp)
}
typedmemclr(c.elemtype, qp)
c.recvx++
if c.recvx == c.dataqsiz {
c.recvx = 0
}
c.qcount--
selunlock(scases, lockorder)
goto retc

bufsend:
// can send to buffer
// 处理从 buf send 的情况
...
typedmemmove(c.elemtype, chanbuf(c, c.sendx), cas.elem)
c.sendx++
if c.sendx == c.dataqsiz {
c.sendx = 0
}
c.qcount++
selunlock(scases, lockorder)
goto retc

recv:
// can receive from sleeping sender (sg)
// 直接从 goroutine 中 recv
recv(c, sg, cas.elem, func() { selunlock(scases, lockorder) }, 2)
if debugSelect {
print("syncrecv: cas0=", cas0, " c=", c, "\n")
}
recvOK = true
goto retc

rclose:
// read at end of closed channel
// recv close chan 的情况
selunlock(scases, lockorder)
recvOK = false
if cas.elem != nil {
typedmemclr(c.elemtype, cas.elem)
}
if raceenabled {
raceacquire(c.raceaddr())
}
goto retc

send:
// can send to a sleeping receiver (sg)
// 直接从 goroutine 中 send
...
send(c, sg, cas.elem, func() { selunlock(scases, lockorder) }, 2)
if debugSelect {
print("syncsend: cas0=", cas0, " c=", c, "\n")
}
goto retc

retc:
// 返回 index 和 recv 状态
if cas.releasetime > 0 {
blockevent(cas.releasetime-t0, 1)
}
return casi, recvOK

sclose:
// send on closed channel
// send close chan 的情况
selunlock(scases, lockorder)
panic(plainError("send on closed channel"))
}

以上是整个select的选择过程。
其中 每个case scase 的数据结构如下:

1
2
3
4
5
6
7
8
9
10
// Select case descriptor.
// Known to compiler.
// Changes here must also be made in src/cmd/internal/gc/select.go's scasetype.
type scase struct {
c *hchan // case 语句中使用到的 chan
elem unsafe.Pointer // data element
kind uint16 // case的类型,包括send, recv, default等
pc uintptr // race pc (for race detector / msan)
releasetime int64
}

对每个casechan上锁的过程如下:

1
2
3
4
5
6
7
8
9
10
11
// 对 select 的 case.c 上锁,根据 lockorder 的顺序
func sellock(scases []scase, lockorder []uint16) {
var c *hchan
for _, o := range lockorder {
c0 := scases[o].c
if c0 != nil && c0 != c { // 这个判断如果不满足证明当前chan和前一个chan地址是一样的,只上一次锁就行了
c = c0
lock(&c.lock)
}
}
}

上面的过程也可以用一个流程图来表示:

参考

The Go Programming Language
select 源码分析

go channel 原理

本篇主要介绍chan的内部实现原理(基于go1.12), 通过源码和图形的方式展示chan的内部结构及对chan进行操作的过程。

make chan

在进入源码分析之前,我们假设自己并不知道去哪里看其源码,我们先简单的创建一个chan

1
2
3
4
5
package main

func main() {
_ = make(chan int, 3)
}

为了分析其内部实现,我们可以通过compile工具对其编译生成伪汇编代码:

1
go tool compile -S chan.go

生成的汇编代码重点的内容入下:

1
2
3
4
5
6
"".main STEXT size=71 args=0x0 locals=0x20
0x0000 00000 (chan1.go:3) TEXT "".main(SB), ABIInternal, $32-0
...
0x0031 00049 (chan1.go:4) CALL runtime.makechan(SB)
...
0x0045 00069 (chan1.go:3) JMP 0

可以看到执行make其实最终执行的是runtime.makechan这个函数,这个函数的实现在runtime/chan.go文件中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
func makechan(t *chantype, size int) *hchan {
elem := t.elem
...
mem, overflow := math.MulUintptr(elem.size, uintptr(size))
...
var c *hchan
switch {
case mem == 0:
// Queue or element size is zero.
c = (*hchan)(mallocgc(hchanSize, nil, true))
// Race detector uses this location for synchronization.
c.buf = c.raceaddr()
case elem.kind&kindNoPointers != 0:
// Elements do not contain pointers.
// Allocate hchan and buf in one call.
c = (*hchan)(mallocgc(hchanSize+mem, nil, true))
c.buf = add(unsafe.Pointer(c), hchanSize)
default:
// Elements contain pointers.
c = new(hchan)
c.buf = mallocgc(mem, elem, true)
}

c.elemsize = uint16(elem.size)
c.elemtype = elem
c.dataqsiz = uint(size)
...
return c

可以看到最终会返回一个*hchan类型,这个就是chan的结构体:

1
2
3
4
5
6
7
8
9
10
11
12
13
type hchan struct {
qcount uint // 队列中有数据的个数
dataqsiz uint // 循环队列的大小z
buf unsafe.Pointer // 指向循环队列的地址
elemsize uint16
closed uint32 // chan的关闭状态
elemtype *_type // element type
sendx uint // 队列中下一个要发送的数据的下标
recvx uint // 队列中下一个要接收的数据的下标
recvq waitq // 等待接受的G队列
sendq waitq // 等待发送的G队列
lock mutex // 操作chan是需要加锁
}

执行完上面的make后,生成的chan如下:

send chan

为了了解我们往chan发送的时候都做了什么我可能先写一个demo:

1
2
3
4
5
6
package main

func main() {
c := make(chan int, 3)
c <- 3
}

查看其汇编代码:

1
2
3
4
5
6
7
8
"".main STEXT size=97 args=0x0 locals=0x20
0x0000 00000 (chan2.go:3) TEXT "".main(SB), ABIInternal, $32-0
...
0x0031 00049 (chan2.go:4) CALL runtime.makechan(SB)
...
0x004b 00075 (chan2.go:5) CALL runtime.chansend1(SB)
...
0x005f 00095 (chan2.go:3) JMP 0

可以看出我们往chan发送数据其实执行的是runtime.chansend1函数,这个函数很简简单,只是调用了runtime.chansend函数,我们主要看一下runtime.chansend函数的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
func chansend(c *hchan, ep unsafe.Pointer, block bool, callerpc uintptr) bool {
if c == nil {
if !block {
return false
}
gopark(nil, nil, waitReasonChanSendNilChan, traceEvGoStop, 2)
throw("unreachable")
}
...
if !block && c.closed == 0 && ((c.dataqsiz == 0 && c.recvq.first == nil) ||
(c.dataqsiz > 0 && c.qcount == c.dataqsiz)) {
return false
}
...
lock(&c.lock)
// 往已经 closed 的 chan 发送数据会直接 panic
if c.closed != 0 {
unlock(&c.lock)
panic(plainError("send on closed channel"))
}
...
// 如果有接收队列,则进入send函数
if sg := c.recvq.dequeue(); sg != nil {
// Found a waiting receiver. We pass the value we want to send
// directly to the receiver, bypassing the channel buffer (if any).
send(c, sg, ep, func() { unlock(&c.lock) }, 3)
return true
}
...
// 没有接收队列,buf还没有满,则直接往里放数据
if c.qcount < c.dataqsiz {
// Space is available in the channel buffer. Enqueue the element to send.
qp := chanbuf(c, c.sendx)
if raceenabled {
raceacquire(qp)
racerelease(qp)
}
typedmemmove(c.elemtype, qp, ep)
c.sendx++
if c.sendx == c.dataqsiz { //如果sendx == dataqsize, 证明buf满了,
c.sendx = 0 // c.sendx=0保证了又从头开始,形成了一个循环队列
}
c.qcount++
unlock(&c.lock)
return true
}

if !block {
unlock(&c.lock)
return false
}

//获取一个sudog结构, 把当前发送数据所在的g和要发送的数据都放到这里
gp := getg()
mysg := acquireSudog()
mysg.releasetime = 0
if t0 != 0 {
mysg.releasetime = -1
}
// No stack splits between assigning elem and enqueuing mysg
// on gp.waiting where copystack can find it.
mysg.elem = ep
mysg.waitlink = nil
mysg.g = gp
mysg.isSelect = false
mysg.c = c
gp.waiting = mysg
gp.param = nil
c.sendq.enqueue(mysg) // 把这个sudog结构体放到发送对队列中
goparkunlock(&c.lock, waitReasonChanSend, traceEvGoBlockSend, 3) //阻塞当前g,直到由于可以发送数据而被唤醒
// Ensure the value being sent is kept alive until the
// receiver copies it out. The sudog has a pointer to the
// stack object, but sudogs aren't considered as roots of the
// stack tracer.
KeepAlive(ep)

// someone woke us up.
if mysg != gp.waiting {
throw("G waiting list is corrupted")
}
gp.waiting = nil
if gp.param == nil {
if c.closed == 0 {
throw("chansend: spurious wakeup")
}
panic(plainError("send on closed channel"))
}
gp.param = nil
if mysg.releasetime > 0 {
blockevent(mysg.releasetime-t0, 2)
}
mysg.c = nil
releaseSudog(mysg)
return true
}

下面我们有一个图来表示其过程,图中主要分为下面几个步骤:

  1. 往上面初始化好的hchan结构体发送第 1 个数据: 数据放到buf[0]的位置
  2. hchan结构体发送第 2 个数据: 数据放到buf[1]的位置
  3. hchan结构体发送第 3 个数据: 数据放到buf[2]的位置, 这时buf满了
  4. buf满了的hchan结构体发送第 4 个数据: g1会放到sudog结构体中,并放到sendq队列中,等待被唤醒
  5. buf满了的hchan结构体发送第 5 个数据: g2会放到sudog结构体中,并放到sendq队列中,等待被唤醒

recv chan

同上面一样,我们先写一个demo看看recv调用的是哪个函数:

1
2
3
4
5
6
package main

func main() {
c := make(chan int, 3)
<-c
}
1
2
3
4
5
6
7
"".main STEXT size=94 args=0x0 locals=0x20
...
0x0031 00049 (chan3.go:4) CALL runtime.makechan(SB)
...
0x0048 00072 (chan3.go:5) CALL runtime.chanrecv1(SB)
...
0x005c 00092 (chan3.go:3) JMP 0

同样runtime.chanrecv1也是简单调用了runtime.chanrecv函数,具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
func chanrecv(c *hchan, ep unsafe.Pointer, block bool) (selected, received bool) {
...
if c == nil {
if !block {
return
}
gopark(nil, nil, waitReasonChanReceiveNilChan, traceEvGoStop, 2)
throw("unreachable")
}
if !block && (c.dataqsiz == 0 && c.sendq.first == nil ||
c.dataqsiz > 0 && atomic.Loaduint(&c.qcount) == 0) &&
atomic.Load(&c.closed) == 0 {
return
}

lock(&c.lock)
// 如果chan已经被关闭,并且qcount==0, 则返回默认零值+false(如x, ok := <- c, x是零值,ok=false)
if c.closed != 0 && c.qcount == 0 {
if raceenabled {
raceacquire(c.raceaddr())
}
unlock(&c.lock)
if ep != nil {
typedmemclr(c.elemtype, ep)
}
return true, false
}
//如果在接收的时候有发送队列存在,则执行recv函数
if sg := c.sendq.dequeue(); sg != nil {
// Found a waiting sender. If buffer is size 0, receive value
// directly from sender. Otherwise, receive from head of queue
// and add sender's value to the tail of the queue (both map to
// the same buffer slot because the queue is full).
recv(c, sg, ep, func() { unlock(&c.lock) }, 3)
return true, true
}
// 如果存在buf, 存在数据
if c.qcount > 0 {
// Receive directly from queue
qp := chanbuf(c, c.recvx) //获取recvx位置的地址
if raceenabled {
raceacquire(qp)
racerelease(qp)
}
if ep != nil {
typedmemmove(c.elemtype, ep, qp) // 把recvx位置的数据copy到接收的变量中
}
typedmemclr(c.elemtype, qp) // 清空原来recvx位置的数据
c.recvx++
if c.recvx == c.dataqsiz { // 如果recvx == dataqsiz 证明已经到达最后一个,需要从头开始
c.recvx = 0 //从头开始,形成一个循环队列
}
c.qcount--
unlock(&c.lock)
return true, true
}

if !block {
unlock(&c.lock)
return false, false
}
gp := getg()
mysg := acquireSudog() // 获取一个sudog结构,把对应的g和接收数据的变量地址放到sudog中
mysg.releasetime = 0
if t0 != 0 {
mysg.releasetime = -1
}
// No stack splits between assigning elem and enqueuing mysg
// on gp.waiting where copystack can find it.
mysg.elem = ep
mysg.waitlink = nil
gp.waiting = mysg
mysg.g = gp
mysg.isSelect = false
mysg.c = c
gp.param = nil
c.recvq.enqueue(mysg) // 把sudog放入接收队列中
goparkunlock(&c.lock, waitReasonChanReceive, traceEvGoBlockRecv, 3) //阻塞当前g,直到被唤醒

// someone woke us up
if mysg != gp.waiting {
throw("G waiting list is corrupted")
}
gp.waiting = nil
if mysg.releasetime > 0 {
blockevent(mysg.releasetime-t0, 2)
}
closed := gp.param == nil
gp.param = nil
mysg.c = nil
releaseSudog(mysg)
return true, !closed
}

上面说到如果存在发送队列就会执行recv函数,下面看一下这个函数的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
func recv(c *hchan, sg *sudog, ep unsafe.Pointer, unlockf func(), skip int) {
//对于nobuf的chan, 直接copy数据
if c.dataqsiz == 0 {
if raceenabled {
racesync(c, sg)
}
if ep != nil {
// copy data from sender
recvDirect(c.elemtype, sg, ep)
}
} else {
// Queue is full. Take the item at the
// head of the queue. Make the sender enqueue
// its item at the tail of the queue. Since the
// queue is full, those are both the same slot.
qp := chanbuf(c, c.recvx) // 获取接收数据的位置
if raceenabled {
raceacquire(qp)
racerelease(qp)
raceacquireg(sg.g, qp)
racereleaseg(sg.g, qp)
}
// copy data from queue to receiver
if ep != nil {
typedmemmove(c.elemtype, ep, qp) //把recvx位置的数据copy到接收的变量中
}
// copy data from sender to queue
typedmemmove(c.elemtype, qp, sg.elem) // 把发送队列的数据copy到当前recvx的位置
c.recvx++
if c.recvx == c.dataqsiz {
c.recvx = 0
}
// 因为上面把发送队列的数据copy到了recvx, 为了保证下一个位置属按照顺序的,需要sendx = recvx
// 这几步保证了chan是一个FIFO的过程
c.sendx = c.recvx // c.sendx = (c.sendx+1) % c.dataqsiz
}
sg.elem = nil
gp := sg.g
unlockf()
gp.param = unsafe.Pointer(sg)
if sg.releasetime != 0 {
sg.releasetime = cputicks()
}
goready(gp, skip+1) // 把出队的g放到ready中,下次调度就可以运行了,不再阻塞
}

下面我们有一个图来表示接收数据的过程,图中主要分为下面几个步骤:

  1. 初始的hchan是上面send之后的结构
  2. g3执行接收操作,首先会把发送队列中的第 1 个g1出队,然后把buf[0]的数据赋值到g3中,再把g1的数据赋值到buf[0]
  3. g3执行接收操作,首先会把发送队列中的第 2 个g2出队,然后把buf[1]的数据赋值到g3中,再把g2的数据赋值到buf[1]
  4. 这个时候没有发送队列了,所以可以直接把buf[2]中的书赋值到g3
  5. 把下一个数据buf[0]中的书赋值到g3
  6. 把最后一个数据buf[1]中的书赋值到g3
  7. 已经没有数据可以赋值给g3了,所以g3被放入sudog结构体中,入队到了接收队列, 进入阻塞状态

send chan again

上面介绍send说到如果发送数据的时候有recvq队列就会调用send函数,这个函数的具体实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
func send(c *hchan, sg *sudog, ep unsafe.Pointer, unlockf func(), skip int) {
if raceenabled {
if c.dataqsiz == 0 {
racesync(c, sg) // no buf 直接同步
} else {
// Pretend we go through the buffer, even though
// we copy directly. Note that we need to increment
// the head/tail locations only when raceenabled.
qp := chanbuf(c, c.recvx) // 获取recvx位置
raceacquire(qp)
racerelease(qp)
raceacquireg(sg.g, qp)
racereleaseg(sg.g, qp)
c.recvx++
if c.recvx == c.dataqsiz {
c.recvx = 0
}
c.sendx = c.recvx // c.sendx = (c.sendx+1) % c.dataqsiz
}
}
if sg.elem != nil {
sendDirect(c.elemtype, sg, ep) //直接把要发送的数据 copy 到 recvq 队列出队的 g 中
sg.elem = nil
}
gp := sg.g
unlockf()
gp.param = unsafe.Pointer(sg)
if sg.releasetime != 0 {
sg.releasetime = cputicks()
}
goready(gp, skip+1) // 把g放到ready队列中,下次有机会被调度,不再阻塞
}

close

当我们close掉一个chan都发生了什么呢? 下面写一个closedemo:

1
2
3
4
5
6
package main

func main() {
c := make(chan int, 3)
close(c)
}

1
2
3
4
5
6
"".main STEXT size=85 args=0x0 locals=0x20
...
0x0031 00049 (chan4.go:4) CALL runtime.makechan(SB)
...
0x003f 00063 (chan4.go:5) CALL runtime.closechan(SB)
0x0053 00083 (chan4.go:3) JMP 0

可以调用了runtime.closechan函数,对应的代码为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
func closechan(c *hchan) {
if c == nil {
panic(plainError("close of nil channel"))
}

lock(&c.lock)
if c.closed != 0 {
unlock(&c.lock)
panic(plainError("close of closed channel")) // 已经关闭的 chan 不能再关闭
}

if raceenabled {
callerpc := getcallerpc()
racewritepc(c.raceaddr(), callerpc, funcPC(closechan))
racerelease(c.raceaddr())
}

c.closed = 1 // 关闭状态设置为 1

var glist gList
// release all readers
// 遍历所有recvq 队列, 从队列中去掉,并清空其内容,把所有g都放到glist结构中
for {
sg := c.recvq.dequeue()
if sg == nil {
break
}
if sg.elem != nil {
typedmemclr(c.elemtype, sg.elem)
sg.elem = nil
}
if sg.releasetime != 0 {
sg.releasetime = cputicks()
}
gp := sg.g
gp.param = nil
if raceenabled {
raceacquireg(gp, c.raceaddr())
}
glist.push(gp)
}

// 遍历所有 sendq 队列, 从队列中去掉,把所有g都放到glist结构中
// release all writers (they will panic)
for {
sg := c.sendq.dequeue()
if sg == nil {
break
}
sg.elem = nil
if sg.releasetime != 0 {
sg.releasetime = cputicks()
}
gp := sg.g
gp.param = nil
if raceenabled {
raceacquireg(gp, c.raceaddr())
}
glist.push(gp)
}
unlock(&c.lock)

// Ready all Gs now that we've dropped the channel lock.
// 把刚才所有放到 glist 中的 g 都改为ready 状态,使其不再阻塞
for !glist.empty() {
gp := glist.pop()
gp.schedlink = 0
goready(gp, 3)
}
}

下面我们分别看一下:

  1. 当存在recvq队列时:

  2. 当存在sendq队列时:

no buffer chan

前面讲的都是带bufferchan, 还有一种是经常使用的不带bufferchan,其实处理起来更简单,前面源码部分已经有涉及了,下面看一下操作过程:

  1. make一个不带bufferchan
  2. g1向这个chan发送数据, 由于没有接收者而被阻塞,放到sendq
  3. g2继续想这个chan发送数据,继续放到sendq
  4. 来一个接收者g3, 这时把g1sendq中出队,并把elem的值赋值给g3x
  5. g3继续接收,把g2sendq中出队,并把elem的值赋值给g3x
  6. 没有发送队列存在,g3也进入了阻塞状态,放到了recvq队列中

下面是其图形化展示:

参考

图解Go的channel底层原理
Go 1.12 runtime/chan.go
GopherCon 2017: Kavya Joshi - Understanding Channels

Go 单元测试

Go单元测试

在计算机编程中,单元测试(英语:Unit Testing)又称为模块测试,是针对程序模块(软件设计的最小单位)来进行正确性检验的测试工作。
我们为何要进行单元测试呢?其实如果你不添加单元测试的话,别人调用你提供的函数是,其实就是帮你做测试,但是这种测试我们越早做越能发现问题.

单元测试的粒度

程序单元是应用的最小可测试部件。在过程化编程中,一个单元就是单个程序、函数、过程等;对于面向对象编程,最小单元就是方法,包括基类(超类)、抽象类、或者派生类(子类)中的方法。在 Go 中程序单元指的是Package中的方法。
那么我们对于 Go 中改对那些函数写单元测试呢?这里的答案是: 包中可导出的函数。 因为这些函数是对外可见的,这些是我们包的入口。那么对于不可导出的函数我们是否需要些单元测试呢?答案是不用。有些人可能会有一位,如果我们不对不可导出的函数写单元测试,那么如何保证单元测试的覆盖率呢?因为有些不可导出函数的覆盖率达不到要求。这里要说的是: 如果有些不可导出函数单元测试覆盖率达不到,有两点可能性:

  • 这些逻辑是不需要的,你可以直接去掉
  • 你的测试用例不够,你需要增加可导出函数的测试用例

单测的三个原则

还有一个问题是: 我该先开发功能在写单元测试,还是先写单元测试再开发功能?
其实关于TDD有三个定律:

  1. You are not allowed to write any production code unless it is to make a failing unit test pass.
  2. You are not allowed to write any more of a unit test than is sufficient to fail; and compilation failures are failures.
  3. You are not allowed to write any more production code than is sufficient to pass the one failing unit test.

关于这三条定律,我发现每个人翻译的都不一样,我觉得比较符合我的理解的翻译是:

  1. 除非是为了使一个失败的 unit test 通过,否则不允许编写任何产品代码
  2. 在一个单元测试中,只允许编写刚好能够导致失败的内容(编译错误也算失败)
  3. 只允许编写刚好能够使一个失败的 unit test 通过的产品代码

如果违反了会怎么样呢?
违反第一条,先编写了产品代码,那这段代码是为了实现什么需求呢?怎么确保它真的实现了呢?
违反第二条,写了多个失败的测试,如果测试长时间不能通过,会增加开发者的压力,另外,测试可能会被重构,这时会增加测试的修改成本。
违反第三条,产品代码实现了超出当前测试的功能,那么这部分代码就没有测试的保护,不知道是否正确,需要手工测试。可能这是不存在的需求,那就凭空增加了代码的复杂性。如果是存在的需求,那后面的测试写出来就会直接通过,破坏了 TDD 的节奏感。

还是针对上面的问题: 先写单元测试还是先写功能?
我的答案是: 单元测试-> 功能开发 -> 单元测试 -> 功能开发…
它们应该是交替进行的,既: 先写小范围的单元测试,然后针对这些测试进行开发功能,等所有测试通过后继续增加测试case, 然后针对新增的case继续编写功能,直到功能满足了需求为止。

测试行为, 而非实现

Avoid Testing Implementation Details, Test Behaviours
当我们测试行为时,我们的意思是 : “我不在乎你是如何得出答案的,只要确保在这种情况下答案是正确的”
当我们测试实现时,我们的意思是 : “我不在乎答案是什么,只要确保它是按照你规定的方式工作的。”

初级

单元测试编写

下面给出一个完整的Go的单元测试的例子:
split.go文件:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
package split

import "strings"

// Split slices s into all substrings separated by sep and
// returns a slice of the substrings between those separators.
func Split(s, sep string) []string {
var result []string
i := strings.Index(s, sep)
for i > -1 {
result = append(result, s[:i])
s = s[i+len(sep):]
i = strings.Index(s, sep)
}
return append(result, s)
}

split_test.go文件:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
package split

import (
"reflect"
"testing"
)

func TestSplit(t *testing.T) {
got := Split("a/b/c", "/")
want := []string{"a", "b", "c"}
if !reflect.DeepEqual(want, got) {
t.Fatalf("expected: %v , got %v", want, got)
}
}

Go官方网站有关于单元测试的写法介绍, 以上面的代码为例:

  1. 一般我们需要单元测试文件和要测试的包的文件需要在同一个目录下,并且以_test.go结尾。

    1
    2
    3
    src/split/
    ├── split.go
    └── split_test.go
  2. 单元测试的函数名为Test + 要测试的函数名。

    1
    2
    3
    4
    // 要测试的函数
    func Split(...)
    // 单元测试函数
    func TestSplit(...)
  3. 单元测试函数的参数是固定的 (*testing.T):

    1
    func TestSplit(t *testing.T) {}

运行单元测试

Go语言的工具链中提供了很强大的单元测试工具:go test, 如果想要运行刚才的单元测试,我们只需要在split文件夹下执行:

1
go test

就可以得出测试结果:

1
2
3
$go test
PASS
ok split 0.008s

运行多个单元测试

有是有我们需要同时运行多个单元测试, 如果这些单元测试在同一个包下:

1
2
3
4
5
6
7
8
9
10
11
12
$GOROOT/src/encoding/xml/
├── atom_test.go
├── example_marshaling_test.go
├── example_test.go
├── example_text_marshaling_test.go
├── marshal.go
├── marshal_test.go
├── read.go
├── read_test.go
├── typeinfo.go
├── xml.go
└── xml_test.go

我们可以直接运行: go test
如果这些单元测试文件不在同一个包下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
$GOROOT/src/encoding/
├── ascii85
│   ├── ascii85.go
│   └── ascii85_test.go
├── asn1
│   ├── asn1.go
│   ├── asn1_test.go
│   ├── common.go
│   ├── marshal.go
│   └── marshal_test.go
├── base32
│   ├── base32.go
│   ├── base32_test.go
│   └── example_test.go
├── base64
│   ├── base64.go
│   ├── base64_test.go
│   └── example_test.go
...

我们需要在这些包的外面运行: go test ./...

覆盖率测试

如果我们想要查看单元测试的覆盖率,Go 工具链也是支持的, 详情可以参考官方的Blog: The cover story
如果要查看单元测试覆盖率,我们可以运行:

1
2
3
4
$go test -cover
PASS
coverage: 100.0% of statements
ok split 0.013s

但是上面的测试只给出了覆盖率的值,并没有看到详细的信息,如果我们需要查看覆盖率的详细信息,可以把测试覆盖率的内容输出到文件中:

1
$go test -coverprofile=coverage.out

这样,测试覆盖率的详细信息就输出到了文件coverage.out中。
如过要查看每个函数的测试覆盖率,可以利用刚才的coverage.out文件:

1
2
3
$go tool cover -func=coverage.out
split/split.go:7: Split 100.0%
total: (statements) 100.0%

如果要想可视化测试覆盖率,还可以生成html格式:

1
$go tool cover -html=coverage.out

我们可以看到每行的覆盖情况:
覆盖率
其中红色代表没有覆盖到,绿色代表覆盖到,灰色代表不计入测试覆盖率的范围

进阶

多个case

前面我们讲了如何进行基本的单元测试,但是现实中往往我们需要对同一个函数进行多个case的测试,那么其实有两种写法:

针对每个case写一个测试函数:

对于比较复杂的函数,其函数的表现可能会收到不同环境因素的影响,他们的单元测试写法差别也比较大,比如beegologs/file 的单元测试, 同样是测试FileDailyRotate函数,TestFileDailyRotate_01测试的是创建文件, TestFileDailyRotate_02测试的是当创建的文件存在时,给文件加后缀。

同一个测试函数里有多个case:

一般比较简单的单元测试,只是根据输入的不同而产生不同的输出,则可以使用这种方式。比如前面说的split函数的多个case测试, 我们把split_test.go改为下面的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
package split

import (
"reflect"
"testing"
)

func TestSplit(t *testing.T) {
type test struct {
input string
sep string
want []string
}

tests := []test{
{input: "a/b/c", sep: "/", want: []string{"a", "b", "c"}},
{input: "a/b/c", sep: ",", want: []string{"a/b/c"}},
{input: "abc", sep: "/", want: []string{"abc"}},
}

for _, tc := range tests {
got := Split(tc.input, tc.sep)
if !reflect.DeepEqual(tc.want, got) {
t.Fatalf("expected: %v , got %v", tc.want, got)
}
}
}

边界条件测试

由长期的测试工作经验得知,大量的错误是发生在输入或输出的边界上。因此针对各种边界情况设计测试用例,可以查出更多的错误。上面的case中我们并没有对边界条件进行测试,下面我们加上一个边界条件的测试case:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
package split

import (
"reflect"
"testing"
)

func TestSplit(t *testing.T) {
type test struct {
input string
sep string
want []string
}

tests := []test{
{input: "a/b/c", sep: "/", want: []string{"a", "b", "c"}},
{input: "a/b/c/", sep: "/", want: []string{"a", "b", "c"}},
{input: "a/b/c", sep: ",", want: []string{"a/b/c"}},
{input: "abc", sep: "/", want: []string{"abc"}},
}

for _, tc := range tests {
got := Split(tc.input, tc.sep)
if !reflect.DeepEqual(tc.want, got) {
t.Fatalf("expected: %v , got %v", tc.want, got)
}
}
}

然后我们执行单元测试:

1
2
3
4
5
6
7
$go test
=== RUN TestSplit
--- FAIL: TestSplit (0.00s)
split_test.go:25: expected: [a b c] , got [a b c ]
FAIL
exit status 1
FAIL split 0.015s

可以看到我们的单元测试有一个case没有通过,但是这里有一点疑问:哪个测试case没过?

定位测试case

通过编号定位

我们可以给每个case一个编号:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
package split

import (
"reflect"
"testing"
)

func TestSplit(t *testing.T) {
type test struct {
input string
sep string
want []string
}

tests := []test{
{input: "a/b/c", sep: "/", want: []string{"a", "b", "c"}},
{input: "a/b/c/", sep: "/", want: []string{"a", "b", "c"}},
{input: "a/b/c", sep: ",", want: []string{"a/b/c"}},
{input: "abc", sep: "/", want: []string{"abc"}},
}

for i, tc := range tests {
got := Split(tc.input, tc.sep)
if !reflect.DeepEqual(tc.want, got) {
t.Fatalf("test %d: expected: %v , got %v", i+1, tc.want, got)
}
}
}

这时候执行

1
2
3
4
5
6
$ go test
--- FAIL: TestSplit (0.00s)
split_test.go:25: test 2: expected: [a b c] , got [a b c ]
FAIL
exit status 1
FAIL split 0.016s

这里可以定位出 test 2 有问题的,但是编号的问题是 :

  • 每个人定义的开始下标可能不同: 有的人是从0开始,有的人从1开始,照成理解不一致
  • 随着case的增多,同样不好定位具体的case: 如果你要从50case中定位第27case, 还是比较费时的。

通过名字定位

还有一种方式: 我们给每个case一个名字:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
package split

import (
"reflect"
"testing"
)

func TestSplit(t *testing.T) {
type test struct {
name string
input string
sep string
want []string
}

tests := []test{
{name: "simple", input: "a/b/c", sep: "/", want: []string{"a", "b", "c"}},
{name: "trailing sep", input: "a/b/c/", sep: "/", want: []string{"a", "b", "c"}},
{name: "wrong sep", input: "a/b/c", sep: ",", want: []string{"a/b/c"}},
{name: "no sep", input: "abc", sep: "/", want: []string{"abc"}},
}

for _, tc := range tests {
got := Split(tc.input, tc.sep)
if !reflect.DeepEqual(tc.want, got) {
t.Fatalf("%s: expected: %v , got %v", tc.name, tc.want, got)
}
}
}

1
2
3
4
5
6
$go test
--- FAIL: TestSplit (0.00s)
split_test.go:26: trailing sep: expected: [a b c] , got [a b c ]
FAIL
exit status 1
FAIL split 0.015s

现在我们可以看到我们可以很好的通过trailing sep快速定位到了具体的case

随机测试case

上面的测试方式看上去很完美了,可以如果我们实现的时候没有注意,case之间可能会相互影响, 比如一个case在函数内部修改了一个全局变量,下一个case的执行就会受到这种影响。为了避免由于测试顺序带来的问题,我们一般都会让每个case之间的顺序是随机的,而不是按照特定的顺序,而slice本身有顺序的,所以不满足我们的条件,这时我们可以使用map, 同时还可以把name放到mapkey中,简化我们的写法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
package split

import (
"reflect"
"testing"
)

func TestSplit(t *testing.T) {
tests := map[string]struct {
input string
sep string
want []string
}{
"simple": {input: "a/b/c", sep: "/", want: []string{"a", "b", "c"}},
"trailing sep": {input: "a/b/c/", sep: "/", want: []string{"a", "b", "c"}},
"wrong sep": {input: "a/b/c", sep: ",", want: []string{"a/b/c"}},
"no sep": {input: "abc", sep: "/", want: []string{"abc"}},
}

for name, tc := range tests {
got := Split(tc.input, tc.sep)
if !reflect.DeepEqual(tc.want, got) {
t.Fatalf("%s: expected: %v , got %v", name, tc.want, got)
}
}
}

然后执行单元测试:

1
2
3
4
5
6
$go test
--- FAIL: TestSplit (0.00s)
split_test.go:23: trailing sep: expected: [a b c] , got [a b c ]
FAIL
exit status 1
FAIL split 0.014s

并发测试

看上去前面的测试更加完美了, 但是……
我们的测试case出现错误的时候,我们会调用:

1
t.Fatalf("%s: expected: %v , got %v", name, tc.want, got)

打印我们的错误信息,但是这个错误信息打印后整个测试过程就结束了,如果我们有很多个case需要测试,而前面的case失败后就无法进行后面的测试了,这时候我们如果针对这个出错的case修改后,我们会发现其他的case有报错了,我们反复的修改,但是我们并不知道自己到底有多少个case是有问题的,我们无法一次性把问题修复好,照成我们工作量变大,并且效率变低,那么我们该如何改进这个情况呢?
我们知道问题出在t.Fatalf,那么我们可不可以即打印出错误信息又不让程序中断呢?答案是: 可以! 我们使用f.Errorf替换f.Fatalf
可是…..
如果某个case出现了panic同样会导致整个程序中断,所以这种方式治标不治本。那么我们该如何改进呢? Go 1.7 开始支持了 sub test。 下面我们就按照Sub Test的写法进行修改:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
package split

import (
"reflect"
"testing"
)

func TestSplit(t *testing.T) {
tests := map[string]struct {
input string
sep string
want []string
}{
"simple": {input: "a/b/c", sep: "/", want: []string{"a", "b", "c"}},
"trailing sep": {input: "a/b/c/", sep: "/", want: []string{"a", "b", "c"}},
"wrong sep": {input: "a/b/c", sep: ",", want: []string{"a/b/c"}},
"no sep": {input: "abc", sep: "/", want: []string{"abc"}},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
got := Split(tc.input, tc.sep)
if !reflect.DeepEqual(tc.want, got) {
t.Fatalf("expected: %v , got %v", tc.want, got)
}
})
}
}

通过t.Run的源码我们看到:

1
2
3
4
5
func (t *T) Run(name string, f func(t *T)) bool {
...
go tRunner(t, f)
...
}

其实会诊对每个case启动一个goroutine, 所以其中一个出现了panic不会影响其他的case执行。

上面这种形态就是目前我们进行单元测试的最佳实践了。

高级

理论知识

外部依赖

外部依赖是指我们的函数需要调用其他的函数,外部依赖有可能涉及到一些数据依赖,网络依赖等。关于单元测试中如何解决外部依赖的问题, 常用的方法是: Test Double(测试替身), 而它也分很多种:

  • Dummy objects are passed around but never actually used. Usually they are just used to fill parameter lists.
  • Fake objects actually have working implementations, but usually take some shortcut which makes them not suitable for production (an in memory database is a good example).
  • Stubs provide canned answers to calls made during the test, usually not responding at all to anything outside what’s programmed in for the test.
  • Spies are stubs that also record some information based on how they were called. One form of this might be an email service that records how many messages it was sent.
  • Mocks are what we are talking about here: objects pre-programmed with expectations which form a specification of the calls they are expected to receive.

看上去有点儿头大,分这么多类型而且他们的接线感觉也比较模糊,为了便于理解我们不对这些概念做过多的解读,我们后面把所有我们的工作都看做是Mock

编写可测试代码

函数要短小

函数的第一规则是要短小。第二条规则是还要短小 ———— 《代码整洁之道》
至于如何才算短小,一般建议是不超过100行,也就是显示器一屏所显示的行数。
函数越短小那么单元测试的编写就越简单。

函数功能要单一

函数应该做一件事。做好这件事。只做一件事。 ————–《代码整洁之道》
一个函数做的事情越少其逻辑越简单,难么对应的单元测试也就越简单。

减少外部依赖

这里要明确的是我们要测试的是自己的函数而不是调用的函数,所以我们应该把中重点放到自己的函数上,至于外部依赖的函数越少越好,因为每个外部依赖都增加了我们单元测试的不确定性。

依赖模块要方便 Mock

为了专注我们自己模块的测试,对于外部的模块我们一般都会使用Mock的方法, 所以依赖模块如果好Mock的话测试起来就会方便很多,反之会很麻烦。

方便依赖注入

一般我们Mock是通过依赖注入的方式,这种方式可以方便的更改依赖的对象的实现,而依赖注入的方式有好几种:

  • 通过变量赋值
  • 通过参数传递
  • 通过Set/Get方法

一个外部依赖的例子

一个User包, 有一个通过uid获取分数score的方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
package user

import (
"strconv"

"github.com/go-redis/redis"
)

func Score(uid int) (int, error) {
client := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "",
DB: 0,
})

_, err := client.Ping().Result()
if err != nil {
return -1, err
}

val, err := client.Get(strconv.Itoa(uid)).Result()
if err == redis.Nil {
return -1, nil
}
if err != nil {
return -1, err
}

return strconv.Atoi(val)
}

一个Class包,通过调用user.Score方法获取分数,根据分数给这个用户一个等级:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
package class

import (
"user"
)

func UserLevel(uid int) string {
score, err := user.Score(uid)
if err != nil {
return "E"
}

switch {
case score < 0:
return "N"
case score <= 60:
return "C"
case score <= 90:
return "B"
case score <= 100:
return "A"
default:
return "W"
}
}

现在我们要给UserLevel写单元测试,该怎么写呢?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
package class

import (
"testing"
)

func TestUserLevel(t *testing.T) {
tests := map[string]struct {
input int
want string
}{
"not found user": {input: 1, want: "N"},
"C level": {input: 2, want: "C"},
"B level": {input: 3, want: "B"},
"A level": {input: 4, want: "A"},
"Got Error": {input: 5, want: "E"},
"Wrong Score": {input: 5, want: "W"},
}

for name, tc := range tests {
got := UserLevel(tc.input)
t.Run(name, func(t *testing.T) {
if tc.want != got {
t.Fatalf("expected: %s, got %s", tc.want, got)
}
})
}
}

运行单元测试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
$ go test
--- FAIL: TestUserLevel (0.02s)
--- FAIL: TestUserLevel/not_found_user (0.00s)
class_test.go:24: expected: N, got E
--- FAIL: TestUserLevel/C_level (0.00s)
class_test.go:24: expected: C, got E
--- FAIL: TestUserLevel/B_level (0.00s)
class_test.go:24: expected: B, got E
--- FAIL: TestUserLevel/A_level (0.00s)
class_test.go:24: expected: A, got E
--- FAIL: TestUserLevel/Wrong_Score (0.00s)
class_test.go:24: expected: W, got E
FAIL
exit status 1
FAIL class 0.023s

可以看到除了Got Error运行成功,其他的都失败了,因为我们本地并没有开启redis服务,所以是连不上的。如果我们要让这个测试用例通过,显然我们不能真的开启一个redis的服务,我们需要对user.Score进行Mock

Mock框架

go中mock的支持也有很多种:

每个框架都有自己的用法, 这里我那github.com/bouk/monkey来举例子, 改造一下我们的单元测试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
package class

import (
"errors"
"testing"

"bou.ke/monkey"
"user"
)

func TestUserLevel(t *testing.T) {
tests := map[string]struct {
input int
want string
}{
"not found user": {input: 1, want: "N"},
"C level": {input: 2, want: "C"},
"B level": {input: 3, want: "B"},
"A level": {input: 4, want: "A"},
"Got Error": {input: 5, want: "E"},
"Wrong Score": {input: 6, want: "W"},
}

monkey.Patch(user.Score, mockScore)
for name, tc := range tests {
got := UserLevel(tc.input)
t.Run(name, func(t *testing.T) {
if tc.want != got {
t.Fatalf("expected: %s, got %s", tc.want, got)
}
})
}
}

func mockScore(uid int) (int, error) {
switch uid {
case 1:
return -1, nil
case 2:
return 10, nil
case 3:
return 70, nil
case 4:
return 95, nil
case 5:
return -1, errors.New("something was error")
case 6:
return 130, nil
}
return -1, nil
}

运行测试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
$ go test -v
=== RUN TestUserLevel
=== RUN TestUserLevel/not_found_user
=== RUN TestUserLevel/C_level
=== RUN TestUserLevel/B_level
=== RUN TestUserLevel/A_level
=== RUN TestUserLevel/Got_Error
=== RUN TestUserLevel/Wrong_Score
--- PASS: TestUserLevel (0.00s)
--- PASS: TestUserLevel/not_found_user (0.00s)
--- PASS: TestUserLevel/C_level (0.00s)
--- PASS: TestUserLevel/B_level (0.00s)
--- PASS: TestUserLevel/A_level (0.00s)
--- PASS: TestUserLevel/Got_Error (0.00s)
--- PASS: TestUserLevel/Wrong_Score (0.00s)
PASS
ok class 0.014s

面相接口编程

前面通过Mock框架我们可以在测试的时候替换原来的实现,这样就可以很方便的进行单元测试了,但是这种代码的实现方式其实并不符合面相对象设计的原则, 下面提出两个问题:

  1. 如果我们不依赖Mock框架该如何mock?
  2. 如果有一天我们不从redis获取数据,而是要从mysql获取数据了,怎么改?直接改Score函数么?那么如果有一天又要从redis获取数据呢?或者有的调用者是从redis获取数据,有的是从mysql获取数据怎么办?

可见上面的方式不太灵活,面对复杂多变的需求无法很好的满足。这时就要求我们改用面相接口编程, 下面是我们使用面相接口编程的方式改进了上面的实现:
user包增加了一个User接口,这个接口有一个函数Score, 然后定义了一个defaultUser, 并且实现了Score函数,最后定一个New函数向外输出这个defaultUser:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
package user

import (
"strconv"

"github.com/go-redis/redis"
)

type User interface {
Score(int) (int, error)
}

func New() User {
return defaultUser{}
}

type defaultUser struct{}

func (defaultUser) Score(uid int) (int, error) {
client := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "",
DB: 0,
})

_, err := client.Ping().Result()
if err != nil {
return -1, err
}

val, err := client.Get(strconv.Itoa(uid)).Result()
if err == redis.Nil {
return -1, nil
}
if err != nil {
return -1, err
}

return strconv.Atoi(val)
}

class包调用由原来的通过包直接调用改为了增加一个u变量, 然后调用u.Score来获取信息:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
package class

import (
"user"
)

var u = user.New()

func UserLevel(uid int) string {
score, err := u.Score(uid)
if err != nil {
return "E"
}

switch {
case score < 0:
return "N"
case score <= 60:
return "C"
case score <= 90:
return "B"
case score <= 100:
return "A"
default:
return "W"
}
}

class_test不再依赖mock框架,而是实现了自己的User接口mockUser,替换了user包的defaultUser:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
package class

import (
"errors"
"testing"
)

func TestUserLevel(t *testing.T) {
tests := map[string]struct {
input int
want string
}{
"not found user": {input: 1, want: "N"},
"C level": {input: 2, want: "C"},
"B level": {input: 3, want: "B"},
"A level": {input: 4, want: "A"},
"Got Error": {input: 5, want: "E"},
"Wrong Score": {input: 6, want: "W"},
}

u = mockUser{}
for name, tc := range tests {
got := UserLevel(tc.input)
t.Run(name, func(t *testing.T) {
if tc.want != got {
t.Fatalf("expected: %s, got %s", tc.want, got)
}
})
}
}

type mockUser struct{}

func (mockUser) Score(uid int) (int, error) {
switch uid {
case 1:
return -1, nil
case 2:
return 10, nil
case 3:
return 70, nil
case 4:
return 95, nil
case 5:
return -1, errors.New("something was error")
case 6:
return 130, nil
}
return -1, nil
}

运行单元测试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
$ go test -v
=== RUN TestUserLevel
=== RUN TestUserLevel/not_found_user
=== RUN TestUserLevel/C_level
=== RUN TestUserLevel/B_level
=== RUN TestUserLevel/A_level
=== RUN TestUserLevel/Got_Error
=== RUN TestUserLevel/Wrong_Score
--- PASS: TestUserLevel (0.00s)
--- PASS: TestUserLevel/not_found_user (0.00s)
--- PASS: TestUserLevel/C_level (0.00s)
--- PASS: TestUserLevel/B_level (0.00s)
--- PASS: TestUserLevel/A_level (0.00s)
--- PASS: TestUserLevel/Got_Error (0.00s)
--- PASS: TestUserLevel/Wrong_Score (0.00s)
PASS
ok class 0.005s

下面再来看上面提出的两个问题:

  1. 如果我们不依赖Mock框架该如何mock?
    答: 根据上面的实现可以看到,我们没有借助任何框架同样完成了Mock的效果
  2. 如果有一天我们不从redis获取数据,而是要从mysql获取数据了,怎么改?直接改Score函数么?那么如果有一天又要从redis获取数据呢?或者有的调用者是从redis获取数据,有的是从mysql获取数据怎么办?
    答: 由于面相接口编程,我们可以在user中增加一个实例实现从mysql获取数据的方法,调用者可以根据需求选择不同的实例,而且如果调用者对这个数据来源有自己的需求,甚至可以自己实现这个接口。

工厂方法

上面的实现我们可以看到每次调用var u = user.New()都会新建一个defaultUser对象,对于有些需要共享defaultUser状态的情况下,例如defaultUser中有一个常驻内存共享的数据, 我们在多个包调用的时候其实那得是不同的对象,为了共享这个数据我们把user.New改成下面的实现:

1
2
3
4
5
var du = defaultUser{}

func New() User {
return du
}

这样每次返回的其实都是同一个defaultUser

更方便的调用

上面我们看出,修改为面相接口编程后我们需要通过依赖注入传递对象,但是这样会对调用者照成麻烦,我们是否可以在优化一下呢?
我们在user中增加一个函数:

1
2
3
func Score(uid int) (int, error) {
return du.Score(uid)
}

这样我们就可以通过user.Score调用du.Score函数了,所以class.go的实现可以改为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
package class

import (
"user"
)

func UserLevel(uid int) string {
score, err := user.Score(uid)
if err != nil {
return "E"
}

switch {
case score < 0:
return "N"
case score <= 60:
return "C"
case score <= 90:
return "B"
case score <= 100:
return "A"
default:
return "W"
}
}

看上去不错,但是我们如何进行依赖注入呢?不然单元测试使用的是默认实现,我们没办法做单元测试了。前面其实我们提过依赖注入的方式有一个Get/Set方式,我们可以再修改一下user包:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
package user

import (
"strconv"

"github.com/go-redis/redis"
)

type User interface {
Score(int) (int, error)
}

func Score(uid int) (int, error) {
if definedUser != nil {
return definedUser.Score(uid)
}
return du.Score(uid)
}

var definedUser User

func SetUser(u User) {
definedUser = u
}

var du = defaultUser{}

func New() User {
return du
}

type defaultUser struct{}

func (defaultUser) Score(uid int) (int, error) {
client := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "",
DB: 0,
})

_, err := client.Ping().Result()
if err != nil {
return -1, err
}

val, err := client.Get(strconv.Itoa(uid)).Result()
if err == redis.Nil {
return -1, nil
}
if err != nil {
return -1, err
}

return strconv.Atoi(val)
}

class不用修改,class_test修改为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
package class

import (
"errors"
"testing"

"user"
)

func TestUserLevel(t *testing.T) {
tests := map[string]struct {
input int
want string
}{
"not found user": {input: 1, want: "N"},
"C level": {input: 2, want: "C"},
"B level": {input: 3, want: "B"},
"A level": {input: 4, want: "A"},
"Got Error": {input: 5, want: "E"},
"Wrong Score": {input: 6, want: "W"},
}

user.SetUser(mockUser{})
for name, tc := range tests {
got := UserLevel(tc.input)
t.Run(name, func(t *testing.T) {
if tc.want != got {
t.Fatalf("expected: %s, got %s", tc.want, got)
}
})
}
}

type mockUser struct{}

func (mockUser) Score(uid int) (int, error) {
switch uid {
case 1:
return -1, nil
case 2:
return 10, nil
case 3:
return 70, nil
case 4:
return 95, nil
case 5:
return -1, errors.New("something was error")
case 6:
return 130, nil
}
return -1, nil
}

我们通过user.SetUser方法用自己的实现替换了之前默认的实现,这样我们就可以方便的进行单元测试了。
运行单元测试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
$ go test -v
=== RUN TestUserLevel
=== RUN TestUserLevel/Wrong_Score
=== RUN TestUserLevel/not_found_user
=== RUN TestUserLevel/C_level
=== RUN TestUserLevel/B_level
=== RUN TestUserLevel/A_level
=== RUN TestUserLevel/Got_Error
--- PASS: TestUserLevel (0.00s)
--- PASS: TestUserLevel/Wrong_Score (0.00s)
--- PASS: TestUserLevel/not_found_user (0.00s)
--- PASS: TestUserLevel/C_level (0.00s)
--- PASS: TestUserLevel/B_level (0.00s)
--- PASS: TestUserLevel/A_level (0.00s)
--- PASS: TestUserLevel/Got_Error (0.00s)
PASS
ok class 0.011s

在大多数情况下,我们都是使用的默认实现,只有在我们必须要修改依赖的实现,或者单元测试时才会使用其他的实现,所以为了大多数的场景下调用简单,我们应该尽量使用这种方式来实现。

总结

本文主要回顾了一下关于单元测试的一些理论知识:

  • 测试的粒度应该是测试包中的可导出函数
  • 测试的原则告诉我们应该是变测试变开发, 相互交替进行
  • 测试的目的应该是测试行为,而不是测试具体的实现

关于Go的单元测试可以分为三个阶段:

  • 初级阶段: 主要是认识Go的单元测试基本写法,以及如何利用Go的工具链运行单元测试及查看单元测试覆盖率的情况
  • 进阶阶段: 主要是举一个单元测试的例子,通过不断改进这个单元测试的写法来告诉我们如何写出更好的单元测试
  • 高级阶段: 介绍了如何写出可测试的函数,面对复杂的调用和多变得需求如何利用面相接口编程和依赖注入改进我们的程序的写法

参考

Test-Driven Development By Example
Testing; how, what, why - Dave
TDD, Where Did It All Go Wrong - Lan Cooper
The Three Laws of TDD.
深度解读 - TDD(测试驱动开发)
如何写出优雅的 Golang 代码
单元测试wiki
How to Write Go Code - Testing
Testing Behavior vs. Testing Implementation
Avoid Testing Implementation Details, Test Behaviours
边界条件测试
代码整洁之道
Mocks Aren’t Stubs
TestDouble