fix device
Showing
1 changed file
with
6 additions
and
4 deletions
| ... | @@ -51,7 +51,8 @@ class MLPSolver(object): | ... | @@ -51,7 +51,8 @@ class MLPSolver(object): |
| 51 | def train_loop(self): | 51 | def train_loop(self): |
| 52 | self.model.train() | 52 | self.model.train() |
| 53 | 53 | ||
| 54 | train_loss, correct = 0 | 54 | train_loss = torch.zeros(1).to(self.device) |
| 55 | correct = torch.zeros(1).to(self.device) | ||
| 55 | for batch, (X, y) in enumerate(self.train_loader): | 56 | for batch, (X, y) in enumerate(self.train_loader): |
| 56 | X, y = X.to(self.device), y.to(self.device) | 57 | X, y = X.to(self.device), y.to(self.device) |
| 57 | 58 | ||
| ... | @@ -73,13 +74,14 @@ class MLPSolver(object): | ... | @@ -73,13 +74,14 @@ class MLPSolver(object): |
| 73 | 74 | ||
| 74 | correct /= self.train_dataset_size | 75 | correct /= self.train_dataset_size |
| 75 | train_loss /= self.train_loader_size | 76 | train_loss /= self.train_loader_size |
| 76 | self.logger.info(f'train accuracy: {correct :.4f}, train mean loss: {train_loss :.4f}') | 77 | self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}') |
| 77 | 78 | ||
| 78 | @torch.no_grad() | 79 | @torch.no_grad() |
| 79 | def val_loop(self, t): | 80 | def val_loop(self, t): |
| 80 | self.model.eval() | 81 | self.model.eval() |
| 81 | 82 | ||
| 82 | val_loss, correct = 0, 0 | 83 | val_loss = torch.zeros(1).to(self.device) |
| 84 | correct = torch.zeros(1).to(self.device) | ||
| 83 | for X, y in self.val_loader: | 85 | for X, y in self.val_loader: |
| 84 | X, y = X.to(self.device), y.to(self.device) | 86 | X, y = X.to(self.device), y.to(self.device) |
| 85 | 87 | ||
| ... | @@ -93,7 +95,7 @@ class MLPSolver(object): | ... | @@ -93,7 +95,7 @@ class MLPSolver(object): |
| 93 | correct /= self.val_dataset_size | 95 | correct /= self.val_dataset_size |
| 94 | val_loss /= self.val_loader_size | 96 | val_loss /= self.val_loader_size |
| 95 | 97 | ||
| 96 | self.logger.info(f"val accuracy: {correct :.4f}, val mean loss: {val_loss :.4f}") | 98 | self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}") |
| 97 | 99 | ||
| 98 | def save_checkpoint(self, epoch_id): | 100 | def save_checkpoint(self, epoch_id): |
| 99 | self.model.eval() | 101 | self.model.eval() | ... | ... |
-
Please register or sign in to post a comment