Source code for detectors.trainer

"""Trainer for classification models."""
import json
import logging
import os
from typing import Callable, Tuple

import accelerate
import torch
import torch.utils.data
from accelerate import Accelerator
from torch import Tensor, nn, optim
from tqdm import tqdm

_logger = logging.getLogger(__name__)


[docs]def get_criterion_cls(criterion_name: str) -> nn.modules.loss._Loss: return getattr(nn, criterion_name)
[docs]def get_optimizer_cls(optimizer_name: str) -> optim.Optimizer: return getattr(optim, optimizer_name)
[docs]def get_scheduler_cls(scheduler_name: str) -> optim.lr_scheduler._LRScheduler: return getattr(optim.lr_scheduler, scheduler_name)
[docs]def training_iteration( batch: Tuple[Tensor, Tensor], model: nn.Module, optimizer: optim.Optimizer, criterion: Callable, accelerator: Accelerator, ): inputs, targets = batch optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, targets) accelerator.backward(loss) optimizer.step() return {"loss": loss.item()}
[docs]def validation_iteration(batch: Tuple[Tensor, Tensor], model: nn.Module, criterion: Callable, accelerator: Accelerator): inputs, targets = batch with torch.no_grad(): outputs = model(inputs) predictions = torch.argmax(outputs, dim=1) predictions, targets = accelerator.gather_for_metrics((predictions, targets)) loss = criterion(outputs, targets).item() acc = (predictions == targets).float().sum().item() return {"loss": loss, "accuracy": acc}
[docs]def save_model(model: nn.Module, accelerator: Accelerator, filename: str): accelerator.wait_for_everyone() unwrapped_model = accelerator.unwrap_model(model) accelerator.save(unwrapped_model.state_dict(), filename)
[docs]def trainer_supervised_classification( model: torch.nn.Module, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler._LRScheduler, criterion: Callable, train_loader: torch.utils.data.DataLoader, val_loader: torch.utils.data.DataLoader, save_root: str, training_function=training_iteration, validation_function=validation_iteration, epochs=10, validation_frequency=1, seed: int = 42, ): os.makedirs(save_root, exist_ok=True) _logger.info(f"Saving model and progress to {save_root}") accelerate.utils.set_seed(seed) accelerator = Accelerator( log_with=["all"], logging_dir=os.path.join(save_root, "logs"), step_scheduler_with_optimizer=False ) accelerator.init_trackers("") model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare( model, optimizer, train_loader, val_loader, scheduler ) best_accuracy = 0.0 step = 0 progress_bar = tqdm(range(epochs), disable=not accelerator.is_local_main_process, dynamic_ncols=True) for epoch in progress_bar: # train model.train() for batch in train_loader: step += 1 tr_obj = training_function(batch, model, optimizer, criterion, accelerator) tr_loss = tr_obj["loss"] lr = optimizer.param_groups[0]["lr"] progress_bar.set_description_str(f"{step}it, loss={tr_loss:.4f}, lr={lr:.4f}") accelerator.log({f"train/{k}": v for k, v in tr_obj.items()}, step=step) accelerator.log({"lr": lr}, step=step) # validate if (epoch + 1) % validation_frequency == 0 or epoch == 0 or epoch == epochs - 1: model.eval() val_acc = 0 val_loss = 0 total = 0 val_obj = {} for batch in val_loader: val_obj = validation_function(batch, model, criterion, accelerator) val_acc += val_obj["accuracy"] val_loss += val_obj["loss"] total += len(batch[0]) # temporary fix val_acc /= total val_loss /= total if val_acc > best_accuracy: best_accuracy = val_acc save_model(model, accelerator, os.path.join(save_root, "best.pth")) progress_bar.set_postfix({"val/loss": val_loss, "val/acc": val_acc, "best/acc": best_accuracy}) accelerator.log({"val/acc": val_acc, "val/loss": val_loss}, step=epoch) if scheduler is not None: if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): scheduler.step(val_acc) # finish training if model is not learning num_classes = model(batch[0]).shape[1] # temporary fix if val_acc < 1 / num_classes: _logger.error("Accuracy is too low, stopping training") break # sync accelerator after epoch accelerator.wait_for_everyone() progress_bar.update(1) if scheduler is not None: if not isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): scheduler.step(epoch) accelerator.end_training() # save model save_model(model, accelerator, os.path.join(save_root, "last.pth")) # save train results train_results = {"epoch": epoch, "train_loss": tr_loss} with open(os.path.join(save_root, "train_results.json"), "w") as f: json.dump(train_results, f) # save eval results eval_results = {"epoch": epoch, "best_accuracy": best_accuracy, "last_accuracy": val_acc, "eval_loss": val_loss} with open(os.path.join(save_root, "eval_results.json"), "w") as f: json.dump(eval_results, f) _logger.info(f"Training finished. Best accuracy: {best_accuracy:.4f}") _logger.info(f"Last model saved to {save_root}/last.pth") _logger.info(f"Best model saved to {save_root}/best.pth") _logger.info(f"Training logs saved to {save_root}/logs") _logger.info(f"Training results saved to {save_root}/train_results.json") _logger.info(f"Evaluation results saved to {save_root}/eval_results.json")