深度学习笔记--Transformer中position encoding的源码理解与实现
创始人
2024-01-28 15:35:20
0

1--源码

import torch
import math
import numpy as np
import torch.nn as nnclass Pos_Embed(nn.Module):def __init__(self, channels, num_frames, num_joints):super().__init__()# 根据帧序和节点序生成位置向量pos_list = [] for tk in range(num_frames):for st in range(num_joints):pos_list.append(st)position = torch.from_numpy(np.array(pos_list)).unsqueeze(1).float()  # num_frames*num_joints, 1pe = torch.zeros(num_frames * num_joints, channels)  # T*N, Cdiv_term = torch.exp(torch.arange(0, channels, 2).float() * -(math.log(10000.0) / channels))pe[:, 0::2] = torch.sin(position * div_term)  # 偶数列 # 偶数C维度sinpe[:, 1::2] = torch.cos(position * div_term)  # 奇数列 # 奇数C维度cospe = pe.view(num_frames, num_joints, channels).permute(2, 0, 1).unsqueeze(0)  # T N C -> C T N -> 1 C T Nself.register_buffer('pe', pe)def forward(self, x):  # nctv # BCTNx = self.pe[:, :, :x.size(2)]return xif __name__ == "__main__":B = 2C = 4T = 120N = 25x = torch.rand((B, C, T, N))Pos_embed_1 = Pos_Embed(C, T, N)PE = Pos_embed_1(x)# print(PE.shape) # 1 C T Nx = x + PEprint("All Done !")

2--源码分析与理解

原理理解:Positional Encoding(位置编码)

代码解释:

①代码 div_term = torch.exp(torch.arange(0, channels, 2).float() * -(math.log(10000.0) / channels)):

令:channels = C, torch.arange(0, channels, 2).float() = k(则k = 0, 2, ..., C-2);

-(math.log(10000.0) / channels)  \large {\color{Red} =\frac{-\log_{e}1000}{C}}

则:torch.arange(0, channels, 2).float() * -(math.log(10000.0) / channels)\large {\color{Red} =\frac{-k\log_{e}10000}{C}}

torch.exp(torch.arange(0, channels, 2).float() * -(math.log(10000.0) / channels))\LARGE {\color{Red} =e^{\frac{-k\log_{e}10000}{C}} = e^{\log_{e}\frac{-10000k}{C}} = \frac{-10000k}{C}};

②代码:pe[:, 0::2] = torch.sin(position * div_term)  pe[:, 1::2] = torch.cos(position * div_term):

令:position = p,则position * div_term\large {\color{Red} =p*\frac{-10000k}{C}=\frac{p}{10000^{\frac{k}{c}}}};

k等价为2ipe[:, 0::2]pe[:, 1::2]分别取行数列和奇数列,就可以得到上图绿框所示的公式。

3--参考

参考1

参考2

相关内容

热门资讯

云南3岁女童头上插刀淡定就医,...   近日,“云南3岁女童头上插刀淡定就医”的视频在网上引发关注。8月17日,记者从当地相关部门获悉,...
“17岁主播被要求陪榜一大哥聊...   今年7月,一则14岁少女与MCN机构(网络信息内容多渠道分发服务机构)解约被索赔的新闻引发社会广...
澳大利亚悉尼发生枪击案致1死1...   中新网悉尼8月18日电 据澳大利亚新南威尔士州警方消息,新南威尔士州首府悉尼当地时间17日晚发生...
这个“恐怖游戏”,为啥会在孩子...   受访专家:张惠姗  国家二级心理咨询师  广东某学校心理老师  广东省汕尾市暖阳社工服务中心首席...
“2岁女儿被抱走” 发布者被拘...   今年7月23日,一位名叫“诺言”的网民在抖音短视频平台发布了一条“2岁女童王喵喵走失的寻人启事”...
很有用!微信转账,务必注意这个...   微信转账暗藏玄机?  这个步骤不做小心吃大亏!  8月17日  关于微信转账的一个话题词冲上热搜...
太危险!骑行者紧跟大货车尾部 ...   太危险!骑行者紧跟大货车尾部疑似利用货车车身“破风”,车内石块堆积,一旦急刹车危险万分(编辑:小...
深圳“00后”女孩代送外卖遭抢...   深圳“00后”女孩代送外卖遭抢单,老人抓伤女孩抢地盘。
成都世运会闭幕   第12届世界运动会8月17日晚在四川省成都市圆满闭幕。过去11天里,来自116个国家和地区的近4...
南宁一房东收房时拒退租金 威胁...   南宁一房东收房时拒退租金,带来人员大闹出租屋,威胁、辱骂租客。(编辑 文文)