fix masked softmax bug
Showing 1 changed file with 9 additions and 1 deletion
@@ -18,7 +18,7 @@ def masked_softmax(X, valid_lens):
     # [batch_size, num_heads, seq_len, seq_len]
     shape = X.shape
     if valid_lens.dim() == 1:
-        valid_lens = torch.repeat_interleave(valid_lens, shape[2])
+        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
     else:
         valid_lens = valid_lens.reshape(-1)
     # On the last axis, replace masked elements with a very large negative
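For context, the idea behind the masked softmax: scores at key positions beyond a row's valid length are replaced with a very large negative value, so they receive (near-)zero weight after the softmax. The sketch below is a simplified stand-alone version for illustration only, not this repo's exact `masked_softmax`; the function name and the sample numbers are made up.

```python
import torch
import torch.nn.functional as F

def masked_softmax_sketch(scores, valid_lens):
    # scores: [rows, seq_len]; valid_lens: [rows], one valid key count per row
    seq_len = scores.shape[-1]
    # mask[i, j] is True where key position j is at or beyond row i's valid length
    mask = torch.arange(seq_len, device=scores.device)[None, :] >= valid_lens[:, None]
    # a very large negative value drives the masked scores to ~0 after softmax
    scores = scores.masked_fill(mask, -1e6)
    return F.softmax(scores, dim=-1)

scores = torch.randn(2, 4)            # 2 rows of attention scores, seq_len = 4
valid_lens = torch.tensor([2, 3])     # row 0 may attend to 2 keys, row 1 to 3
print(masked_softmax_sketch(scores, valid_lens))
# masked positions get (near-)zero probability after the softmax
```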
@@ -126,8 +126,16 @@ class Attention(nn.Module):
         # transpose: -> [batch_size, num_heads, embed_dim_per_head, seq_len]
         # @: multiply -> [batch_size, num_heads, seq_len, seq_len]
         attn = (q @ k.transpose(-2, -1)) * self.scale
+
+        if valid_lens is not None:
+            # On axis 0, copy the first item (scalar or vector) for
+            # `num_heads` times, then copy the next item, and so on
+            valid_lens = torch.repeat_interleave(
+                valid_lens, repeats=self.num_heads, dim=0)
         # attn = attn.softmax(dim=-1)
         attn = masked_softmax(attn, valid_lens)
+
+
         attn = self.attn_drop(attn)
 
         # @: multiply -> [batch_size, num_heads, seq_len, embed_dim_per_head]
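The added `repeat_interleave` call expands the per-sequence valid lengths across heads, so each of the `batch_size * num_heads` score matrices is masked with its own sequence's length. A quick sketch of that expansion (the batch size, head count, and lengths below are made-up numbers):

```python
import torch

# Made-up sizes, just to show the shape bookkeeping behind the fix.
batch_size, num_heads = 2, 3
valid_lens = torch.tensor([5, 7])   # one valid length per sequence in the batch

# On axis 0, each entry is copied num_heads times in a row, so every head
# of a sequence is masked with that sequence's own valid length.
per_head_lens = torch.repeat_interleave(valid_lens, repeats=num_heads, dim=0)
print(per_head_lens)                # tensor([5, 5, 5, 7, 7, 7])
print(per_head_lens.shape)          # torch.Size([6]) == batch_size * num_heads
```

Without this expansion, `valid_lens` carries one entry per sequence while the attention scores contain `batch_size * num_heads` score matrices, so the lengths would not line up with the per-head scores being masked.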