fix device
Showing
1 changed file
with
6 additions
and
4 deletions
... | @@ -51,7 +51,8 @@ class MLPSolver(object): | ... | @@ -51,7 +51,8 @@ class MLPSolver(object): |
51 | def train_loop(self): | 51 | def train_loop(self): |
52 | self.model.train() | 52 | self.model.train() |
53 | 53 | ||
54 | train_loss, correct = 0 | 54 | train_loss = torch.zeros(1).to(self.device) |
55 | correct = torch.zeros(1).to(self.device) | ||
55 | for batch, (X, y) in enumerate(self.train_loader): | 56 | for batch, (X, y) in enumerate(self.train_loader): |
56 | X, y = X.to(self.device), y.to(self.device) | 57 | X, y = X.to(self.device), y.to(self.device) |
57 | 58 | ||
... | @@ -73,13 +74,14 @@ class MLPSolver(object): | ... | @@ -73,13 +74,14 @@ class MLPSolver(object): |
73 | 74 | ||
74 | correct /= self.train_dataset_size | 75 | correct /= self.train_dataset_size |
75 | train_loss /= self.train_loader_size | 76 | train_loss /= self.train_loader_size |
76 | self.logger.info(f'train accuracy: {correct :.4f}, train mean loss: {train_loss :.4f}') | 77 | self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}') |
77 | 78 | ||
78 | @torch.no_grad() | 79 | @torch.no_grad() |
79 | def val_loop(self, t): | 80 | def val_loop(self, t): |
80 | self.model.eval() | 81 | self.model.eval() |
81 | 82 | ||
82 | val_loss, correct = 0, 0 | 83 | val_loss = torch.zeros(1).to(self.device) |
84 | correct = torch.zeros(1).to(self.device) | ||
83 | for X, y in self.val_loader: | 85 | for X, y in self.val_loader: |
84 | X, y = X.to(self.device), y.to(self.device) | 86 | X, y = X.to(self.device), y.to(self.device) |
85 | 87 | ||
... | @@ -93,7 +95,7 @@ class MLPSolver(object): | ... | @@ -93,7 +95,7 @@ class MLPSolver(object): |
93 | correct /= self.val_dataset_size | 95 | correct /= self.val_dataset_size |
94 | val_loss /= self.val_loader_size | 96 | val_loss /= self.val_loader_size |
95 | 97 | ||
96 | self.logger.info(f"val accuracy: {correct :.4f}, val mean loss: {val_loss :.4f}") | 98 | self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}") |
97 | 99 | ||
98 | def save_checkpoint(self, epoch_id): | 100 | def save_checkpoint(self, epoch_id): |
99 | self.model.eval() | 101 | self.model.eval() | ... | ... |
-
Please register or sign in to post a comment