3310e154 by 周伟奇

add cuda

1 parent bcb17d0f
......@@ -18,6 +18,7 @@ class MLPModel(nn.Module):
nn.ReLU(),
nn.Linear(256, 5),
nn.Sigmoid(),
# nn.ReLU(),
)
self._initialize_weights()
......
......@@ -13,6 +13,8 @@ from utils import SOLVER_REGISTRY, get_logger_and_log_dir
class MLPSolver(object):
def __init__(self, cfg):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.cfg = copy.deepcopy(cfg)
self.train_loader, self.val_loader = build_dataloader(cfg)
......@@ -20,7 +22,7 @@ class MLPSolver(object):
self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)
# BatchNorm ?
self.model = build_model(cfg)
self.model = build_model(cfg).to(self.device)
self.loss_fn = build_loss(cfg)
......@@ -49,10 +51,14 @@ class MLPSolver(object):
def train_loop(self):
self.model.train()
train_loss = 0
train_loss, correct = 0, 0
for batch, (X, y) in enumerate(self.train_loader):
X, y = X.to(self.device), y.to(self.device)
pred = self.model(X)
correct += self.evaluate(pred, y)
# loss = self.loss_fn(pred, y, reduction="mean")
loss = self.loss_fn(pred, y)
train_loss += loss.item()
......@@ -65,8 +71,9 @@ class MLPSolver(object):
loss.backward()
self.optimizer.step()
correct /= self.train_dataset_size
train_loss /= self.train_loader_size
self.logger.info(f'train mean loss: {train_loss :.4f}')
self.logger.info(f'train accuracy: {correct :.4f}, train mean loss: {train_loss :.4f}')
@torch.no_grad()
def val_loop(self, t):
......@@ -74,6 +81,8 @@ class MLPSolver(object):
val_loss, correct = 0, 0
for X, y in self.val_loader:
X, y = X.to(self.device), y.to(self.device)
pred = self.model(X)
correct += self.evaluate(pred, y)
......@@ -84,7 +93,7 @@ class MLPSolver(object):
correct /= self.val_dataset_size
val_loss /= self.val_loader_size
self.logger.info(f"val accuracy: {correct :.4f}, val loss: {val_loss :.4f}")
self.logger.info(f"val accuracy: {correct :.4f}, val mean loss: {val_loss :.4f}")
def save_checkpoint(self, epoch_id):
self.model.eval()
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!