# train.py 3.75 KB
import os.path

from torch import nn, optim
import torch
from dataset import *
from torch.utils.data import random_split
from net import *
import tqdm
import time

# Location of the weights file that is both resumed from and written to.
CHECKPOINT_PATH = 'params/new_face_mobilenet_v2.pth'

# Fixed seed for the train/test partition.  The script resumes from a saved
# checkpoint, so the partition must be the same on every run; otherwise samples
# the network was trained on in an earlier run can end up in the test split.
SPLIT_SEED = 0


def split_lengths(total, train_rate):
    """Return ``[train_len, test_len]`` for a dataset of ``total`` samples.

    The training part is ``int(total * train_rate)`` and the test part is the
    remainder, so the two lengths always sum to ``total``.
    """
    train_len = int(total * train_rate)
    return [train_len, total - train_len]


def count_correct(logits, labels):
    """Return how many arg-max predictions (over dim 1) of ``logits`` equal ``labels``."""
    predictions = torch.argmax(logits, dim=1)
    return int(torch.eq(labels, predictions).sum().item())


def train_one_epoch(net, loader, loss_fun, opt, device, epoch,
                    checkpoint_path=CHECKPOINT_PATH, save_every=10):
    """Run one optimisation pass over ``loader`` and checkpoint the weights.

    Weights are written to ``checkpoint_path`` every ``save_every`` batches and
    once more at the end of the epoch if any batch ran since the last save, so
    the final updates of an epoch are never lost.
    """
    net.train()
    unsaved = False
    with tqdm.tqdm(loader, desc=f'Epoch {epoch} train') as bar:
        for i, (image_data, image_label) in enumerate(bar):
            image_data, image_label = image_data.to(device), image_label.to(device)
            out = net(image_data)
            train_loss = loss_fun(out, image_label)
            opt.zero_grad()
            train_loss.backward()
            opt.step()
            bar.set_postfix(train_loss=train_loss.item(),
                            train_acc=count_correct(out, image_label) / len(image_label))
            unsaved = True
            if (i + 1) % save_every == 0:
                torch.save(net.state_dict(), checkpoint_path)
                # bar.write keeps the message from corrupting the progress bar.
                bar.write(f'epoch : {epoch}  {i}  successfully save weights!')
                unsaved = False
    if unsaved:
        torch.save(net.state_dict(), checkpoint_path)
        print(f'epoch : {epoch}  successfully save weights!')


def evaluate(net, loader, loss_fun, device, epoch):
    """Return the accuracy of ``net`` over every sample in ``loader``.

    Accuracy is correct samples divided by total samples, so a smaller final
    batch is not over-weighted.  Returns ``None`` when ``loader`` yields no
    samples.
    """
    net.eval()
    correct, seen = 0, 0
    with torch.no_grad(), tqdm.tqdm(loader, desc=f'Epoch {epoch} test') as bar:
        for image_data, image_label in bar:
            image_data, image_label = image_data.to(device), image_label.to(device)
            out = net(image_data)
            test_loss = loss_fun(out, image_label)
            batch_correct = count_correct(out, image_label)
            bar.set_postfix(test_loss=test_loss.item(),
                            test_acc=batch_correct / len(image_label))
            correct += batch_correct
            seen += len(image_label)
    return correct / seen if seen else None


def main():
    """Train FaceMaskNet on the face-mask dataset, evaluating after every epoch."""
    train_rate = 0.8
    batch_size = 50
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    epochs = 50

    datasets = FaceMaskDataset('/data2/new_face_mask')
    train_datasets, test_datasets = random_split(
        datasets,
        split_lengths(len(datasets), train_rate),
        generator=torch.Generator().manual_seed(SPLIT_SEED),
    )
    print(f'train_datasets:{len(train_datasets)}   test_datasets:{len(test_datasets)}')
    train_data_loader = DataLoader(train_datasets, batch_size=batch_size, shuffle=True)
    # Evaluation does not depend on sample order, so no shuffling is needed.
    test_data_loader = DataLoader(test_datasets, batch_size=batch_size, shuffle=False)
    loss_fun = nn.CrossEntropyLoss()

    net = FaceMaskNet().to(device)
    if os.path.exists(CHECKPOINT_PATH):
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
        net.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
        print('successfully loading weights!')
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
    opt = optim.Adam(net.parameters())

    # Inclusive upper bound so that exactly `epochs` epochs are run.
    for epoch in range(1, epochs + 1):
        train_one_epoch(net, train_data_loader, loss_fun, opt, device, epoch)
        avg_acc = evaluate(net, test_data_loader, loss_fun, device, epoch)
        if avg_acc is None:
            print(f'epoch : {epoch}  test set is empty, accuracy not computed')
        else:
            print(f'epoch : {epoch}  avg acc : ', avg_acc)


if __name__ == '__main__':
    main()