import logging
from functools import partial, reduce
from typing import Callable, List, Literal
import numpy as np
import torch
from torch import Tensor, nn
_logger = logging.getLogger(__name__)
# Feature reductions
def flatten(data: Tensor, **kwargs):
    """Collapse every dimension after the first into one.

    Args:
        data: Input tensor; dimension 0 is kept as-is.
        **kwargs: Accepted and ignored.

    Returns:
        ``data`` reshaped to ``(data.shape[0], -1)``.
    """
    return data.flatten(start_dim=1)
def adaptive_avg_pool2d(data: Tensor, **kwargs):
    """Global-average-pool the trailing two dims, then flatten from dim 1.

    Tensors with at most two dimensions are returned unchanged (same object).
    For a 4-D ``(N, C, H, W)`` input the result is ``(N, C)``.

    NOTE(review): ``nn.AdaptiveAvgPool2d`` treats a 3-D input as a single
    unbatched ``(C, H, W)`` image, so a 3-D tensor comes back as ``(C, 1)``.
    Confirm this is what callers want.

    Args:
        data: Input tensor.
        **kwargs: Accepted and ignored.
    """
    if data.dim() <= 2:
        return data
    pooled = nn.AdaptiveAvgPool2d((1, 1))(data)
    return pooled.flatten(start_dim=1)
def adaptive_max_pool2d(data: Tensor, **kwargs):
    """Global-max-pool the trailing two dims, then flatten from dim 1.

    Tensors with at most two dimensions are returned unchanged (same object).
    For a 4-D ``(N, C, H, W)`` input the result is ``(N, C)``.

    NOTE(review): a 3-D input is treated by ``nn.AdaptiveMaxPool2d`` as one
    unbatched ``(C, H, W)`` image and comes back as ``(C, 1)``.

    Args:
        data: Input tensor.
        **kwargs: Accepted and ignored.
    """
    if data.dim() <= 2:
        return data
    pooler = nn.AdaptiveMaxPool2d((1, 1))
    return pooler(data).flatten(start_dim=1)
def getitem(data: Tensor, **kwargs):
    """Take index 0 along dimension 1 of a 3-D tensor.

    Args:
        data: Input tensor with at most 3 dimensions.
        **kwargs: Accepted and ignored.

    Returns:
        For 3-D input, a fresh contiguous copy of ``data[:, 0]``; otherwise
        ``data`` itself.

    Raises:
        ValueError: if ``data`` has more than 3 dimensions.
    """
    rank = data.dim()
    if rank > 3:
        raise ValueError("Data must be a 3D or 2D tensor")
    if rank < 3:
        return data
    # clone() detaches the result from ``data``'s storage.
    return data[:, 0].clone().contiguous()
def avg_or_getitem(data: Tensor, **kwargs):
    """Reduce by rank: index 3-D input, average-pool higher-rank input.

    * rank <= 2: returned unchanged (same object).
    * rank == 3: contiguous copy of ``data[:, 0]``.
    * rank >= 4: global average pool over the trailing two dims, flattened
      from dim 1.

    Args:
        data: Input tensor.
        **kwargs: Accepted and ignored.
    """
    rank = data.dim()
    if rank < 3:
        return data
    if rank == 3:
        return data[:, 0].clone().contiguous()
    pooled = nn.AdaptiveAvgPool2d((1, 1))(data)
    return torch.flatten(pooled, 1)
def max_or_getitem(data: Tensor, **kwargs):
    """Reduce by rank: index 3-D input, max-pool higher-rank input.

    * rank <= 2: returned unchanged (same object).
    * rank == 3: contiguous copy of ``data[:, 0]``.
    * rank >= 4: global max pool over the trailing two dims, flattened
      from dim 1.

    Args:
        data: Input tensor.
        **kwargs: Accepted and ignored.
    """
    rank = data.dim()
    if rank < 3:
        return data
    if rank == 3:
        return data[:, 0].clone().contiguous()
    pooled = nn.AdaptiveMaxPool2d((1, 1))(data)
    return torch.flatten(pooled, 1)
def none_reduction(data: Tensor, **kwargs):
    """Identity reduction: hand back ``data`` untouched.

    Args:
        data: Any tensor.
        **kwargs: Accepted and ignored.

    Returns:
        The very same ``data`` object.
    """
    unchanged = data
    return unchanged
# Maps the string names accepted by create_reduction to the reduction
# functions defined above.
reductions_registry = dict(
    flatten=flatten,
    avg=adaptive_avg_pool2d,
    max=adaptive_max_pool2d,
    getitem=getitem,
    avg_or_getitem=avg_or_getitem,
    max_or_getitem=max_or_getitem,
    none=none_reduction,
)
def create_reduction(reduction: str, **kwargs):
    """Look up a reduction by name and pre-bind keyword arguments to it.

    Args:
        reduction: Key into ``reductions_registry``.
        **kwargs: Keyword arguments bound via ``functools.partial``.

    Returns:
        A ``functools.partial`` wrapping the registered reduction.

    Raises:
        KeyError: if ``reduction`` is not a registered name.
    """
    reduction_fn = reductions_registry[reduction]
    return partial(reduction_fn, **kwargs)
def get_penultimate_layer_name(model: nn.Module):
    """Return the name of the second-to-last direct child registered on ``model``.

    Raises:
        IndexError: if ``model`` has fewer than two direct children.
    """
    child_names = tuple(model._modules)
    return child_names[-2]
def get_penultimate_layer(model: nn.Module):
    """Return the second-to-last direct child module registered on ``model``.

    Raises:
        IndexError: if ``model`` has fewer than two direct children.
    """
    children = tuple(model._modules.values())
    return children[-2]
def get_last_layer_name(model: nn.Module):
    """Return the name of the last direct child registered on ``model``.

    Raises:
        IndexError: if ``model`` has no direct children.
    """
    child_names = tuple(model._modules)
    return child_names[-1]
def get_last_layer(model: nn.Module):
    """Return the last direct child module registered on ``model``.

    Raises:
        IndexError: if ``model`` has no direct children.
    """
    children = tuple(model._modules.values())
    return children[-1]
# matrix operations
def torch_reduction_matrix(sigma: Tensor, reduction_method="pseudo"):
    """Derive a reduction matrix from a covariance-like matrix ``sigma``.

    Supported ``reduction_method`` values:

    * ``"cholesky"``: with ``sigma = C @ C^T`` (lower Cholesky factor ``C``),
      returns ``W = inv(C^T)``, which satisfies ``W^T @ sigma @ W == I``.
    * ``"svd"``: with ``sigma = U @ diag(s) @ V^T``, returns
      ``W = U @ diag(1 / sqrt(s))``. For a symmetric positive-definite
      ``sigma`` this also satisfies ``W^T @ sigma @ W == I``.
    * ``"pseudo"`` / ``"pinv"``: Moore-Penrose pseudo-inverse of ``sigma``.
    * ``"inverse"`` / ``"inv"``: ordinary inverse of ``sigma``.

    Batched input of shape ``(..., n, n)`` is supported. Transposition and
    diagonal scaling act on the last two dimensions only.

    Args:
        sigma: Square matrix, or batch of square matrices.
        reduction_method: One of the names listed above.

    Returns:
        A tensor with the same leading shape as ``sigma``.

    Raises:
        ValueError: if ``reduction_method`` is not recognised.
    """
    if reduction_method == "cholesky":
        chol = torch.linalg.cholesky(sigma)
        # transpose(-2, -1), not .T: .T reverses *all* dims, which scrambles
        # batch and matrix axes for batched input.
        return torch.linalg.inv(chol.transpose(-2, -1))
    if reduction_method == "svd":
        u, s, _ = torch.linalg.svd(sigma)
        # Scaling column j of u by 1/sqrt(s_j) equals u @ diag(1/sqrt(s)).
        # Unlike torch.diag, this also broadcasts over batched s.
        return u * torch.sqrt(1 / s).unsqueeze(-2)
    if reduction_method in ("pseudo", "pinv"):
        return torch.linalg.pinv(sigma)
    if reduction_method in ("inverse", "inv"):
        return torch.linalg.inv(sigma)
    raise ValueError(f"Unknown reduction method {reduction_method}")
def sklearn_cov_matrix_estimarion(
    x: np.ndarray,
    method: Literal[
        "EmpiricalCovariance",
        "GraphicalLasso",
        "GraphicalLassoCV",
        "LedoitWolf",
        "MinCovDet",
        "ShrunkCovariance",
        "OAS",
    ] = "EmpiricalCovariance",
    **method_kwargs,
):
    """Fit a ``sklearn.covariance`` estimator on ``x`` and return its statistics.

    The (misspelled) function name is kept unchanged so existing callers keep
    working.

    Args:
        x: Data matrix passed straight to the estimator's ``fit``.
        method: Name of a class in ``sklearn.covariance``. It is resolved with
            ``getattr``, so names outside the ``Literal`` hint also work if
            that module defines them.
        **method_kwargs: Forwarded to the estimator's constructor.

    Returns:
        Tuple ``(location_, covariance_, precision_)`` read from the fitted
        estimator. ``precision_`` may be ``None`` if the estimator was
        configured not to store it (e.g. ``store_precision=False``).

    Raises:
        ValueError: if ``sklearn.covariance`` has no attribute named ``method``.
    """
    import sklearn.covariance

    # Only the name lookup is guarded: an AttributeError raised while
    # constructing the estimator is a real error, not an unknown method.
    try:
        estimator_cls = getattr(sklearn.covariance, method)
    except AttributeError:
        raise ValueError(f"Unknown method {method}") from None
    estimator = estimator_cls(**method_kwargs)
    estimator.fit(x)

    # Each diagnostic below costs O(d^3) on a d x d matrix. Only compute them
    # when the debug messages will actually be emitted.
    if _logger.isEnabledFor(logging.DEBUG):
        cov_mat = estimator.covariance_
        _logger.debug("Cov mat determinant %s", np.linalg.det(cov_mat))
        _logger.debug("Cov mat rank %s", np.linalg.matrix_rank(cov_mat))
        _logger.debug("Cov mat condition number %s", np.linalg.cond(cov_mat))
        _logger.debug("Cov mat norm %s", np.linalg.norm(cov_mat))
        _logger.debug("Cov mat trace %s", np.trace(cov_mat))
        _logger.debug("Cov mat eigvals %s", np.linalg.eigvalsh(cov_mat))
    return estimator.location_, estimator.covariance_, estimator.precision_
def get_composed_attr(model, attrs: List[str]):
    """Follow a chain of attribute names starting from ``model``.

    ``get_composed_attr(m, ["a", "b"])`` is ``m.a.b``. An empty ``attrs``
    returns ``model`` itself.

    Raises:
        AttributeError: if any link in the chain is missing.
    """
    target = model
    for attr_name in attrs:
        target = getattr(target, attr_name)
    return target
def add_output_op(feature_extractor, output_op: Callable) -> nn.Module:
    """Wrap a traced module's output in one extra function call.

    The graph's first ``"output"`` node is removed. A ``call_function`` node
    applying ``output_op`` to that node's former arguments is appended, and
    a new output node returning its result is appended after it. The module
    is then recompiled and returned; it is modified in place.

    ``feature_extractor`` presumably is a ``torch.fx.GraphModule`` (it must
    expose ``.graph`` and ``.recompile()``).

    Raises:
        IndexError: if the graph has no output node, or no nodes remain after
            removing it.
    """
    graph = feature_extractor.graph
    output_nodes = [node for node in graph.nodes if node.op == "output"]
    old_output = output_nodes[0]
    output_args = old_output.args
    graph.erase_node(old_output)

    tail = list(graph.nodes)[-1]
    with graph.inserting_after(tail):
        wrapped = graph.call_function(output_op, args=output_args)
    # ``wrapped`` is now the last node, so the new output goes right after it.
    with graph.inserting_after(wrapped):
        graph.output(wrapped)

    feature_extractor.recompile()
    return feature_extractor