Skip to content
Toggle navigation
Toggle navigation
This project
Loading...
Sign in
周伟奇
/
test_on_pytorch
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Issue Boards
Files
Commits
Network
Compare
Branches
Tags
3310e154
authored
2022-12-12 15:22:03 +0800
by
周伟奇
Browse Files
Options
Browse Files
Tag
Download
Email Patches
Plain Diff
add cuda
1 parent
bcb17d0f
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
4 deletions
model/mlp.py
solver/mlp_solver.py
model/mlp.py
View file @
3310e15
...
...
@@ -18,6 +18,7 @@ class MLPModel(nn.Module):
nn
.
ReLU
(),
nn
.
Linear
(
256
,
5
),
nn
.
Sigmoid
(),
# nn.ReLU(),
)
self
.
_initialize_weights
()
...
...
solver/mlp_solver.py
View file @
3310e15
...
...
@@ -13,6 +13,8 @@ from utils import SOLVER_REGISTRY, get_logger_and_log_dir
class
MLPSolver
(
object
):
def
__init__
(
self
,
cfg
):
self
.
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
self
.
cfg
=
copy
.
deepcopy
(
cfg
)
self
.
train_loader
,
self
.
val_loader
=
build_dataloader
(
cfg
)
...
...
@@ -20,7 +22,7 @@ class MLPSolver(object):
self
.
train_dataset_size
,
self
.
val_dataset_size
=
len
(
self
.
train_loader
.
dataset
),
len
(
self
.
val_loader
.
dataset
)
# BatchNorm ?
self
.
model
=
build_model
(
cfg
)
self
.
model
=
build_model
(
cfg
)
.
to
(
self
.
device
)
self
.
loss_fn
=
build_loss
(
cfg
)
...
...
@@ -49,10 +51,14 @@ class MLPSolver(object):
def
train_loop
(
self
):
self
.
model
.
train
()
train_loss
=
0
train_loss
,
correct
=
0
for
batch
,
(
X
,
y
)
in
enumerate
(
self
.
train_loader
):
X
,
y
=
X
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
pred
=
self
.
model
(
X
)
correct
+=
self
.
evaluate
(
pred
,
y
)
# loss = self.loss_fn(pred, y, reduction="mean")
loss
=
self
.
loss_fn
(
pred
,
y
)
train_loss
+=
loss
.
item
()
...
...
@@ -65,8 +71,9 @@ class MLPSolver(object):
loss
.
backward
()
self
.
optimizer
.
step
()
correct
/=
self
.
train_dataset_size
train_loss
/=
self
.
train_loader_size
self
.
logger
.
info
(
f
'train mean loss: {train_loss :.4f}'
)
self
.
logger
.
info
(
f
'train
accuracy: {correct :.4f}, train
mean loss: {train_loss :.4f}'
)
@torch.no_grad
()
def
val_loop
(
self
,
t
):
...
...
@@ -74,6 +81,8 @@ class MLPSolver(object):
val_loss
,
correct
=
0
,
0
for
X
,
y
in
self
.
val_loader
:
X
,
y
=
X
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
pred
=
self
.
model
(
X
)
correct
+=
self
.
evaluate
(
pred
,
y
)
...
...
@@ -84,7 +93,7 @@ class MLPSolver(object):
correct
/=
self
.
val_dataset_size
val_loss
/=
self
.
val_loader_size
self
.
logger
.
info
(
f
"val accuracy: {correct :.4f}, val loss: {val_loss :.4f}"
)
self
.
logger
.
info
(
f
"val accuracy: {correct :.4f}, val
mean
loss: {val_loss :.4f}"
)
def
save_checkpoint
(
self
,
epoch_id
):
self
.
model
.
eval
()
...
...
Write
Preview
Styling with
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment