# Source code for detectors.eval

"""
Module containing evaluation metrics.
"""
from typing import Dict

import numpy as np
import sklearn
import sklearn.metrics
import torch
from torch import Tensor


def fpr_at_fixed_tpr(fprs: np.ndarray, tprs: np.ndarray, thresholds: np.ndarray, tpr_level: float = 0.95):
    """Return the FPR at a fixed TPR level.

    The first operating point (lowest index) whose TPR is at least
    ``tpr_level`` is selected, and its FPR, TPR and threshold are returned.

    Args:
        fprs (np.ndarray): False positive rates.
        tprs (np.ndarray): True positive rates, index-aligned with ``fprs``.
        thresholds (np.ndarray): Thresholds, index-aligned with ``fprs``.
        tpr_level (float, optional): TPR level. Defaults to 0.95.

    Returns:
        Tuple[float, float, float]: FPR, TPR, threshold.

    Raises:
        ValueError: If no entry of ``tprs`` is at least ``tpr_level``
            (this includes empty input and NaN-only input).
    """
    tprs = np.asarray(tprs)
    # Vectorized search for every operating point that reaches the level.
    # NaN entries compare as False, so they are never selected.
    reached = np.flatnonzero(tprs >= tpr_level)
    if reached.size == 0:
        raise ValueError(f"No threshold allows for TPR at least {tpr_level}.")
    # flatnonzero returns indices in ascending order, so the first one is
    # the smallest qualifying index.
    idx = int(reached[0])
    return float(fprs[idx]), float(tprs[idx]), float(thresholds[idx])
def compute_detection_error(fpr: float, tpr: float, pos_ratio: float):
    """Return the detection error of a single operating point.

    The error is the miss rate ``1 - tpr`` weighted by ``pos_ratio`` plus
    the false positive rate weighted by ``1 - pos_ratio``.

    Args:
        fpr (float): False positive rate.
        tpr (float): True positive rate.
        pos_ratio (float): Ratio of positive labels.

    Returns:
        float: Detection error.
    """
    miss_rate = 1 - tpr
    negative_share = 1 - pos_ratio
    return pos_ratio * miss_rate + negative_share * fpr
def minimum_detection_error(fprs: np.ndarray, tprs: np.ndarray, pos_ratio: float):
    """Compute the minimum detection error over all operating points.

    The detection error ``pos_ratio * (1 - tpr) + (1 - pos_ratio) * fpr`` is
    evaluated for every ``(fpr, tpr)`` pair and the smallest value is
    returned.

    Args:
        fprs (np.ndarray): False positive rates.
        tprs (np.ndarray): True positive rates, same shape as ``fprs``.
        pos_ratio (float): Ratio of positive labels.

    Returns:
        float: The minimum detection error.

    Raises:
        ValueError: If the inputs are empty.
    """
    fprs = np.asarray(fprs, dtype=float)
    tprs = np.asarray(tprs, dtype=float)
    # The helper is plain arithmetic, so it evaluates every operating point
    # at once when handed arrays.
    detection_errors = compute_detection_error(fprs, tprs, pos_ratio)
    return float(np.min(detection_errors))
def get_ood_results(in_scores: Tensor, ood_scores: Tensor) -> Dict[str, float]:
    """Compute OOD detection metrics.

    In-distribution samples are labelled 1 and out-of-distribution samples
    are labelled 0, so higher scores are treated as "more in-distribution".

    Args:
        in_scores (Tensor): In-distribution scores. NumPy arrays and lists
            are also accepted and converted to tensors.
        ood_scores (Tensor): Out-of-distribution scores. NumPy arrays and
            lists are also accepted and converted to tensors.

    Returns:
        Dict[str, float]: OOD detection metrics.
            keys: `fpr_at_0.95_tpr`, `tnr_at_0.95_tpr`, `detection_error`,
            `auroc`, `aupr_in`, `aupr_out`, `thr`.
    """
    if isinstance(in_scores, (np.ndarray, list)):
        in_scores = torch.tensor(in_scores)
    if isinstance(ood_scores, (np.ndarray, list)):
        ood_scores = torch.tensor(ood_scores)

    in_labels = torch.ones(len(in_scores))
    ood_labels = torch.zeros(len(ood_scores))

    # Detach and move each input to the CPU before concatenating, so inputs
    # that track gradients or live on different devices are handled.
    _test_scores = torch.cat([in_scores.detach().cpu(), ood_scores.detach().cpu()]).numpy()
    _test_labels = torch.cat([in_labels, ood_labels]).numpy()

    fprs, tprs, thrs = sklearn.metrics.roc_curve(_test_labels, _test_scores)
    precision, recall, _ = sklearn.metrics.precision_recall_curve(_test_labels, _test_scores, pos_label=1)
    # With OOD (label 0) as the positive class, the score must be high for
    # OOD samples, so the scores are negated. Reusing the raw scores would
    # rank the wrong class first.
    precision_out, recall_out, _ = sklearn.metrics.precision_recall_curve(_test_labels, -_test_scores, pos_label=0)

    fpr, tpr, thr = fpr_at_fixed_tpr(fprs, tprs, thrs, 0.95)
    auroc = sklearn.metrics.auc(fprs, tprs)
    aupr_in = sklearn.metrics.auc(recall, precision)
    aupr_out = sklearn.metrics.auc(recall_out, precision_out)

    pos_ratio = np.mean(_test_labels == 1)
    detection_error = minimum_detection_error(fprs, tprs, pos_ratio)

    results = {
        "fpr_at_0.95_tpr": fpr,
        "tnr_at_0.95_tpr": 1 - fpr,
        "detection_error": float(detection_error),
        "auroc": float(auroc),
        "aupr_in": float(aupr_in),
        "aupr_out": float(aupr_out),
        "thr": thr,
    }
    return results