add cuda
Showing
2 changed files
with
14 additions
and
4 deletions
| ... | @@ -18,6 +18,7 @@ class MLPModel(nn.Module): | ... | @@ -18,6 +18,7 @@ class MLPModel(nn.Module): |
| 18 | nn.ReLU(), | 18 | nn.ReLU(), |
| 19 | nn.Linear(256, 5), | 19 | nn.Linear(256, 5), |
| 20 | nn.Sigmoid(), | 20 | nn.Sigmoid(), |
| 21 | # nn.ReLU(), | ||
| 21 | ) | 22 | ) |
| 22 | self._initialize_weights() | 23 | self._initialize_weights() |
| 23 | 24 | ... | ... |
| ... | @@ -13,6 +13,8 @@ from utils import SOLVER_REGISTRY, get_logger_and_log_dir | ... | @@ -13,6 +13,8 @@ from utils import SOLVER_REGISTRY, get_logger_and_log_dir |
| 13 | class MLPSolver(object): | 13 | class MLPSolver(object): |
| 14 | 14 | ||
| 15 | def __init__(self, cfg): | 15 | def __init__(self, cfg): |
| 16 | self.device = "cuda" if torch.cuda.is_available() else "cpu" | ||
| 17 | |||
| 16 | self.cfg = copy.deepcopy(cfg) | 18 | self.cfg = copy.deepcopy(cfg) |
| 17 | 19 | ||
| 18 | self.train_loader, self.val_loader = build_dataloader(cfg) | 20 | self.train_loader, self.val_loader = build_dataloader(cfg) |
| ... | @@ -20,7 +22,7 @@ class MLPSolver(object): | ... | @@ -20,7 +22,7 @@ class MLPSolver(object): |
| 20 | self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset) | 22 | self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset) |
| 21 | 23 | ||
| 22 | # BatchNorm ? | 24 | # BatchNorm ? |
| 23 | self.model = build_model(cfg) | 25 | self.model = build_model(cfg).to(self.device) |
| 24 | 26 | ||
| 25 | self.loss_fn = build_loss(cfg) | 27 | self.loss_fn = build_loss(cfg) |
| 26 | 28 | ||
| ... | @@ -49,10 +51,14 @@ class MLPSolver(object): | ... | @@ -49,10 +51,14 @@ class MLPSolver(object): |
| 49 | def train_loop(self): | 51 | def train_loop(self): |
| 50 | self.model.train() | 52 | self.model.train() |
| 51 | 53 | ||
| 52 | train_loss = 0 | 54 | train_loss, correct = 0 |
| 53 | for batch, (X, y) in enumerate(self.train_loader): | 55 | for batch, (X, y) in enumerate(self.train_loader): |
| 56 | X, y = X.to(self.device), y.to(self.device) | ||
| 57 | |||
| 54 | pred = self.model(X) | 58 | pred = self.model(X) |
| 55 | 59 | ||
| 60 | correct += self.evaluate(pred, y) | ||
| 61 | |||
| 56 | # loss = self.loss_fn(pred, y, reduction="mean") | 62 | # loss = self.loss_fn(pred, y, reduction="mean") |
| 57 | loss = self.loss_fn(pred, y) | 63 | loss = self.loss_fn(pred, y) |
| 58 | train_loss += loss.item() | 64 | train_loss += loss.item() |
| ... | @@ -65,8 +71,9 @@ class MLPSolver(object): | ... | @@ -65,8 +71,9 @@ class MLPSolver(object): |
| 65 | loss.backward() | 71 | loss.backward() |
| 66 | self.optimizer.step() | 72 | self.optimizer.step() |
| 67 | 73 | ||
| 74 | correct /= self.train_dataset_size | ||
| 68 | train_loss /= self.train_loader_size | 75 | train_loss /= self.train_loader_size |
| 69 | self.logger.info(f'train mean loss: {train_loss :.4f}') | 76 | self.logger.info(f'train accuracy: {correct :.4f}, train mean loss: {train_loss :.4f}') |
| 70 | 77 | ||
| 71 | @torch.no_grad() | 78 | @torch.no_grad() |
| 72 | def val_loop(self, t): | 79 | def val_loop(self, t): |
| ... | @@ -74,6 +81,8 @@ class MLPSolver(object): | ... | @@ -74,6 +81,8 @@ class MLPSolver(object): |
| 74 | 81 | ||
| 75 | val_loss, correct = 0, 0 | 82 | val_loss, correct = 0, 0 |
| 76 | for X, y in self.val_loader: | 83 | for X, y in self.val_loader: |
| 84 | X, y = X.to(self.device), y.to(self.device) | ||
| 85 | |||
| 77 | pred = self.model(X) | 86 | pred = self.model(X) |
| 78 | 87 | ||
| 79 | correct += self.evaluate(pred, y) | 88 | correct += self.evaluate(pred, y) |
| ... | @@ -84,7 +93,7 @@ class MLPSolver(object): | ... | @@ -84,7 +93,7 @@ class MLPSolver(object): |
| 84 | correct /= self.val_dataset_size | 93 | correct /= self.val_dataset_size |
| 85 | val_loss /= self.val_loader_size | 94 | val_loss /= self.val_loader_size |
| 86 | 95 | ||
| 87 | self.logger.info(f"val accuracy: {correct :.4f}, val loss: {val_loss :.4f}") | 96 | self.logger.info(f"val accuracy: {correct :.4f}, val mean loss: {val_loss :.4f}") |
| 88 | 97 | ||
| 89 | def save_checkpoint(self, epoch_id): | 98 | def save_checkpoint(self, epoch_id): |
| 90 | self.model.eval() | 99 | self.model.eval() | ... | ... |
-
Please register or sign in to post a comment