Testing in a Data Analysis Workflow#
It’s hard to do much programming without coming across the concept of testing: the practice of writing code whose sole purpose is to verify that the programs you write actually do what you expect them to do. Indeed, testing is such a core part of programming that few people would consider posting their code publicly on a service like GitHub without including a suite of tests. Some programmers have even adopted a practice known as test-driven development, in which they start a project by writing the tests they want the code they plan to write to eventually pass, and then work until those tests pass.
But while testing is deeply ingrained in the practice of writing software, all too often data scientists assume testing isn’t relevant to the simple scripts they write to run linearly: scripts that load, clean, merge, reshape, and analyze datasets using canned library functions. After all, if the point of tests is to ensure that the code you write works correctly, and all your simple scripts use are functions from libraries like numpy and pandas (which are tested by the library maintainers), why would you need tests?
The answer is twofold. First, it’s easy to make coding mistakes even when you aren’t writing complicated programs. Operations like merging, grouping, and reshaping can easily go wrong, so it’s good to verify that the results of those operations are what you expect.
But the second and bigger reason is that tests are needed in data analysis workflows to verify that what you think is true about your data actually is. In other words, writing tests in a data analysis workflow is often not about ensuring you wrote the code you think you wrote, but rather about verifying that your assumptions about the structure and properties of your data (and thus the behavior of the code you write) are correct.
A Simple Example#
To illustrate, consider the following simple example: I load a small subset of the World Bank’s World Development Indicators, covering countries, their GDP per capita, and their Polity scores (a measure of political freedom).
In the code, I then try to compare the average GDP per capita of the ten largest oil producers to that of all other countries:
import pandas as pd
import numpy as np
pd.set_option("mode.copy_on_write", True)
wdi = pd.read_csv("data/world-small.csv")
wdi
|     | country   | region       | gdppcap08 | polityIV |
|-----|-----------|--------------|-----------|----------|
| 0   | Albania   | C&E Europe   | 7715      | 17.8     |
| 1   | Algeria   | Africa       | 8033      | 10.0     |
| 2   | Angola    | Africa       | 5899      | 8.0      |
| 3   | Argentina | S. America   | 14333     | 18.0     |
| 4   | Armenia   | C&E Europe   | 6070      | 15.0     |
| ... | ...       | ...          | ...       | ...      |
| 140 | Venezuela | S. America   | 12804     | 16.0     |
| 141 | Vietnam   | Asia-Pacific | 2785      | 3.0      |
| 142 | Yemen     | Middle East  | 2400      | 8.0      |
| 143 | Zambia    | Africa       | 1356      | 15.0     |
| 144 | Zimbabwe  | Africa       | 188       | 6.0      |
145 rows × 4 columns
# Top 10 from Wikipedia: https://en.wikipedia.org/wiki/List_of_countries_by_oil_production
list_of_top10_oil_producers = [
"United States",
"Russia",
"Saudi Arabia",
"Canada",
"Iraq",
"China",
"Iran",
"Brazil",
"United Arab Emirates",
"Kuwait",
]
wdi["large_oil_producers"] = wdi["country"].isin(list_of_top10_oil_producers)
avg_income_of_large_producers = wdi.loc[wdi["large_oil_producers"], "gdppcap08"].mean()
avg_income_of_non_producers = wdi.loc[~wdi["large_oil_producers"], "gdppcap08"].mean()
print(
"The average GDP per Capita of the ten largest "
f"oil producers is: ${avg_income_of_large_producers:,.0f}."
)
print(
"The average GDP per Capita of all other "
f"countries is: ${avg_income_of_non_producers:,.0f}."
)
The average GDP per Capita of the ten largest oil producers is: $21,625.
The average GDP per Capita of all other countries is: $12,698.
Simple, right? Except… that answer is wrong. Why? Well, we were trying to subset for the 10 largest oil producers. But were we successful? Let’s check.
# Look at countries for which we subset:
wdi.loc[wdi["large_oil_producers"], "country"]
16 Brazil
21 Canada
25 China
60 Iran
61 Iraq
71 Kuwait
109 Russia
111 Saudi Arabia
137 United States
Name: country, dtype: object
# Wait... is that 10?
len(wdi.loc[wdi["large_oil_producers"], "country"])
9
Oops! What happened? Well, if we compare the list of country names for which `large_oil_producers` is `True` to our original list, we can see which country is missing:
big_producers_in_df = wdi.loc[wdi["large_oil_producers"], "country"]
set(big_producers_in_df).symmetric_difference(list_of_top10_oil_producers)
{'United Arab Emirates'}
And if we look at our `wdi` data more carefully, we can see the problem: United Arab Emirates isn’t `"United Arab Emirates"` in the World Development Indicators data, it’s `"UAE"`:
list(wdi["country"][130:])
['Tunisia',
'Turkey',
'Turkmenistan',
'UAE',
'Uganda',
'Ukraine',
'United Kingdom',
'United States',
'Uruguay',
'Uzbekistan',
'Venezuela',
'Vietnam',
'Yemen',
'Zambia',
'Zimbabwe']
If this feels contrived, I can promise it’s not! “Messy data” is a constant in data science, especially when you’re trying to answer a question no one has answered before and are thus working with data in ways no one else has explored!
So now let’s correct that problem and write a test to ensure we fixed it correctly!
list_of_top10_oil_producers = [
"United States",
"Russia",
"Saudi Arabia",
"Canada",
"Iraq",
"China",
"Iran",
"Brazil",
"UAE",
"Kuwait",
]
wdi["large_oil_producers"] = wdi["country"].isin(list_of_top10_oil_producers)
big_producers_in_df = wdi.loc[wdi["large_oil_producers"], "country"]
# Check that every top-10 oil producer was found in the wdi data
assert (
set(big_producers_in_df).symmetric_difference(list_of_top10_oil_producers) == set()
)
Note
Catching errors in your intermediate code is especially important when doing data science. Why? Because when analyzing data, you’re trying to answer a question no one has answered before, meaning it’s inherent to the undertaking that you can’t verify your final answer directly! The only way to have confidence in your answer is to have confidence in the steps you took to get to it.
I check my code, why test?#
In my experience, most data analysts check that things are working correctly interactively as they go. For example, they might do exactly what I did above: type `len(wdi.loc[wdi["large_oil_producers"], "country"])` and look at the result. But that’s not a test. A test has to check something and alert you to a problem automatically, which requires, at the very least, an `assert` statement.
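For example, compare the interactive check to an actual test (both using the `wdi` data from above):

# Interactive check: prints a number, but only helps if you stop and look at it
print(len(wdi.loc[wdi["large_oil_producers"], "country"]))

# Test: raises an AssertionError automatically if the count is wrong
assert wdi["large_oil_producers"].sum() == 10

Why does having tests like this matter so much? A few reasons: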
- **Tests are executed every time your code is run.** Most of us check things the first time we write a piece of code. But days, weeks, or months later, we may come back, modify code that occurs earlier in the pipeline, and simply re-run everything. If those changes cause problems downstream, we may never notice them; with tests in place, the downstream tests will fail, and we can track down the problem.
- **Testing gets you in the habit of always checking.** Most of us only stop to check aspects of our data when we suspect problems. But if you become accustomed to writing a handful of tests at the bottom of every file, or after every execution of certain operations (for example, I always include tests after a merge or reshape), you get into the habit of always stopping to think about what your data should look like.
- **Tests help you catch problems faster.** This is less about code integrity than sanity, but a great upside of tests is that they ensure that if a mistake slips into your code, you become aware of it quickly, making it easier to identify and fix the change that caused the problem.
- **Tests catch more than anticipated problems.** When problems emerge in code, they often manifest in lots of different ways. Duplicate observations, for example, will not only lead to inaccurate observation counts but may also give rise to bizarre summary statistics, bad subsequent merges, etc. Tests thus guard not only against the errors we anticipate while writing them but also against ones we don’t.
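To make that last point concrete, here is a minimal sketch (using small hypothetical DataFrames, not the `wdi` data) of a test written for one purpose catching a different, unanticipated problem:

import pandas as pd

left = pd.DataFrame({"country": ["Albania", "Algeria"], "gdppcap08": [7715, 8033]})

# Unbeknownst to us, Algeria appears twice in the second table:
right = pd.DataFrame(
    {"country": ["Albania", "Algeria", "Algeria"], "polityIV": [17.8, 10.0, 10.0]}
)

merged = left.merge(right, on="country")

# Written just to confirm the row count, this assert *also* catches the
# duplicate rows the merge silently created (`merged` has 3 rows, not 2):
assert len(merged) == len(left)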
Writing Tests#
Tests are easy to write in any language, but given this course’s focus on Python, I will discuss Python here. For examples of tests in Stata and R, you can find some resources I created towards the bottom of this site.
In Python, the easiest way to add tests to a data analysis script is with the `assert` keyword:
x = 7
y = 5
# Make sure that x is greater than y
assert x > y
# Make sure that x is odd
assert x % 2 == 1
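One detail worth knowing: `assert` accepts an optional message (after a comma) that is included in the `AssertionError` when the test fails, which makes failures much easier to diagnose:

# If this assertion fails, the message is included in the AssertionError:
assert x > y, f"expected x ({x}) to be greater than y ({y})"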
`assert` can also be used with vectors, though doing so requires one additional step. Since logical tests applied to vectors return vectors of Booleans, we have to specify how to evaluate that whole vector using `.any()` (returns `True` if any entries in the vector are `True`) or `.all()` (returns `True` only if all the entries are `True`). For example:
# Make sure everyone's GDP per capita estimates are positive:
assert (wdi["gdppcap08"] > 0).all()
If you don’t use `.all()` or `.any()`, you will get an error saying “the truth value of a Series is ambiguous.” In effect, Python is saying: “This vector may have both `True` and `False` values, but `assert` is looking for a single `True` or `False`. How should I interpret a mix of `True`s and `False`s?”
assert wdi["gdppcap08"] > 0
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/var/folders/fs/h_8_rwsn5hvg9mhp0txgc_s9v6191b/T/ipykernel_48692/2951563830.py in ?()
1 # Make sure everyone's GDP per capita estimates are positive:
----> 2 assert wdi["gdppcap08"] > 0
/users/nce8/miniforge3/lib/python3.12/site-packages/pandas/core/generic.py in ?(self)
1575 @final
1576 def __nonzero__(self) -> NoReturn:
-> 1577 raise ValueError(
1578 f"The truth value of a {type(self).__name__} is ambiguous. "
1579 "Use a.empty, a.bool(), a.item(), a.any() or a.all()."
1580 )
ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
When Should I Write Tests?#
The best way to get into writing tests is to think about how you already check your data interactively to make sure your code works. After a `merge` or a `groupby` command, most people pause to browse the data, step through the code, or run a set of quick tabs or plots. But these checks are not systematic, and you generally only do them once (when first writing the code).

So, a good starting point is to examine what you do when you check your data interactively and convert the logic of those interactive interrogations into systematic `assert` statements. A few moments when it is particularly important to add tests:
- **After a `merge`:** Nowhere are problems with data made clearer than in a merge. ALWAYS add tests after a merge (see the sketch after this list)!
- **After complicated manipulations:** If you have to think more than a little about how to get Python or pandas to do something, there’s a chance you missed something. Add a test or two to make sure you did it right! Personally, for example, I seldom use `groupby` commands without adding tests; grouping is just not a natural way to think about things, so I know I may have screwed up (and often have!).
- **Before dropping observations:** Dropping observations masks problems. Before you drop rows, add a test to count the number of observations you expect to drop or retain.
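Here is a sketch of what such tests might look like, using small hypothetical DataFrames (the expected counts below are illustrations, not rules). Note that `validate` and `indicator` are built-in arguments to pandas’ `merge`:

import pandas as pd

countries = pd.DataFrame(
    {"country": ["Albania", "Algeria"], "gdppcap08": [7715, 8033]}
)
polity = pd.DataFrame({"country": ["Albania", "Algeria"], "polityIV": [17.8, 10.0]})

# After a merge: `validate="1:1"` raises an error if either side has
# duplicate keys, and `indicator=True` adds a `_merge` column we can
# test to confirm that every row found a match.
merged = countries.merge(
    polity, on="country", how="outer", validate="1:1", indicator=True
)
assert (merged["_merge"] == "both").all()
merged = merged.drop(columns="_merge")

# Before dropping observations: count what you expect to lose first.
num_to_drop = (merged["gdppcap08"] < 8000).sum()
assert num_to_drop == 1  # we expect to drop exactly one country
merged = merged[merged["gdppcap08"] >= 8000]
assert len(merged) == 1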
Common Test Examples#
Test that the number of observations is right:
# We'll use `wdi` again:
wdi.sample(5)
|    | country   | region      | gdppcap08 | polityIV | large_oil_producers |
|----|-----------|-------------|-----------|----------|---------------------|
| 16 | Brazil    | S. America  | 10296     | 18.0     | True                |
| 67 | Jordan    | Middle East | 5283      | 8.0      | False               |
| 78 | Lithuania | C&E Europe  | 18824     | 20.0     | False               |
| 44 | France    | W. Europe   | 34045     | 19.0     | False               |
| 55 | Haiti     | S. America  | 1177      | 8.0      | False               |
assert len(wdi) == 145
Check that a variable that should never be missing has no missing values:
assert pd.notnull(wdi["country"]).all()
# Or, the same test written with any instead of all
assert not pd.isnull(wdi["country"]).any()
Check that what I think is my unique identifier is actually unique:
assert not wdi["country"].duplicated().any()
Make sure values of GDP per capita are reasonable. Note this is a “reasonableness” test, not an absolute test. It’s possible for this test to fail even when the data is fine, but if it does fail, it flags the issue for your attention so you can check.
assert (wdi["gdppcap08"] > 0).all()
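Reasonableness tests can bound values from above, too. For example (the threshold here is a judgment call of mine, not a rule), no country’s GDP per capita in 2008 should plausibly exceed $150,000:

# Flags absurdly large values, e.g., from unit errors during data entry:
assert (wdi["gdppcap08"] < 150_000).all()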