# Copyright 2022 - 2026 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interrupted Time Series Analysis."""
from dataclasses import dataclass
from typing import Any, Literal
import arviz as az
import numpy as np
import pandas as pd
import xarray as xr
from matplotlib import pyplot as plt
from patsy import build_design_matrices, dmatrices
from sklearn.base import RegressorMixin
from causalpy.constants import HDI_PROB, LEGEND_FONT_SIZE
from causalpy.custom_exceptions import BadIndexException
from causalpy.date_utils import _combine_datetime_indices, format_date_axes
from causalpy.experiments.model_adapter import build_coords
from causalpy.plot_utils import get_hdi_to_df, plot_xY
from causalpy.pymc_models import LinearRegression, PyMCModel
from causalpy.reporting import EffectSummary
from causalpy.utils import _as_scalar, round_num
from .base import BaseExperiment
@dataclass
class _Config:
"""Container for experiment configuration.
Holds the constructor inputs that define the experiment. They are exposed
through the ``data``, ``treatment_time``, ``treatment_end_time``,
``expt_type`` and ``formula`` properties of ``InterruptedTimeSeries``.
"""
# Input time series; the experiment constructor names its index "obs_ind".
data: pd.DataFrame
# Start of treatment; rows with index >= this value form the post period.
treatment_time: int | float | pd.Timestamp
# Optional end of treatment; when set, the post period is split in two.
treatment_end_time: int | float | pd.Timestamp | None
# Human-readable experiment label (the constructor passes "Pre-Post Fit").
expt_type: str
# Patsy formula used to build the design matrices.
formula: str
@dataclass
class _DesignInfo:
"""Container for design matrix metadata.
Produced by ``InterruptedTimeSeries._build_design_matrices`` from the patsy
design info of the pre-period matrices.
"""
# Name of the outcome (left-hand side) column of the formula.
outcome_variable_name: str
# Column names of the predictor matrix X, used as coefficient labels.
labels: list[str]
@dataclass
class _DesignData:
"""Container for pre and post design datasets.
Built by ``InterruptedTimeSeries._prepare_data``. The experiment reads the
``"X"`` and ``"y"`` variables from each dataset.
"""
# Design dataset for rows before ``treatment_time``.
pre_design: xr.Dataset
# Design dataset for rows on or after ``treatment_time``.
post_design: xr.Dataset
@dataclass
class _Results:
"""Container for experiment results.
Populated by ``InterruptedTimeSeries.algorithm``. Each field is exposed through
an experiment property of the same name (``plot_data`` also has a setter).
"""
# Score returned by the model backend's ``score`` on the pre-period data.
score: Any
# Backend predictions for the pre-period.
pre_pred: Any
# Backend out-of-sample predictions for the post-period; for Bayesian
# backends this object exposes ``posterior_predictive``.
post_pred: Any
# Pre-period impact from ``model.calculate_impact``.
pre_impact: Any
# Post-period impact from ``model.calculate_impact``.
post_impact: Any
# ``model.calculate_cumulative_impact`` applied to ``post_impact``.
post_impact_cumulative: Any
# Cached plot data; None until assigned.
plot_data: Any = None
@dataclass
class _PeriodData:
"""Container for three-period analysis results.
Groups the data, predictions, impacts, and cumulative impacts for a single
period (intervention or post-intervention) into one object, replacing 8
separate instance attributes. Instances are stored in
``InterruptedTimeSeries._period_results`` under the keys ``"intervention"``
and ``"post_intervention"``.
"""
# Rows of the post-period data that fall inside this period.
data: pd.DataFrame
# Slice of the post-period forecast covering this period.
pred: Any
# Slice of the post-period impact covering this period.
impact: Any
# Cumulative impact computed over this period's slice only.
impact_cumulative: Any
def _as_array(data: Any) -> np.ndarray:
    """Convert ``data`` to a NumPy array.

    Containers exposing a ``.values`` *attribute* (pandas Series/Index,
    xarray DataArray, ...) are unwrapped first. Everything else goes through
    :func:`numpy.asarray`.

    Parameters
    ----------
    data : Any
        Array-like input.

    Returns
    -------
    np.ndarray
        ``data`` as an ndarray. Pandas extension arrays (e.g. categoricals)
        are converted as well, so the result always matches the annotation.
    """
    values = getattr(data, "values", None)
    # ``dict.values`` and similar are methods, not array data. Skipping
    # callables lets such inputs fall through to ``np.asarray`` instead of
    # returning a bound method.
    if values is not None and not callable(values):
        return np.asarray(values)
    return np.asarray(data)
[docs]
class InterruptedTimeSeries(BaseExperiment):
"""
The class for interrupted time series analysis.
Supports both two-period (permanent intervention) and three-period (temporary
intervention) designs. When ``treatment_end_time`` is provided, the analysis
splits the post-intervention period into an intervention period and a
post-intervention period, enabling analysis of effect persistence and decay.
Parameters
----------
data : pd.DataFrame
A pandas dataframe with time series data. The index should be either
a DatetimeIndex or numeric (integer/float).
treatment_time : Union[int, float, pd.Timestamp]
The time when treatment occurred, should be in reference to the data index.
Must match the index type (DatetimeIndex requires pd.Timestamp).
**INCLUSIVE**: Observations at exactly ``treatment_time`` are included in the
post-intervention period (uses ``>=`` comparison).
formula : str
A statistical model formula using patsy syntax (e.g., "y ~ 1 + t + C(month)").
model : Union[PyMCModel, RegressorMixin], optional
A PyMC (Bayesian) or sklearn (OLS) model. If None, defaults to a PyMC
LinearRegression model.
treatment_end_time : Union[int, float, pd.Timestamp], optional
The time when treatment ended, enabling three-period analysis. Must be
greater than ``treatment_time`` and within the data range. If None (default),
the analysis assumes a permanent intervention (two-period design).
**INCLUSIVE**: Observations at exactly ``treatment_end_time`` are included in the
post-intervention period (uses ``>=`` comparison).
**kwargs : dict
Additional keyword arguments. Currently accepted but ignored; they are
not forwarded to the model.
Notes
-----
For Bayesian models, the causal impact is calculated using the posterior expectation
(``mu``) rather than the posterior predictive (``y_hat``). This means the impact and
its uncertainty represent the systematic causal effect, excluding observation-level
noise. The uncertainty bands in the plots reflect parameter uncertainty and
counterfactual prediction uncertainty, but not individual observation variability.
The three-period design is useful for analyzing temporary interventions such as:
- Marketing campaigns with defined start and end dates
- Policy trials or pilot programs
- Clinical treatments with limited duration
- Seasonal interventions
Use ``effect_summary(period="intervention")`` to analyze effects during the
intervention, and ``effect_summary(period="post")`` to analyze effect persistence
after the intervention ends.
Examples
--------
**Two-period design (permanent intervention):**
>>> import causalpy as cp
>>> import pandas as pd
>>> df = (
... cp.load_data("its")
... .assign(date=lambda x: pd.to_datetime(x["date"]))
... .set_index("date")
... )
>>> treatment_time = pd.to_datetime("2017-01-01")
>>> result = cp.InterruptedTimeSeries(
... df,
... treatment_time,
... formula="y ~ 1 + t + C(month)",
... model=cp.pymc_models.LinearRegression(
... sample_kwargs={"random_seed": 42, "progressbar": False}
... ),
... )
**Three-period design (temporary intervention):**
>>> treatment_time = pd.to_datetime("2017-01-01")
>>> treatment_end_time = pd.to_datetime("2017-06-01")
>>> result = cp.InterruptedTimeSeries(
... df,
... treatment_time,
... formula="y ~ 1 + t + C(month)",
... model=cp.pymc_models.LinearRegression(
... sample_kwargs={"random_seed": 42, "progressbar": False}
... ),
... treatment_end_time=treatment_end_time,
... )
>>> # Get period-specific effect summaries
>>> intervention_summary = result.effect_summary(period="intervention")
>>> post_summary = result.effect_summary(period="post")
"""
# Backend capability flags: both OLS (scikit-learn) and Bayesian (PyMC)
# models are accepted.
supports_ols = True
supports_bayes = True
# Presumably the model class instantiated when ``model=None`` (the class
# docstring names LinearRegression as the default). NOTE(review): confirm
# against BaseExperiment.
_default_model_class = LinearRegression
# Legacy attribute name -> (design dataset attribute, variable name); e.g.
# ``pre_X`` corresponds to ``self.pre_design["X"]``. The lookup/deprecation
# handling is not in this file chunk. NOTE(review): confirm it lives in
# BaseExperiment.
_deprecated_design_aliases = {
"pre_X": ("pre_design", "X"),
"pre_y": ("pre_design", "y"),
"post_X": ("post_design", "X"),
"post_y": ("post_design", "y"),
}
[docs]
def __init__(
    self,
    data: pd.DataFrame,
    treatment_time: int | float | pd.Timestamp,
    formula: str,
    model: PyMCModel | RegressorMixin | None = None,
    treatment_end_time: int | float | pd.Timestamp | None = None,
    **kwargs: Any,
) -> None:
    # Parameters are documented on the class. ``**kwargs`` is accepted for
    # signature compatibility but is not used.
    super().__init__(model=model)
    # Relabel the index on a new object. Assigning ``data.index.name``
    # directly would rename the index of the caller's DataFrame as a
    # side effect of constructing the experiment.
    data = data.rename_axis(index="obs_ind")
    self.input_validation(data, treatment_time, treatment_end_time)
    self._config = _Config(
        data=data,
        treatment_time=treatment_time,
        treatment_end_time=treatment_end_time,
        expt_type="Pre-Post Fit",
        formula=formula,
    )
    design_info, pre_raw, post_raw = self._build_design_matrices()
    self._design_info = design_info
    self._design_data = self._prepare_data(pre_raw, post_raw)
    # Filled by ``_split_post_period`` only for three-period designs.
    self._period_results: dict[str, _PeriodData] | None = None
    self.algorithm()
# ------------------------------------------------------------------
# Legacy flat-attribute access. Every property below reads one field
# of a private container dataclass, so code written against the older
# attribute API (including the test-suite) keeps working unchanged.
# ------------------------------------------------------------------
@property
def data(self) -> pd.DataFrame:
    """The input DataFrame the experiment was built from."""
    return self._config.data

@property
def treatment_time(self) -> int | float | pd.Timestamp:
    """Index value at which the treatment begins."""
    return self._config.treatment_time

@property
def treatment_end_time(self) -> int | float | pd.Timestamp | None:
    """Index value at which the treatment ends, or None if it is permanent."""
    return self._config.treatment_end_time

@property
def formula(self) -> str:
    """Patsy formula the design matrices are built from."""
    return self._config.formula

@property
def expt_type(self) -> str:
    """Label naming the kind of experiment."""
    return self._config.expt_type

@property
def outcome_variable_name(self) -> str:
    """Column name of the modelled outcome."""
    return self._design_info.outcome_variable_name

@property
def labels(self) -> list[str]:
    """Predictor column names, i.e. the model coefficient labels."""
    return self._design_info.labels

@property
def pre_design(self) -> xr.Dataset:
    """``X``/``y`` design dataset for rows before the treatment."""
    return self._design_data.pre_design

@property
def post_design(self) -> xr.Dataset:
    """``X``/``y`` design dataset for rows from the treatment onwards."""
    return self._design_data.post_design

@property
def score(self) -> Any:
    """Pre-period model score as returned by the backend (e.g. Bayesian R²)."""
    return self._results.score

@property
def pre_pred(self) -> Any:
    """Backend predictions over the pre-intervention rows."""
    return self._results.pre_pred

@property
def post_pred(self) -> Any:
    """Backend (out-of-sample) predictions over the post-intervention rows."""
    return self._results.post_pred

@property
def pre_impact(self) -> Any:
    """Estimated causal impact over the pre-intervention rows."""
    return self._results.pre_impact

@property
def post_impact(self) -> Any:
    """Estimated causal impact over the post-intervention rows."""
    return self._results.post_impact

@property
def post_impact_cumulative(self) -> Any:
    """Running total of the post-intervention causal impact."""
    return self._results.post_impact_cumulative

@property
def plot_data(self) -> Any:
    """Plot data cached on the results container (None until assigned)."""
    return self._results.plot_data

@plot_data.setter
def plot_data(self, value: Any) -> None:
    # Store the value on the results container rather than on the instance.
    self._results.plot_data = value
def _build_design_matrices(self) -> tuple[_DesignInfo, tuple, tuple]:
    """Evaluate the formula on the pre period and reapply that design to the post period.

    ``dmatrices`` builds the design from the pre-period rows. The resulting
    design infos are then passed to ``build_design_matrices`` so the
    post-period matrices follow the same design (same columns).

    Returns
    -------
    design_info : _DesignInfo
        Outcome column name and predictor column labels.
    pre_raw : tuple of (np.ndarray, np.ndarray)
        Pre-period ``(X, y)`` arrays.
    post_raw : tuple of (np.ndarray, np.ndarray)
        Post-period ``(X, y)`` arrays.
    """
    outcome, predictors = dmatrices(self.formula, self.datapre)
    outcome_info = outcome.design_info
    predictor_info = predictors.design_info
    post_outcome, post_predictors = build_design_matrices(
        [outcome_info, predictor_info], self.datapost
    )
    info = _DesignInfo(
        outcome_variable_name=outcome_info.column_names[0],
        labels=predictor_info.column_names,
    )
    pre_raw = (np.asarray(predictors), np.asarray(outcome))
    post_raw = (np.asarray(post_predictors), np.asarray(post_outcome))
    return info, pre_raw, post_raw
def _prepare_data(
    self, pre_raw: tuple, post_raw: tuple
) -> _DesignData:
    """Wrap the raw pre/post design arrays into ``xr.Dataset`` objects.

    Parameters
    ----------
    pre_raw : tuple of (np.ndarray, np.ndarray)
        Pre-period ``(X, y)`` arrays from ``_build_design_matrices``.
    post_raw : tuple of (np.ndarray, np.ndarray)
        Post-period ``(X, y)`` arrays from ``_build_design_matrices``.

    Returns
    -------
    _DesignData
        Pre and post datasets indexed by the period's rows and the
        coefficient labels.
    """
    (pre_X_raw, pre_y_raw), (post_X_raw, post_y_raw) = pre_raw, post_raw
    coefficient_names = self.labels
    return _DesignData(
        pre_design=self._build_design_dataset(
            pre_X_raw,
            pre_y_raw,
            obs_ind=self.datapre.index,
            coeffs=coefficient_names,
        ),
        post_design=self._build_design_dataset(
            post_X_raw,
            post_y_raw,
            obs_ind=self.datapost.index,
            coeffs=coefficient_names,
        ),
    )
[docs]
def algorithm(self) -> None:
    """Fit on the pre period, forecast the post period, and compute impacts.

    The backend is fitted and scored on the pre-period design. It then
    predicts both periods (the post period out of sample). Per-period causal
    impacts and the cumulative post-period impact are computed, and
    everything is stored in ``self._results``. For three-period designs the
    post-period results are additionally split by ``_split_post_period``.
    """
    pre, post = self.pre_design, self.post_design
    X_pre, y_pre = pre["X"], pre["y"]
    X_post, y_post = post["X"], post["y"]
    backend = self._model_backend

    backend.fit(
        X=X_pre,
        y=y_pre,
        coords=build_coords(
            self.labels,
            X_pre.shape[0],
            datetime_index=self.datapre.index,
        ),
    )
    fit_score = backend.score(X=X_pre, y=y_pre)
    fitted = backend.predict(X=X_pre)
    forecast = backend.predict(X=X_post, out_of_sample=True)

    # Non-Bayesian impacts are computed against the first treated unit only.
    if not backend.is_bayesian:
        y_pre = y_pre.isel(treated_units=0)
        y_post = y_post.isel(treated_units=0)
    impact_pre = self.model.calculate_impact(y_pre, fitted)
    impact_post = self.model.calculate_impact(y_post, forecast)

    self._results = _Results(
        score=fit_score,
        pre_pred=fitted,
        post_pred=forecast,
        pre_impact=impact_pre,
        post_impact=impact_post,
        post_impact_cumulative=self.model.calculate_cumulative_impact(
            impact_post
        ),
    )
    # Three-period design: split the post period at treatment_end_time.
    if self.treatment_end_time is not None:
        self._split_post_period()
@property
def datapre(self) -> pd.DataFrame:
    """Rows strictly before the treatment (``index < treatment_time``)."""
    frame = self.data
    before_treatment = frame.index < self.treatment_time
    return frame.loc[before_treatment]

@property
def datapost(self) -> pd.DataFrame:
    """Rows from the treatment onwards (``index >= treatment_time``, inclusive)."""
    frame = self.data
    from_treatment = frame.index >= self.treatment_time
    return frame.loc[from_treatment]
def _split_post_period(self) -> None:
    """Split post period into intervention and post-intervention periods.

    Stores one ``_PeriodData`` per period in ``self._period_results`` under
    the keys ``"intervention"`` and ``"post_intervention"``. Only called
    when ``treatment_end_time`` is provided.

    The period predictions are slices of the single continuous post-period
    forecast (``post_pred``), not new model evaluations.

    NOTE: treatment_end_time is INCLUSIVE (>=) in the post-intervention period.
    - Intervention period: treatment_time <= index < treatment_end_time
    - Post-intervention period: index >= treatment_end_time (inclusive)
    """
    # ``datapost`` re-filters ``self.data`` on every access; evaluate it once.
    datapost = self.datapost
    during_mask = np.asarray(datapost.index < self.treatment_end_time)
    post_mask = np.asarray(datapost.index >= self.treatment_end_time)
    intervention_data = datapost[during_mask]
    post_intervention_data = datapost[post_mask]

    if self._model_backend.is_bayesian:
        time_dim = "obs_ind"
        intervention_coords = intervention_data.index
        post_intervention_coords = post_intervention_data.index
        posterior_predictive = self.post_pred.posterior_predictive
        intervention_pred = az.InferenceData(
            posterior_predictive=posterior_predictive.sel(
                {time_dim: intervention_coords}
            )
        )
        post_intervention_pred = az.InferenceData(
            posterior_predictive=posterior_predictive.sel(
                {time_dim: post_intervention_coords}
            )
        )
        post_impact_sel = self.post_impact
        if "treated_units" in post_impact_sel.dims:
            post_impact_sel = post_impact_sel.isel(treated_units=0)
        intervention_impact = post_impact_sel.sel({time_dim: intervention_coords})
        post_intervention_impact = post_impact_sel.sel(
            {time_dim: post_intervention_coords}
        )
    else:
        # Positional row indices into the post-period arrays, taken straight
        # from the masks. This avoids one ``get_loc`` per row (each of which
        # re-filtered the full data) and also works for non-unique indices.
        intervention_positions = np.flatnonzero(during_mask)
        post_intervention_positions = np.flatnonzero(post_mask)
        intervention_pred = self.post_pred[intervention_positions]
        post_intervention_pred = self.post_pred[post_intervention_positions]
        intervention_impact = self.post_impact[intervention_positions]
        post_intervention_impact = self.post_impact[post_intervention_positions]

    # Each period's cumulative impact is computed over that period's slice only.
    intervention_impact_cumulative = self.model.calculate_cumulative_impact(
        intervention_impact
    )
    post_intervention_impact_cumulative = self.model.calculate_cumulative_impact(
        post_intervention_impact
    )

    self._period_results = {
        "intervention": _PeriodData(
            data=intervention_data,
            pred=intervention_pred,
            impact=intervention_impact,
            impact_cumulative=intervention_impact_cumulative,
        ),
        "post_intervention": _PeriodData(
            data=post_intervention_data,
            pred=post_intervention_pred,
            impact=post_intervention_impact,
            impact_cumulative=post_intervention_impact_cumulative,
        ),
    }
# Backward-compatible properties for three-period attributes.
# They read from ``self._period_results`` through ``_period_result`` so
# internal code and external tests continue to work without change.
def _period_result(self, period: str) -> _PeriodData:
    """Return the stored three-period results for ``period``.

    Parameters
    ----------
    period : {"intervention", "post_intervention"}
        Key into ``self._period_results``.

    Raises
    ------
    AttributeError
        If no three-period results exist (``treatment_end_time`` was None).
        AttributeError is used so ``hasattr`` checks return False.
    """
    if self._period_results is None:
        msg = "No three-period results available. Provide treatment_end_time."
        raise AttributeError(msg)
    return self._period_results[period]

@property
def data_intervention(self) -> pd.DataFrame:
    """Data from the intervention period (treatment_time <= index < treatment_end_time)."""
    return self._period_result("intervention").data

@property
def data_post_intervention(self) -> pd.DataFrame:
    """Data from the post-intervention period (index >= treatment_end_time)."""
    return self._period_result("post_intervention").data

@property
def intervention_pred(self) -> Any:
    """Predictions for the intervention period."""
    return self._period_result("intervention").pred

@property
def post_intervention_pred(self) -> Any:
    """Predictions for the post-intervention period."""
    return self._period_result("post_intervention").pred

@property
def intervention_impact(self) -> Any:
    """Causal impact for the intervention period."""
    return self._period_result("intervention").impact

@property
def post_intervention_impact(self) -> Any:
    """Causal impact for the post-intervention period."""
    return self._period_result("post_intervention").impact

@property
def intervention_impact_cumulative(self) -> Any:
    """Cumulative causal impact for the intervention period."""
    return self._period_result("intervention").impact_cumulative

@property
def post_intervention_impact_cumulative(self) -> Any:
    """Cumulative causal impact for the post-intervention period."""
    return self._period_result("post_intervention").impact_cumulative
def _comparison_period_summary(
    self,
    direction: Literal["increase", "decrease", "two-sided"] = "increase",
    alpha: float = 0.05,
    cumulative: bool = True,
    relative: bool = True,
    min_effect: float | None = None,
) -> EffectSummary:
    """Generate comparative summary between intervention and post-intervention periods.

    Compares the average impact during the intervention with the average
    impact after it ended. Their ratio is reported as a persistence
    percentage.

    Parameters
    ----------
    direction : {"increase", "decrease", "two-sided"}, default="increase"
        Accepted for signature parity; not used by this comparison.
        ``prob_persisted`` is always based on the effect being > 0.
    alpha : float, default=0.05
        Significance level; intervals are computed at ``1 - alpha``.
    cumulative : bool, default=True
        Accepted for signature parity; not used by this comparison.
    relative : bool, default=True
        Accepted for signature parity; not used by this comparison.
    min_effect : float, optional
        Accepted for signature parity; not used by this comparison.

    Returns
    -------
    EffectSummary
        Object with ``.table`` (DataFrame indexed by ``"intervention"`` and
        ``"post_intervention"``) and ``.text`` (str) attributes.
    """
    from causalpy.reporting import _extract_hdi_bounds

    time_dim = "obs_ind"
    hdi_prob = 1 - alpha
    # ``:g`` avoids int() truncation artefacts (e.g. 0.29 * 100 -> "28")
    # and still renders 0.95 as "95".
    interval_pct = f"{hdi_prob * 100:g}"
    prob_persisted: float | None

    def _persistence_pct(post_value: float, intervention_value: float) -> float:
        # The epsilon keeps the historical behaviour for a zero intervention
        # mean; the explicit check covers the case where it cancels exactly.
        denominator = intervention_value + 1e-8
        if denominator == 0:
            return float("nan")
        return (post_value / denominator) * 100

    if self._model_backend.is_bayesian:
        # Average over time first, then summarise the posterior of that average.
        intervention_avg = self.intervention_impact.mean(dim=time_dim)
        intervention_mean = _as_scalar(intervention_avg.mean(dim=["chain", "draw"]))
        intervention_lower, intervention_upper = _extract_hdi_bounds(
            az.hdi(intervention_avg, hdi_prob=hdi_prob), hdi_prob
        )
        post_avg = self.post_intervention_impact.mean(dim=time_dim)
        post_mean = _as_scalar(post_avg.mean(dim=["chain", "draw"]))
        post_lower, post_upper = _extract_hdi_bounds(
            az.hdi(post_avg, hdi_prob=hdi_prob), hdi_prob
        )
        persistence_ratio_pct = _persistence_pct(post_mean, intervention_mean)
        # Posterior probability that the post-intervention average effect is > 0.
        prob_persisted = _as_scalar((post_avg > 0).mean())

        table = pd.DataFrame(
            {
                "mean": [intervention_mean, post_mean],
                "hdi_lower": [intervention_lower, post_lower],
                "hdi_upper": [intervention_upper, post_upper],
                "persistence_ratio_pct": [None, persistence_ratio_pct],
                "prob_persisted": [None, prob_persisted],
            },
            index=["intervention", "post_intervention"],
        )
        text = (
            f"Effect persistence: The post-intervention effect "
            f"({post_mean:.1f}, {interval_pct}% HDI [{post_lower:.1f}, {post_upper:.1f}]) "
            f"was {persistence_ratio_pct:.1f}% of the intervention effect "
            f"({intervention_mean:.1f}, {interval_pct}% HDI [{intervention_lower:.1f}, {intervention_upper:.1f}]), "
            f"with a posterior probability of {prob_persisted:.2f} that some effect persisted "
            f"beyond the intervention period."
        )
    else:
        from causalpy.reporting import _compute_statistics_ols

        intervention_avg_stats = _compute_statistics_ols(
            _as_array(self.intervention_impact),
            self.intervention_pred,
            alpha=alpha,
            cumulative=False,
            relative=False,
        )["avg"]
        post_avg_stats = _compute_statistics_ols(
            _as_array(self.post_intervention_impact),
            self.post_intervention_pred,
            alpha=alpha,
            cumulative=False,
            relative=False,
        )["avg"]
        persistence_ratio_pct = _persistence_pct(
            post_avg_stats["mean"], intervention_avg_stats["mean"]
        )
        # 1 - p_value is used as a rough proxy; it is not a posterior probability.
        prob_persisted = (
            1 - post_avg_stats["p_value"] if "p_value" in post_avg_stats else None
        )

        table_data = {
            "mean": [intervention_avg_stats["mean"], post_avg_stats["mean"]],
            "ci_lower": [intervention_avg_stats["ci_lower"], post_avg_stats["ci_lower"]],
            "ci_upper": [intervention_avg_stats["ci_upper"], post_avg_stats["ci_upper"]],
            "persistence_ratio_pct": [None, persistence_ratio_pct],
        }
        if prob_persisted is not None:
            table_data["prob_persisted"] = [None, prob_persisted]
        table = pd.DataFrame(
            table_data,
            index=["intervention", "post_intervention"],
        )

        effect_text = (
            f"Effect persistence: The post-intervention effect "
            f"({post_avg_stats['mean']:.1f}, {interval_pct}% CI [{post_avg_stats['ci_lower']:.1f}, {post_avg_stats['ci_upper']:.1f}]) "
            f"was {persistence_ratio_pct:.1f}% of the intervention effect "
            f"({intervention_avg_stats['mean']:.1f}, {interval_pct}% CI [{intervention_avg_stats['ci_lower']:.1f}, {intervention_avg_stats['ci_upper']:.1f}])"
        )
        if prob_persisted is not None:
            text = (
                effect_text
                + f", with a probability of {prob_persisted:.2f} that some effect persisted "
                "beyond the intervention period."
            )
        else:
            text = effect_text + "."

    return EffectSummary(table=table, text=text)
