c424aba7 by 周伟奇

fix device

1 parent 9a4eb652
...@@ -51,7 +51,8 @@ class MLPSolver(object): ...@@ -51,7 +51,8 @@ class MLPSolver(object):
51 def train_loop(self): 51 def train_loop(self):
52 self.model.train() 52 self.model.train()
53 53
54 train_loss, correct = 0 54 train_loss = torch.zeros(1).to(self.device)
55 correct = torch.zeros(1).to(self.device)
55 for batch, (X, y) in enumerate(self.train_loader): 56 for batch, (X, y) in enumerate(self.train_loader):
56 X, y = X.to(self.device), y.to(self.device) 57 X, y = X.to(self.device), y.to(self.device)
57 58
...@@ -73,13 +74,14 @@ class MLPSolver(object): ...@@ -73,13 +74,14 @@ class MLPSolver(object):
73 74
74 correct /= self.train_dataset_size 75 correct /= self.train_dataset_size
75 train_loss /= self.train_loader_size 76 train_loss /= self.train_loader_size
76 self.logger.info(f'train accuracy: {correct :.4f}, train mean loss: {train_loss :.4f}') 77 self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')
77 78
78 @torch.no_grad() 79 @torch.no_grad()
79 def val_loop(self, t): 80 def val_loop(self, t):
80 self.model.eval() 81 self.model.eval()
81 82
82 val_loss, correct = 0, 0 83 val_loss = torch.zeros(1).to(self.device)
84 correct = torch.zeros(1).to(self.device)
83 for X, y in self.val_loader: 85 for X, y in self.val_loader:
84 X, y = X.to(self.device), y.to(self.device) 86 X, y = X.to(self.device), y.to(self.device)
85 87
...@@ -93,7 +95,7 @@ class MLPSolver(object): ...@@ -93,7 +95,7 @@ class MLPSolver(object):
93 correct /= self.val_dataset_size 95 correct /= self.val_dataset_size
94 val_loss /= self.val_loader_size 96 val_loss /= self.val_loader_size
95 97
96 self.logger.info(f"val accuracy: {correct :.4f}, val mean loss: {val_loss :.4f}") 98 self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")
97 99
98 def save_checkpoint(self, epoch_id): 100 def save_checkpoint(self, epoch_id):
99 self.model.eval() 101 self.model.eval()
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!