Source code for causalpy.pipeline
# Copyright 2022 - 2026 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pipeline orchestration for composable causal inference workflows.
Provides a ``Pipeline`` class that chains steps (``EstimateEffect``,
``SensitivityAnalysis``, ``GenerateReport``) into a reproducible,
lazily-validated workflow. All steps are validated before any fitting
begins so that configuration errors surface before expensive MCMC sampling.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Protocol, runtime_checkable
import pandas as pd
from causalpy.experiments.base import BaseExperiment
from causalpy.reporting import EffectSummary
[docs]
@dataclass
class PipelineContext:
"""Mutable container that accumulates results as pipeline steps execute.
Each step reads from and writes to this context, building up a complete
record of the analysis.
Attributes
----------
data : pd.DataFrame
The input dataset.
experiment : BaseExperiment or None
The fitted experiment object, populated by ``EstimateEffect``.
experiment_config : dict or None
The configuration used to create the experiment (method class +
keyword arguments), so that downstream steps like
``SensitivityAnalysis`` can derive experiment factories.
effect_summary : EffectSummary or None
The effect summary from the primary experiment.
sensitivity_results : list
Accumulated sensitivity / diagnostic check results.
report : object or None
Generated report artifact, populated by ``GenerateReport``.
"""
data: pd.DataFrame
experiment: BaseExperiment | None = None
experiment_config: dict[str, Any] | None = None
effect_summary: EffectSummary | None = None
sensitivity_results: list[Any] = field(default_factory=list)
report: Any = None
[docs]
@dataclass
class PipelineResult:
"""Immutable result returned by :meth:`Pipeline.run`.
Attributes
----------
experiment : BaseExperiment or None
The fitted experiment.
effect_summary : EffectSummary or None
The effect summary from the experiment.
sensitivity_results : list
Results of all sensitivity / diagnostic checks.
report : object or None
Generated report artifact.
"""
experiment: BaseExperiment | None
effect_summary: EffectSummary | None
sensitivity_results: list[Any]
report: Any
[docs]
@classmethod
def from_context(cls, context: PipelineContext) -> PipelineResult:
"""Build a ``PipelineResult`` from a completed ``PipelineContext``."""
return cls(
experiment=context.experiment,
effect_summary=context.effect_summary,
sensitivity_results=list(context.sensitivity_results),
report=context.report,
)
[docs]
@runtime_checkable
class Step(Protocol):
"""Protocol that all pipeline steps must satisfy.
Implementations must provide two methods:
* ``validate`` -- called *before* any step runs. Should raise on
configuration errors (wrong types, missing parameters, etc.).
* ``run`` -- called sequentially. Receives the shared
``PipelineContext``, mutates it, and returns it.
"""
[docs]
def validate(self, context: PipelineContext) -> None:
"""Check configuration before execution."""
...
[docs]
def run(self, context: PipelineContext) -> PipelineContext:
"""Execute the step, mutating and returning the context."""
...
[docs]
class Pipeline:
"""Orchestrate a sequence of causal-inference steps.
The pipeline validates *all* steps before executing any of them,
ensuring configuration errors are caught before potentially expensive
model fitting.
Parameters
----------
data : pd.DataFrame
The dataset to analyse.
steps : list of Step
Ordered sequence of pipeline steps.
Examples
--------
>>> import causalpy as cp # doctest: +SKIP
>>> result = cp.Pipeline( # doctest: +SKIP
... data=df,
... steps=[
... cp.EstimateEffect(
... method=cp.InterruptedTimeSeries,
... treatment_time=pd.Timestamp("2020-01-01"),
... formula="y ~ 1 + t",
... model=cp.pymc_models.LinearRegression(),
... ),
... ],
... ).run()
"""
[docs]
def __init__(self, data: pd.DataFrame, steps: list[Step]) -> None:
if not isinstance(data, pd.DataFrame):
raise TypeError(
f"data must be a pandas DataFrame, got {type(data).__name__}"
)
if not steps:
raise ValueError("steps must be a non-empty list")
for i, step in enumerate(steps):
if not isinstance(step, Step):
raise TypeError(
f"Step {i} ({type(step).__name__}) does not satisfy the "
f"Step protocol (must implement validate and run)"
)
self.data = data
self.steps = list(steps)
[docs]
def run(self) -> PipelineResult:
"""Validate all steps, then execute them sequentially.
Returns
-------
PipelineResult
The accumulated results of the pipeline.
Raises
------
Exception
Re-raises any exception from validation or step execution.
"""
context = PipelineContext(data=self.data)
for step in self.steps:
step.validate(context)
for step in self.steps:
context = step.run(context)
return PipelineResult.from_context(context)