Skip to content
Toggle navigation
Toggle navigation
This project
Loading...
Sign in
周伟奇
/
test_on_pytorch
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Issue Boards
Files
Commits
Network
Compare
Branches
Tags
3e58f6b0
authored
2022-12-19 11:48:58 +0800
by
周伟奇
Browse Files
Options
Browse Files
Tag
Download
Email Patches
Plain Diff
modify evaluate
1 parent
40ca6fe1
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
15 deletions
config/sl.yaml
data/create_dataset2.py
solver/sl_solver.py
config/sl.yaml
View file @
3e58f6b
...
...
@@ -16,7 +16,7 @@ dataloader:
model
:
name
:
'
SLTransformer'
args
:
seq_lens
:
20
0
seq_lens
:
16
0
num_classes
:
10
embed_dim
:
9
depth
:
6
...
...
data/create_dataset2.py
View file @
3e58f6b
...
...
@@ -166,7 +166,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
X
=
list
()
y_true
=
list
()
for
i
in
range
(
20
0
):
for
i
in
range
(
16
0
):
if
i
>=
valid_lens
:
X
.
append
([
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
])
y_true
.
append
([
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
])
...
...
solver/sl_solver.py
View file @
3e58f6b
...
...
@@ -65,7 +65,7 @@ class SLSolver(object):
train_loss
=
torch
.
zeros
(
1
)
.
to
(
self
.
device
)
correct
=
torch
.
zeros
(
1
)
.
to
(
self
.
device
)
for
batch
,
(
X
,
y
,
valid_lens
)
in
enumerate
(
self
.
train_loader
):
X
,
y
=
X
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
X
,
y
,
valid_lens
=
X
.
to
(
self
.
device
),
y
.
to
(
self
.
device
),
valid_lens
.
to
(
self
.
device
)
pred
=
self
.
model
(
X
,
valid_lens
)
# [batch_size, seq_len, num_classes]
...
...
@@ -97,7 +97,7 @@ class SLSolver(object):
val_loss
=
torch
.
zeros
(
1
)
.
to
(
self
.
device
)
correct
=
torch
.
zeros
(
1
)
.
to
(
self
.
device
)
for
X
,
y
,
valid_lens
in
self
.
val_loader
:
X
,
y
=
X
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
X
,
y
,
valid_lens
=
X
.
to
(
self
.
device
),
y
.
to
(
self
.
device
),
valid_lens
.
to
(
self
.
device
)
# pred = torch.nn.Sigmoid()(self.model(X))
pred
=
self
.
model
(
X
,
valid_lens
)
...
...
@@ -162,25 +162,30 @@ class SLSolver(object):
label_true_list
=
[]
label_pred_list
=
[]
for
X
,
y
in
self
.
val_loader
:
X
,
y_true
=
X
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
for
X
,
y
,
valid_lens
in
self
.
val_loader
:
X
,
y_true
,
valid_lens
=
X
.
to
(
self
.
device
),
y
.
to
(
self
.
device
),
valid_lens
.
to
(
self
.
device
)
# pred = torch.nn.Sigmoid()(self.model(X))
pred
=
self
.
model
(
X
)
y_pred
=
torch
.
nn
.
Sigmoid
()(
pred
)
y_pred
=
self
.
model
(
X
,
valid_lens
)
y_pred_idx
=
torch
.
argmax
(
y_pred
,
dim
=
1
)
+
1
y_pred_is_other
=
(
torch
.
amax
(
y_pred
,
dim
=
1
)
>
0.5
)
.
int
()
# [batch_size, seq_len, num_classes]
y_pred_sigmoid
=
torch
.
nn
.
Sigmoid
()(
y_pred
)
# [batch_size, seq_len]
y_pred_idx
=
torch
.
argmax
(
y_pred_sigmoid
,
dim
=-
1
)
+
1
# [batch_size, seq_len]
y_pred_is_other
=
(
torch
.
amax
(
y_pred_sigmoid
,
dim
=-
1
)
>
0.5
)
.
int
()
y_pred_rebuild
=
torch
.
multiply
(
y_pred_idx
,
y_pred_is_other
)
y_true_idx
=
torch
.
argmax
(
y_true
,
dim
=
1
)
+
1
y_true_is_other
=
torch
.
sum
(
y_true
,
dim
=
1
)
y_true_idx
=
torch
.
argmax
(
y_true
,
dim
=
-
1
)
+
1
y_true_is_other
=
torch
.
sum
(
y_true
,
dim
=
-
1
)
.
int
(
)
y_true_rebuild
=
torch
.
multiply
(
y_true_idx
,
y_true_is_other
)
label_true_list
.
extend
(
y_true_rebuild
.
cpu
()
.
numpy
()
.
tolist
())
label_pred_list
.
extend
(
y_pred_rebuild
.
cpu
()
.
numpy
()
.
tolist
())
# masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)
for
idx
,
seq_result
in
enumerate
(
y_true_rebuild
.
cpu
()
.
numpy
()
.
tolist
()):
label_true_list
.
extend
(
seq_result
[:
valid_lens
.
cpu
()
.
numpy
()[
idx
]])
for
idx
,
seq_result
in
enumerate
(
y_pred_rebuild
.
cpu
()
.
numpy
()
.
tolist
()):
label_pred_list
.
extend
(
seq_result
[:
valid_lens
.
cpu
()
.
numpy
()[
idx
]])
acc
=
accuracy_score
(
label_true_list
,
label_pred_list
)
cm
=
confusion_matrix
(
label_true_list
,
label_pred_list
)
...
...
Write
Preview
Styling with
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment