Skip to content
Toggle navigation
Toggle navigation
This project
Loading...
Sign in
周伟奇
/
test_on_pytorch
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Issue Boards
Files
Commits
Network
Compare
Branches
Tags
b3694ec8
authored
2022-12-15 19:24:06 +0800
by
周伟奇
Browse Files
Options
Browse Files
Tag
Download
Email Patches
Plain Diff
add eval
1 parent
82a85c6d
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
55 additions
and
6 deletions
config/vit.yaml
main.py
solver/vit_solver.py
config/vit.yaml
View file @
b3694ec
...
...
@@ -41,6 +41,7 @@ solver:
epoch
:
100
no_other
:
false
base_on
:
null
model_path
:
null
optimizer
:
name
:
'
Adam'
...
...
main.py
View file @
b3694ec
...
...
@@ -7,6 +7,7 @@ from solver.builder import build_solver
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--config'
,
default
=
'./config/mlp.yaml'
,
type
=
str
,
help
=
'config file'
)
parser
.
add_argument
(
'-e'
,
'--eval'
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
cfg
=
yaml
.
load
(
open
(
args
.
config
,
'r'
)
.
read
(),
Loader
=
yaml
.
FullLoader
)
...
...
@@ -14,7 +15,11 @@ def main():
# print(torch.cuda.is_available())
solver
=
build_solver
(
cfg
)
solver
.
run
()
if
args
.
eval
:
solver
.
evaluate
()
else
:
solver
.
run
()
if
__name__
==
'__main__'
:
...
...
solver/vit_solver.py
View file @
b3694ec
...
...
@@ -8,6 +8,7 @@ from loss import build_loss
from
model
import
build_model
from
optimizer
import
build_lr_scheduler
,
build_optimizer
from
utils
import
SOLVER_REGISTRY
,
get_logger_and_log_dir
from
sklearn.metrics
import
confusion_matrix
,
accuracy_score
,
classification_report
@SOLVER_REGISTRY.register
()
...
...
@@ -32,6 +33,7 @@ class VITSolver(object):
self
.
hyper_params
=
cfg
[
'solver'
][
'args'
]
self
.
no_other
=
self
.
hyper_params
[
'no_other'
]
self
.
base_on
=
self
.
hyper_params
[
'base_on'
]
self
.
model_path
=
self
.
hyper_params
[
'model_path'
]
try
:
self
.
epoch
=
self
.
hyper_params
[
'epoch'
]
except
Exception
:
...
...
@@ -39,7 +41,7 @@ class VITSolver(object):
self
.
logger
,
self
.
log_dir
=
get_logger_and_log_dir
(
**
cfg
[
'solver'
][
'logger'
])
def
evaluate
(
self
,
y_pred
,
y_true
,
thresholds
=
0.5
):
def
accuracy
(
self
,
y_pred
,
y_true
,
thresholds
=
0.5
):
if
self
.
no_other
:
return
(
y_pred
.
argmax
(
1
)
==
y_true
.
argmax
(
1
))
.
type
(
torch
.
float
)
.
sum
()
.
item
()
else
:
...
...
@@ -80,9 +82,9 @@ class VITSolver(object):
self
.
optimizer
.
step
()
if
self
.
no_other
:
correct
+=
self
.
evaluate
(
pred
,
y
)
correct
+=
self
.
accuracy
(
pred
,
y
)
else
:
correct
+=
self
.
evaluate
(
torch
.
nn
.
Sigmoid
()(
pred
),
y
)
correct
+=
self
.
accuracy
(
torch
.
nn
.
Sigmoid
()(
pred
),
y
)
correct
/=
self
.
train_dataset_size
train_loss
/=
self
.
train_loader_size
...
...
@@ -107,9 +109,9 @@ class VITSolver(object):
val_loss
+=
loss
.
item
()
if
self
.
no_other
:
correct
+=
self
.
evaluate
(
pred
,
y
)
correct
+=
self
.
accuracy
(
pred
,
y
)
else
:
correct
+=
self
.
evaluate
(
torch
.
nn
.
Sigmoid
()(
pred
),
y
)
correct
+=
self
.
accuracy
(
torch
.
nn
.
Sigmoid
()(
pred
),
y
)
correct
/=
self
.
val_dataset_size
val_loss
/=
self
.
val_loader_size
...
...
@@ -140,3 +142,44 @@ class VITSolver(object):
lr_scheduler
.
step
()
self
.
logger
.
info
(
'==> End Training'
)
def
evaluate
(
self
):
if
isinstance
(
self
.
model_path
,
str
)
and
os
.
path
.
exists
(
self
.
model_path
):
self
.
model
.
load_state_dict
(
torch
.
load
(
self
.
model_path
))
self
.
logger
.
info
(
f
'==> Load Model from {self.model_path}'
)
else
:
return
self
.
model
.
eval
()
label_true_list
=
[]
label_pred_list
=
[]
for
X
,
y
in
self
.
val_loader
:
X
,
y_true
=
X
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
if
self
.
no_other
:
pred
=
torch
.
nn
.
Softmax
(
dim
=
1
)(
self
.
model
(
X
))
else
:
# pred = torch.nn.Sigmoid()(self.model(X))
pred
=
self
.
model
(
X
)
y_pred
=
torch
.
nn
.
Sigmoid
()(
pred
)
y_pred_idx
=
torch
.
argmax
(
y_pred
,
dim
=
1
)
+
1
y_pred_is_other
=
(
torch
.
amax
(
y_pred
,
dim
=
1
)
>
0.5
)
.
int
()
y_pred_rebuild
=
torch
.
multiply
(
y_pred_idx
,
y_pred_is_other
)
y_true_idx
=
torch
.
argmax
(
y_true
,
dim
=
1
)
+
1
y_true_is_other
=
torch
.
sum
(
y_true
,
dim
=
1
)
y_true_rebuild
=
torch
.
multiply
(
y_true_idx
,
y_true_is_other
)
label_true_list
.
extend
(
y_true_rebuild
.
cpu
()
.
numpy
()
.
tolist
())
label_pred_list
.
extend
(
y_pred_rebuild
.
cpu
()
.
numpy
()
.
tolist
())
acc
=
accuracy_score
(
label_true_list
,
label_pred_list
)
cm
=
confusion_matrix
(
label_true_list
,
label_pred_list
)
report
=
classification_report
(
label_true_list
,
label_pred_list
)
print
(
acc
)
print
(
cm
)
print
(
report
)
...
...
Write
Preview
Styling with
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment