Source code for causalpy.steps.sensitivity

#   Copyright 2022 - 2026 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""
SensitivityAnalysis pipeline step.

A container step that holds a list of pluggable ``Check`` objects and
runs them against the fitted experiment.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any

from causalpy.checks.base import Check, CheckResult
from causalpy.experiments.base import BaseExperiment
from causalpy.pipeline import PipelineContext

# Module-level logger; ``SensitivityAnalysis.run`` logs each check's class
# name at INFO level immediately before executing that check.
logger = logging.getLogger(__name__)

# Registry mapping experiment types to their default checks.
# Populated by individual check modules via ``register_default_check``,
# which stores each check class in the list kept under every experiment
# type it is registered for. Read by ``SensitivityAnalysis.default_for``,
# which looks up the given experiment class (exact key match) and
# instantiates every registered check class with no arguments.
_DEFAULT_CHECKS: dict[type[BaseExperiment], list[type]] = {}


def register_default_check(
    check_class: type,
    experiment_types: set[type[BaseExperiment]],
) -> None:
    """Register a check class as a default for the given experiment types.

    Called by check modules at import time so that
    ``SensitivityAnalysis.default_for`` can auto-select checks.

    Registration is idempotent: registering the same class object again for
    an experiment type it is already registered for is a no-op, so the
    check is never listed (and therefore never instantiated) twice.

    Parameters
    ----------
    check_class : type
        The check class to register. ``SensitivityAnalysis.default_for``
        instantiates it with no arguments.
    experiment_types : set[type[BaseExperiment]]
        Experiment classes for which ``check_class`` should be a default.
        An empty set leaves the registry untouched.
    """
    for exp_type in experiment_types:
        registered = _DEFAULT_CHECKS.setdefault(exp_type, [])
        # Skip duplicates so a repeated registration call does not make
        # ``default_for`` build the same check more than once.
        if check_class not in registered:
            registered.append(check_class)
@dataclass
class SensitivitySummary:
    """Collected outcome of a full set of sensitivity checks.

    Attributes
    ----------
    results : list[CheckResult]
        The individual check results.
    all_passed : bool or None
        ``True`` when every check that reported a pass/fail verdict passed,
        ``False`` when at least one of them failed, and ``None`` when no
        check reported a verdict at all.
    text : str
        The non-empty result texts joined into one prose summary.
    """

    results: list[CheckResult] = field(default_factory=list)
    all_passed: bool | None = None
    text: str = ""

    @classmethod
    def from_results(cls, results: list[CheckResult]) -> SensitivitySummary:
        """Assemble a summary out of individual check results.

        Parameters
        ----------
        results : list[CheckResult]
            Results to aggregate; each is read for ``passed`` and ``text``.

        Returns
        -------
        SensitivitySummary
            Holds a copy of ``results``, the aggregated verdict, and the
            non-empty texts separated by blank lines.
        """
        # Results whose ``passed`` is None carry no verdict and are ignored
        # when computing the overall outcome.
        decided = [item.passed for item in results if item.passed is not None]
        if decided:
            overall = all(decided)
        else:
            overall = None

        # Empty texts are skipped so they do not add stray blank paragraphs.
        prose = "\n\n".join(item.text for item in results if item.text)

        return cls(results=list(results), all_passed=overall, text=prose)
class SensitivityAnalysis:
    """Pipeline step that runs a suite of sensitivity / diagnostic checks.

    Parameters
    ----------
    checks : list of Check
        The checks to run against the fitted experiment.

    Examples
    --------
    >>> import causalpy as cp  # doctest: +SKIP
    >>> step = cp.SensitivityAnalysis(  # doctest: +SKIP
    ...     checks=[
    ...         cp.checks.PlaceboInTime(n_folds=4),
    ...         cp.checks.PriorSensitivity(priors=[...]),
    ...     ]
    ... )
    """

    def __init__(self, checks: list[Any] | None = None) -> None:
        """Store the checks to run.

        Parameters
        ----------
        checks : list of Check, optional
            Checks to run; ``None`` or an empty list yields a step with no
            checks.
        """
        # Copy so the step owns its own list, independent of the caller's.
        self.checks: list[Any] = list(checks) if checks else []

    @classmethod
    def default_for(cls, method: type[BaseExperiment]) -> SensitivityAnalysis:
        """Create a ``SensitivityAnalysis`` pre-loaded with all registered
        default checks for *method*.

        Parameters
        ----------
        method : type[BaseExperiment]
            The experiment class to look up defaults for.

        Returns
        -------
        SensitivityAnalysis
            Instance with applicable default checks instantiated. Holds no
            checks when nothing is registered for *method*.
        """
        check_classes = _DEFAULT_CHECKS.get(method, [])
        # Registered entries are classes; each is built with no arguments.
        checks = [cc() for cc in check_classes]
        return cls(checks=checks)

    def validate(self, context: PipelineContext) -> None:
        """Validate that checks are well-formed.

        At validation time the experiment may not yet be fitted, so we only
        check structural issues (e.g. that each object satisfies the Check
        protocol). ``context`` is not inspected here.

        Raises
        ------
        TypeError
            If any item in ``checks`` does not satisfy the ``Check`` protocol.
        """
        for i, check in enumerate(self.checks):
            if not isinstance(check, Check):
                raise TypeError(
                    f"Check {i} ({type(check).__name__}) does not satisfy the "
                    f"Check protocol"
                )

    def run(self, context: PipelineContext) -> PipelineContext:
        """Run all checks against the fitted experiment.

        Applicability of every check is verified before any check is
        executed, so a misconfigured suite fails before any check has run
        rather than after earlier checks have already completed.

        Parameters
        ----------
        context : PipelineContext
            Pipeline context; must carry a fitted ``experiment``.

        Returns
        -------
        PipelineContext
            The same ``context`` object, with ``sensitivity_results`` set to
            the list of check results and ``report`` set to their
            ``SensitivitySummary``.

        Raises
        ------
        RuntimeError
            If no experiment has been fitted (``context.experiment is None``).
        TypeError
            If a check is not applicable to the experiment type.
        """
        if context.experiment is None:
            raise RuntimeError(
                "SensitivityAnalysis requires a fitted experiment in the "
                "pipeline context. Add an EstimateEffect step before "
                "SensitivityAnalysis."
            )

        experiment = context.experiment
        experiment_type = type(experiment)

        # First pass: reject inapplicable checks up front, before any check
        # has been validated or run.
        for check in self.checks:
            if experiment_type not in check.applicable_methods:
                raise TypeError(
                    f"{type(check).__name__} is not applicable to "
                    f"{experiment_type.__name__}. Applicable methods: "
                    f"{[m.__name__ for m in check.applicable_methods]}"
                )

        # Second pass: validate and execute each check in order.
        results: list[CheckResult] = []
        for check in self.checks:
            check.validate(experiment)
            logger.info("Running check: %s", type(check).__name__)
            result = check.run(experiment, context)
            results.append(result)

        summary = SensitivitySummary.from_results(results)
        context.sensitivity_results = results
        context.report = summary  # overwritten by GenerateReport if present
        return context

    def __repr__(self) -> str:
        """Return a string representation of the step."""
        check_names = [type(c).__name__ for c in self.checks]
        return f"SensitivityAnalysis(checks={check_names})"