Linear Regression in Python#
There are many packages that implement linear regression in Python. As detailed in our last reading, however, the way these packages operate can vary substantially depending on whether they were designed for prediction or for inference.
In this reading, we will look at how linear regression has been implemented in two major packages — statsmodels and scikit-learn. Both of these packages can fit a wide range of linear models, but the way they are organized and the results they report reflect the different audiences for whom they were designed.
Broadly speaking, statsmodels is a library written by statisticians for statisticians, biostatisticians, social scientists, and natural scientists. It can do prediction, but its focus is inference, and as we will see, that focus is reflected throughout the package.
scikit-learn, by contrast, was written by and for computer scientists interested in machine learning. Its focus is on prediction, and while it includes a far more diverse collection of machine learning models than statsmodels, it does not include all the features someone doing inference might expect for evaluating model performance or doing things like calculating different types of standard errors.
Regression in statsmodels#
Because it is the more feature-rich library when it comes to regression, we will start our exploration of linear regression in Python with statsmodels. If you have any interest in inference, are coming from a programming language like R or Stata, and/or have a background in statistics, social science, or the natural sciences, then statsmodels is the package that will feel most familiar and useful.
While our focus will be on linear regression, the statsmodels package includes a wide range of tools for inference and modeling, from simple models like linear and logistic regression to generalized linear models (GLMs), non-parametric regression, robust linear models, time series models, survival analysis, multiple imputation, generalized additive models (GAMs), and more. (Curious if it includes that one model that’s near and dear to your heart? Feel free to go check and come back).
Moreover, it provides by far the easiest interface in the Python ecosystem for moving from a pandas DataFrame to a regression. To illustrate, let's fit a quick regression looking at countries' under-5 mortality rates as a function of GDP per capita and World Bank ratings of each country's public sector in terms of transparency, accountability, and levels of corruption:
import pandas as pd
import numpy as np
import statsmodels.formula.api as smf
pd.set_option("mode.copy_on_write", True)
# Load data on infant mortality, gdp per capita, and
# World Bank CPIA public sector transparency, accountability,
# and corruption in the public sector scores
# (1 = low transparency and accountability, 6 = high transparency and accountability).
wdi = pd.read_csv("data/wdi_corruption.csv")
# Check one observation to get a feel for things.
wdi.sample().T
|  | 73 |
| --- | --- |
| country_name | Yemen, Rep. |
| gdp_per_capita_ppp | 3108.764217 |
| CPIA_public_sector_rating | 1.5 |
| mortality_rate_under5_per_1000 | 55.4 |
| Mortality rate, under-5, female (per 1,000 live births) | 51.3 |
| Mortality rate, under-5, male (per 1,000 live births) | 59.4 |
| Population, total | 26497889.0 |
| region | Middle East and North Africa |
# Fit model
corruption_model = smf.ols(
"mortality_rate_under5_per_1000 ~ np.log(gdp_per_capita_ppp) +"
" CPIA_public_sector_rating + region",
data=wdi,
).fit()
# Get regression result
corruption_model.summary()
| Dep. Variable: | mortality_rate_under5_per_1000 | R-squared: | 0.586 |
| --- | --- | --- | --- |
| Model: | OLS | Adj. R-squared: | 0.541 |
| Method: | Least Squares | F-statistic: | 13.12 |
| Date: | Sat, 20 Jul 2024 | Prob (F-statistic): | 2.11e-10 |
| Time: | 19:04:19 | Log-Likelihood: | -322.68 |
| No. Observations: | 73 | AIC: | 661.4 |
| Df Residuals: | 65 | BIC: | 679.7 |
| Df Model: | 7 |  |  |
| Covariance Type: | nonrobust |  |  |
|  | coef | std err | t | P>\|t\| | [0.025 | 0.975] |
| --- | --- | --- | --- | --- | --- | --- |
| Intercept | 169.9397 | 36.430 | 4.665 | 0.000 | 97.183 | 242.696 |
| region[T.Europe and Central Asia] | -15.9265 | 12.304 | -1.294 | 0.200 | -40.499 | 8.646 |
| region[T.Latin America and Caribbean] | 1.9023 | 9.226 | 0.206 | 0.837 | -16.523 | 20.327 |
| region[T.Middle East and North Africa] | 3.7668 | 23.057 | 0.163 | 0.871 | -42.280 | 49.814 |
| region[T.South Asia] | 4.9372 | 9.818 | 0.503 | 0.617 | -14.671 | 24.545 |
| region[T.Sub-Saharan Africa] | 27.8448 | 7.360 | 3.783 | 0.000 | 13.145 | 42.544 |
| np.log(gdp_per_capita_ppp) | -13.3790 | 4.547 | -2.942 | 0.005 | -22.461 | -4.297 |
| CPIA_public_sector_rating | -7.1417 | 4.387 | -1.628 | 0.108 | -15.902 | 1.619 |
| Omnibus: | 4.467 | Durbin-Watson: | 1.617 |
| --- | --- | --- | --- |
| Prob(Omnibus): | 0.107 | Jarque-Bera (JB): | 4.375 |
| Skew: | 0.592 | Prob(JB): | 0.112 |
| Kurtosis: | 2.813 | Cond. No.: | 128. |
Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
(Why `smf.ols`? Recall that one name for linear regression is "Ordinary Least Squares," which is often shortened to "OLS.")
Not bad, huh? Not only was our regression a one-liner, but the `.summary()` method has provided us with labeled coefficients, standard errors, t-statistics, p-values, 95% confidence intervals, and a range of additional statistics about the regression.
We were also able to apply some manipulations right in the regression specification — note the use of `np.log()` to apply the log function to `gdp_per_capita_ppp` without having to create a new variable in our DataFrame.
Finally (if this doesn't mean much to you, don't worry about it), statsmodels recognized that `region` contained strings rather than numbers, so it dynamically created a set of indicator variables (i.e., it created one-hot encodings).

If you come from R or Stata, none of that is likely to seem notable to you, but as we'll discuss in the next section (and as we'll see when we get to `scikit-learn`), those are conveniences that should not be taken for granted.
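The formula syntax goes further than automatic handling of strings. As a minimal sketch (the model specification and the name `interaction_model` here are purely for illustration), you can force categorical treatment with `C()` and add an interaction between two variables with `:`, all inside the formula string:

# Illustrative only: explicitly mark region as categorical with C(),
# and interact log GDP per capita with the CPIA rating using `:`
interaction_model = smf.ols(
    "mortality_rate_under5_per_1000 ~ np.log(gdp_per_capita_ppp)"
    " + CPIA_public_sector_rating"
    " + np.log(gdp_per_capita_ppp):CPIA_public_sector_rating"
    " + C(region)",
    data=wdi,
).fit()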
Accessing Results#
While `.summary()` is great for printing out results, this is a class about programming for data science, and so, of course, we also need to know how to access these results programmatically.

All results — along with a large number of other useful statistics about model performance — are accessible through the fit model (in this example, `corruption_model`). This fit model is a `RegressionResults` object, and you can find a full list of `RegressionResults` attributes and methods here. Two of the most important attributes of a `RegressionResults` object are `.params` (a pandas Series of regression coefficients) and `.resid` (a pandas Series of residuals).
To illustrate:
corruption_model.params
Intercept 169.939656
region[T.Europe and Central Asia] -15.926485
region[T.Latin America and Caribbean] 1.902313
region[T.Middle East and North Africa] 3.766843
region[T.South Asia] 4.937173
region[T.Sub-Saharan Africa] 27.844846
np.log(gdp_per_capita_ppp) -13.379033
CPIA_public_sector_rating -7.141716
dtype: float64
print(
    "The estimated association of CPIA rating with under "
    f"five mortality is {corruption_model.params['CPIA_public_sector_rating']:.2f}"
)
The estimated association of CPIA rating with under five mortality is -7.14
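And `.params` is far from the only labeled result available. For example, all of the following are standard attributes and methods of a fit OLS model:

# Standard errors, p-values, and 95% confidence intervals,
# all labeled just like .params
corruption_model.bse
corruption_model.pvalues
corruption_model.conf_int()

# Overall fit and residuals
corruption_model.rsquared
corruption_model.resid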
For those who are familiar with these concepts, a RegressionResult object also has methods for post-regression testing (e.g., f-tests, Lagrange Multiplier (LM) tests of linear restrictions, likelihood ratio tests).
For example, to test whether the coefficient for the South Asia region is statistically significantly different from the coefficient for the Middle East and North Africa region, you would:
corruption_model.t_test("region[T.South Asia] = region[T.Middle East and North Africa]")
<class 'statsmodels.stats.contrast.ContrastResults'>
Test for Constraints
==============================================================================
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
c0 1.1703 23.625 0.050 0.961 -46.013 48.354
==============================================================================
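If you're familiar with joint hypothesis tests, the same constraint syntax also works with `.f_test()`, which can test several restrictions at once. For example, a sketch of a joint test that the region indicators, taken together, add nothing to the model:

# Joint F-test that all region coefficients are zero
corruption_model.f_test(
    "region[T.Europe and Central Asia] = 0,"
    " region[T.Latin America and Caribbean] = 0,"
    " region[T.Middle East and North Africa] = 0,"
    " region[T.South Asia] = 0,"
    " region[T.Sub-Saharan Africa] = 0"
)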
Alternative Standard Error Calculations#
`RegressionResults` objects also have a method for different standard error calculations (again, don't worry if this doesn't mean anything to you!). statsmodels supports `HC1`, `HC2`, and `HC3` heteroskedasticity-robust standard errors, as well as heteroskedasticity-and-autocorrelation-robust standard errors:
corruption_model.get_robustcov_results(cov_type="HC2").summary()
| Dep. Variable: | mortality_rate_under5_per_1000 | R-squared: | 0.586 |
| --- | --- | --- | --- |
| Model: | OLS | Adj. R-squared: | 0.541 |
| Method: | Least Squares | F-statistic: | 48.92 |
| Date: | Sat, 20 Jul 2024 | Prob (F-statistic): | 1.68e-23 |
| Time: | 19:04:19 | Log-Likelihood: | -322.68 |
| No. Observations: | 73 | AIC: | 661.4 |
| Df Residuals: | 65 | BIC: | 679.7 |
| Df Model: | 7 |  |  |
| Covariance Type: | HC2 |  |  |
|  | coef | std err | t | P>\|t\| | [0.025 | 0.975] |
| --- | --- | --- | --- | --- | --- | --- |
| Intercept | 169.9397 | 37.846 | 4.490 | 0.000 | 94.357 | 245.522 |
| region[T.Europe and Central Asia] | -15.9265 | 5.763 | -2.764 | 0.007 | -27.436 | -4.417 |
| region[T.Latin America and Caribbean] | 1.9023 | 6.687 | 0.284 | 0.777 | -11.453 | 15.257 |
| region[T.Middle East and North Africa] | 3.7668 | 8.304 | 0.454 | 0.652 | -12.817 | 20.351 |
| region[T.South Asia] | 4.9372 | 9.361 | 0.527 | 0.600 | -13.759 | 23.633 |
| region[T.Sub-Saharan Africa] | 27.8448 | 7.238 | 3.847 | 0.000 | 13.389 | 42.300 |
| np.log(gdp_per_capita_ppp) | -13.3790 | 4.550 | -2.941 | 0.005 | -22.465 | -4.293 |
| CPIA_public_sector_rating | -7.1417 | 3.966 | -1.801 | 0.076 | -15.063 | 0.779 |
| Omnibus: | 4.467 | Durbin-Watson: | 1.617 |
| --- | --- | --- | --- |
| Prob(Omnibus): | 0.107 | Jarque-Bera (JB): | 4.375 |
| Skew: | 0.592 | Prob(JB): | 0.112 |
| Kurtosis: | 2.813 | Cond. No.: | 128. |
Notes:
[1] Standard Errors are heteroscedasticity robust (HC2)
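You don't have to fit first and adjust afterwards, by the way: `.fit()` also accepts a `cov_type` argument, so the following sketch (the name `corruption_model_hc2` is just for illustration) produces the same HC2 standard errors in a single step:

# Request HC2 robust standard errors at fit time
corruption_model_hc2 = smf.ols(
    "mortality_rate_under5_per_1000 ~ np.log(gdp_per_capita_ppp) +"
    " CPIA_public_sector_rating + region",
    data=wdi,
).fit(cov_type="HC2")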
`RegressionResults` also support clustered standard errors using `corruption_model.get_robustcov_results(cov_type="cluster", groups=)`. However, `groups` has to be passed a vector of group memberships with exactly one entry for each observation used in the regression, which is why we drop rows with missing data when passing our region labels:
corruption_model.get_robustcov_results(
cov_type="cluster", groups=wdi.dropna().region
).summary()
/Users/nce8/opt/miniconda3/lib/python3.11/site-packages/statsmodels/base/model.py:1896: ValueWarning: covariance of constraints does not have full rank. The number of constraints is 7, but rank is 2
warnings.warn('covariance of constraints does not have full '
| Dep. Variable: | mortality_rate_under5_per_1000 | R-squared: | 0.586 |
| --- | --- | --- | --- |
| Model: | OLS | Adj. R-squared: | 0.541 |
| Method: | Least Squares | F-statistic: | 4.404 |
| Date: | Sat, 20 Jul 2024 | Prob (F-statistic): | 0.0789 |
| Time: | 19:04:20 | Log-Likelihood: | -322.68 |
| No. Observations: | 73 | AIC: | 661.4 |
| Df Residuals: | 65 | BIC: | 679.7 |
| Df Model: | 7 |  |  |
| Covariance Type: | cluster |  |  |
|  | coef | std err | t | P>\|t\| | [0.025 | 0.975] |
| --- | --- | --- | --- | --- | --- | --- |
| Intercept | 169.9397 | 16.975 | 10.011 | 0.000 | 126.304 | 213.575 |
| region[T.Europe and Central Asia] | -15.9265 | 1.241 | -12.829 | 0.000 | -19.118 | -12.735 |
| region[T.Latin America and Caribbean] | 1.9023 | 0.694 | 2.739 | 0.041 | 0.117 | 3.688 |
| region[T.Middle East and North Africa] | 3.7668 | 2.644 | 1.425 | 0.214 | -3.030 | 10.564 |
| region[T.South Asia] | 4.9372 | 0.719 | 6.869 | 0.001 | 3.090 | 6.785 |
| region[T.Sub-Saharan Africa] | 27.8448 | 1.385 | 20.098 | 0.000 | 24.283 | 31.406 |
| np.log(gdp_per_capita_ppp) | -13.3790 | 2.727 | -4.906 | 0.004 | -20.389 | -6.369 |
| CPIA_public_sector_rating | -7.1417 | 2.067 | -3.455 | 0.018 | -12.455 | -1.829 |
| Omnibus: | 4.467 | Durbin-Watson: | 1.617 |
| --- | --- | --- | --- |
| Prob(Omnibus): | 0.107 | Jarque-Bera (JB): | 4.375 |
| Skew: | 0.592 | Prob(JB): | 0.112 |
| Kurtosis: | 2.813 | Cond. No.: | 128. |
Notes:
[1] Standard Errors are robust to cluster correlation (cluster)
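If your version of statsmodels insists on integer group codes rather than string labels, you can do the conversion yourself. A quick sketch using `pd.factorize()` (the names `region_codes` and `region_labels` are just for illustration):

# Convert region labels to integer codes (0, 1, 2, ...) aligned
# with the estimation sample, then cluster on those codes
region_codes, region_labels = pd.factorize(wdi.dropna().region)
corruption_model.get_robustcov_results(
    cov_type="cluster", groups=region_codes
).summary()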
Linear Regression in scikit-learn#
What about `scikit-learn`? `scikit-learn` is probably the most popular library for machine learning, and like statsmodels it is also able to fit linear regressions. Because `scikit-learn` is a library written for prediction rather than inference, however, the way it has been implemented is very, very different from statsmodels. To illustrate, let's fit the same regression we did at the top of this reading using scikit-learn.

(Note that for reasons we'll cover in the following reading, I have to do some extra data wrangling first.)
wdi["log_gdp_per_cap"] = np.log(wdi["gdp_per_capita_ppp"])
wdi_w_onehots = pd.concat(
[wdi, pd.get_dummies(wdi["region"], drop_first=True, prefix="reg").astype("int")],
axis="columns",
)
subset = wdi_w_onehots[
[
"mortality_rate_under5_per_1000",
"log_gdp_per_cap",
"CPIA_public_sector_rating",
"reg_Europe and Central Asia",
"reg_Latin America and Caribbean",
"reg_Middle East and North Africa",
"reg_South Asia",
"reg_Sub-Saharan Africa",
]
].dropna()
from sklearn.linear_model import LinearRegression

# Fit linear regression; X (the first argument) is every column
# after the first, and y is the first column (under-5 mortality)
my_model = LinearRegression(fit_intercept=True)
my_model.fit(subset.iloc[:, 1:].values, subset.iloc[:, 0].values)

# Get coefficients
my_model.coef_
array([-13.37903311, -7.14171606, -15.9264855 , 1.9023125 ,
3.76684256, 4.9371726 , 27.84484623])
And there it is! If you look back to the first regression we fit, you will see that the coefficients calculated by `scikit-learn` are identical to those computed by statsmodels.
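Because we know the order of the columns we passed to `.fit()`, we can re-attach labels ourselves. A quick sketch:

# Label scikit-learn's bare coefficient array with the column
# names we used for X (everything after the first column)
pd.Series(my_model.coef_, index=subset.columns[1:])

# The intercept is stored separately
my_model.intercept_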
However, as you can also see from the fact that the only way `scikit-learn` presents regression results is by spitting out the coefficients as a bare numpy array, `scikit-learn` was never designed for users interested in interpreting the coefficients of a regression. There's no `.summary()` method that presents all the statistics provided by statsmodels, and none of the additional functionality for doing things like calculating different types of standard errors (if you're new to linear regression, those are just different ways of calculating the statistical properties of estimated coefficients).
Yes, `scikit-learn` can fit the same models, but the user experience is radically different.
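To be fair, what `scikit-learn` does make easy is the thing it was built for: prediction. For example (the name `predictions` is just for illustration):

# Predicted under-5 mortality rates for the data we fit on
predictions = my_model.predict(subset.iloc[:, 1:].values)

# In-sample R-squared, the default metric .score() reports
# for LinearRegression
my_model.score(subset.iloc[:, 1:].values, subset.iloc[:, 0].values)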
Next: Pulling Back the Curtain#
This reading has hopefully given you a sense of how easy statsmodels makes it to fit and analyze a linear regression. Some of the extensions to basic modeling discussed here may not be familiar to everyone, but even so, hopefully this gives you a sense of how things are organized so that when you do learn more about linear regression, you'll know how to put that knowledge to use.
In our next reading, we'll put ourselves squarely back in the land of data science programming as we pull back the curtain and take a look at what's going on behind the scenes of `statsmodels.formula.api.ols`.