How to Check Model Assumptions with Python Seaborn Graphics


We need to check a model's assumptions to give us confidence that our model has integrity and that its results are not biased or misleading.

In this example, we check three assumptions of a linear regression model (linearity, homoscedasticity and normality of the residuals) using a 2 x 2 grid of Seaborn and statsmodels subplots.
The code for creating the linear regression model can be found in this post.
You can run the code below once you have built the model, which describes the relationship between radio advertising spend and sales.
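The plotting code expects three objects to exist already: df, y_pred and residuals. As a stand-in for the linked post, here is a minimal sketch of how y_pred and residuals could be produced, assuming df is a pandas DataFrame with 'Radio' and 'Sales' columns. It uses NumPy least squares, which may differ from how the original post fits the model, and fit_simple_regression is a hypothetical helper name.

```python
import numpy as np


def fit_simple_regression(x, y):
    """Fit y = intercept + slope * x by ordinary least squares (hypothetical helper).

    Returns (intercept, slope, y_pred, residuals); residuals are actual minus predicted.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # np.polyfit returns coefficients highest power first: [slope, intercept]
    slope, intercept = np.polyfit(x, y, deg=1)
    y_pred = intercept + slope * x
    residuals = y - y_pred
    return intercept, slope, y_pred, residuals


# Tiny made-up example (not the post's data): y = 1 + 2x exactly
intercept, slope, y_pred, residuals = fit_simple_regression([1, 2, 3, 4], [3, 5, 7, 9])
```

With the real data you would call it as `intercept, slope, y_pred, residuals = fit_simple_regression(df['Radio'], df['Sales'])`. If you fitted the model with statsmodels instead, the results object's `fittedvalues` and `resid` attributes hold the same two quantities.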

[Figure: assumption-checking plots]

Graph 1: Checking the Linearity Assumption
The linearity assumption is as follows: ‘Each predictor variable (x) is linearly related to the outcome variable (y).’
In the first graph, we plot radio advertising spend against sales with a fitted regression line. The points follow a roughly straight line, so we can conclude that the linearity assumption is met.
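Reading linearity off a plot is a visual judgement. One numeric companion (a swapped-in technique, not from the original post) is to add a squared term and see whether it noticeably improves the fit. The sketch below assumes a single predictor; quadratic_term_check is a hypothetical helper, and the data in it is made up for illustration.

```python
import numpy as np


def quadratic_term_check(x, y):
    """Compare a straight-line fit with a quadratic fit (hypothetical helper).

    Returns (b2, improvement): the fitted coefficient on x**2 and the fraction
    by which adding that term reduces the residual sum of squares.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    line = np.polyfit(x, y, 1)   # [b1, b0]
    quad = np.polyfit(x, y, 2)   # [b2, b1, b0]
    rss_line = np.sum((y - np.polyval(line, x)) ** 2)
    rss_quad = np.sum((y - np.polyval(quad, x)) ** 2)
    improvement = 0.0 if rss_line == 0 else (rss_line - rss_quad) / rss_line
    return quad[0], improvement


x = np.arange(10)
# Made-up data: a straight line with alternating noise, and a curve
linear_y = 2 * x + 1 + 0.5 * (-1) ** x
curved_y = x ** 2
```

A coefficient near zero and an improvement close to 0 are consistent with a linear relationship; a large improvement suggests curvature that a straight line misses. With the post's data you would pass `df['Radio']` and `df['Sales']`.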

Graph 2: Checking the Homoscedasticity Assumption with a Scatterplot

The homoscedasticity assumption (extra points if you can spell it correctly) is as follows: ‘The residuals have equal, or almost equal, variance across the regression line.’

In the second graph, we plot the residuals of the model, which are the differences between the actual values and the model's predictions, against the predicted values. In the code, y_pred holds the y values predicted by the regression line.

By plotting the residuals against the predicted values, we can check that there is no pattern in the residuals. Good homoscedasticity therefore looks like an even band of residuals scattered above and below zero, with a similar spread at every fitted value (no funnel shape).
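The residual plot can be backed up with a Breusch-Pagan test. statsmodels offers `het_breuschpagan` in `statsmodels.stats.diagnostic`; the sketch below instead computes the n·R² (Koenker) form by hand with NumPy so each step is visible. breusch_pagan_1d is a hypothetical helper, it assumes a single regressor, and the residuals are made up. With one predictor, regressing on the fitted values is equivalent to regressing on x.

```python
import math

import numpy as np


def breusch_pagan_1d(fitted, residuals):
    """n * R^2 (Koenker) form of the Breusch-Pagan test with one regressor.

    Regresses the squared residuals on an intercept and the fitted values.
    Returns (lm_statistic, p_value); a small p-value suggests heteroscedasticity.
    """
    z = np.asarray(fitted, dtype=float)
    e2 = np.asarray(residuals, dtype=float) ** 2
    n = e2.size
    # Auxiliary regression: e^2 = a + b * z
    b, a = np.polyfit(z, e2, 1)
    aux_pred = a + b * z
    ss_res = np.sum((e2 - aux_pred) ** 2)
    ss_tot = np.sum((e2 - e2.mean()) ** 2)
    r2 = 0.0 if ss_tot == 0 else 1.0 - ss_res / ss_tot
    lm = max(n * r2, 0.0)  # guard against tiny negative values from rounding
    # Under the null, lm ~ chi-square(1), whose survival function is erfc(sqrt(lm / 2))
    p_value = math.erfc(math.sqrt(lm / 2.0))
    return lm, p_value


fitted = np.arange(1, 51, dtype=float)
signs = (-1.0) ** np.arange(50)
# Made-up residuals: constant spread vs. spread that grows with the fitted value
even_lm, even_p = breusch_pagan_1d(fitted, signs)
funnel_lm, funnel_p = breusch_pagan_1d(fitted, signs * fitted)
```

For the post's model you would call `breusch_pagan_1d(y_pred, residuals)`. The funnel-shaped residuals give a very small p-value, while the constant-spread residuals give a p-value of 1.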

Graph 3: Checking the Normality Assumption
In the third graph, a histogram shows the distribution of the residuals (the actual y values minus the predicted y values). If the normality assumption holds, the residuals should be approximately normally distributed around zero, and that is what we see.
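To put a number on what the histogram shows, one option (not used in the original post) is the Jarque-Bera test, which combines the sample skewness and excess kurtosis. SciPy provides `scipy.stats.jarque_bera` and `scipy.stats.shapiro`; the sketch below computes Jarque-Bera directly with NumPy. jarque_bera here is a hypothetical local helper and the sample data is made up.

```python
import math

import numpy as np


def jarque_bera(residuals):
    """Jarque-Bera normality test computed from sample skewness and kurtosis.

    Returns (skewness, excess_kurtosis, jb_statistic, p_value). The p-value
    uses the large-sample chi-square(2) approximation, so treat it as rough
    for small samples.
    """
    e = np.asarray(residuals, dtype=float)
    n = e.size
    d = e - e.mean()
    m2 = np.mean(d ** 2)
    skewness = np.mean(d ** 3) / m2 ** 1.5
    excess_kurtosis = np.mean(d ** 4) / m2 ** 2 - 3.0
    jb = n / 6.0 * (skewness ** 2 + excess_kurtosis ** 2 / 4.0)
    # Survival function of chi-square with 2 df is exp(-x / 2)
    p_value = math.exp(-jb / 2.0)
    return skewness, excess_kurtosis, jb, p_value


symmetric = [-2, -1, 0, 1, 2]
skewed = [0] * 20 + [10]  # made-up residuals with one large outlier
```

Skewness near 0, excess kurtosis near 0 and a large p-value are consistent with normal residuals; the single outlier in `skewed` drives the p-value close to zero.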

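The fourth graph is a Q-Q plot, which is also used to check the normality assumption. It plots the quantiles of the residuals against the quantiles of a normal distribution; if the residuals are normally distributed, the points lie close to the reference line.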
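To see what a Q-Q plot actually draws, here is a standard-library sketch that computes the points by hand. It assumes the common plotting positions (i - 0.5) / n (statsmodels may use a slightly different convention) and a standardised reference line of mean + standard deviation × theoretical quantile, which is roughly what `line='s'` adds. qq_points is a hypothetical helper.

```python
from statistics import NormalDist, mean, pstdev


def qq_points(residuals):
    """Return (theoretical, sample, reference) lists for a normal Q-Q plot.

    theoretical: standard normal quantiles at plotting positions (i - 0.5) / n
    sample:      the residuals sorted in ascending order
    reference:   mean + std * theoretical, a standardised reference line
    """
    sample = sorted(residuals)
    n = len(sample)
    std_normal = NormalDist()  # mean 0, standard deviation 1
    theoretical = [std_normal.inv_cdf((i - 0.5) / n) for i in range(1, n + 1)]
    mu, sigma = mean(sample), pstdev(sample)
    reference = [mu + sigma * q for q in theoretical]
    return theoretical, sample, reference


theoretical, sample, reference = qq_points([3.0, 1.0, 2.0])
```

Plotting `sample` against `theoretical` gives the Q-Q scatter, and `reference` against `theoretical` gives the line; normal residuals hug that line, while heavy tails bend away from it at both ends.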

import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm

# Assumes df (with 'Radio' and 'Sales' columns), y_pred and residuals
# already exist from building the model.
fig, ax = plt.subplots(2, 2, figsize=(18, 10))

fig.suptitle('Assumption Checks')

# Check for linearity: scatter of spend vs. sales with a fitted regression line
sns.regplot(ax=ax[0, 0], data=df, x='Radio', y='Sales')
ax[0, 0].set_title('Radio Sales')
ax[0, 0].set_xlabel('Radio Spend ($K)')
ax[0, 0].set_ylabel('Sales ($)')

# Check for homoscedasticity: plot residuals against the fitted values
sns.scatterplot(ax=ax[0, 1], x=y_pred, y=residuals)
ax[0, 1].set_title('Residuals vs Fitted Values')
ax[0, 1].set_xlabel('Fitted Values')
ax[0, 1].set_ylabel('Residuals')
ax[0, 1].axhline(0, linestyle='--', color='red')

# Check for normality: histogram of the residuals
sns.histplot(ax=ax[1, 0], x=residuals)
ax[1, 0].set_title('Histogram of Residuals')
ax[1, 0].set_xlabel('Residual Value')

# Check for normality: Q-Q plot of the residuals against a normal distribution.
# sm.qqplot returns a Figure, not an Axes, so don't assign its result to ax[1, 1].
sm.qqplot(residuals, line='s', ax=ax[1, 1])
ax[1, 1].set_title('Q-Q Plot')

plt.show()
