Source code for causalpy.checks.prior_sensitivity

#   Copyright 2022 - 2026 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""
Prior sensitivity check for Bayesian causal inference experiments.

Re-fits the experiment with alternative prior specifications and
compares posterior estimates to assess how sensitive the conclusions
are to prior choices.
"""

from __future__ import annotations

import logging
from typing import Any

import pandas as pd

from causalpy.checks.base import CheckResult, clone_model
from causalpy.experiments.base import BaseExperiment
from causalpy.experiments.diff_in_diff import DifferenceInDifferences
from causalpy.experiments.instrumental_variable import InstrumentalVariable
from causalpy.experiments.interrupted_time_series import InterruptedTimeSeries
from causalpy.experiments.inverse_propensity_weighting import (
    InversePropensityWeighting,
)
from causalpy.experiments.prepostnegd import PrePostNEGD
from causalpy.experiments.regression_discontinuity import RegressionDiscontinuity
from causalpy.experiments.regression_kink import RegressionKink
from causalpy.experiments.staggered_did import StaggeredDifferenceInDifferences
from causalpy.experiments.synthetic_control import SyntheticControl
from causalpy.pipeline import PipelineContext
from causalpy.pymc_models import PyMCModel

logger = logging.getLogger(__name__)


class PriorSensitivity:
    """Re-fit the experiment with alternative models/priors and compare.

    Each alternative is specified as a dict with ``"name"`` and ``"model"``
    keys. The check re-instantiates the experiment for each alternative
    model and compares the resulting effect summaries.

    Parameters
    ----------
    alternatives : list of dict
        Each dict must have ``"name"`` (str) and ``"model"`` (PyMCModel)
        keys.

    Examples
    --------
    >>> import causalpy as cp  # doctest: +SKIP
    >>> check = cp.checks.PriorSensitivity(  # doctest: +SKIP
    ...     alternatives=[
    ...         {"name": "diffuse", "model": cp.pymc_models.LinearRegression(...)},
    ...         {"name": "tight", "model": cp.pymc_models.LinearRegression(...)},
    ...     ]
    ... )
    """

    applicable_methods: set[type[BaseExperiment]] = {
        InterruptedTimeSeries,
        DifferenceInDifferences,
        SyntheticControl,
        StaggeredDifferenceInDifferences,
        RegressionDiscontinuity,
        RegressionKink,
        PrePostNEGD,
        InversePropensityWeighting,
        InstrumentalVariable,
    }

    def __init__(self, alternatives: list[dict[str, Any]]) -> None:
        if not alternatives:
            raise ValueError("alternatives must be a non-empty list")
        for i, alt in enumerate(alternatives):
            if "name" not in alt or "model" not in alt:
                raise ValueError(
                    f"Alternative {i} must have 'name' and 'model' keys, "
                    f"got keys: {list(alt.keys())}"
                )
        self.alternatives = alternatives

    def validate(self, experiment: BaseExperiment) -> None:
        """Verify the experiment uses a Bayesian (PyMC) model."""
        if not isinstance(experiment.model, PyMCModel):
            raise TypeError(
                "PriorSensitivity requires a Bayesian (PyMC) model. "
                f"Got {type(experiment.model).__name__}."
            )

    def run(
        self,
        experiment: BaseExperiment,
        context: PipelineContext,
    ) -> CheckResult:
        """Re-fit with each alternative model and compare effect estimates."""
        if context.experiment_config is None:
            raise RuntimeError(
                "No experiment_config in context. Use EstimateEffect "
                "before SensitivityAnalysis."
            )

        method = context.experiment_config["method"]
        base_kwargs = {
            k: v
            for k, v in context.experiment_config.items()
            if k not in ("method", "model")
        }

        rows: list[dict[str, Any]] = []
        for alt in self.alternatives:
            name = alt["name"]
            model = clone_model(alt["model"])
            logger.info("PriorSensitivity: fitting with '%s'", name)
            alt_experiment = method(context.data, model=model, **base_kwargs)
            try:
                summary = alt_experiment.effect_summary()
                row: dict[str, Any] = {"prior_spec": name}
                if summary.table is not None and not summary.table.empty:
                    for col in summary.table.columns:
                        row[col] = summary.table[col].iloc[0]
                rows.append(row)
            except Exception as exc:
                logger.warning(
                    "PriorSensitivity: effect_summary() failed for '%s': %s",
                    name,
                    exc,
                )
                rows.append({"prior_spec": name, "error": str(exc)})

        table = pd.DataFrame(rows) if rows else None
        text = (
            f"Prior sensitivity analysis: compared {len(self.alternatives)} "
            f"alternative prior specifications."
        )
        return CheckResult(
            check_name="PriorSensitivity",
            passed=None,
            table=table,
            text=text,
        )
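
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). It assumes an
# `experiment` already fitted with a PyMC model and a `context`
# (PipelineContext) populated by an upstream EstimateEffect step, as required
# by ``run()`` above; the prior configuration passed to LinearRegression is
# deliberately elided because it depends on the model subclass used.
#
#     import causalpy as cp
#
#     check = cp.checks.PriorSensitivity(
#         alternatives=[
#             {"name": "diffuse", "model": cp.pymc_models.LinearRegression(...)},
#             {"name": "tight", "model": cp.pymc_models.LinearRegression(...)},
#         ]
#     )
#     check.validate(experiment)       # raises TypeError for non-PyMC models
#     result = check.run(experiment, context)
#     result.table                     # one row of effect estimates per prior spec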