c424aba7 by 周伟奇

fix device

1 parent 9a4eb652
......@@ -51,7 +51,8 @@ class MLPSolver(object):
def train_loop(self):
self.model.train()
train_loss, correct = 0
train_loss = torch.zeros(1).to(self.device)
correct = torch.zeros(1).to(self.device)
for batch, (X, y) in enumerate(self.train_loader):
X, y = X.to(self.device), y.to(self.device)
......@@ -73,13 +74,14 @@ class MLPSolver(object):
correct /= self.train_dataset_size
train_loss /= self.train_loader_size
self.logger.info(f'train accuracy: {correct :.4f}, train mean loss: {train_loss :.4f}')
self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')
@torch.no_grad()
def val_loop(self, t):
self.model.eval()
val_loss, correct = 0, 0
val_loss = torch.zeros(1).to(self.device)
correct = torch.zeros(1).to(self.device)
for X, y in self.val_loader:
X, y = X.to(self.device), y.to(self.device)
......@@ -93,7 +95,7 @@ class MLPSolver(object):
correct /= self.val_dataset_size
val_loss /= self.val_loader_size
self.logger.info(f"val accuracy: {correct :.4f}, val mean loss: {val_loss :.4f}")
self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")
def save_checkpoint(self, epoch_id):
self.model.eval()
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!