Causal AI, exploring the integration of causal reasoning into machine learning
Welcome to my series on Causal AI, where we will explore the integration of causal reasoning into machine learning models. Expect to explore a number of practical applications in different business contexts.
In the last article we covered measuring the intrinsic causal influence of your marketing campaigns. In this article we will move on to validating the causal impact of synthetic controls.
If you missed the last article on intrinsic causal influence, check it out here:
In this article, we will focus on understanding the synthetic control method and how we can validate the estimated causal impact.
The following aspects will be covered:
- What is the synthetic control method?
- What challenge does it try to overcome?
- How to validate the estimated causal impact?
- A Python case study using realistic Google Trends data, demonstrating how we can validate the estimated causal impact of synthetic controls.
The full notebook can be found here:
What is the synthetic control method?
The synthetic control method is a causal technique that can be used to assess the causal impact of an intervention or treatment when a randomized controlled trial (RCT) or A/B test is not possible. It was initially proposed in 2003 by Abadie and Gardeazabal. The following paper includes an excellent case study to help you understand the proposed method:
https://web.stanford.edu/~jhain/Paper/JASA2010.pdf
Let’s take a look at some of the basics ourselves… The synthetic control method creates a counterfactual version of the treated unit by creating a weighted combination of control units that did not receive the intervention or treatment.
- Treated unit: The unit receiving the intervention.
- Control units: A set of similar units that did not receive the intervention.
- Counterfactual: Created as a weighted combination of control units. The goal is to find weights for each control unit that result in a counterfactual that closely matches the treated unit in the pre-intervention period.
- Causal impact: The difference between the treated unit and the counterfactual in the post-intervention period.
If we really wanted to simplify things, we could think of this as a linear regression where each control unit is a feature and the treated unit is the target. The pre-intervention period is our training set, and we use the model to score our post-intervention period. The difference between actual and predicted is the causal impact.
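To make the regression framing concrete, here is a minimal sketch on toy data. Everything in it is an assumption for illustration only: the three made-up control units, the mixing weights, and the injected effect of +10 per period are not from the case study.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
n_pre, n_post, n_controls = 40, 8, 3

# Toy control units (columns) observed over the pre- and post-intervention periods
controls = rng.normal(loc=100, scale=10, size=(n_pre + n_post, n_controls))

# The treated unit follows a fixed mix of the controls plus a little noise
true_weights = np.array([0.5, 0.3, 0.2])
treated = controls @ true_weights + rng.normal(scale=1.0, size=n_pre + n_post)

# Inject a known effect after the intervention (hypothetical value)
true_effect = 10.0
treated[n_pre:] += true_effect

# "Train" on the pre-intervention period only
model = LinearRegression().fit(controls[:n_pre], treated[:n_pre])

# "Score" the post-intervention period to get the counterfactual
counterfactual = model.predict(controls[n_pre:])

# Causal impact = actual minus predicted
estimated_effect = (treated[n_pre:] - counterfactual).mean()
print(f"True effect per period:      {true_effect:.2f}")
print(f"Estimated effect per period: {estimated_effect:.2f}")
```

With only three features and 40 training observations the regression has an easy job here; the next section looks at what happens when that ratio is far less comfortable.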
Below are some examples to bring it to life when we might consider using it:
- When we run a television marketing campaign, we cannot randomly divide the audience into those who can and those who cannot see the campaign. We could, however, carefully select a region to test the campaign and use the remaining regions as control units. Once we have measured the effect, the campaign could be extended to other regions. This is often called a geo-lift test.
- Policy changes that are made to some areas but not others: for example, a local council may implement a policy change to reduce unemployment. Other regions where this policy was not in place could serve as control units.
What challenge does it try to overcome?
When we combine high dimensionality (lots of features) with limited observations, we can obtain an overfit model.
Let’s take the example of geo-lift to illustrate. If we use last year's weekly data as the pre-intervention period, this gives us 52 observations. If we then use the countries across Europe as our units (around 50 of them), this gives us an observation-to-feature ratio of roughly 1:1!
Earlier we explained how the synthetic control method could be implemented using linear regression. However, the observation-to-feature ratio means that it is very likely that the linear regression is overfit, resulting in poor estimation of the causal impact in the post-intervention period.
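A quick way to see the problem is to fit an unconstrained linear regression on pure noise with the same shape as the geo-lift example. This is a sketch under assumed dimensions (52 weekly observations, 50 control units, plus 20 extra rows standing in for unseen data); the features carry no information about the target at all.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

rng = np.random.default_rng(42)
n_weeks, n_controls, n_new = 52, 50, 20

# Pure noise: the "control units" are unrelated to the target
X = rng.normal(size=(n_weeks + n_new, n_controls))
y = rng.normal(size=n_weeks + n_new)

X_pre, y_pre = X[:n_weeks], y[:n_weeks]
X_new, y_new = X[n_weeks:], y[n_weeks:]

# 50 coefficients plus an intercept fitted on only 52 observations
model = LinearRegression().fit(X_pre, y_pre)

r2_pre = r2_score(y_pre, model.predict(X_pre))
r2_new = r2_score(y_new, model.predict(X_new))
print(f"R2 on the data used for fitting: {r2_pre:.2f}")
print(f"R2 on unseen data:               {r2_new:.2f}")
```

The fit on the training rows looks excellent even though there is nothing to learn, while the score on unseen rows collapses. That gap is exactly what we risk carrying into the post-intervention period.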
In linear regression, the weights (coefficients) for each feature (control unit) can be negative or positive, and they can sum to a number greater than 1. The synthetic control method, however, learns the weights while applying the constraints below:
- Constrain weights to sum to 1
- Constrain weights to be ≥ 0
These constraints facilitate regularization and avoid extrapolation beyond the range of observed data.
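Written out, the constrained problem can be sketched as follows. The notation is my own assumption rather than something taken from the paper: y_t is the treated unit in period t, x_{jt} is control unit j in period t, w_j is the weight on control unit j, J is the number of control units, and T_0 is the number of pre-intervention periods.

```latex
\hat{w} \;=\; \arg\min_{w} \sum_{t=1}^{T_0} \Bigl( y_t - \sum_{j=1}^{J} w_j \, x_{jt} \Bigr)^{2}
\quad \text{subject to} \quad \sum_{j=1}^{J} w_j = 1, \qquad w_j \ge 0 \;\; \text{for all } j

% Counterfactual and causal impact in a post-intervention period t > T_0
\hat{y}_t = \sum_{j=1}^{J} \hat{w}_j \, x_{jt}, \qquad \text{impact}_t = y_t - \hat{y}_t

% Non-negative weights that sum to 1 keep the counterfactual inside the range of the controls
\min_{j} x_{jt} \;\le\; \hat{y}_t \;\le\; \max_{j} x_{jt}
```

The last line is the "no extrapolation" property: the counterfactual is a weighted average of the control units, so in every period it sits between the smallest and largest control value.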
It should be noted that in terms of regularization, Ridge and Lasso regression can achieve this and are reasonable alternatives in some cases. But we will test this in the case study!
How to validate the estimated causal impact?
Perhaps a more significant challenge is that we are unable to validate the estimated causal impact in the post-intervention period.
How long should my pre-intervention period be? Are we sure we haven't overfit our pre-intervention period? How can we know if our model generalizes well in the post-intervention period? What if I want to try different implementations of synthetic control methods?
We could randomly select a few observations from the pre-intervention period and retain them for validation. But we've already highlighted the challenge that comes from having limited observations, which could make things worse!
What if we could do some kind of pre-intervention simulation? Could this help us answer some of the questions highlighted above and gain confidence in the estimated causal impact of our models? Everything will be explained in the case study!
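Before the case study, the core idea can be sketched on toy data: inject a known uplift into the final periods of a series where no real intervention took place, fit on the periods before it, and check how closely the estimated lift recovers the known one. This is only a sketch. The helper `simulate_lift_error`, the toy data, the 15% uplift and the use of `Ridge` as the model are all assumptions for illustration, not the case study's implementation.

```python
import numpy as np
from sklearn.linear_model import Ridge

def simulate_lift_error(controls, treated, n_sim, uplift, start):
    """Inject a known uplift into the last n_sim periods and return the
    absolute percentage error of the recovered lift (hypothetical helper)."""
    y = treated.copy()
    true_lift = (y[-n_sim:] * uplift).sum()
    y[-n_sim:] *= 1 + uplift

    # Fit only on the periods between `start` and the simulated intervention
    model = Ridge(alpha=1.0).fit(controls[start:-n_sim], y[start:-n_sim])

    # Estimated lift = actual minus predicted counterfactual
    est_lift = (y[-n_sim:] - model.predict(controls[-n_sim:])).sum()
    return abs(est_lift - true_lift) / true_lift

rng = np.random.default_rng(1)
n_periods, n_controls = 156, 10

# Toy data: the treated unit is the average of the controls plus noise
controls = rng.normal(100, 10, size=(n_periods, n_controls))
treated = controls.mean(axis=1) + rng.normal(0, 1, size=n_periods)

# Try several pre-intervention period lengths by moving the start index
errors = {
    start: simulate_lift_error(controls, treated, n_sim=8, uplift=0.15, start=start)
    for start in (0, 50, 100)
}
for start, err in errors.items():
    print(f"start_index={start}: absolute error percentage = {err:.2f}")
```

Because the uplift is known, every design choice (model type, length of the pre-intervention period) gets a score we can compare before the real test begins.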
Background
After convincing Finance that brand marketing provides real added value, the marketing team approaches you with questions about geo-lift testing. Someone from Facebook told them it was the next big thing (even though it was the same person who told them Prophet was a good forecasting model) and they want to know if they could use it to measure their new upcoming television campaign.
You're a little worried, because the last time you ran a geo-lift test, the marketing analytics team thought it was a good idea to play with the pre-intervention period until they got a significant causal impact.
This time you suggest they do a “pre-intervention simulation”, after which you propose that the pre-intervention period be agreed before the test begins.
So let’s explore what a “pre-intervention simulation” looks like!
Data creation
To make this as realistic as possible, I pulled some Google Trends data for the majority of countries in Europe. The search term is irrelevant, just pretend it's your company's sales (and that you operate across Europe).
However, if you want to know how I got the Google Trends data, check out my notebook:
Below we can see the dataframe. We have been selling for 3 years in 50 European countries. The marketing team plans to broadcast its television campaign in Britain.
Now here's the clever part. We will simulate an intervention in the last 7 weeks of the time series.
import numpy as np
import pandas as pd

np.random.seed(1234)

# Create intervention flag
mask = (df['date'] >= "2024-04-14") & (df['date'] <= "2024-06-02")
df['intervention'] = mask.astype(int)
row_count = len(df)

# Create intervention uplift
df['uplift_perc'] = np.random.uniform(0.10, 0.20, size=row_count)
df['uplift_abs'] = round(df['uplift_perc'] * df['GB'])
df['y'] = df['GB']
df.loc[df['intervention'] == 1, 'y'] = df['GB'] + df['uplift_abs']
Now let's plot actual and counterfactual sales in Britain to bring to life what we've done:
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

def synth_plot(df, counterfactual):
    plt.figure(figsize=(14, 8))
    sns.set_style("white")

    # Create plot
    sns.lineplot(data=df, x='date', y='y', label='Actual', color='b', linewidth=2.5)
    sns.lineplot(data=df, x='date', y=counterfactual, label='Counterfactual', color='r', linestyle='--', linewidth=2.5)
    plt.title('Synthetic Control Method: Actual vs. Counterfactual', fontsize=24)
    plt.xlabel('Date', fontsize=20)
    plt.ylabel('Metric Value', fontsize=20)
    plt.legend(fontsize=16)
    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    plt.xticks(rotation=90)
    plt.grid(True, linestyle='--', alpha=0.5)

    # Highlight the intervention point
    intervention_date = '2024-04-07'
    plt.axvline(pd.to_datetime(intervention_date), color='k', linestyle='--', linewidth=1)
    plt.text(pd.to_datetime(intervention_date), plt.ylim()[1]*0.95, 'Intervention', color='k', fontsize=18, ha='right')

    plt.tight_layout()
    plt.show()

synth_plot(df, 'GB')
Now that we have simulated an intervention, we can explore how well the synthetic control method will work.
Pre-processing
All European countries except Great Britain are defined as control units (features). The treatment unit (target) corresponds to sales in GB with the intervention applied.
# Delete the original target column so we don't use it as a feature by accident
del df['GB']

# Set features & target
X = df.columns[1:50]
y = 'y'
Regression
Below I have configured a function that we can reuse with different pre-intervention periods and different regression models (e.g. Ridge, Lasso):
from sklearn.linear_model import LinearRegression, RidgeCV, LassoCV
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

def train_reg(df, start_index, reg_class):
    # start_index controls how much of the pre-intervention period is used
    df_temp = df.iloc[start_index:].copy().reset_index()

    # Pre-intervention data only
    X_pre = df_temp[df_temp['intervention'] == 0][X]
    y_pre = df_temp[df_temp['intervention'] == 0][y]

    X_train, X_test, y_train, y_test = train_test_split(X_pre, y_pre, test_size=0.10, random_state=42)

    model = reg_class
    model.fit(X_train, y_train)

    yhat_train = model.predict(X_train)
    yhat_test = model.predict(X_test)

    mse_train = mean_squared_error(y_train, yhat_train)
    mse_test = mean_squared_error(y_test, yhat_test)
    print(f"Mean Squared Error train: {round(mse_train, 2)}")
    print(f"Mean Squared Error test: {round(mse_test, 2)}")

    r2_train = r2_score(y_train, yhat_train)
    r2_test = r2_score(y_test, yhat_test)
    print(f"R2 train: {round(r2_train, 2)}")
    print(f"R2 test: {round(r2_test, 2)}")

    # Counterfactual for the whole series; delta is the estimated impact
    df_temp['pred'] = model.predict(df_temp.loc[:, X])
    df_temp['delta'] = df_temp['y'] - df_temp['pred']

    # Compare the estimated lift with the simulated (known) lift
    pred_lift = df_temp[df_temp['intervention'] == 1]['delta'].sum()
    actual_lift = df_temp[df_temp['intervention'] == 1]['uplift_abs'].sum()
    abs_error_perc = abs(pred_lift - actual_lift) / actual_lift
    print(f"Predicted lift: {round(pred_lift, 2)}")
    print(f"Actual lift: {round(actual_lift, 2)}")
    print(f"Absolute error percentage: {round(abs_error_perc, 2)}")

    return df_temp, abs_error_perc
To start, we keep things simple and use linear regression to estimate the causal impact, using a short pre-intervention period:
df_lin_reg_100, pred_lift_lin_reg_100 = train_reg(df, 100, LinearRegression())
Looking at the results, linear regression does not give good results. But this is not surprising given the observation/feature ratio.
synth_plot(df_lin_reg_100, 'pred')
Synthetic control method
Let's go ahead and see how this compares to the synthetic control method. Below I have configured a function similar to the previous one, but applying the synthetic control method using SciPy:
from scipy.optimize import minimize

def synthetic_control(weights, control_units, treated_unit):
    # Loss: Euclidean distance between the treated unit and the weighted controls
    synthetic = np.dot(control_units.values, weights)
    return np.sqrt(np.sum((treated_unit - synthetic)**2))

def train_synth(df, start_index):
    df_temp = df.iloc[start_index:].copy().reset_index()

    # Pre-intervention data only
    X_pre = df_temp[df_temp['intervention'] == 0][X]
    y_pre = df_temp[df_temp['intervention'] == 0][y]

    X_train, X_test, y_train, y_test = train_test_split(X_pre, y_pre, test_size=0.10, random_state=42)

    # Start from equal weights; constrain them to sum to 1 and lie in [0, 1]
    initial_weights = np.ones(len(X)) / len(X)
    constraints = ({'type': 'eq', 'fun': lambda w: np.sum(w) - 1})
    bounds = [(0, 1) for _ in range(len(X))]

    result = minimize(synthetic_control,
                      initial_weights,
                      args=(X_train, y_train),
                      method='SLSQP',
                      bounds=bounds,
                      constraints=constraints,
                      options={'disp': False, 'maxiter': 1000, 'ftol': 1e-9},
                      )
    optimal_weights = result.x

    yhat_train = np.dot(X_train.values, optimal_weights)
    yhat_test = np.dot(X_test.values, optimal_weights)

    mse_train = mean_squared_error(y_train, yhat_train)
    mse_test = mean_squared_error(y_test, yhat_test)
    print(f"Mean Squared Error train: {round(mse_train, 2)}")
    print(f"Mean Squared Error test: {round(mse_test, 2)}")

    r2_train = r2_score(y_train, yhat_train)
    r2_test = r2_score(y_test, yhat_test)
    print(f"R2 train: {round(r2_train, 2)}")
    print(f"R2 test: {round(r2_test, 2)}")

    # Counterfactual for the whole series; delta is the estimated impact
    df_temp['pred'] = np.dot(df_temp.loc[:, X].values, optimal_weights)
    df_temp['delta'] = df_temp['y'] - df_temp['pred']

    # Compare the estimated lift with the simulated (known) lift
    pred_lift = df_temp[df_temp['intervention'] == 1]['delta'].sum()
    actual_lift = df_temp[df_temp['intervention'] == 1]['uplift_abs'].sum()
    abs_error_perc = abs(pred_lift - actual_lift) / actual_lift
    print(f"Predicted lift: {round(pred_lift, 2)}")
    print(f"Actual lift: {round(actual_lift, 2)}")
    print(f"Absolute error percentage: {round(abs_error_perc, 2)}")

    return df_temp, abs_error_perc
I keep the pre-intervention period the same to create a fair comparison with linear regression:
df_synth_100, pred_lift_synth_100 = train_synth(df, 100)
Wow! I'll be the first to admit that I wasn't expecting such a significant improvement!
synth_plot(df_synth_100, 'pred')
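Since the whole point of the method is the constraints, it can be worth checking that the optimizer actually honored them. Below is a small standalone sketch on toy data (five made-up control units, with the treated unit built from the first two). It uses the sum of squared errors as the loss rather than its square root; the two share the same minimizer, and the squared version is smooth everywhere. None of the names or values here come from the case study.

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(7)
n_periods, n_controls = 60, 5

# Toy data: the treated unit is a 60/40 mix of the first two controls plus noise
controls = rng.normal(100, 10, size=(n_periods, n_controls))
true_weights = np.array([0.6, 0.4, 0.0, 0.0, 0.0])
treated = controls @ true_weights + rng.normal(0, 0.5, size=n_periods)

def loss(weights):
    # Sum of squared differences between the treated unit and the weighted controls
    return np.sum((treated - controls @ weights) ** 2)

result = minimize(
    loss,
    x0=np.full(n_controls, 1 / n_controls),             # start from equal weights
    method="SLSQP",
    bounds=[(0, 1)] * n_controls,                        # weights >= 0 (and <= 1)
    constraints={"type": "eq", "fun": lambda w: np.sum(w) - 1},  # weights sum to 1
)

weights = result.x
print("Converged:", result.success)
print("Weights:", np.round(weights, 3))
print("Sum of weights:", round(float(weights.sum()), 6))
```

In the case study's `train_synth`, the same checks could be applied to `result.x` and `result.success` before trusting the estimated lift.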
Comparison of results
Let's not get carried away yet. Below we conduct some additional experiments exploring model types and pre-intervention periods:
# Run regression experiments
df_lin_reg_00, pred_lift_lin_reg_00 = train_reg(df, 0, LinearRegression())
df_lin_reg_100, pred_lift_lin_reg_100 = train_reg(df, 100, LinearRegression())
df_ridge_00, pred_lift_ridge_00 = train_reg(df, 0, RidgeCV())
df_ridge_100, pred_lift_ridge_100 = train_reg(df, 100, RidgeCV())
df_lasso_00, pred_lift_lasso_00 = train_reg(df, 0, LassoCV())
df_lasso_100, pred_lift_lasso_100 = train_reg(df, 100, LassoCV())

# Run synthetic control experiments
df_synth_00, pred_lift_synth_00 = train_synth(df, 0)
df_synth_100, pred_lift_synth_100 = train_synth(df, 100)

# Note: the pred_lift_* variables hold the absolute error percentage returned by each function
experiment_data = {
    "Method": ["Linear", "Linear", "Ridge", "Ridge", "Lasso", "Lasso", "Synthetic Control", "Synthetic Control"],
    "Data Size": ["Large", "Small", "Large", "Small", "Large", "Small", "Large", "Small"],
    "Value": [pred_lift_lin_reg_00, pred_lift_lin_reg_100, pred_lift_ridge_00, pred_lift_ridge_100,
              pred_lift_lasso_00, pred_lift_lasso_100, pred_lift_synth_00, pred_lift_synth_100]
}
df_experiments = pd.DataFrame(experiment_data)
We will use the code below to visualize the results:
# Set the style
sns.set_style("whitegrid")

# Create the bar plot
plt.figure(figsize=(10, 6))
bar_plot = sns.barplot(x="Method", y="Value", hue="Data Size", data=df_experiments, palette="muted")

# Add labels and title
plt.xlabel("Method")
plt.ylabel("Absolute error percentage")
plt.title("Synthetic Controls - Comparison of Methods Across Different Data Sizes")
plt.legend(title="Data Size")

# Show the plot
plt.show()
The results for the small dataset are really interesting! As expected, regularization helped improve causal impact estimates. Synthetic control then went even further!
Results from the large dataset suggest that longer pre-intervention periods are not always better.
However, what I want you to take away is how useful it is to perform a pre-intervention simulation. There are so many avenues you can explore with your own dataset!
Today we explored the synthetic control method and how you can validate causal impact. I leave you with some final thoughts:
- The simplicity of the synthetic control method makes it one of the most widely used techniques in the causal AI toolbox.
- Unfortunately, it's also the most widely abused – Let's run the R package CausalImpact, changing the pre-intervention period until we see an improvement we like. 😭
- This is where I highly recommend carrying out pre-intervention simulations to agree on the test design from the outset.
- The synthetic control method is a much-studied area. It's worth checking out the proposed adaptations: Augmented SC, Robust SC and Penalized SC.
Alberto Abadie, Alexis Diamond and Jens Hainmueller (2010). Synthetic control methods for comparative case studies: Estimating the effect of California's tobacco control program. Journal of the American Statistical Association, 105(490), 493-505. DOI: 10.1198/jasa.2009.ap08746