transformer系列2---transformer架构详细解析

2023-09-21 11:12:18

请添加图片描述

Transformer是一个序列到序列的模型,包含一个encoder和一个decoder,每个encoder或decoder均由 H个相同的结构堆叠而成。编码器模块主要包含:多头自注意力模块和前馈神经网络,同时采用残差连接增加网络深度,采用layerNorm进行归一处理。解码器模块基本结构和编码器模块相同,只是使用mask的多头注意力模块,然后增加一个交叉注意力模块,再输入逐点前馈网络,最后经过一个全连接层和一个Softmax,得到输出结果,每个输出结果又会作为下一轮decoder的输入,因此,transformer属于自回归。整体结构如上图所示,下面对每个模块进行详细解读。

Encoder

1 输入

输入部分

1.1 Embedding 词嵌入

1.1.1 Embedding 定义

首先介绍一下常见的几种编码方式:

  1. 整数编码:用一种数字来代表一个词
  2. one-hot 编码:用一个序列向量表示一个词,该向量只有词汇表中表示这个单词的位置是1,其余都是0,序列向量长度是预定义的词汇表中单词数量。
  3. word embedding 词嵌入编码:将词映射或者嵌入(Embedding)到另一个数值向量空间(常常存在降维),它以one hot的稀疏矩阵为输入,经过一个线性变换(查表)将其转换成一个密集矩阵的过程。Embedding的原理是使用矩阵乘法来进行降维,节约存储空间。例如下图,一个2×6的矩阵,乘以6×3的矩阵,得到2×3的矩阵。虽然矩阵的变小,但数字蕴藏的信息没有改变,只是按照某一种映射关系将矩阵映射到一个新的维度的矩阵。每个单词都表示为一个3维的浮点值向量。可以将其理解为“查找表”,通过查找表查找密集矩阵中的值,得到每个单词进行编码。比如下面的第一个词,查找密集矩阵中第一行即可得到新的表示
    在这里插入图片描述
1.1.2 几种编码方式对比

下表是对几种编码方式的对比
在这里插入图片描述

1.1.3 实现代码
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

1.2 位置编码

1.2.1 使用位置编码原因

由于句子中的每个词语同时通过Transformer的编码器/解码器,模型本身对于每个词语的位置/顺序没有任何概念。因此,需要将词语的顺序融入到模型中,使模型能够获取词语的位置信息,也就是位置编码。位置编码需要满足以下条件:

  1. 词语在句子中的每个位置应输出一个唯一的编码。
  2. 不同长度的句子之间的任意两个位置之间的距离应保持一致。
  3. 模型应能适应更长的句子,其取值应该受到限制。
  4. 它必须是确定值。
1.2.2 位置编码方式
  1. 固定位置编码

  2. 可学习的位置编码

1.2.3 位置编码代码

2 注意力 Attention

在这里插入图片描述

2.1 自注意力self-attention

2.1.1 QKV含义
2.1.2 自注意力公式

在这里插入图片描述
dk代表矩阵K的维度,这里进行归一化计算,假设q和k的分量是独立的随机变量,均值为0,方差为1。那么它们的点积在这里插入图片描述的均值为0,方差为dk,点积会变大,因此需要归一化计算。

2.1.3 自注意力计算流程
  1. 对于输入的序列a1,a2,a3,a4,通过Wq,Wk,Wv三个矩阵,将其转换为q1,k1,v1,q2,k2,v2,q3,k3,v3,q4,k4,v4。
    请添加图片描述

  2. q1和k1—k4分别计算相似性,得到相似性系数在这里插入图片描述

  3. 计算soft-max,得到在这里插入图片描述
    在这里插入图片描述
    该步骤

  4. 将上一步得到的注意力分数与v1,v2,v3,v4相乘,求和,得到。
    具体流程如下图:
    请添加图片描述

  5. 总结:上述流程用矩阵表示为:
    请添加图片描述

2.2 多头注意力 Multi-Head Attention

2.2.1 多头注意力公式

在这里插入图片描述
这里在这里插入图片描述在这里插入图片描述

2.2.2 多头注意力计算流程

顾名思义,多头注意力就是用H个不同参数的 QKV注意力结构对输入的 Dk维度的 query,key和 value进行计算;然后将所有输出结果进行拼接,并将N*d_v维度映射回 D_m维,注意,这里的H个head是并行计算的,而不是一个一个head串行。
在这里插入图片描述

2.3 模型使用的三种注意力

  1. 自注意力 self attention
    在 Transformer编码器中使用,Q=K=V=X, X为输入或前一层的输出
  2. 掩码自注意力 masked self attention
    在Transformer解码器中使用,此时自注意力限制为:每个位置的 Query只注意到该位置和之前的所有 Key-Value值。具体实现为:对位归一化的注意力矩阵进行掩码操作,即当 i<j时,设置 A_ij 为负无穷。这种自注意力也称之为自回归(autogress)或因果(causal)注意力。
  3. 交叉注意力 cross Attention
    Q,K,V的来源不同,其中 Q(query)来自于输入或上一层的输出,而 K和 V则来自于编码器的输出。

3 前馈神经网络

逐点前馈网络(Position-wise Feed-Forward Networks)为全连接(通常为两层)前馈结构,对特征图每个点进行逐点计算,计算公式为:
在这里插入图片描述
x是多头注意力的输出,第一层全连接结构经过RELU激活函数后,再计算一次全连接,W1,W2,b1,b2为可学习的参数。

4 残差连接和归一化

Transformer结构在每个模块中采用了残差连接,并对连接后的张量进行归一化操作可以表示为:
LN(X)= LayerNorm(x + Sublayer(x))

