深度学习:pytorch nn.Embedding详解

2023-09-18 20:48:41

目录

1 nn.Embedding介绍

1.1 nn.Embedding作用

1.2 nn.Embedding函数描述

1.3 nn.Embedding词向量转化

2 nn.Embedding实战

2.1 embedding如何处理文本

2.2 embedding使用示例

2.3 nn.Embedding的可学习性


1 nn.Embedding介绍

1.1 nn.Embedding作用

nn.Embedding是PyTorch中的一个常用模块,其主要作用是将输入的整数序列转换为密集向量表示。在自然语言处理(NLP)任务中,可以将每个单词表示成一个向量,从而方便进行下一步的计算和处理。

1.2 nn.Embedding函数描述

nn.Embedding是将输入向量化,定义如下:

torch.nn.Embedding(num_embeddings, 
                   embedding_dim, 
                   padding_idx=None, 
                   max_norm=None, 
                   norm_type=2.0, 
                   scale_grad_by_freq=False, 
                   sparse=False, 
                   _weight=None, 
                   _freeze=False, 
                   device=None, 
                   dtype=None)

参数说明:

  • num_embeddings :字典中词的个数
  • embedding_dim:embedding的维度
  • padding_idx(索引指定填充):如果给定,则遇到padding_idx中的索引,则将其位置填0(0是默认值,事实上随便填充什么值都可以)。

注:embeddings中的值是正态分布N(0,1)中随机取值。

1.3 nn.Embedding词向量转化

在PyTorch中,nn.Embedding用来实现词与词向量的映射。nn.Embedding具有一个权重(.weight),形状是(num_words, embedding_dim)。例如一共有100个词,每个词用16维向量表征,对应的权重就是一个100×16的矩阵。

Embedding的输入形状N×W,N是batch size,W是序列的长度,输出的形状是N×W×embedding_dim。

Embedding输入必须是LongTensor,FloatTensor需通过tensor.long()方法转成LongTensor。

Embedding的权重是可以训练的,既可以采用随机初始化,也可以采用预训练好的词向量初始化。

2 nn.Embedding实战

2.1 embedding如何处理文本

在NLP任务中,首先要对文本进行处理,将文本进行编码转换,形成向量表达,embedding处理文本的流程如下:

(1)输入一段文本,中文会先分词(如jieba分词),英文会按照空格提取词

(2)首先将单词转成字典的形式,由于英语中以空格为词的分割,所以可以直接建立词典索引结构。类似于:word2id = {'i' : 1, 'like' : 2, 'you' : 3, 'want' : 4, 'an' : 5, 'apple' : 6} 这样的形式。如果是中文的话,首先进行分词操作。

(3)然后再以句子为list,为每个句子建立索引结构,list [ [ sentence1 ] , [ sentence2 ] ] 。以上面字典的索引来说,最终建立的就是 [ [ 1 , 2 , 3 ] , [ 1 , 4 , 5 , 6 ] ] 。这样长短不一的句子

(4)接下来要进行padding的操作。由于tensor结构中都是等长的,所以要对上面那样的句子做padding操作后再利用 nn.Embedding 来进行词的初始化。padding后的可能是这样的结构

[ [ 1 , 2 , 3, 0 ] , [ 1 , 4 , 5 , 6 ] ] 。其中0作为填充。(注意:由于在NMT任务中肯定存在着填充问题,所以在embedding时一定存在着第三个参数,让某些索引下的值为0,代表无实际意义的填充)

2.2 embedding使用示例

比如有两个句子:

  • I want a plane

  • I want to travel to Beijing

将两个句子转化为ID映射:

{I:1,want:2,a:3,plane:4,to:5,travel:6,Beijing:7}

转化成ID表示的两个句子如下:

  • 1,2,3,4

  • 1,2,5,6,5,7

import torch
from torch import nn

# 创建最大词个数为10,每个词用维度为4表示
embedding = nn.Embedding(10, 4)

# 将第一个句子填充0,与第二个句子长度对齐
in_vector = torch.LongTensor([[1, 2, 3, 4, 0, 0], [1, 2, 5, 6, 5, 7]])
out_emb = embedding(in_vector)
print(in_vector.shape)
print((out_emb.shape))
print(out_emb)
print(embedding.weight)

运行结果显示如下:

