c919b68e by 周伟奇

fix masked bug

1 parent 05e0f320
@@ -18,7 +18,7 @@ def masked_softmax(X, valid_lens):
     # [batch_size, num_heads, seq_len, seq_len]
     shape = X.shape
     if valid_lens.dim() == 1:
-        valid_lens = torch.repeat_interleave(valid_lens, shape[2])
+        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
     else:
         valid_lens = valid_lens.reshape(-1)
     # On the last axis, replace masked elements with a very large negative
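The hunk's last context line refers to the usual masking trick: positions past an item's valid length get a very large negative value so they contribute roughly zero after the softmax. A minimal standalone sketch of that idea (not this repo's implementation; the function name and the -1e6 constant are illustrative):

import torch
import torch.nn.functional as F

def masked_softmax_sketch(X, valid_lens):
    # X: [rows, keys]; valid_lens: [rows], one valid length per row
    mask = torch.arange(X.shape[-1])[None, :] < valid_lens[:, None]
    # Masked positions get a very large negative value, so their
    # exponentiation inside the softmax is effectively 0
    return F.softmax(X.masked_fill(~mask, -1e6), dim=-1)

X = torch.rand(2, 4)
print(masked_softmax_sketch(X, torch.tensor([2, 3])))
# row 0: last two probabilities ~0; row 1: last probability ~0

Once a 4D attention tensor is flattened into rows like this, a 1D valid_lens has to be copied once per row, which is what the repeat factor changed above controls.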
@@ -126,8 +126,16 @@ class Attention(nn.Module):
         # transpose: -> [batch_size, num_heads, embed_dim_per_head, seq_len]
         # @: multiply -> [batch_size, num_heads, seq_len, seq_len]
         attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
+        if valid_lens is not None:
+            # On axis 0, copy the first item (scalar or vector) for
+            # `num_heads` times, then copy the next item, and so on
+            valid_lens = torch.repeat_interleave(
+                valid_lens, repeats=self.num_heads, dim=0)
+        # attn = attn.softmax(dim=-1)
+        attn = masked_softmax(attn, valid_lens)
         attn = self.attn_drop(attn)
         # @: multiply -> [batch_size, num_heads, seq_len, embed_dim_per_head]
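The axis-0 copy described in the new comment is plain torch.repeat_interleave behavior and can be checked standalone (the tensors and the num_heads value below are made up):

import torch

num_heads = 2                       # illustrative value
# 1D case: one valid length per batch item (the "scalar" per item)
v1 = torch.tensor([2, 3])
print(torch.repeat_interleave(v1, repeats=num_heads, dim=0))
# tensor([2, 2, 3, 3]) -- item 0 copied num_heads times, then item 1

# 2D case: one row of per-query lengths per batch item (the "vector" per item)
v2 = torch.tensor([[1, 2, 3],
                   [2, 2, 4]])
print(torch.repeat_interleave(v2, repeats=num_heads, dim=0))
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [2, 2, 4],
#         [2, 2, 4]]) -- whole rows are duplicated along axis 0

Item i is fully duplicated num_heads times before item i+1 appears, matching the layout of attn, whose heads for one batch item sit consecutively along axis 0 after flattening.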