b3694ec8 by 周伟奇

add eval

1 parent 82a85c6d
......@@ -41,6 +41,7 @@ solver:
epoch: 100
no_other: false
base_on: null
model_path: null
optimizer:
name: 'Adam'
......
......@@ -7,6 +7,7 @@ from solver.builder import build_solver
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='./config/mlp.yaml', type=str, help='config file')
parser.add_argument('-e', '--eval', action="store_true")
args = parser.parse_args()
cfg = yaml.load(open(args.config, 'r').read(), Loader=yaml.FullLoader)
......@@ -14,6 +15,10 @@ def main():
# print(torch.cuda.is_available())
solver = build_solver(cfg)
if args.eval:
solver.evaluate()
else:
solver.run()
......
......@@ -8,6 +8,7 @@ from loss import build_loss
from model import build_model
from optimizer import build_lr_scheduler, build_optimizer
from utils import SOLVER_REGISTRY, get_logger_and_log_dir
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
@SOLVER_REGISTRY.register()
......@@ -32,6 +33,7 @@ class VITSolver(object):
self.hyper_params = cfg['solver']['args']
self.no_other = self.hyper_params['no_other']
self.base_on = self.hyper_params['base_on']
self.model_path = self.hyper_params['model_path']
try:
self.epoch = self.hyper_params['epoch']
except Exception:
......@@ -39,7 +41,7 @@ class VITSolver(object):
self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])
def evaluate(self, y_pred, y_true, thresholds=0.5):
def accuracy(self, y_pred, y_true, thresholds=0.5):
if self.no_other:
return (y_pred.argmax(1) == y_true.argmax(1)).type(torch.float).sum().item()
else:
......@@ -80,9 +82,9 @@ class VITSolver(object):
self.optimizer.step()
if self.no_other:
correct += self.evaluate(pred, y)
correct += self.accuracy(pred, y)
else:
correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
correct += self.accuracy(torch.nn.Sigmoid()(pred), y)
correct /= self.train_dataset_size
train_loss /= self.train_loader_size
......@@ -107,9 +109,9 @@ class VITSolver(object):
val_loss += loss.item()
if self.no_other:
correct += self.evaluate(pred, y)
correct += self.accuracy(pred, y)
else:
correct += self.evaluate(torch.nn.Sigmoid()(pred), y)
correct += self.accuracy(torch.nn.Sigmoid()(pred), y)
correct /= self.val_dataset_size
val_loss /= self.val_loader_size
......@@ -140,3 +142,44 @@ class VITSolver(object):
lr_scheduler.step()
self.logger.info('==> End Training')
def evaluate(self):
if isinstance(self.model_path, str) and os.path.exists(self.model_path):
self.model.load_state_dict(torch.load(self.model_path))
self.logger.info(f'==> Load Model from {self.model_path}')
else:
return
self.model.eval()
label_true_list = []
label_pred_list = []
for X, y in self.val_loader:
X, y_true = X.to(self.device), y.to(self.device)
if self.no_other:
pred = torch.nn.Softmax(dim=1)(self.model(X))
else:
# pred = torch.nn.Sigmoid()(self.model(X))
pred = self.model(X)
y_pred = torch.nn.Sigmoid()(pred)
y_pred_idx = torch.argmax(y_pred, dim=1) + 1
y_pred_is_other = (torch.amax(y_pred, dim=1) > 0.5).int()
y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)
y_true_idx = torch.argmax(y_true, dim=1) + 1
y_true_is_other = torch.sum(y_true, dim=1)
y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)
label_true_list.extend(y_true_rebuild.cpu().numpy().tolist())
label_pred_list.extend(y_pred_rebuild.cpu().numpy().tolist())
acc = accuracy_score(label_true_list, label_pred_list)
cm = confusion_matrix(label_true_list, label_pred_list)
report = classification_report(label_true_list, label_pred_list)
print(acc)
print(cm)
print(report)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!