Skip to content
Toggle navigation
Toggle navigation
This project
Loading...
Sign in
周伟奇
/
test_on_pytorch
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Issue Boards
Files
Commits
Network
Compare
Branches
Tags
69e75f77
authored
2022-12-14 20:10:44 +0800
by
周伟奇
Browse Files
Options
Browse Files
Tag
Download
Email Patches
Plain Diff
add load model
1 parent
fb5f4ba1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
9 deletions
config/vit.yaml
solver/vit_solver.py
config/vit.yaml
View file @
69e75f7
...
...
@@ -40,6 +40,7 @@ solver:
args
:
epoch
:
100
no_other
:
false
base_on
:
null
optimizer
:
name
:
'
Adam'
...
...
solver/vit_solver.py
View file @
69e75f7
import
os
import
copy
import
os
import
torch
from
model
import
build_model
from
data
import
build_dataloader
from
optimizer
import
build_optimizer
,
build_lr_scheduler
from
loss
import
build_loss
from
model
import
build_model
from
optimizer
import
build_lr_scheduler
,
build_optimizer
from
utils
import
SOLVER_REGISTRY
,
get_logger_and_log_dir
...
...
@@ -30,6 +31,7 @@ class VITSolver(object):
self
.
hyper_params
=
cfg
[
'solver'
][
'args'
]
self
.
no_other
=
self
.
hyper_params
[
'no_other'
]
self
.
base_on
=
self
.
hyper_params
[
'base_on'
]
try
:
self
.
epoch
=
self
.
hyper_params
[
'epoch'
]
except
Exception
:
...
...
@@ -62,9 +64,8 @@ class VITSolver(object):
if
self
.
no_other
:
pred
=
torch
.
nn
.
Softmax
(
dim
=
1
)(
self
.
model
(
X
))
else
:
pred
=
torch
.
nn
.
Sigmoid
(
self
.
model
(
X
))
correct
+=
self
.
evaluate
(
pred
,
y
)
# pred = torch.nn.Sigmoid()(self.model(X))
pred
=
self
.
model
(
X
)
# loss = self.loss_fn(pred, y, reduction="mean")
loss
=
self
.
loss_fn
(
pred
,
y
)
...
...
@@ -73,6 +74,8 @@ class VITSolver(object):
if
batch
%
100
==
0
:
loss_value
,
current
=
loss
.
item
(),
batch
self
.
logger
.
info
(
f
'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}'
)
correct
+=
self
.
evaluate
(
torch
.
nn
.
Sigmoid
()(
pred
),
y
)
self
.
optimizer
.
zero_grad
()
loss
.
backward
()
...
...
@@ -94,13 +97,14 @@ class VITSolver(object):
if
self
.
no_other
:
pred
=
torch
.
nn
.
Softmax
(
dim
=
1
)(
self
.
model
(
X
))
else
:
pred
=
torch
.
nn
.
Sigmoid
(
self
.
model
(
X
))
correct
+=
self
.
evaluate
(
pred
,
y
)
# pred = torch.nn.Sigmoid()(self.model(X))
pred
=
self
.
model
(
X
)
loss
=
self
.
loss_fn
(
pred
,
y
)
val_loss
+=
loss
.
item
()
correct
+=
self
.
evaluate
(
torch
.
nn
.
Sigmoid
()(
pred
),
y
)
correct
/=
self
.
val_dataset_size
val_loss
/=
self
.
val_loader_size
...
...
@@ -111,6 +115,10 @@ class VITSolver(object):
torch
.
save
(
self
.
model
.
state_dict
(),
os
.
path
.
join
(
self
.
log_dir
,
f
'ckpt_epoch_{epoch_id}.pt'
))
def
run
(
self
):
if
isinstance
(
self
.
base_on
,
str
)
and
os
.
path
.
exists
(
self
.
base_on
):
self
.
model
.
load_state_dict
(
torch
.
load
(
self
.
base_on
))
self
.
logger
.
info
(
f
'==> Load Model from {self.base_on}'
)
self
.
logger
.
info
(
'==> Start Training'
)
print
(
self
.
model
)
...
...
Write
Preview
Styling with
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment