Skip to content
Toggle navigation
Toggle navigation
This project
Loading...
Sign in
周伟奇
/
test_on_pytorch
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Issue Boards
Files
Commits
Network
Compare
Branches
Tags
fb5f4ba1
authored
2022-12-14 17:17:44 +0800
by
周伟奇
Browse Files
Options
Browse Files
Tag
Download
Email Patches
Plain Diff
add no other
1 parent
30cf7dbe
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
3 deletions
config/vit.yaml
solver/vit_solver.py
config/vit.yaml
View file @
fb5f4ba
...
...
@@ -39,6 +39,7 @@ solver:
name
:
'
VITSolver'
args
:
epoch
:
100
no_other
:
false
optimizer
:
name
:
'
Adam'
...
...
solver/vit_solver.py
View file @
fb5f4ba
...
...
@@ -29,6 +29,7 @@ class VITSolver(object):
self
.
optimizer
=
build_optimizer
(
cfg
)(
self
.
model
.
parameters
(),
**
cfg
[
'solver'
][
'optimizer'
][
'args'
])
self
.
hyper_params
=
cfg
[
'solver'
][
'args'
]
self
.
no_other
=
self
.
hyper_params
[
'no_other'
]
try
:
self
.
epoch
=
self
.
hyper_params
[
'epoch'
]
except
Exception
:
...
...
@@ -36,9 +37,8 @@ class VITSolver(object):
self
.
logger
,
self
.
log_dir
=
get_logger_and_log_dir
(
**
cfg
[
'solver'
][
'logger'
])
@staticmethod
def
evaluate
(
y_pred
,
y_true
,
thresholds
=
0.5
,
no_other
=
False
):
if
no_other
:
def
evaluate
(
self
,
y_pred
,
y_true
,
thresholds
=
0.5
):
if
self
.
no_other
:
return
(
y_pred
.
argmax
(
1
)
==
y_true
.
argmax
(
1
))
.
type
(
torch
.
float
)
.
sum
()
.
item
()
else
:
y_pred_idx
=
torch
.
argmax
(
y_pred
,
dim
=
1
)
+
1
...
...
@@ -59,7 +59,10 @@ class VITSolver(object):
for
batch
,
(
X
,
y
)
in
enumerate
(
self
.
train_loader
):
X
,
y
=
X
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
if
self
.
no_other
:
pred
=
torch
.
nn
.
Softmax
(
dim
=
1
)(
self
.
model
(
X
))
else
:
pred
=
torch
.
nn
.
Sigmoid
(
self
.
model
(
X
))
correct
+=
self
.
evaluate
(
pred
,
y
)
...
...
@@ -88,7 +91,10 @@ class VITSolver(object):
for
X
,
y
in
self
.
val_loader
:
X
,
y
=
X
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
if
self
.
no_other
:
pred
=
torch
.
nn
.
Softmax
(
dim
=
1
)(
self
.
model
(
X
))
else
:
pred
=
torch
.
nn
.
Sigmoid
(
self
.
model
(
X
))
correct
+=
self
.
evaluate
(
pred
,
y
)
...
...
Write
Preview
Styling with
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment