Source code for detectors.pipelines

"""
Pipeline module.
"""
from enum import Enum
from typing import Any, List, Optional, Tuple

from detectors.pipelines.base import Pipeline

# Module-level registry of pipelines. Keys are either ``name`` or
# ``name + "_" + dataset_name`` (see ``register_pipeline``); values are the
# decorated objects, which ``create_pipeline`` later calls with ``**kwargs``.
pipeline_registry = {}


def register_pipeline(name: str, dataset_name: Optional[str] = None):
    """
    Build a decorator that stores the decorated object in ``pipeline_registry``.

    Args:
        name (str): Base name of the pipeline to register.
        dataset_name (Optional[str]): When provided, the registry key is
            ``name + "_" + dataset_name``; when ``None``, the key is ``name``.

    Returns:
        A decorator that records its argument under the resolved key and
        returns that argument unchanged.
    """

    def decorator(f):
        # The key is resolved at decoration time, exactly when the entry is written.
        key = name if dataset_name is None else name + "_" + dataset_name
        pipeline_registry[key] = f
        return f

    return decorator
def create_pipeline(name: str, dataset_name: Optional[str] = None, **kwargs) -> Pipeline:
    """
    Utility factory method to build a Pipeline.

    Args:
        name (str): The name defining which pipeline will be returned.
            Currently accepted pipeline names are:
            `ood_benchmark`, `ood_validation`, `ood_validation_noise`, `covariate_drift`, `osr`
            Complete list:
            `ood_benchmark_cifar10`, `ood_benchmark_cifar100`, `ood_benchmark_imagenet`, `ood_mnist_benchmark`,
            `ood_validation_cifar10`, `ood_validation_cifar100`, `ood_validation_imagenet`, `ood_validation_mnist`,
            `ood_validation_noise_cifar10`, `ood_validation_noise_cifar100`, `ood_validation_noise_imagenet`,
            `covariate_drift_cifar10`, `covariate_drift_cifar100`, `covariate_drift_imagenet`.
        dataset_name (Optional[str]): When provided, the pipeline is looked up
            under ``name + "_" + dataset_name``; when ``None``, under ``name``.
        **kwargs: Additional arguments to pass to the pipeline.

    Returns:
        Pipeline: A suitable pipeline for the task.

    Raises:
        KeyError: If nothing is registered under the resolved key. The message
            lists the keys that are currently registered.

    Example::

        >>> import detectors
        >>> pipe = detectors.create_pipeline("ood_benchmark", "cifar10")
        >>> pipe.run(detector)
    """
    key = name if dataset_name is None else name + "_" + dataset_name

    # Only the lookup is guarded, so a KeyError raised while constructing the
    # pipeline itself is not mistaken for an unknown pipeline name.
    try:
        factory = pipeline_registry[key]
    except KeyError:
        available = ", ".join(sorted(map(str, pipeline_registry)))
        raise KeyError(
            "Unknown pipeline {!r}. Available pipelines: {}".format(key, available)
        ) from None

    return factory(**kwargs)
def list_pipelines() -> List[str]:
    """
    Return the keys currently present in the pipeline registry.

    Returns:
        List[str]: A new list holding every registered pipeline key.
    """
    # Iterating a dict yields its keys, so this matches ``list(d.keys())``.
    return list(pipeline_registry)
def list_pipeline_args(name: str) -> List[Tuple[str, Any]]:
    """
    Describe the call signature of a registered pipeline.

    Args:
        name (str): The registry key of the pipeline to inspect.

    Returns:
        list: One ``(argument_name, default_value)`` pair per parameter of the
        registered object's signature, in signature order. Parameters without a
        default carry ``inspect.Parameter.empty`` as their default value.
    """
    import inspect

    parameters = inspect.signature(pipeline_registry[name]).parameters
    return [(arg_name, spec.default) for arg_name, spec in parameters.items()]
# These imports sit at the bottom of the module, after the registry helpers
# are defined.
# NOTE(review): presumably the submodules call ``register_pipeline`` at import
# time, which would require this ordering — confirm before moving them.
from .covariate_drift import *
from .ood import *

# Enum whose member names and values are both the registered pipeline keys,
# captured at the moment this module finishes importing.
PipelinesRegistry = Enum(
    "PipelinesRegistry",
    {pipeline_name: pipeline_name for pipeline_name in list_pipelines()},
)