Sublayer表示当前层的操作,比如当前层计算多头注意力,这里就是Multihead(x),若是逐点前馈网络,就是FFN(x),LayerNorm表示层归一化操作。残差连接和归一化主要用于构建更深的网络结构,减缓梯度消失问题。

Decoder

1 输入

Decoder的输入部分与Encoder类似,都是将输入先经过embedding转换,然后增加位置编码,区别是,这里的输入是Decoder上一轮的输出结果,因此,transformer是自回归。

2 mask 多头自注意力

Decoder也有多头自注意力(MHA)模块,不过这里需要对当前单词和之后的单词做mask,也就是模型只会看到当前位置之前的内容,这样,才能确保预测仅依赖于已生成的输出单词。

3 多头交叉注意力

交叉注意力也成为跨注意力,这里的K和V是来源于encoder的输出,Q是来源于Decoder的Mask MHA,而解码器自注意力中,Q、K和V都来自上一个解码器层的输出。

4 前馈神经网络

这部分和encoder相同,不再赘述

5 输出

经过decoder后,网络再经过全连接层和softmax,即可得到输出结果

参考内容:
感谢李宏毅老师的视频讲解,文章中的图来源于李老师的PPT。

更多推荐

markdown学习笔记

markdown学习笔记1.文字(依靠HTML)1.1文字缩进-空格转义符单字符空:&emsp;半字符空:&ensp;1.2文字对齐「居中:」<center>居中</center>or<palign="center">居中</p>「左对齐:」<palign="left">左对齐</p>「右对齐:」<palign="ri

VUE build:gulp打包:测试、正式环境

目录项目结构GulpVUE使用GulpVue安装GulpVue定义Gulp.jspackage.jsonbuild文件夹config文件夹static-config文件夹项目结构GulpGulp是一个自动化构建工具,可以帮助前端开发者通过自动化任务来管理工作流程。Gulp使用Node.js的代码编写,可以更加灵活地管理

STM32 基础学习——GPIO位结构(江科大老师教程)

一、GPIO内部结构1、GPIO外设名称是由GPIOA、GPIOB、GPIOC等命名,共有16个引脚2、每个GPIO模块内,主要包含了寄存器和驱动器这些东西3、寄存器写1,对应的端口就是高电平。写0,对应的端口就是低电平4、寄存器只负责存储数据这是GPIO结构图,总体来说上半部分是输入部分,下半部分是输出部分这是部分是

Bartender for Mac菜单栏图标自定义

Bartender是一款可以帮助用户更好地管理和组织菜单栏图标的macOS软件。它允许用户隐藏和重新排列菜单栏图标,从而减少混乱和杂乱。以下是Bartender的主要特点:菜单栏图标隐藏:Bartender允许用户隐藏菜单栏图标,只在需要时显示。这样可以减少菜单栏的拥挤和视觉干扰,使界面更加整洁和专注。可自定义的菜单栏

Layui快速入门之第九节 表格事件的使用

目录一:事件二:头部工具栏事件三:排序切换事件四:列拖拽宽度后的事件五:列筛选(显示或隐藏)后的事件六:行单击和双击事件七:行右键菜单事件八:单元格编辑事件九:单元格工具事件十:复选框事件十一:单选框事件十二:尾部分页栏事件一:事件table.on('event(filter)',callback);参数event(f

Spring后处理器-BeanPostProcessor

Spring后处理器-BeanPostProcessorBean被实例化后,到最终缓存到名为singletonObjects单例池之前,中间会经过bean的初始化过程((该后处理器的执行时机)),例如:属性的填充、初始化方法init的执行等,其中有一个对外拓展的点BeanPostProcessor,我们称之为bean后

2D游戏开发和3D游戏开发有什么不同?

2D游戏开发和3D游戏开发是两种不同类型的游戏制作方法,它们之间有一些显著的区别:1.图形和视觉效果:2D游戏开发:2D游戏通常使用二维图形,游戏世界和角色通常在一个平面上显示。这种类型的游戏具有平面的外观,就像经典的平台游戏,如《超级马里奥》或《糖果传奇》。3D游戏开发:3D游戏使用三维图形,玩家可以在三维环境中自由

MySQL学习系列(3)-每天学习10个知识

目录1.全文搜索(Full-TextSearch)vs.LIKE操作符2.MySQL中的大数据量处理3.分区(Partitioning)在MySQL中的作用和用法4.MySQL中的数据复制(Replication)5.索引的覆盖和索引下推6.预处理语句(PreparedStatements)7.视图和存储过程8.MyS

C语言知识阶段性总结项目:电子词典

项目需求使用TCP实现客户端和服务端通信使用sqlite存放用户信息客户端需要有登录、注册、查询单词、账号查询记录功能服务器需要实时显示在线用户解决方案使用sqlite创建三个数据库,分别存放用户账号密码,单词表,用户查询记录使用链表存放在线用户的信息,在子线程中循环遍历,达到实时显示在线用户的效果主要的功能代码头文件

大数据(九):数据可视化(一)

专栏介绍结合自身经验和内部资料总结的Python教程,每天3-5章,最短1个月就能全方位的完成Python的学习并进行实战开发,学完了定能成为大佬!加油吧!卷起来!全部文章请访问专栏:《Python全栈教程(0基础)》再推荐一下最近热更的:《大厂测试高频面试题详解》该专栏对近年高频测试相关面试题做详细解答,结合自己多年

Mysql---第六篇

系列文章目录文章目录系列文章目录一、分表后非sharding_key的查询怎么处理,分表后的排序?二、mysql主从同步原理一、分表后非sharding_key的查询怎么处理,分表后的排序?可以做一个mapping表,比如这时候商家要查询订单列表怎么办呢?不带user_id查询的话你总不能扫全表吧?所以我们可以做一个映

热文推荐