It can be challenging enough trying to evaluate the quality of a run-of-the-mill machine-learning model. For more sophisticated classification models, it can seem like a near-impossible task. That might seem like something of a confusing statement – after all, can’t you just measure a model’s accuracy?
Not exactly.
While it’s true that a high level of accuracy tends to indicate that a model was well-trained, it doesn’t provide enough information to accurately measure performance. Let’s say, for instance, you feed a computer vision model a test set of 200 images, and it classifies 190 of those images correctly. On paper, an error rate of 5% is excellent.
Looking a bit deeper, the problem becomes evident. What if that data set consisted of six classes of samples, and all ten incorrect predictions fell within just two of those six classes? The model would be nowhere near as accurate as it seems, but because the samples in the test set were imbalanced, its headline accuracy gives you no way of knowing that.
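This imbalance problem is easy to demonstrate with a short Python sketch. The class labels and per-class counts below are invented for illustration — 190 of 200 predictions are correct, but every error falls on the two smallest classes:

```python
# Hypothetical example: 200 test samples across six classes ("a" through "f").
# 190 predictions are correct, but classes "e" and "f" are never right.
y_true = ["a"] * 60 + ["b"] * 60 + ["c"] * 40 + ["d"] * 30 + ["e"] * 5 + ["f"] * 5
y_pred = ["a"] * 60 + ["b"] * 60 + ["c"] * 40 + ["d"] * 30 + ["a"] * 5 + ["b"] * 5

overall = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(f"overall accuracy: {overall:.0%}")  # 95% -- looks excellent on paper

# Per-class accuracy tells a very different story.
for cls in sorted(set(y_true)):
    hits = sum(p == cls for t, p in zip(y_true, y_pred) if t == cls)
    total = sum(t == cls for t in y_true)
    print(f"{cls}: {hits}/{total}")  # "e" and "f" score 0 out of 5
```

The aggregate figure hides the fact that two classes are misclassified every single time.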
Enter the confusion matrix.
What is a Confusion Matrix?
The fittingly named confusion matrix is an assessment framework that helps machine-learning engineers identify how and where a classification model becomes “confused” when making predictions. A confusion matrix starts by identifying each unique class present in a data set or test set. Once the data has been fed to a machine-learning model, the matrix then summarizes the number of correct predictions, incorrect predictions, and count values belonging to each class.
This provides you with a few insights:
- Classes that are under- or over-represented in the data set.
- The classes your machine-learning model consistently struggles to correctly predict or understand.
- The types of predictive errors your machine-learning model most frequently commits.
In other words, a confusion matrix not only helps you identify the strengths and weaknesses of your machine-learning model but also helps you find potential weaknesses in your training data. This also makes confusion matrices invaluable in ensemble learning, as they help researchers determine how best to combine multiple classifiers.
Confusion matrices cannot be applied when the output distribution of a data set is unknown. Their applications beyond classification models and supervised learning frameworks are, therefore, relatively limited. They cannot, for instance, be used to assess a predictive model that primarily works with unstructured data.
Understanding the Different Confusion Matrix Calculations
How a confusion matrix is calculated and laid out largely depends on the nature of the data set it’s being used to assess. In broad strokes, there are two types of confusion matrices. One is used to assess data sets with binary classes; the other is used for data sets with multiple classes.
We’ll start with an explanation of binary classes.
Binary Class Confusion Matrices
As the name suggests, a binary class data set consists exclusively of two distinct classifications of data. For instance, imagine a computer vision algorithm that identifies whether a photograph was taken during the day or at night. We have a test data set of thirty photographs, and the model correctly classified 24 of them.
First, we need to identify the two classes. Next, we’ll calculate the number of correct and incorrect classifications. We’ll start with the correct ones:
- Day Classified as Day: 16
- Night Classified as Night: 8
Followed by the incorrect classifications:
- Night Classified as Day: 4
- Day Classified as Night: 2
Finally, we can arrange these values into a binary confusion matrix:
| Predicted \ Actual | Day | Night |
| --- | --- | --- |
| Day | 16 | 4 |
| Night | 2 | 8 |
To read the table above, keep the following in mind:
- The total number of daytime photographs in the data set is the sum of values in the Day column (16+2).
- The total number of nighttime photographs in the data set is the sum of values in the Night column (4+8).
- The correct classifications form a diagonal line from the top left to the bottom right of the matrix (16+8).
This confusion matrix provides us with two key insights:
- Daytime photographs are over-represented in the data set compared to nighttime photographs, making up 60% of samples (18 of 30).
- The model more frequently classified nighttime photographs as day than daytime photographs as night.
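Building such a matrix by hand is straightforward. The sketch below tallies the day/night table from a list of labelled predictions; the label lists are reconstructed to match the counts in the example:

```python
# Label lists reconstructed to match the example's counts:
# 18 day photos (16 classified correctly), 12 night photos (8 correct).
y_true = ["day"] * 18 + ["night"] * 12
y_pred = ["day"] * 16 + ["night"] * 2 + ["day"] * 4 + ["night"] * 8

labels = ["day", "night"]
# matrix[predicted][actual], matching the table's layout (columns = actual class)
matrix = {p: {a: 0 for a in labels} for p in labels}
for actual, predicted in zip(y_true, y_pred):
    matrix[predicted][actual] += 1

print(matrix)
# {'day': {'day': 16, 'night': 4}, 'night': {'day': 2, 'night': 8}}
accuracy = (matrix["day"]["day"] + matrix["night"]["night"]) / len(y_true)
print(f"accuracy: {accuracy:.0%}")  # 80%
```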
There’s also a unique form of confusion matrix that applies exclusively to machine-learning models designed to solve something known as a two-class problem. Essentially, two-class problems are intended to differentiate between an outcome and the absence of that outcome. For the purposes of developing a confusion matrix, we can classify the model’s predictions into the following categories:
- True Positive: The occurrence or presence of the event was correctly predicted.
- True Negative: The absence or non-occurrence of the event was correctly predicted.
- False Positive: The event was incorrectly predicted to have occurred.
- False Negative: The event was incorrectly predicted to have not occurred.
A confusion matrix created from the categories above can also be used to help calculate more advanced machine-learning assessment metrics, such as:
- Accuracy: The percentage of samples the model correctly classified out of the entire data set.
- Precision: Of the samples the model predicted as positive, the percentage that actually were positive.
- Recall: Of the samples that actually were positive, the percentage the model correctly identified. Ideally, both precision and recall should be high; a gap between them signals a bias toward false positives or false negatives.
- Specificity: Of the samples that actually were negative, the percentage the model correctly identified.
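All four metrics fall directly out of the four counts. A short sketch, using illustrative counts rather than figures from the examples above:

```python
# Illustrative two-class counts (not taken from the article's tables).
tp, tn, fp, fn = 40, 45, 5, 10

accuracy = (tp + tn) / (tp + tn + fp + fn)  # correct predictions / all samples
precision = tp / (tp + fp)                  # how trustworthy a "positive" call is
recall = tp / (tp + fn)                     # share of actual positives found
specificity = tn / (tn + fp)                # share of actual negatives found

print(f"accuracy={accuracy:.2f} precision={precision:.2f} "
      f"recall={recall:.2f} specificity={specificity:.2f}")
# accuracy=0.85 precision=0.89 recall=0.80 specificity=0.90
```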
Multiple-Class Confusion Matrices
Confusion matrices for multiple classes aren’t especially different from binary-class confusion matrices. They are read and calculated in much the same way, although they are slightly more complex by nature. Circling back to computer vision, imagine there is a data set consisting of photographs of cats, dogs, mice, and rabbits with the following distribution:
| Class | Samples |
| --- | --- |
| Cats | 43 |
| Dogs | 55 |
| Rabbits | 39 |
| Mice | 27 |
After testing, the model produces the following confusion matrix:
| Predicted \ Actual | Cat | Dog | Rabbit | Mouse |
| --- | --- | --- | --- | --- |
| Cat | 37 | 15 | 1 | 1 |
| Dog | 2 | 25 | 3 | 1 |
| Rabbit | 3 | 10 | 20 | 3 |
| Mouse | 1 | 5 | 15 | 22 |
Based on the confusion matrix above, we can determine the following:
- As with any confusion matrix, accuracy can be calculated by dividing the sum of the diagonal cells (top left to bottom right) by the number of samples in the full data set. The algorithm correctly identified 104 samples out of a total of 164, giving it a total accuracy of 63.41%.
- The algorithm displays a high degree of accuracy in categorizing both cats and mice, at roughly 86% and 81%, respectively.
- The model appears extremely confused about what constitutes a dog. It also seems prone to confuse rabbits for mice.
- To improve the model’s performance, one’s efforts would be best directed towards further training on classifying dogs and rabbits.
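These per-class figures can be computed directly from the matrix. A short sketch — the matrix literal follows the example above, with each column summing to the class counts in the distribution table:

```python
# matrix[row][col]: row = predicted class, col = actual class,
# mirroring the example's layout. Columns sum to 43, 55, 39, and 27.
classes = ["cat", "dog", "rabbit", "mouse"]
matrix = [
    [37, 15, 1, 1],
    [2, 25, 3, 1],
    [3, 10, 20, 3],
    [1, 5, 15, 22],
]

total = sum(sum(row) for row in matrix)
correct = sum(matrix[i][i] for i in range(len(classes)))  # diagonal cells
print(f"overall accuracy: {correct}/{total} = {correct / total:.2%}")  # 104/164 = 63.41%

# Per-class accuracy: diagonal cell divided by its column total.
for j, cls in enumerate(classes):
    column_total = sum(row[j] for row in matrix)
    print(f"{cls}: {matrix[j][j]}/{column_total} = {matrix[j][j] / column_total:.0%}")
```

The loop makes the dog problem obvious: only 25 of 55 actual dogs are classified correctly.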
How to Read a Confusion Matrix
As demonstrated above, reading a confusion matrix is relatively simple. You need only keep two facts in mind. First, each row and column represents one class from the data set — in the tables above, columns hold the actual classes and rows hold the model’s predictions. Second, correct results always form a diagonal line from the top left to the bottom right.
It’s also worth noting that you don’t necessarily need to compile a confusion matrix manually – several tools exist to compute and generate matrices for you. For example, the scikit-learn Python package not only features the confusion_matrix function but also several functions for calculating the associated metrics.
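A minimal usage sketch is shown below. One caveat: scikit-learn’s convention places actual classes in rows and predictions in columns, which is transposed relative to the tables in this article.

```python
# Requires scikit-learn: pip install scikit-learn
from sklearn.metrics import classification_report, confusion_matrix

# A tiny illustrative test set.
y_true = ["day", "day", "day", "night", "night", "day"]
y_pred = ["day", "night", "day", "night", "night", "day"]

# In scikit-learn's output, rows are actual classes, columns are predictions.
cm = confusion_matrix(y_true, y_pred, labels=["day", "night"])
print(cm)

# classification_report computes per-class precision, recall, and related metrics.
print(classification_report(y_true, y_pred))
```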