add cuda
Showing
2 changed files
with
14 additions
and
4 deletions
... | @@ -18,6 +18,7 @@ class MLPModel(nn.Module): | ... | @@ -18,6 +18,7 @@ class MLPModel(nn.Module): |
18 | nn.ReLU(), | 18 | nn.ReLU(), |
19 | nn.Linear(256, 5), | 19 | nn.Linear(256, 5), |
20 | nn.Sigmoid(), | 20 | nn.Sigmoid(), |
21 | # nn.ReLU(), | ||
21 | ) | 22 | ) |
22 | self._initialize_weights() | 23 | self._initialize_weights() |
23 | 24 | ... | ... |
... | @@ -13,6 +13,8 @@ from utils import SOLVER_REGISTRY, get_logger_and_log_dir | ... | @@ -13,6 +13,8 @@ from utils import SOLVER_REGISTRY, get_logger_and_log_dir |
13 | class MLPSolver(object): | 13 | class MLPSolver(object): |
14 | 14 | ||
15 | def __init__(self, cfg): | 15 | def __init__(self, cfg): |
16 | self.device = "cuda" if torch.cuda.is_available() else "cpu" | ||
17 | |||
16 | self.cfg = copy.deepcopy(cfg) | 18 | self.cfg = copy.deepcopy(cfg) |
17 | 19 | ||
18 | self.train_loader, self.val_loader = build_dataloader(cfg) | 20 | self.train_loader, self.val_loader = build_dataloader(cfg) |
... | @@ -20,7 +22,7 @@ class MLPSolver(object): | ... | @@ -20,7 +22,7 @@ class MLPSolver(object): |
20 | self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset) | 22 | self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset) |
21 | 23 | ||
22 | # BatchNorm ? | 24 | # BatchNorm ? |
23 | self.model = build_model(cfg) | 25 | self.model = build_model(cfg).to(self.device) |
24 | 26 | ||
25 | self.loss_fn = build_loss(cfg) | 27 | self.loss_fn = build_loss(cfg) |
26 | 28 | ||
... | @@ -49,10 +51,14 @@ class MLPSolver(object): | ... | @@ -49,10 +51,14 @@ class MLPSolver(object): |
49 | def train_loop(self): | 51 | def train_loop(self): |
50 | self.model.train() | 52 | self.model.train() |
51 | 53 | ||
52 | train_loss = 0 | 54 | train_loss, correct = 0 |
53 | for batch, (X, y) in enumerate(self.train_loader): | 55 | for batch, (X, y) in enumerate(self.train_loader): |
56 | X, y = X.to(self.device), y.to(self.device) | ||
57 | |||
54 | pred = self.model(X) | 58 | pred = self.model(X) |
55 | 59 | ||
60 | correct += self.evaluate(pred, y) | ||
61 | |||
56 | # loss = self.loss_fn(pred, y, reduction="mean") | 62 | # loss = self.loss_fn(pred, y, reduction="mean") |
57 | loss = self.loss_fn(pred, y) | 63 | loss = self.loss_fn(pred, y) |
58 | train_loss += loss.item() | 64 | train_loss += loss.item() |
... | @@ -65,8 +71,9 @@ class MLPSolver(object): | ... | @@ -65,8 +71,9 @@ class MLPSolver(object): |
65 | loss.backward() | 71 | loss.backward() |
66 | self.optimizer.step() | 72 | self.optimizer.step() |
67 | 73 | ||
74 | correct /= self.train_dataset_size | ||
68 | train_loss /= self.train_loader_size | 75 | train_loss /= self.train_loader_size |
69 | self.logger.info(f'train mean loss: {train_loss :.4f}') | 76 | self.logger.info(f'train accuracy: {correct :.4f}, train mean loss: {train_loss :.4f}') |
70 | 77 | ||
71 | @torch.no_grad() | 78 | @torch.no_grad() |
72 | def val_loop(self, t): | 79 | def val_loop(self, t): |
... | @@ -74,6 +81,8 @@ class MLPSolver(object): | ... | @@ -74,6 +81,8 @@ class MLPSolver(object): |
74 | 81 | ||
75 | val_loss, correct = 0, 0 | 82 | val_loss, correct = 0, 0 |
76 | for X, y in self.val_loader: | 83 | for X, y in self.val_loader: |
84 | X, y = X.to(self.device), y.to(self.device) | ||
85 | |||
77 | pred = self.model(X) | 86 | pred = self.model(X) |
78 | 87 | ||
79 | correct += self.evaluate(pred, y) | 88 | correct += self.evaluate(pred, y) |
... | @@ -84,7 +93,7 @@ class MLPSolver(object): | ... | @@ -84,7 +93,7 @@ class MLPSolver(object): |
84 | correct /= self.val_dataset_size | 93 | correct /= self.val_dataset_size |
85 | val_loss /= self.val_loader_size | 94 | val_loss /= self.val_loader_size |
86 | 95 | ||
87 | self.logger.info(f"val accuracy: {correct :.4f}, val loss: {val_loss :.4f}") | 96 | self.logger.info(f"val accuracy: {correct :.4f}, val mean loss: {val_loss :.4f}") |
88 | 97 | ||
89 | def save_checkpoint(self, epoch_id): | 98 | def save_checkpoint(self, epoch_id): |
90 | self.model.eval() | 99 | self.model.eval() | ... | ... |
-
Please register or sign in to post a comment