torch.Size([2, 6])
torch.Size([2, 6, 4])
tensor([[[-0.6642, -0.6263,  1.2333, -0.6055],
         [ 0.9950, -0.2912,  1.0008,  0.1202],
         [ 1.2501,  0.1923,  0.5791, -1.4586],
         [-0.6935,  2.1906,  1.0595,  0.2089],
         [ 0.7359, -0.1194, -0.2195,  0.9161],
         [ 0.7359, -0.1194, -0.2195,  0.9161]],

        [[-0.6642, -0.6263,  1.2333, -0.6055],
         [ 0.9950, -0.2912,  1.0008,  0.1202],
         [-0.3216,  1.2407,  0.2542,  0.8630],
         [ 0.6886, -0.6119,  1.5270,  0.1228],
         [-0.3216,  1.2407,  0.2542,  0.8630],
         [ 0.0048,  1.8500,  1.4381,  0.3675]]], grad_fn=<EmbeddingBackward0>)
Parameter containing:
tensor([[ 0.7359, -0.1194, -0.2195,  0.9161],
        [-0.6642, -0.6263,  1.2333, -0.6055],
        [ 0.9950, -0.2912,  1.0008,  0.1202],
        [ 1.2501,  0.1923,  0.5791, -1.4586],
        [-0.6935,  2.1906,  1.0595,  0.2089],
        [-0.3216,  1.2407,  0.2542,  0.8630],
        [ 0.6886, -0.6119,  1.5270,  0.1228],
        [ 0.0048,  1.8500,  1.4381,  0.3675],
        [ 0.3810, -0.7594, -0.1821,  0.5859],
        [-1.4029,  1.2243,  0.0374, -1.0549]], requires_grad=True)

注意:

  • 句子中的ID不能大于最大词的index(上面例子中,不能大于10)
  • embeding的输入必须是维度对齐的,如果长度不够,需要预先做填充

2.3 nn.Embedding的可学习性

nn.Embedding中的参数并不是一成不变的,它也是会参与梯度下降的。也就是更新模型参数也会更新nn.Embedding的参数,或者说nn.Embedding的参数本身也是模型参数的一部分。

import torch
from torch import nn

# 创建最大词个数为10,每个词用维度为4表示
embedding = nn.Embedding(10, 4)

# 将第一个句子填充0,与第二个句子长度对齐
in_vector = torch.LongTensor([[1, 2, 3, 4, 0, 0], [1, 2, 5, 6, 5, 7]])

optimizer = torch.optim.SGD(embedding.parameters(), lr=0.01)
criteria = nn.MSELoss()

for i in range(1000):
    outputs = embedding(torch.LongTensor([1, 2, 3, 4]))
    loss = criteria(outputs, torch.ones(4, 4))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

print(embedding.weight)
new_output = embedding(in_vector)
print(new_output)

经过1000epochs的训练后,查看新的编码结果,显示如下:

Parameter containing:
tensor([[-0.2475, -1.3436, -0.0449,  0.2093],
        [ 0.4831,  0.5887,  1.2278,  1.1106],
        [ 1.1809,  0.7451,  0.2049,  1.3053],
        [ 0.7369,  1.1276,  1.0066,  0.4399],
        [ 1.3064,  0.3979,  0.8753,  0.9410],
        [-0.6222,  0.2574,  1.1211,  0.1801],
        [-0.5072,  0.2564,  0.5500,  0.3136],
        [-1.7473,  0.0504, -0.0633, -0.3138],
        [-2.4507, -0.6092,  0.0348, -0.4384],
        [ 0.9458, -0.2867, -0.0285,  1.1842]], requires_grad=True)
tensor([[[ 0.4831,  0.5887,  1.2278,  1.1106],
         [ 1.1809,  0.7451,  0.2049,  1.3053],
         [ 0.7369,  1.1276,  1.0066,  0.4399],
         [ 1.3064,  0.3979,  0.8753,  0.9410],
         [-0.2475, -1.3436, -0.0449,  0.2093],
         [-0.2475, -1.3436, -0.0449,  0.2093]],

        [[ 0.4831,  0.5887,  1.2278,  1.1106],
         [ 1.1809,  0.7451,  0.2049,  1.3053],
         [-0.6222,  0.2574,  1.1211,  0.1801],
         [-0.5072,  0.2564,  0.5500,  0.3136],
         [-0.6222,  0.2574,  1.1211,  0.1801],
         [-1.7473,  0.0504, -0.0633, -0.3138]]], grad_fn=<EmbeddingBackward0>)

权重参数和编码结果都发生了很大变化,所以nn.Embedding在构建模型过程中,可以作为模型的一部分,进行共同训练。

更多推荐

死锁详细解读

