3310e154 by 周伟奇

add cuda

1 parent bcb17d0f
...@@ -18,6 +18,7 @@ class MLPModel(nn.Module): ...@@ -18,6 +18,7 @@ class MLPModel(nn.Module):
18 nn.ReLU(), 18 nn.ReLU(),
19 nn.Linear(256, 5), 19 nn.Linear(256, 5),
20 nn.Sigmoid(), 20 nn.Sigmoid(),
21 # nn.ReLU(),
21 ) 22 )
22 self._initialize_weights() 23 self._initialize_weights()
23 24
......
...@@ -13,6 +13,8 @@ from utils import SOLVER_REGISTRY, get_logger_and_log_dir ...@@ -13,6 +13,8 @@ from utils import SOLVER_REGISTRY, get_logger_and_log_dir
13 class MLPSolver(object): 13 class MLPSolver(object):
14 14
15 def __init__(self, cfg): 15 def __init__(self, cfg):
16 self.device = "cuda" if torch.cuda.is_available() else "cpu"
17
16 self.cfg = copy.deepcopy(cfg) 18 self.cfg = copy.deepcopy(cfg)
17 19
18 self.train_loader, self.val_loader = build_dataloader(cfg) 20 self.train_loader, self.val_loader = build_dataloader(cfg)
...@@ -20,7 +22,7 @@ class MLPSolver(object): ...@@ -20,7 +22,7 @@ class MLPSolver(object):
20 self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset) 22 self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)
21 23
22 # BatchNorm ? 24 # BatchNorm ?
23 self.model = build_model(cfg) 25 self.model = build_model(cfg).to(self.device)
24 26
25 self.loss_fn = build_loss(cfg) 27 self.loss_fn = build_loss(cfg)
26 28
...@@ -49,10 +51,14 @@ class MLPSolver(object): ...@@ -49,10 +51,14 @@ class MLPSolver(object):
49 def train_loop(self): 51 def train_loop(self):
50 self.model.train() 52 self.model.train()
51 53
52 train_loss = 0 54 train_loss, correct = 0, 0
53 for batch, (X, y) in enumerate(self.train_loader): 55 for batch, (X, y) in enumerate(self.train_loader):
56 X, y = X.to(self.device), y.to(self.device)
57
54 pred = self.model(X) 58 pred = self.model(X)
55 59
60 correct += self.evaluate(pred, y)
61
56 # loss = self.loss_fn(pred, y, reduction="mean") 62 # loss = self.loss_fn(pred, y, reduction="mean")
57 loss = self.loss_fn(pred, y) 63 loss = self.loss_fn(pred, y)
58 train_loss += loss.item() 64 train_loss += loss.item()
...@@ -65,8 +71,9 @@ class MLPSolver(object): ...@@ -65,8 +71,9 @@ class MLPSolver(object):
65 loss.backward() 71 loss.backward()
66 self.optimizer.step() 72 self.optimizer.step()
67 73
74 correct /= self.train_dataset_size
68 train_loss /= self.train_loader_size 75 train_loss /= self.train_loader_size
69 self.logger.info(f'train mean loss: {train_loss :.4f}') 76 self.logger.info(f'train accuracy: {correct :.4f}, train mean loss: {train_loss :.4f}')
70 77
71 @torch.no_grad() 78 @torch.no_grad()
72 def val_loop(self, t): 79 def val_loop(self, t):
...@@ -74,6 +81,8 @@ class MLPSolver(object): ...@@ -74,6 +81,8 @@ class MLPSolver(object):
74 81
75 val_loss, correct = 0, 0 82 val_loss, correct = 0, 0
76 for X, y in self.val_loader: 83 for X, y in self.val_loader:
84 X, y = X.to(self.device), y.to(self.device)
85
77 pred = self.model(X) 86 pred = self.model(X)
78 87
79 correct += self.evaluate(pred, y) 88 correct += self.evaluate(pred, y)
...@@ -84,7 +93,7 @@ class MLPSolver(object): ...@@ -84,7 +93,7 @@ class MLPSolver(object):
84 correct /= self.val_dataset_size 93 correct /= self.val_dataset_size
85 val_loss /= self.val_loader_size 94 val_loss /= self.val_loader_size
86 95
87 self.logger.info(f"val accuracy: {correct :.4f}, val loss: {val_loss :.4f}") 96 self.logger.info(f"val accuracy: {correct :.4f}, val mean loss: {val_loss :.4f}")
88 97
89 def save_checkpoint(self, epoch_id): 98 def save_checkpoint(self, epoch_id):
90 self.model.eval() 99 self.model.eval()
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!