[docs]
def summary(self, round_to: int | None = None) -> None:
    """Print summary of main results and model coefficients.

    Prints a banner with the experiment type, the model formula, and then
    the model coefficients via ``print_coefficients``.

    Parameters
    ----------
    round_to : int, optional
        Number of decimals used to round results. Defaults to ``None``,
        which prints raw (unrounded) numbers.
    """
    print(f"{self.expt_type:=^80}")
    print(f"Formula: {self.formula}")
    self.print_coefficients(round_to)
[docs]
def plot(
    self,
    *,
    round_to: int | None = 2,
    hdi_prob: float = HDI_PROB,
    figsize: tuple[float, float] = (7, 8),
    show: bool = True,
    legend_kwargs: dict[str, Any] | None = None,
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Draw the three-panel interrupted time-series figure.

    Parameters
    ----------
    round_to : int, optional
        Decimal places for numbers shown in the figure title (such as the
        Bayesian :math:`R^2`). Defaults to 2; pass ``None`` for raw values.
    hdi_prob : float
        Mass of the highest density interval shaded around the posterior
        predictive, causal impact and cumulative impact. Must lie in
        ``(0, 1]``; OLS models ignore it. Defaults to
        :data:`~causalpy.constants.HDI_PROB` (currently 0.94).
    figsize : tuple of (float, float)
        Figure width and height in inches, forwarded to
        :func:`matplotlib.pyplot.subplots`. Defaults to ``(7, 8)``.
    show : bool
        Display the figure immediately (default ``True``). Use ``False`` to
        customise it before showing.
    legend_kwargs : dict, optional
        Options for repositioning/restyling the legend: ``loc``,
        ``bbox_to_anchor``, ``fontsize``, ``frameon`` and ``title``
        (``bbox_transform`` may accompany ``bbox_to_anchor``). The legend
        is edited **in place**, so custom handles survive.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The newly created figure.
    ax : list[matplotlib.axes.Axes]
        Its three axes, top to bottom: predictions, causal impact,
        cumulative impact.
    """
    render_options = dict(round_to=round_to, hdi_prob=hdi_prob, figsize=figsize)
    return self._render_plot(
        show=show, legend_kwargs=legend_kwargs, **render_options
    )
@staticmethod
def _draw_singleton_hdi_marker(
    ax: plt.Axes,
    x: Any,
    Y: xr.DataArray,
    color: str,
    hdi_prob: float = HDI_PROB,
) -> Any:
    """Overlay a median dot + HDI errorbar for a single post-period datum.

    ``plot_xY`` (and the ``arviz.plot_hdi`` it wraps) renders a degenerate
    zero-area polygon when the post-period contains a single observation,
    so neither the median line nor the HDI ribbon is visible. Drawing an
    explicit point and errorbar makes both the central tendency and the
    uncertainty plain to read in that edge case.

    Parameters
    ----------
    ax : plt.Axes
        Target axes.
    x : Any
        x position(s) of the single observation.
    Y : xr.DataArray
        Samples with ``chain`` and ``draw`` dims (``treated_units`` is
        reduced to its first entry when present).
    color : str
        Marker and errorbar colour.
    hdi_prob : float
        HDI mass used for the errorbar.

    Returns
    -------
    matplotlib ``ErrorbarContainer``, usable as a legend handle.
    """
    Y_plot = Y.isel(treated_units=0) if "treated_units" in Y.dims else Y
    median = float(np.asarray(Y_plot.median(("chain", "draw")).values).item())
    hdi = az.hdi(Y_plot, hdi_prob=hdi_prob)
    data_var = list(hdi.data_vars)[0]
    bounds = np.asarray(hdi[data_var].values).reshape(-1)
    lower, upper = float(bounds[0]), float(bounds[1])
    # Matplotlib rejects negative error lengths. For narrow intervals
    # (hdi_prob < 0.5) the median can lie outside the HDI, so clamp at 0.
    yerr = [[max(median - lower, 0.0)], [max(upper - median, 0.0)]]
    return ax.errorbar(
        x,
        [median],
        yerr=yerr,
        fmt="o",
        color=color,
        ecolor=color,
        capsize=4,
        zorder=3,
    )
def _plot_bayesian_top_panel(
    self,
    ax: plt.Axes,
    hdi_prob: float,
    single_post_obs: bool,
    round_to: int | None,
) -> tuple[list, list]:
    """Plot the top panel: pre/post intervention predictions.

    Draws the pre-period posterior ``mu`` with its HDI band (C0), the
    observations, and the post-period counterfactual ``mu`` (C1). It also
    shades the area between the counterfactual mean and the observed
    post-period outcome. The title reports the pre-period Bayesian R².

    Parameters
    ----------
    ax : plt.Axes
        Target axes.
    hdi_prob : float
        HDI mass for the bands and for the single-observation marker.
    single_post_obs : bool
        Whether the post period has at most one observation. If so, an
        explicit median + HDI marker is used as the counterfactual handle.
    round_to : int, optional
        Decimals for the R² values in the title; ``None`` for raw numbers.

    Returns
    -------
    handles, labels : list, list
        Legend handles and their labels.
    """
    counterfactual_label = "Counterfactual"
    pre_mu = self.pre_pred["posterior_predictive"].mu
    pre_mu_plot = (
        pre_mu.isel(treated_units=0) if "treated_units" in pre_mu.dims else pre_mu
    )
    h_line, h_patch = plot_xY(
        self.datapre.index,
        pre_mu_plot,
        ax=ax,
        hdi_prob=hdi_prob,
        plot_hdi_kwargs={"color": "C0"},
    )
    handles = [(h_line, h_patch)]
    labels = ["Pre-intervention period"]
    (h,) = ax.plot(
        self.datapre.index,
        self.pre_design["y"].isel(treated_units=0),
        "k.",
        label="Observations",
    )
    handles.append(h)
    labels.append("Observations")

    post_mu = self.post_pred["posterior_predictive"].mu
    post_mu_plot = (
        post_mu.isel(treated_units=0)
        if "treated_units" in post_mu.dims
        else post_mu
    )
    h_line, h_patch = plot_xY(
        self.datapost.index,
        post_mu_plot,
        ax=ax,
        hdi_prob=hdi_prob,
        plot_hdi_kwargs={"color": "C1"},
    )
    if single_post_obs:
        # Forward hdi_prob so the marker shows the same HDI mass as the bands.
        errbar = self._draw_singleton_hdi_marker(
            ax, self.datapost.index, post_mu, color="C1", hdi_prob=hdi_prob
        )
        handles.append(errbar)
    else:
        handles.append((h_line, h_patch))
    labels.append(counterfactual_label)
    ax.plot(
        self.datapost.index,
        self.post_design["y"].isel(treated_units=0),
        "k.",
        zorder=3,
    )

    post_pred_mu = az.extract(
        self.post_pred, group="posterior_predictive", var_names="mu"
    )
    if "treated_units" in post_pred_mu.dims:
        post_pred_mu = post_pred_mu.isel(treated_units=0)
    post_pred_mu = post_pred_mu.mean("sample")
    h = ax.fill_between(
        self.datapost.index,
        y1=post_pred_mu,
        y2=self.post_design["y"].isel(treated_units=0),
        color="C0",
        alpha=0.25,
    )
    # A single post observation gives no visible filled area, so no legend entry.
    if not single_post_obs:
        handles.append(h)
        labels.append("Causal impact")

    ax.set(title=self._bayesian_r2_title(round_to))
    return handles, labels

def _bayesian_r2_title(self, round_to: int | None) -> str:
    """Build the top-panel title from the pre-period score.

    When ``self.score`` is a ``pd.Series``, looks up ``unit_0_r2`` (falling
    back to ``r2``) and the matching ``*_std`` entry. Otherwise the title
    carries no value.
    """
    r2_val = None
    r2_std_val = None
    try:
        score = self.score
        if isinstance(score, pd.Series):
            if "unit_0_r2" in score.index:
                r2_val = score["unit_0_r2"]
                r2_std_val = score.get("unit_0_r2_std", None)
            elif "r2" in score.index:
                r2_val = score["r2"]
                r2_std_val = score.get("r2_std", None)
    except (AttributeError, KeyError, TypeError):
        # Best effort: a missing or unexpected score only drops the value
        # from the title; it must not abort plotting.
        pass
    title_str = "Pre-intervention Bayesian $R^2$"
    if r2_val is not None:
        title_str += f": {round_num(r2_val, round_to)}"
        if r2_std_val is not None:
            title_str += f"\n(std = {round_num(r2_std_val, round_to)})"
    return title_str
def _plot_bayesian_middle_panel(
    self,
    ax: plt.Axes,
    hdi_prob: float,
    single_post_obs: bool,
) -> None:
    """Plot the middle panel: causal impact with HDI bands.

    The pre-period impact is drawn in C0 and the post-period impact in C1.
    The posterior-mean post-period impact is shaded against zero.

    Parameters
    ----------
    ax : plt.Axes
        Target axes.
    hdi_prob : float
        HDI mass for the bands and for the single-observation marker.
    single_post_obs : bool
        Whether the post period has at most one observation.
    """
    pre_impact_plot = (
        self.pre_impact.isel(treated_units=0)
        if hasattr(self.pre_impact, "dims")
        and "treated_units" in self.pre_impact.dims
        else self.pre_impact
    )
    plot_xY(
        self.datapre.index,
        pre_impact_plot,
        ax=ax,
        hdi_prob=hdi_prob,
        plot_hdi_kwargs={"color": "C0"},
    )
    post_impact_plot = (
        self.post_impact.isel(treated_units=0)
        if hasattr(self.post_impact, "dims")
        and "treated_units" in self.post_impact.dims
        else self.post_impact
    )
    plot_xY(
        self.datapost.index,
        post_impact_plot,
        ax=ax,
        hdi_prob=hdi_prob,
        plot_hdi_kwargs={"color": "C1"},
    )
    if single_post_obs:
        # Forward hdi_prob so the marker shows the same HDI mass as the bands.
        self._draw_singleton_hdi_marker(
            ax, self.datapost.index, self.post_impact, color="C1", hdi_prob=hdi_prob
        )
    ax.axhline(y=0, c="k")
    post_impact_mean = (
        self.post_impact.mean(["chain", "draw"])
        if hasattr(self.post_impact, "mean")
        else self.post_impact
    )
    if (
        hasattr(post_impact_mean, "dims")
        and "treated_units" in post_impact_mean.dims
    ):
        post_impact_mean = post_impact_mean.isel(treated_units=0)
    ax.fill_between(
        self.datapost.index,
        y1=post_impact_mean,
        color="C0",
        alpha=0.25,
        label="Causal impact",
    )
    ax.set(title="Causal Impact")
def _plot_bayesian_bottom_panel(
    self,
    ax: plt.Axes,
    hdi_prob: float,
    single_post_obs: bool,
) -> None:
    """Plot the bottom panel: cumulative causal impact.

    Parameters
    ----------
    ax : plt.Axes
        Target axes.
    hdi_prob : float
        HDI mass for the band and for the single-observation marker.
    single_post_obs : bool
        Whether the post period has at most one observation.
    """
    ax.set(title="Cumulative Causal Impact")
    post_cum_plot = (
        self.post_impact_cumulative.isel(treated_units=0)
        if hasattr(self.post_impact_cumulative, "dims")
        and "treated_units" in self.post_impact_cumulative.dims
        else self.post_impact_cumulative
    )
    plot_xY(
        self.datapost.index,
        post_cum_plot,
        ax=ax,
        hdi_prob=hdi_prob,
        plot_hdi_kwargs={"color": "C1"},
    )
    if single_post_obs:
        # Forward hdi_prob so the marker shows the same HDI mass as the band.
        self._draw_singleton_hdi_marker(
            ax,
            self.datapost.index,
            self.post_impact_cumulative,
            color="C1",
            hdi_prob=hdi_prob,
        )
    ax.axhline(y=0, c="k")
def _add_intervention_lines(
    self,
    axes: list[plt.Axes],
) -> None:
    """Mark the treatment window on each of the three panels.

    Treatment start is drawn dashed and, when configured, treatment end is
    drawn dotted. Only the first panel's lines carry legend labels.
    """
    common_style = {"lw": 1.5, "color": "k", "zorder": 1.5}
    for panel_number in range(3):
        axis = axes[panel_number]
        is_first_panel = panel_number == 0
        axis.axvline(
            x=self.treatment_time,
            ls="--",
            label="Treatment start" if is_first_panel else None,
            **common_style,
        )
        if self.treatment_end_time is None:
            continue
        axis.axvline(
            x=self.treatment_end_time,
            ls=":",
            label="Treatment end" if is_first_panel else None,
            **common_style,
        )
def _format_date_axes(
    self,
    axes: list[plt.Axes],
) -> None:
    """Apply date-aware tick formatting when the data has a datetime index.

    Does nothing if the pre-period index is not a ``pd.DatetimeIndex``.
    """
    pre_index = self.datapre.index
    if not isinstance(pre_index, pd.DatetimeIndex):
        return
    combined_index = _combine_datetime_indices(
        pd.DatetimeIndex(pre_index),
        pd.DatetimeIndex(self.datapost.index),
    )
    format_date_axes(axes, combined_index)
def _bayesian_plot(
    self,
    round_to: int | None = 2,
    hdi_prob: float = HDI_PROB,
    figsize: tuple[float, float] = (7, 8),
    **kwargs: Any,
) -> tuple[plt.Figure, list[plt.Axes]]:
    """
    Plot the results.

    Parameters
    ----------
    round_to : int, optional
        Number of decimals used to round results. Defaults to 2. Use ``None``
        to return raw numbers.
    hdi_prob : float, optional
        Probability mass of the highest density interval drawn around the
        posterior predictive, causal impact, and cumulative impact bands.
        Must be in ``(0, 1]``. Defaults to
        :data:`~causalpy.constants.HDI_PROB` (currently 0.94).
    figsize : tuple of (float, float), optional
        Width and height of the figure in inches. Defaults to ``(7, 8)``.
    **kwargs
        Accepted for interface compatibility; not used by this method.

    Returns
    -------
    tuple
        ``(fig, ax)`` where ``ax`` holds three stacked, x-shared axes: the
        top panel drawn by ``_plot_bayesian_top_panel``, the causal impact
        (middle) and the cumulative causal impact (bottom).
    """
    # With at most one post-treatment observation the helpers also draw a
    # singleton HDI marker.
    single_post_obs = len(self.datapost) <= 1
    fig, ax = plt.subplots(3, 1, sharex=True, figsize=figsize)
    handles, labels = self._plot_bayesian_top_panel(
        ax[0], hdi_prob, single_post_obs, round_to
    )
    self._plot_bayesian_middle_panel(ax[1], hdi_prob, single_post_obs)
    self._plot_bayesian_bottom_panel(ax[2], hdi_prob, single_post_obs)
    self._add_intervention_lines(ax)
    # Matplotlib documents ``handles`` as a list of artists (or tuples of
    # artists); pass a concrete list rather than a one-shot generator.
    ax[0].legend(
        handles=list(handles),
        labels=labels,
        fontsize=LEGEND_FONT_SIZE,
    )
    self._format_date_axes(ax)
    return fig, ax
def _ols_plot(
    self,
    round_to: int | None = 2,
    figsize: tuple[float, float] = (7, 8),
    **kwargs: Any,
) -> tuple[plt.Figure, list[plt.Axes]]:
    """
    Plot the results.

    Parameters
    ----------
    round_to : int, optional
        Number of decimals used to round results. Defaults to 2. Use ``None``
        to return raw numbers.
    figsize : tuple of (float, float), optional
        Width and height of the figure in inches. Defaults to ``(7, 8)``.
    **kwargs
        Accepted for interface compatibility; not used by this method.

    Returns
    -------
    tuple
        ``(fig, ax)`` with three stacked, x-shared axes: observations with
        model fit and counterfactual (top), causal impact (middle) and
        cumulative causal impact (bottom).
    """
    counterfactual_label = "Counterfactual"
    fig, ax = plt.subplots(3, 1, sharex=True, figsize=figsize)

    # Top panel: observations, in-sample fit and post-period counterfactual.
    ax[0].plot(self.datapre.index, self.pre_design["y"], "k.")
    ax[0].plot(self.datapre.index, self.pre_pred, c="k", label="model fit")
    ax[0].plot(self.datapost.index, self.post_design["y"], "k.")
    ax[0].plot(
        self.datapost.index,
        self.post_pred,
        label=counterfactual_label,
        ls=":",
        c="k",
    )
    # Shade the gap between the counterfactual and the observations.
    ax[0].fill_between(
        self.datapost.index,
        y1=np.squeeze(self.post_pred),
        y2=np.squeeze(self.post_design["y"]),
        color="C0",
        alpha=0.25,
        label="Causal impact",
    )
    ax[0].set(
        title=f"$R^2$ on pre-intervention data = {round_num(_as_scalar(self.score), round_to)}"
    )

    # Middle panel: pointwise causal impact.
    ax[1].plot(self.datapre.index, self.pre_impact, "k.")
    ax[1].plot(
        self.datapost.index,
        self.post_impact,
        "k.",
        label=counterfactual_label,
    )
    ax[1].axhline(y=0, c="k")
    ax[1].fill_between(
        self.datapost.index,
        y1=np.squeeze(self.post_impact),
        color="C0",
        alpha=0.25,
        label="Causal impact",
    )
    ax[1].set(title="Causal Impact")

    # Bottom panel: cumulative causal impact.
    ax[2].plot(self.datapost.index, self.post_impact_cumulative, c="k")
    ax[2].axhline(y=0, c="k")
    ax[2].set(title="Cumulative Causal Impact")

    # Reuse the shared helpers (also used by the Bayesian plot) for the
    # intervention markers and date formatting. The markers are drawn before
    # the legend so the "Treatment start"/"Treatment end" labels are included.
    self._add_intervention_lines(ax)
    ax[0].legend(fontsize=LEGEND_FONT_SIZE)
    self._format_date_axes(ax)
    return (fig, ax)
def _extract_predictions(
    self,
    pre_data: pd.DataFrame,
    post_data: pd.DataFrame,
    hdi_prob: float,
    lower_col: str,
    upper_col: str,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Store the posterior-predictive mean of ``mu`` and its HDI bounds.

    Writes a ``"prediction"`` column plus ``lower_col``/``upper_col`` HDI
    columns into ``pre_data`` and ``post_data`` (in place) and returns both.
    When the model has a ``treated_units`` dimension, the first unit (by
    position) is used for both the mean and the HDI.

    Parameters
    ----------
    pre_data, post_data : pandas.DataFrame
        Frames indexed like the pre-/post-treatment data; modified in place.
    hdi_prob : float
        Probability mass of the HDI passed to ``get_hdi_to_df``.
    lower_col, upper_col : str
        Column names for the lower and upper HDI bounds.

    Returns
    -------
    tuple of pandas.DataFrame
        The (mutated) ``pre_data`` and ``post_data``.
    """
    pre_mu_draws = self.pre_pred["posterior_predictive"].mu
    post_mu_draws = self.post_pred["posterior_predictive"].mu

    # Posterior mean over stacked (chain, draw) samples.
    pre_mu = az.extract(
        self.pre_pred, group="posterior_predictive", var_names="mu"
    )
    post_mu = az.extract(
        self.post_pred, group="posterior_predictive", var_names="mu"
    )
    if "treated_units" in pre_mu.dims:
        pre_mu = pre_mu.isel(treated_units=0)
    if "treated_units" in post_mu.dims:
        post_mu = post_mu.isel(treated_units=0)
    pre_data["prediction"] = pre_mu.mean("sample").values
    post_data["prediction"] = post_mu.mean("sample").values

    hdi_pre_pred = get_hdi_to_df(pre_mu_draws, hdi_prob=hdi_prob)
    hdi_post_pred = get_hdi_to_df(post_mu_draws, hdi_prob=hdi_prob)
    if (
        isinstance(hdi_pre_pred.index, pd.MultiIndex)
        and "treated_units" in hdi_pre_pred.index.names
    ):
        # Select the same unit the mean above used (position 0) instead of
        # assuming it is labelled "unit_0"; otherwise mean and HDI could
        # describe different units, or ``xs`` could raise KeyError.
        pre_unit = pre_mu_draws["treated_units"][0].item()
        post_unit = post_mu_draws["treated_units"][0].item()
        pre_hdi = hdi_pre_pred.xs(pre_unit, level="treated_units")
        post_hdi = hdi_post_pred.xs(post_unit, level="treated_units")
    else:
        pre_hdi, post_hdi = hdi_pre_pred, hdi_post_pred
    pre_data[[lower_col, upper_col]] = pre_hdi.set_index(pre_data.index)
    post_data[[lower_col, upper_col]] = post_hdi.set_index(post_data.index)
    return pre_data, post_data
def _extract_impacts(
    self,
    pre_data: pd.DataFrame,
    post_data: pd.DataFrame,
    hdi_prob: float,
    lower_col: str,
    upper_col: str,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Store the mean causal impact and a quantile interval in pre/post data.

    Writes an ``"impact"`` column (posterior mean over ``chain``/``draw``)
    and ``lower_col``/``upper_col`` bounds into ``pre_data`` and
    ``post_data`` (in place) and returns both.

    The bounds are the ``(1 - hdi_prob) / 2`` and ``1 - (1 - hdi_prob) / 2``
    quantiles, i.e. an equal-tailed interval. For skewed posteriors this
    can differ from a true highest-density interval.

    Parameters
    ----------
    pre_data, post_data : pandas.DataFrame
        Frames indexed like the pre-/post-treatment data; modified in place.
    hdi_prob : float
        Central probability mass of the interval.
    lower_col, upper_col : str
        Column names for the lower and upper bounds.

    Returns
    -------
    tuple of pandas.DataFrame
        The (mutated) ``pre_data`` and ``post_data``.
    """
    sample_dims = ["chain", "draw"]
    alpha = 1 - hdi_prob
    lower_q = alpha / 2
    upper_q = 1 - alpha / 2

    def _first_unit(values):
        # Keep the first treated unit *by position* for the mean and the
        # bounds alike, rather than mixing positional and "unit_0" label
        # selection.
        if hasattr(values, "dims") and "treated_units" in values.dims:
            return values.isel(treated_units=0)
        return values

    for frame, impact in (
        (pre_data, self.pre_impact),
        (post_data, self.post_impact),
    ):
        impact_mean = (
            impact.mean(dim=sample_dims) if hasattr(impact, "mean") else impact
        )
        frame["impact"] = _first_unit(impact_mean).values
        lower = _first_unit(impact.quantile(lower_q, dim=sample_dims))
        upper = _first_unit(impact.quantile(upper_q, dim=sample_dims))
        # Align on the frame's index in case coordinate order differs.
        frame[lower_col] = lower.to_series().reindex(frame.index).values
        frame[upper_col] = upper.to_series().reindex(frame.index).values
    return pre_data, post_data
[docs]
def get_plot_data_bayesian(self, hdi_prob: float = HDI_PROB) -> pd.DataFrame:
    """
    Recover the data of the experiment along with the prediction and causal impact information.

    Parameters
    ----------
    hdi_prob : float, default :data:`~causalpy.constants.HDI_PROB`
        Probability mass of the highest density interval. Defaults to the
        project-wide :data:`~causalpy.constants.HDI_PROB` (currently 0.94).

    Returns
    -------
    pandas.DataFrame
        Pre- and post-treatment data concatenated, with ``prediction`` and
        ``impact`` columns plus interval columns suffixed by the rounded
        percentage (e.g. ``pred_hdi_lower_94``). Also stored on
        ``self.plot_data``.

    Raises
    ------
    ValueError
        If the model backend is not Bayesian.
    """
    if not self._model_backend.is_bayesian:
        raise ValueError("Unsupported model type")
    suffix = int(round(hdi_prob * 100))
    pre_frame, post_frame = self._extract_predictions(
        self.datapre.copy(),
        self.datapost.copy(),
        hdi_prob,
        f"pred_hdi_lower_{suffix}",
        f"pred_hdi_upper_{suffix}",
    )
    pre_frame, post_frame = self._extract_impacts(
        pre_frame,
        post_frame,
        hdi_prob,
        f"impact_hdi_lower_{suffix}",
        f"impact_hdi_upper_{suffix}",
    )
    self.plot_data = pd.concat([pre_frame, post_frame])
    return self.plot_data
[docs]
def get_plot_data_ols(self) -> pd.DataFrame:
    """
    Recover the data of the experiment along with the prediction and causal impact information.

    Returns
    -------
    pandas.DataFrame
        Copies of the pre- and post-treatment data, each extended with
        ``prediction`` and ``impact`` columns, concatenated in that order.
        The result is also stored on ``self.plot_data``.
    """
    frames = []
    for source, prediction, impact in (
        (self.datapre, self.pre_pred, self.pre_impact),
        (self.datapost, self.post_pred, self.post_impact),
    ):
        # Work on a copy so the experiment's own data is left untouched.
        frame = source.copy()
        frame["prediction"] = prediction
        frame["impact"] = impact
        frames.append(frame)
    self.plot_data = pd.concat(frames)
    return self.plot_data
[docs]
def analyze_persistence(
    self,
    hdi_prob: float = HDI_PROB,
    direction: Literal["increase", "decrease", "two-sided"] = "increase",
) -> dict[str, Any]:
    """Analyze effect persistence between intervention and post-intervention periods.

    Computes mean effects, persistence ratio, and total (cumulative) impacts for both periods.
    The persistence ratio is the post-intervention mean effect divided by the intervention
    mean effect (as a decimal, e.g., 0.30 means 30% persistence, 1.5 means 150%).

    Note: The ratio can exceed 1.0 if the post-intervention effect is larger than the
    intervention effect. A small constant (1e-8) is added to the denominator; if the
    denominator is still exactly zero, the ratio is reported as NaN.

    Automatically prints a summary of the results.

    Parameters
    ----------
    hdi_prob : float
        Interval probability. For Bayesian models it is the HDI probability; for
        OLS models the confidence interval uses ``alpha = 1 - hdi_prob``. Defaults
        to :data:`~causalpy.constants.HDI_PROB` (currently 0.94).
    direction : {"increase", "decrease", "two-sided"}, default="increase"
        Currently not used by this method (no tail probabilities are computed);
        kept for API compatibility.

    Returns
    -------
    dict[str, Any]
        Dictionary containing:
        - "mean_effect_during": Mean effect during intervention period
        - "mean_effect_post": Mean effect during post-intervention period
        - "persistence_ratio": Post-intervention mean effect divided by intervention mean (decimal, can exceed 1.0)
        - "total_effect_during": Total (cumulative) effect during intervention period
        - "total_effect_post": Total (cumulative) effect during post-intervention period

    Raises
    ------
    ValueError
        If treatment_end_time is not provided (two-period design)

    Examples
    --------
    >>> import causalpy as cp
    >>> import pandas as pd
    >>> df = (
    ...     cp.load_data("its")
    ...     .assign(date=lambda x: pd.to_datetime(x["date"]))
    ...     .set_index("date")
    ... )
    >>> result = cp.InterruptedTimeSeries(
    ...     df,
    ...     treatment_time=pd.Timestamp("2017-01-01"),
    ...     treatment_end_time=pd.Timestamp("2017-06-01"),
    ...     formula="y ~ 1 + t + C(month)",
    ...     model=cp.pymc_models.LinearRegression(
    ...         sample_kwargs={"random_seed": 42, "progressbar": False}
    ...     ),
    ... )
    >>> persistence = result.analyze_persistence()  # doctest: +SKIP
    ... # Note: Results are automatically printed to console
    >>> persistence["persistence_ratio"]  # doctest: +SKIP
    -1.224
    """
    if self.treatment_end_time is None:
        raise ValueError(
            "analyze_persistence() requires treatment_end_time to be provided. "
            "This method is only available for three-period designs."
        )
    is_pymc = self._model_backend.is_bayesian
    if is_pymc:
        # PyMC: summarise posterior draws with xarray reductions.
        from causalpy.reporting import _extract_hdi_bounds

        time_dim = "obs_ind"
        sample_dims = ["chain", "draw"]

        # Intervention period: per-draw average effect over time.
        intervention_avg = self.intervention_impact.mean(dim=time_dim)
        mean_during = _as_scalar(intervention_avg.mean(dim=sample_dims))
        during_lower, during_upper = _extract_hdi_bounds(
            az.hdi(intervention_avg, hdi_prob=hdi_prob), hdi_prob
        )

        # Post-intervention period.
        post_avg = self.post_intervention_impact.mean(dim=time_dim)
        mean_post = _as_scalar(post_avg.mean(dim=sample_dims))
        post_lower, post_upper = _extract_hdi_bounds(
            az.hdi(post_avg, hdi_prob=hdi_prob), hdi_prob
        )

        # Totals: final value of each cumulative impact series.
        total_during = _as_scalar(
            self.intervention_impact_cumulative.isel({time_dim: -1}).mean(
                dim=sample_dims
            )
        )
        total_post = _as_scalar(
            self.post_intervention_impact_cumulative.isel({time_dim: -1}).mean(
                dim=sample_dims
            )
        )
    else:
        # OLS: point estimates and CIs from numpy arrays.
        from causalpy.reporting import _compute_statistics_ols

        def _to_numpy(values):
            return values.values if hasattr(values, "values") else np.asarray(values)

        during_stats = _compute_statistics_ols(
            _to_numpy(self.intervention_impact),
            self.intervention_pred,
            alpha=1 - hdi_prob,
            cumulative=True,
            relative=False,
        )
        post_stats = _compute_statistics_ols(
            _to_numpy(self.post_intervention_impact),
            self.post_intervention_pred,
            alpha=1 - hdi_prob,
            cumulative=True,
            relative=False,
        )
        mean_during = during_stats["avg"]["mean"]
        mean_post = post_stats["avg"]["mean"]
        total_during = during_stats["cum"]["mean"]
        total_post = post_stats["cum"]["mean"]
        during_lower = during_stats["avg"]["ci_lower"]
        during_upper = during_stats["avg"]["ci_upper"]
        post_lower = post_stats["avg"]["ci_lower"]
        post_upper = post_stats["avg"]["ci_upper"]

    # Persistence ratio as a decimal (not a percentage). The epsilon avoids
    # dividing by an exactly-zero mean; guard the rare case where it cancels.
    epsilon = 1e-8
    denominator = mean_during + epsilon
    if denominator == 0:
        persistence_ratio = float("nan")
    else:
        persistence_ratio = float(mean_post / denominator)

    result = {
        "mean_effect_during": mean_during,
        "mean_effect_post": mean_post,
        "persistence_ratio": persistence_ratio,
        "total_effect_during": total_during,
        "total_effect_post": total_post,
    }
    self._print_persistence_results(
        result,
        hdi_prob,
        is_pymc,
        during_lower,
        during_upper,
        post_lower,
        post_upper,
    )
    return result
def _print_persistence_results(
    self,
    result: dict[str, Any],
    hdi_prob: float,
    is_pymc: bool,
    intervention_ci_lower: float,
    intervention_ci_upper: float,
    post_ci_lower: float,
    post_ci_upper: float,
) -> None:
    """Print a formatted summary of persistence-analysis results.

    Parameters
    ----------
    result : dict
        Must contain the keys ``mean_effect_during``, ``total_effect_during``,
        ``mean_effect_post``, ``total_effect_post`` and ``persistence_ratio``.
    hdi_prob : float
        Interval probability; shown as a whole-number percentage.
    is_pymc : bool
        Label the interval "HDI" when true, otherwise "CI".
    intervention_ci_lower, intervention_ci_upper : float
        Interval bounds of the mean effect during the intervention period.
    post_ci_lower, post_ci_upper : float
        Interval bounds of the mean effect in the post-intervention period.
    """
    # Round rather than truncate: in floating point 0.57 * 100 is
    # 56.99999999999999, which int() alone would print as "56%".
    hdi_pct = int(round(hdi_prob * 100))
    ci_label = "HDI" if is_pymc else "CI"
    print("=" * 60)
    print("Effect Persistence Analysis")
    print("=" * 60)
    print("\nDuring intervention period:")
    print(f" Mean effect: {result['mean_effect_during']:.2f}")
    print(
        f" {hdi_pct}% {ci_label}: [{intervention_ci_lower:.2f}, {intervention_ci_upper:.2f}]"
    )
    print(f" Total effect: {result['total_effect_during']:.2f}")
    print("\nPost-intervention period:")
    print(f" Mean effect: {result['mean_effect_post']:.2f}")
    print(f" {hdi_pct}% {ci_label}: [{post_ci_lower:.2f}, {post_ci_upper:.2f}]")
    print(f" Total effect: {result['total_effect_post']:.2f}")
    print(f"\nPersistence ratio: {result['persistence_ratio']:.3f}")
    print(
        f" ({result['persistence_ratio'] * 100:.1f}% of intervention effect persisted)"
    )
    print("=" * 60)
def _resolve_summary_period(
    self,
    period: str | None,
    window: str | tuple | slice,
    prefix: str,
    direction: str,
    alpha: float,
    cumulative: bool,
    relative: bool,
    min_effect: float | None,
    treated_unit: str | None,
) -> tuple[EffectSummary | None, str | tuple | slice | None, str | None]:
    """Handle period parameter for three-period designs.

    Returns a ``(early_summary, window, prefix)`` tuple:

    - If ``period is None``: ``(None, window, prefix)`` — caller proceeds
      with the defaults passed in.
    - If ``period == "comparison"``:
      ``(EffectSummary, None, None)`` — caller returns this immediately.
    - If ``period == "intervention"`` or ``"post"``:
      ``(None, resolved_window, resolved_prefix)`` — caller uses those
      values instead of the defaults.

    ``treated_unit`` is accepted but not used here.

    Raises
    ------
    ValueError
        If ``period`` is not one of the supported values, if
        ``treatment_end_time`` is not set, or if no post-treatment
        observations fall in the requested period (which would otherwise
        produce a NaN or inverted window).
    """
    if period is None:
        return None, window, prefix
    valid_periods = ["intervention", "post", "comparison"]
    if period not in valid_periods:
        raise ValueError(
            f"period must be one of {valid_periods}, got '{period}'"
        )
    if not (
        hasattr(self, "treatment_end_time")
        and self.treatment_end_time is not None
    ):
        raise ValueError(
            f"Period '{period}' not available. This experiment may not "
            "support three-period designs. Provide treatment_end_time to "
            "enable period-specific analysis."
        )
    if period == "comparison":
        return (
            self._comparison_period_summary(
                direction=direction,
                alpha=alpha,
                cumulative=cumulative,
                relative=relative,
                min_effect=min_effect,
            ),
            None,
            None,
        )
    post_index = self.datapost.index
    if period == "intervention":
        intervention_indices = post_index[post_index < self.treatment_end_time]
        if len(intervention_indices) == 0:
            raise ValueError(
                "No post-treatment observations fall before treatment_end_time "
                f"({self.treatment_end_time}); cannot summarize period "
                "'intervention'."
            )
        return (
            None,
            (self.treatment_time, intervention_indices.max()),
            "During intervention",
        )
    # period == "post"
    post_indices = post_index[post_index >= self.treatment_end_time]
    if len(post_indices) == 0:
        raise ValueError(
            "No post-treatment observations fall at or after "
            f"treatment_end_time ({self.treatment_end_time}); cannot "
            "summarize period 'post'."
        )
    return (
        None,
        (self.treatment_end_time, post_indices.max()),
        "Post-intervention",
    )
[docs]
def effect_summary(
    self,
    *,
    window: Literal["post"] | tuple | slice = "post",
    direction: Literal["increase", "decrease", "two-sided"] = "increase",
    alpha: float = 0.05,
    cumulative: bool = True,
    relative: bool = True,
    min_effect: float | None = None,
    treated_unit: str | None = None,
    period: Literal["intervention", "post", "comparison"] | None = None,
    prefix: str = "Post-period",
    **kwargs: Any,
) -> EffectSummary:
    """
    Generate a decision-ready summary of causal effects for Interrupted Time Series.

    Parameters
    ----------
    window : str, tuple, or slice, default="post"
        Time window for analysis:

        - "post": All post-treatment time points (default)
        - (start, end): Tuple of start and end times (handles both datetime and integer indices)
        - slice: Python slice object for integer indices
    direction : {"increase", "decrease", "two-sided"}, default="increase"
        Direction for tail probability calculation (PyMC only, ignored for OLS).
    alpha : float, default=0.05
        Significance level for HDI/CI intervals (1-alpha confidence level).
    cumulative : bool, default=True
        Whether to include cumulative effect statistics.
    relative : bool, default=True
        Whether to include relative effect statistics (% change vs counterfactual).
    min_effect : float, optional
        Region of Practical Equivalence (ROPE) threshold (PyMC only, ignored for OLS).
    treated_unit : str, optional
        Ignored for Interrupted Time Series (single unit).
    period : {"intervention", "post", "comparison"}, optional
        For three-period designs (with treatment_end_time), specify which period to summarize.
        Defaults to None for standard behavior.
    prefix : str, optional
        Prefix for prose generation (e.g., "During intervention", "Post-intervention").
        Defaults to "Post-period".
    **kwargs
        Reserved for forward-compatibility; not consumed by this
        implementation.

    Returns
    -------
    EffectSummary
        Object with .table (DataFrame) and .text (str) attributes.
        The .text attribute contains a detailed multi-paragraph narrative report.
    """
    from causalpy.reporting import (
        _compute_statistics,
        _compute_statistics_ols,
        _extract_counterfactual,
        _extract_window,
        _generate_prose_detailed,
        _generate_prose_detailed_ols,
        _generate_table,
        _generate_table_ols,
    )

    is_pymc = self._model_backend.is_bayesian
    # Handle three-period design via shared helper
    early_summary, resolved_window, resolved_prefix = (
        self._resolve_summary_period(
            period,
            window,
            prefix,
            direction,
            alpha,
            cumulative,
            relative,
            min_effect,
            treated_unit,
        )
    )
    if early_summary is not None:
        return early_summary
    if resolved_window is not None:
        window = resolved_window
        prefix = resolved_prefix
    windowed_impact, window_coords = _extract_window(
        self, window, treated_unit=treated_unit
    )
    counterfactual = _extract_counterfactual(
        self, window_coords, treated_unit=treated_unit
    )
    if is_pymc:
        # PyMC model: use posterior draws
        hdi_prob = 1 - alpha
        stats = _compute_statistics(
            windowed_impact,
            counterfactual,
            hdi_prob=hdi_prob,
            direction=direction,
            cumulative=cumulative,
            relative=relative,
            min_effect=min_effect,
        )
        table = _generate_table(stats, cumulative=cumulative, relative=relative)
        # Compute observed/counterfactual averages for prose
        time_dim = "obs_ind"
        cf_avg = _as_scalar(counterfactual.mean(dim=[time_dim, "chain", "draw"]))
        obs_avg = cf_avg + stats["avg"]["mean"]
        cf_cum = _as_scalar(
            counterfactual.sum(dim=time_dim).mean(dim=["chain", "draw"])
        )
        obs_cum = cf_cum + stats["cum"]["mean"] if cumulative else None
        text = _generate_prose_detailed(
            stats,
            window_coords,
            alpha=alpha,
            direction=direction,
            cumulative=cumulative,
            relative=relative,
            prefix=prefix,
            observed_avg=obs_avg,
            counterfactual_avg=cf_avg,
            observed_cum=obs_cum,
            counterfactual_cum=cf_cum if cumulative else None,
            experiment_type="its",
        )
    else:
        # OLS model: use point estimates and CIs. Convert xarray/pandas
        # containers to plain numpy arrays with the same rule that
        # analyze_persistence applies before calling _compute_statistics_ols.
        impact_array = (
            windowed_impact.values
            if hasattr(windowed_impact, "values")
            else np.asarray(windowed_impact)
        )
        counterfactual_array = (
            counterfactual.values
            if hasattr(counterfactual, "values")
            else np.asarray(counterfactual)
        )
        stats = _compute_statistics_ols(
            impact_array,
            counterfactual_array,
            alpha=alpha,
            cumulative=cumulative,
            relative=relative,
        )
        table = _generate_table_ols(stats, cumulative=cumulative, relative=relative)
        cf_avg = float(np.mean(counterfactual_array))
        obs_avg = cf_avg + stats["avg"]["mean"]
        cf_cum = float(np.sum(counterfactual_array))
        obs_cum = cf_cum + stats["cum"]["mean"] if cumulative else None
        text = _generate_prose_detailed_ols(
            stats,
            window_coords,
            alpha=alpha,
            cumulative=cumulative,
            relative=relative,
            prefix=prefix,
            observed_avg=obs_avg,
            counterfactual_avg=cf_avg,
            observed_cum=obs_cum,
            counterfactual_cum=cf_cum if cumulative else None,
            experiment_type="its",
        )
    return EffectSummary(table=table, text=text)