c919b68e by 周伟奇

fix masked softmax bug

1 parent 05e0f320
@@ -18,7 +18,7 @@ def masked_softmax(X, valid_lens):
     # [batch_size, num_heads, seq_len, seq_len]
     shape = X.shape
     if valid_lens.dim() == 1:
-        valid_lens = torch.repeat_interleave(valid_lens, shape[2])
+        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
     else:
         valid_lens = valid_lens.reshape(-1)
     # On the last axis, replace masked elements with a very large negative
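The trailing comment describes the usual masking trick: positions beyond valid_lens on the last axis are filled with a very large negative value, so softmax assigns them essentially zero weight. A minimal sketch of that idea, assuming a single attention row and an illustrative -1e6 fill value (not necessarily this repo's exact constant or shapes):

import torch

# One attention row of 4 scores; only the first 2 keys are valid (assumed values).
scores = torch.tensor([[0.5, 1.2, -0.3, 0.8]])
valid_len = 2

# Replace masked positions with a very large negative value before softmax.
mask = torch.arange(scores.shape[-1]) < valid_len
masked_scores = scores.masked_fill(~mask, -1e6)

probs = masked_scores.softmax(dim=-1)
# probs ≈ [[0.33, 0.67, 0.00, 0.00]] -- the masked keys get ~zero attention weight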
@@ -126,8 +126,16 @@ class Attention(nn.Module):
         # transpose: -> [batch_size, num_heads, embed_dim_per_head, seq_len]
         # @: multiply -> [batch_size, num_heads, seq_len, seq_len]
         attn = (q @ k.transpose(-2, -1)) * self.scale
+
+        if valid_lens is not None:
+            # On axis 0, copy the first item (scalar or vector) for
+            # `num_heads` times, then copy the next item, and so on
+            valid_lens = torch.repeat_interleave(
+                valid_lens, repeats=self.num_heads, dim=0)
         # attn = attn.softmax(dim=-1)
         attn = masked_softmax(attn, valid_lens)
+
+
         attn = self.attn_drop(attn)

         # @: multiply -> [batch_size, num_heads, seq_len, embed_dim_per_head]
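The added block follows the d2l handling of valid_lens in multi-head attention: when valid_lens carries one length per example, each entry is duplicated num_heads times so it lines up with the [batch_size, num_heads, seq_len, seq_len] attention scores. A small sketch of what that repeat_interleave call does (the batch size, head count, and lengths below are made-up values):

import torch

num_heads = 3
valid_lens = torch.tensor([2, 4])  # one valid length per example, batch_size = 2 (assumed)

# Copy the first item num_heads times, then the next item, and so on,
# so every head of an example sees that example's valid length.
per_head = torch.repeat_interleave(valid_lens, repeats=num_heads, dim=0)
print(per_head)  # tensor([2, 2, 2, 4, 4, 4]) -> length batch_size * num_heads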