目录死锁(1)一、死锁的定义二、产生死锁的原因三、产生死锁的四个必要条件四、解决死锁的方法死锁(2)第三节死锁避免一、死锁避免的概念二、安全状态与安全序列三、银行家算法第四节、死锁的检测与解除一、死锁的检测和解除二、死锁检测的算法三、解除死锁的方法死锁(3)第五节资源分配图一、资源分配图二、死锁定理第六节哲学家就餐问题

SIEM:网络攻击检测

如果您正在寻找一种能够检测环境中的网络威胁、发送实时警报并自动执行事件响应的网络攻击检测平台,Log360SIEM解决方案可以完成所有这些以及更多,能够准确检测安全威胁并遏制网络攻击。网络攻击检测能力基于规则的攻击检测MITREATT&CK实现来检测APTS基于ML的行为分析基于规则的攻击检测使用从Log360强大的关

Spring Cloud Alibaba Nacos注册中心(单机)

文章目录SpringCloudAlibabaNacos注册中心(单机)1.docker安装nacos(先别着急)2.配置nacos持久化到mysql、2.1properties文件3.java注册3.1POM文件3.2properties文件3.3测试配置中心4.注册中心4.1配置文件4.2测试类4.3补充演示Spri

Vivado初体验LED工程

文章目录前言一、PL和PS二、LED硬件介绍三、创建Vivado工程四、创建VerilogHDL文件五、添加管脚约束六、添加时序约束七、生成BIT文件八、仿真测试九、下载测试前言本节我们要做的是熟练使用Vivado创建工程并实现对LED灯控制,每秒钟控制开发板上的LED灯翻转一次,实现亮、灭、亮、灭的控制。会控制LED

中国这么多 Java 开发者,应该诞生出生态级应用开发框架

1、必须要有,不然就永远不会有应用开发框架,虽然没有芯片、操作系统、数据库、编程语言这些重要。但是最终呈现在用户面前的,总是有软件部分。而软件系统开发,一般都需要应用开发框架,它是软件系统的基础性部件之一。很多很多软件系统都会有Java开发的部分,尤其是政府部门的软件系统大量的使用了Java。市场非常的大,我们有很多的

【国产32位mcu】电动车控制芯片CS32F031C8T6的应用

近年来,随着“新国标”的落地,双轮电动车在智能化、强性能、安全性等方面不断演进,带动了新一轮的换车高峰。电动车控制器作为双轮电动车的核心部件,迎来新的增长。芯海科技32位MCUCS32F031C8T6,作为电动车控制器的主控MCU芯片,很好地满足了双轮电动车在户外工作中的高温宽、高耐潮的工作环境,以及PWM、ADC等高

《DevOps实践指南》- 读书笔记(八)

DevOps实践指南Part6集成信息安全、变更管理和合规性的技术实践22.将信息安全融入每个人的日常工作22.1将安全集成到开发迭代的演示中22.2将安全集成到缺陷跟踪和事后分析会议中22.3将预防性安全控制集成到共享源代码库及共享服务中22.4将安全集成到部署流水线中22.5保证应用程序的安全性22.6确保软件供应

BD就业复习第三天

1.连续活跃区间表的实现思路实现连续活跃区间表是数据仓库中常见的需求,通常用于分析用户或实体在一段时间内的活跃情况。以下是一种可能的实现思路:1.数据模型设计:首先,您需要设计一个数据模型来存储连续活跃区间。通常,这个表包含以下字段:用户/实体ID:标识活跃实体的唯一标识符。开始日期:活跃区间的开始日期或时间戳。结束日

【DevOps系列】Docker数据卷(volume)详解

【DevOps系列】Docker数据卷(volume)详解文章目录【DevOps系列】Docker数据卷(volume)详解一、概述二、数据卷三、为什么使用数据卷volume数据卷的作用:数据卷的特点:四、数据卷volume基本操作4.1创建数据卷4.2查看数据卷4.3查看数据卷详细信息4.4数据卷删除五、数据卷的使用

Python发布订阅模式

Python发布订阅模式1、broadcast-service模块2、基本使用3、使用装饰器4、发布Topic带参数1、broadcast-service模块Python发布订阅模式可以实现程序间的松耦合broadcast-service是一个轻量级的Python发布订阅者框架,且支持同步、异步、多主题订阅等不同场景下

【Python基础】S01E02 列表

S01E02列表列表是什么列表的操作修改、添加和删除元素列表排序列表倒序列表长度遍历整个列表数值列表创建数值列表数值列表简单统计计算列表推导式列表切片复制列表列表是什么在Python中,用方括号([])表示列表,用逗号分隔其中的元素。bicycles=['trek','cannon','redline','specia

热文推荐