Fix masked_softmax bug
Showing
1 changed file
with
9 additions
and
1 deletion
... | @@ -18,7 +18,7 @@ def masked_softmax(X, valid_lens): | ... | @@ -18,7 +18,7 @@ def masked_softmax(X, valid_lens): |
18 | # [batch_size, num_heads, seq_len, seq_len] | 18 | # [batch_size, num_heads, seq_len, seq_len] |
19 | shape = X.shape | 19 | shape = X.shape |
20 | if valid_lens.dim() == 1: | 20 | if valid_lens.dim() == 1: |
21 | valid_lens = torch.repeat_interleave(valid_lens, shape[2]) | 21 | valid_lens = torch.repeat_interleave(valid_lens, shape[1]) |
22 | else: | 22 | else: |
23 | valid_lens = valid_lens.reshape(-1) | 23 | valid_lens = valid_lens.reshape(-1) |
24 | # On the last axis, replace masked elements with a very large negative | 24 | # On the last axis, replace masked elements with a very large negative |
... | @@ -126,8 +126,16 @@ class Attention(nn.Module): | ... | @@ -126,8 +126,16 @@ class Attention(nn.Module): |
126 | # transpose: -> [batch_size, num_heads, embed_dim_per_head, seq_len] | 126 | # transpose: -> [batch_size, num_heads, embed_dim_per_head, seq_len] |
127 | # @: multiply -> [batch_size, num_heads, seq_len, seq_len] | 127 | # @: multiply -> [batch_size, num_heads, seq_len, seq_len] |
128 | attn = (q @ k.transpose(-2, -1)) * self.scale | 128 | attn = (q @ k.transpose(-2, -1)) * self.scale |
129 | |||
130 | if valid_lens is not None: | ||
131 | # On axis 0, copy the first item (scalar or vector) for | ||
132 | # `num_heads` times, then copy the next item, and so on | ||
133 | valid_lens = torch.repeat_interleave( | ||
134 | valid_lens, repeats=self.num_heads, dim=0) | ||
129 | # attn = attn.softmax(dim=-1) | 135 | # attn = attn.softmax(dim=-1) |
130 | attn = masked_softmax(attn, valid_lens) | 136 | attn = masked_softmax(attn, valid_lens) |
137 | |||
138 | |||
131 | attn = self.attn_drop(attn) | 139 | attn = self.attn_drop(attn) |
132 | 140 | ||
133 | # @: multiply -> [batch_size, num_heads, seq_len, embed_dim_per_head] | 141 | # @: multiply -> [batch_size, num_heads, seq_len, embed_dim_per_head] | ... | ... |
-
Please register or sign in to post a comment