Partial Implementation of a Transformer
0.Imports
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
1.word_embedding
batch_size = 2
# vocabulary size
max_num_src_words = 8
max_num_tgt_words = 8
# model dimension
model_dim = 8
# maximum sequence length
max_src_seq_len = 5
max_tgt_seq_len = 5
max_position_len = 5
# src_len = torch.randint(2, 5, (batch_size, ))
# tgt_len = torch.randint(2, 5, (batch_size, ))
src_len = torch.tensor([2, 4]).to(torch.int32)
tgt_len = torch.tensor([4, 3]).to(torch.int32)
# store word indices, used to look up the corresponding words in the vocabulary
# pad sentences of different lengths to the same length
src_seq = [F.pad(torch.randint(1, max_num_src_words, (L,)), (0, max_src_seq_len-L)) for L in src_len]
src_seq = [torch.unsqueeze(x, 0) for x in src_seq]
src_seq = torch.cat(src_seq, dim = 0)
tgt_seq = [F.pad(torch.randint(1, max_num_tgt_words, (L,)), (0, max_tgt_seq_len-L)) for L in tgt_len]
tgt_seq = [torch.unsqueeze(x, 0) for x in tgt_seq]
tgt_seq = torch.cat(tgt_seq, dim = 0)
src_embedding_table = nn.Embedding(max_num_src_words+1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words+1, model_dim)
# src_embedding_table.weight.shape = (9, 8)
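As a quick check (not in the original notes, for illustration): the padded index tensors can be fed straight into the embedding tables.
src_embedding = src_embedding_table(src_seq)   # shape: (batch_size, max_src_seq_len, model_dim) = (2, 5, 8)
tgt_embedding = tgt_embedding_table(tgt_seq)   # shape: (2, 5, 8)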
2.positional embedding
Using trigonometric functions gives the encoding generalization: even if the longest position seen during training is t, and a position t+ϕ unexpectedly appears at inference time, its encoding can still be expressed as a linear combination of the trigonometric functions at earlier positions.
Proof as follows:
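A sketch of the standard argument (my addition, following the sinusoidal definition used in the code below, with ω_i = 1 / 10000^(2i/d_model)):
PE(t+ϕ, 2i)   = sin(ω_i·(t+ϕ)) = sin(ω_i·t)·cos(ω_i·ϕ) + cos(ω_i·t)·sin(ω_i·ϕ)
PE(t+ϕ, 2i+1) = cos(ω_i·(t+ϕ)) = cos(ω_i·t)·cos(ω_i·ϕ) - sin(ω_i·t)·sin(ω_i·ϕ)
Since cos(ω_i·ϕ) and sin(ω_i·ϕ) depend only on the offset ϕ, every entry of PE(t+ϕ) is a fixed linear combination of sin(ω_i·t) and cos(ω_i·t), i.e. of the entries of PE(t).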
# construct the position embedding table
pos_mat = torch.arange(max_position_len).reshape((-1, 1))
i_mat = torch.pow(10000, torch.arange(0, model_dim, 2).reshape((1, -1)) / model_dim)
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(pos_mat / i_mat)
pe_embedding_table[:, 1::2] = torch.cos(pos_mat / i_mat)
pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)
src_pos = torch.cat([torch.unsqueeze(torch.arange(max(src_len)),0) for _ in src_len]).to(torch.int32)
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max(tgt_len)),0) for _ in tgt_len]).to(torch.int32)
src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)
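A hedged sketch of combining the two embeddings (my addition; note that src_pos above only covers max(src_len) positions while src_seq is padded to max_src_seq_len, so here the positions are taken up to the padded length instead):
full_src_pos = torch.arange(src_seq.shape[1]).unsqueeze(0).repeat(batch_size, 1)   # positions 0..max_src_seq_len-1 for each sample
encoder_input = src_embedding_table(src_seq) + pe_embedding(full_src_pos)          # word embedding + positional embedding, shape (2, 5, 8)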
3.Attention
Attention(Q, K, V) = softmax(QK^T / √d_k) V
Dividing by √d_k here keeps the post-softmax probabilities from becoming too extreme (too close to one-hot).
As can be seen, when the scores (= QK^T) are large in magnitude, the softmax saturates and the gradient vanishes almost everywhere.
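A small demo of this effect (my addition, not part of the original notes): scaling the scores pushes the softmax output towards one-hot and shrinks the entries of its Jacobian towards zero.
score = torch.randn(5)
prob_small = F.softmax(score, dim=-1)        # relatively smooth distribution
prob_large = F.softmax(score * 10, dim=-1)   # close to one-hot
def softmax_jacobian(s):
    # Jacobian of softmax with respect to the input scores
    return torch.autograd.functional.jacobian(lambda x: F.softmax(x, dim=-1), s)
jac_small = softmax_jacobian(score)          # clearly non-zero entries
jac_large = softmax_jacobian(score * 10)     # nearly all entries ~0 -> vanishing gradients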
3.1 Constructing the encoder self-attention mask matrix
valid_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max(src_len)-L)), 0) \
for L in src_len], 0), 1)
# use (2, 4, 1) @ (2, 1, 4) to represent, within each sentence, which word pairs are both valid (non-pad)
valid_encoder_pos = torch.bmm(valid_encoder_pos.transpose(1, 2), valid_encoder_pos)
invalid_encoder_pos = 1 - valid_encoder_pos
mask_encoder_self_attention = invalid_encoder_pos.to(torch.bool)
# simulate the masking on random attention scores
score = torch.randn(batch_size, max(src_len), max(src_len))
masked_score = torch.masked_fill(score, mask_encoder_self_attention, -np.inf)
# note: rows for fully padded positions become NaN after softmax with -inf;
# the attention function in 3.4 uses -1e9 instead to avoid this
prob = F.softmax(masked_score, dim=-1)
3.2 Constructing the intra-attention (encoder-decoder cross-attention) mask matrix
# Q @ K.T shape = [batch_size, tgt_len, src_len]
valid_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max(src_len)-L)), 0) \
for L in src_len], 0), 2)
valid_decoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max(tgt_len)-L)), 0) \
for L in tgt_len], 0), 2)
valid_cross_pos_matrix = torch.bmm(valid_decoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_cross_pos_matrix = 1 - valid_cross_pos_matrix
masked_cross_pos_matrix = invalid_cross_pos_matrix.to(torch.bool)
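A brief simulation (my addition, analogous to 3.1): applying the cross mask to random decoder-to-encoder scores of shape [batch_size, tgt_len, src_len]; -1e9 is used instead of -inf so that fully padded rows stay finite.
cross_score = torch.randn(batch_size, max(tgt_len), max(src_len))
masked_cross_score = torch.masked_fill(cross_score, masked_cross_pos_matrix, -1e9)
cross_prob = F.softmax(masked_cross_score, dim=-1)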
3.3 Constructing the decoder self-attention mask matrix
Lower-triangular matrix
valid_tri_matrix = [F.pad(torch.tril(torch.ones(L, L)), (0, max(tgt_len)-L, 0, max(tgt_len)-L)) for L in tgt_len]
valid_tri_matrix = [torch.unsqueeze(x, dim=0) for x in valid_tri_matrix]
valid_tri_matrix = torch.cat(valid_tri_matrix, dim = 0)
invalid_decoder_tri_matrix = 1 - valid_tri_matrix
masked_decoder_tri_matrix = invalid_decoder_tri_matrix.to(torch.bool)
masked_decoder_tri_matrix
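A small simulation (my addition, parallel to the one in 3.1): applying the causal mask to random decoder self-attention scores shows each position attending only to itself and earlier valid positions.
decoder_score = torch.randn(batch_size, max(tgt_len), max(tgt_len))
masked_decoder_score = torch.masked_fill(decoder_score, masked_decoder_tri_matrix, -1e9)
decoder_prob = F.softmax(masked_decoder_score, dim=-1)   # row i attends to positions <= i only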
3.4 Constructing scaled dot-product attention
def scaled_dot_product_attention(Q, K, V, attn_mask):
    # shape of Q, K, V: [batch_size * num_head, seq_len, model_dim / num_head]
    score = torch.bmm(Q, K.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)  # scale by sqrt(d_k)
    masked_score = torch.masked_fill(score, attn_mask, -1e9)
    prob = F.softmax(masked_score, dim=-1)
    context = torch.bmm(prob, V)
    return context
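A quick usage sketch under assumed inputs (my addition, not from the original notes): single-head attention on the decoder side, reusing the causal mask built in 3.3 with random Q, K, V.
Q = torch.randn(batch_size, max(tgt_len), model_dim)
K = torch.randn(batch_size, max(tgt_len), model_dim)
V = torch.randn(batch_size, max(tgt_len), model_dim)
context = scaled_dot_product_attention(Q, K, V, masked_decoder_tri_matrix)
# context.shape == (batch_size, max(tgt_len), model_dim)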
4.Loss Mask
logits = torch.rand(2, 3, 4)
# batch_size=2, seq_len=3, vocab_size=4
label = torch.randint(0, 4, (2, 3))
# label values lie in [0, 3]
# per-sample valid (non-pad) lengths for this toy example, chosen <= seq_len=3
# (the earlier tgt_len = [4, 3] does not fit a sequence length of 3)
toy_tgt_len = torch.tensor([2, 3])
mask = torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, 3 - L)), 0) for L in toy_tgt_len])
# F.cross_entropy expects the class dimension second, so transpose logits to [batch_size, vocab_size, seq_len]
F.cross_entropy(logits.transpose(1, 2), label, reduction='none') * mask
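As a hedged follow-up (my addition), the per-token masked loss can be reduced to a single scalar by averaging only over the valid positions:
masked_loss = F.cross_entropy(logits.transpose(1, 2), label, reduction='none') * mask
mean_loss = masked_loss.sum() / mask.sum()   # average over non-pad tokens only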
5.Model summary
1. No prior assumptions (such as local correlation or ordered modelling). (A prior assumption means injecting more human experience into the model.)
2. The amount of data required is inversely related to the strength of the prior assumptions.
3. The core computation is self-attention, which is O(n^2). (To reduce this complexity, priors must be added back, e.g. letting attention look only at a local neighbourhood instead of the whole sequence, which is exactly the local-correlation prior.)
4. Q: Do multi-head self-attention and the FFN play the same role?
A: No. Attention mixes information across positions, while the FFN mixes the features within each position.
5.