周伟奇 / test_on_pytorch
Commit c919b68e authored 2022-12-20 17:31:08 +0800 by 周伟奇
fix masked bug
1 parent: 05e0f320
Showing 1 changed file with 9 additions and 1 deletion.
model/seq_labeling.py
@@ -18,7 +18,7 @@ def masked_softmax(X, valid_lens):
    # [batch_size, num_heads, seq_len, seq_len]
    shape = X.shape
    if valid_lens.dim() == 1:
        valid_lens = torch.repeat_interleave(valid_lens, shape[2])
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    # On the last axis, replace masked elements with a very large negative
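Only the masking setup of masked_softmax is visible in this hunk; it is also exactly where the repeat factor (shape[2] versus shape[1]) was swapped, and the rest of the function is not shown. As a point of reference, below is a minimal self-contained sketch of the usual pattern for such a function, assuming valid_lens arrives with one entry per (batch, head) pair, as the Attention change further down produces. The sequence_mask helper, the -1e6 fill value, and the choice of shape[2] as the repeat factor are assumptions for illustration, not code taken from this repository.

import torch
import torch.nn.functional as F

def sequence_mask(X, valid_len, value=0.0):
    # Hypothetical helper (not part of the diff): overwrite positions at or
    # beyond each row's valid length with `value`.
    # X: [rows, seq_len], valid_len: [rows]
    maxlen = X.size(1)
    mask = torch.arange(maxlen, device=X.device)[None, :] < valid_len[:, None]
    return X.masked_fill(~mask, value)

def masked_softmax_sketch(X, valid_lens):
    # X: [batch_size, num_heads, seq_len, seq_len]
    # valid_lens: None, 1D with one entry per (batch, head) pair, or 2D
    if valid_lens is None:
        return F.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        # One length per (batch, head) pair -> one length per query row.
        valid_lens = torch.repeat_interleave(valid_lens, shape[2])
    else:
        valid_lens = valid_lens.reshape(-1)
    # On the last axis, replace masked elements with a very large negative
    # value, whose exponentiation yields (near) zero after the softmax.
    X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return F.softmax(X.reshape(shape), dim=-1)

# Quick check: keys beyond each row's valid length get ~0 attention weight.
scores = torch.randn(2, 4, 6, 6)                  # [batch, heads, seq, seq]
lens = torch.tensor([3, 6]).repeat_interleave(4)  # one entry per (batch, head)
weights = masked_softmax_sketch(scores, lens)
assert torch.allclose(weights[0, :, :, 3:].sum(), torch.tensor(0.0), atol=1e-5)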
@@ -126,8 +126,16 @@ class Attention(nn.Module):
        # transpose: -> [batch_size, num_heads, embed_dim_per_head, seq_len]
        # @: multiply -> [batch_size, num_heads, seq_len, seq_len]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for
            # `num_heads` times, then copy the next item, and so on
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        # attn = attn.softmax(dim=-1)
        attn = masked_softmax(attn, valid_lens)
        attn = self.attn_drop(attn)
        # @: multiply -> [batch_size, num_heads, seq_len, embed_dim_per_head]
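The block added in this hunk expands valid_lens so that there is one valid length per (batch, head) row before masked_softmax flattens the score tensor. A small standalone sketch of that step follows; the sizes are illustrative, and q, k, and the scale normally come from the surrounding Attention.forward, which the diff does not show in full.

import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 6, 8   # illustrative sizes
q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)
scale = head_dim ** -0.5

# [batch_size, num_heads, seq_len, seq_len]
attn = (q @ k.transpose(-2, -1)) * scale

# One valid length per sequence in the batch ...
valid_lens = torch.tensor([3, 6])
# ... copied num_heads times each along axis 0:
# [3, 6] -> [3, 3, 3, 3, 6, 6, 6, 6], one entry per (batch, head) pair
valid_lens = torch.repeat_interleave(valid_lens, repeats=num_heads, dim=0)
print(attn.shape, valid_lens.shape)  # torch.Size([2, 4, 6, 6]) torch.Size([8])

Because repeat_interleave with dim=0 keeps the copies for a given sequence adjacent, the expanded valid_lens lines up with the [batch_size, num_heads, ...] layout of the attention scores when masked_softmax flattens them.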