Source code for causalpy.checks.convex_hull

#   Copyright 2022 - 2026 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""
Convex hull diagnostic check for Synthetic Control experiments.

Verifies that pre-treatment values of treated units fall within
the range of control units — a key assumption of the synthetic
control method.
"""

from __future__ import annotations

import pandas as pd

from causalpy.checks.base import CheckResult
from causalpy.experiments.base import BaseExperiment
from causalpy.experiments.synthetic_control import SyntheticControl
from causalpy.pipeline import PipelineContext
from causalpy.utils import check_convex_hull_violation


class ConvexHullCheck:
    """Check that treated unit values lie within the convex hull of controls.

    For every treated unit, its pre-treatment series is compared against the
    pre-treatment control-unit data using
    :func:`causalpy.utils.check_convex_hull_violation`. The per-unit results
    are collected into a table and summarised as a single pass/fail
    :class:`~causalpy.checks.base.CheckResult`.

    Note that this check applies ``check_convex_hull_violation`` directly to
    each treated unit; it does not call ``SyntheticControl._check_convex_hull()``.
    """

    # Experiment classes this check may be applied to.
    applicable_methods: set[type[BaseExperiment]] = {SyntheticControl}

    def validate(self, experiment: BaseExperiment) -> None:
        """Verify the experiment is a SyntheticControl instance.

        Raises
        ------
        TypeError
            If ``experiment`` is not a :class:`SyntheticControl`.
        """
        if not isinstance(experiment, SyntheticControl):
            raise TypeError(
                "ConvexHullCheck requires a SyntheticControl experiment."
            )

    def run(
        self,
        experiment: BaseExperiment,
        context: PipelineContext,
    ) -> CheckResult:
        """Run the convex hull violation check on pre-treatment data.

        Parameters
        ----------
        experiment : BaseExperiment
            The experiment to check. It must expose ``datapre_control`` and
            ``datapre_treated``; the latter is indexed as 2-D with one column
            per treated unit. If it exposes ``treated_units``, those labels
            name the rows of the result table; otherwise positional labels
            ``unit_0``, ``unit_1``, ... are used.
        context : PipelineContext
            Pipeline context. Not used by this check.

        Returns
        -------
        CheckResult
            ``passed`` is True only if every treated unit passes (vacuously
            True when there are no treated units). ``table`` is a DataFrame
            with columns ``treated_unit``, ``passes``, ``n_violations``,
            ``pct_above`` and ``pct_below`` (one row per treated unit), or
            None when there are no treated units. ``text`` summarises the
            outcome and the total number of violations.

        Raises
        ------
        ValueError
            If the number of ``treated_units`` labels does not match the
            number of treated-unit columns in ``datapre_treated``.
        """
        sc = experiment
        datapre_control = sc.datapre_control  # type: ignore[attr-defined]
        datapre_treated = sc.datapre_treated  # type: ignore[attr-defined]

        # One violation report per treated unit (column of datapre_treated).
        n_treated = datapre_treated.shape[1]
        all_results = [
            check_convex_hull_violation(
                datapre_treated[:, unit_idx], datapre_control
            )
            for unit_idx in range(n_treated)
        ]
        total_violations = sum(res["n_violations"] for res in all_results)
        all_pass = all(res["passes"] for res in all_results)

        # Fall back to positional labels when the experiment has no unit names.
        treated_units = list(
            getattr(sc, "treated_units", [f"unit_{i}" for i in range(n_treated)])
        )
        if len(treated_units) != n_treated:
            # Fail with an explicit message instead of zip(strict=True)'s
            # generic length-mismatch error.
            raise ValueError(
                f"ConvexHullCheck: experiment has {len(treated_units)} "
                f"treated unit label(s) but {n_treated} treated-unit "
                f"column(s) in the pre-treatment data."
            )

        rows = [
            {
                "treated_unit": unit_name,
                "passes": res["passes"],
                "n_violations": res["n_violations"],
                "pct_above": res["pct_above"],
                "pct_below": res["pct_below"],
            }
            for unit_name, res in zip(treated_units, all_results, strict=True)
        ]
        table = pd.DataFrame(rows) if rows else None

        if all_pass:
            text = (
                "Convex hull check passed: all treated unit values lie "
                "within the range of control units in the pre-treatment period."
            )
        else:
            text = (
                f"Convex hull check failed: {total_violations} violations "
                f"detected. Some treated unit values fall outside the range "
                f"of control units, which may compromise the synthetic "
                f"control fit."
            )

        return CheckResult(
            check_name="ConvexHullCheck",
            passed=all_pass,
            table=table,
            text=text,
        )