Machine learning
Model quality

Dealing with class imbalance - Part 2

Evaluating models on unbalanced datasets

How do you develop a machine learning (ML) model to pick out one fraudulent transaction out of a million, or to diagnose a rare disease, when there isn’t enough (or sometimes any) representative data?

Scenarios like these are not rare in the field of ML and are known as class imbalance problems. In the first post of this series, we tackled the first half of the problem, which is training models on unbalanced datasets, i.e., datasets that have many more samples from some classes than from others. Now, in the second part, we are going to explore the problems and solutions associated with another stage of the ML development pipeline: model evaluation.

If you haven’t read the first part yet, we encourage you to do so, but we made an effort to make this post self-contained. Feel free to explore them in whichever order you prefer.

And hey — if you are already familiar with the topics that we cover today, feel free to start evaluating and debugging your models with Openlayer!

Evaluation metrics and unbalanced datasets

At the beginning of the first part of our series, we motivated the problem of class imbalance using the example of a model that detects a rare disease. We showed that even if the model had a very high accuracy of 99.9%, it was likely not a great model.

Now, you might be wondering: what accuracy value would be enough to make that model good? Would it be a 99.99% accuracy or should we be even more rigorous?

The answer is none of the above. In such a scenario, the problem is not the accuracy value per se, but the metric we use.

The accuracy has one important characteristic that makes it a tricky metric to use with the rare disease classifier: it treats all of the model’s predictions equally. This means that it doesn’t care if the correct prediction was on an example from the majority or minority class; both count exactly the same when we compute it.

The problem is that when we have an unbalanced dataset as we do in our example, the model’s performance on the majority class ends up dominating the accuracy, eclipsing the likely not-so-satisfactory performance on the minority class. This is why a model that does nothing but always predicts the majority class will have a high accuracy even though it is useless in the real world.
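To make this concrete, here is a minimal sketch using synthetic labels: it assumes a hypothetical dataset of 1,000 patients where only one has the disease, and a “classifier” that always predicts the majority class (0, healthy). The helper names `accuracy` and `minority_hits` are illustrative, not part of any library.

```python
# Hypothetical labels: 1 = has the rare disease, 0 = healthy.
# 999 healthy patients and a single sick one (assumed for illustration).
y_true = [0] * 999 + [1]

# A useless "model" that always predicts the majority class.
y_pred = [0] * len(y_true)


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def minority_hits(y_true, y_pred, positive=1):
    """Number of minority-class samples the model predicted correctly."""
    return sum(t == positive and p == positive for t, p in zip(y_true, y_pred))


print(accuracy(y_true, y_pred))       # 0.999
print(minority_hits(y_true, y_pred))  # 0
```

The baseline scores 99.9% accuracy while never detecting the one sick patient.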

We cannot emphasize enough how important correctly evaluating an ML model is. Choosing the wrong metric to optimize for leaves ML teams navigating blindly through the cascade of decisions that need to be made when fine-tuning models, performing model selection, conducting error analysis, and deciding which version of the model to ship.

The accuracy is not alone in the category of popular metrics that are tricky to use with unbalanced datasets. The error rate falls into the same category, as does any other metric that treats all of the model’s predictions and mistakes equally.

Fortunately, there are solutions to such a problem and the first one starts by addressing the point we just highlighted above.

Solution #1: treating different model predictions differently

The problem with the accuracy and the error rate is that they treat all the model’s predictions equally, so a prediction in the majority class counts the same as a prediction in the minority class. One possible solution, then, is to start segregating the model’s predictions and start treating different model predictions differently.

Let’s look at the rare disease classifier’s predictions in more detail:

Our model’s predictions can be organized into a confusion matrix, illustrated above. Notice that, instead of simply labeling each prediction as “right” or “wrong”, the confusion matrix distinguishes two different types of “rights” and two different types of “wrongs”.

The model’s correct predictions are the true negatives (TN: the model correctly predicts that a patient doesn’t have the disease) and the true positives (TP: the model correctly predicts that the patient does have the disease). On the other hand, the model can make two types of mistakes: false negatives (FN: the model predicts that the patient is disease-free when, in fact, they have the disease) and false positives (FP: the model predicts that the patient has the disease when, in fact, they don’t).
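Counting the four cells takes only a few lines. The sketch below is a from-scratch illustration; `confusion_counts` is a hypothetical helper name, and in practice scikit-learn’s `sklearn.metrics.confusion_matrix` does the same job.

```python
from collections import Counter


def confusion_counts(y_true, y_pred, positive=1):
    """Return (TP, TN, FP, FN) for a binary problem.

    `positive` is the label treated as the positive class
    (here: "has the disease").
    """
    counts = Counter()  # missing keys default to 0
    for t, p in zip(y_true, y_pred):
        if p == positive:
            # Model predicted "disease": correct only if the patient is sick.
            counts["TP" if t == positive else "FP"] += 1
        else:
            # Model predicted "healthy": correct only if the patient is healthy.
            counts["FN" if t == positive else "TN"] += 1
    return counts["TP"], counts["TN"], counts["FP"], counts["FN"]


# Hypothetical labels and predictions for six patients (1 = disease).
y_true = [1, 0, 0, 1, 0, 0]
y_pred = [1, 0, 1, 0, 0, 0]
print(confusion_counts(y_true, y_pred))  # (1, 3, 1, 1)
```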

As we mentioned previously, the accuracy only cares about the ratio of correct model predictions; it doesn’t care if they are correct predictions for the majority class (TN) or the minority class (TP). Using the above notation, the accuracy is given by:

$$\text{Accuracy} = \frac{TP + TN}{n}$$

where n is the total number of samples. On an unbalanced dataset where there are many more negative examples, TN might be very high and TP very low, which still results in high accuracy.
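A worked example with hypothetical counts (invented for illustration, not taken from any real dataset) shows the effect. Suppose n = 1000 patients, 10 of them sick, and the model detects only 2 of them (TP = 2, FN = 8) while wrongly flagging 5 healthy patients (FP = 5, TN = 985):

```latex
\text{Accuracy} = \frac{TP + TN}{n} = \frac{2 + 985}{1000} = 0.987
```

The model misses 8 of the 10 sick patients, yet its accuracy is close to 99% because the 985 true negatives dominate the numerator.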

However, now that we have separated the different model predictions, we can compute other metrics that focus on the slices of the confusion matrix we care about the most. This motivates the use of recall, precision, and F1. For the sake of completeness, let’s briefly define these metrics here; various resources online dive deeper into each of them, including a helpful diagram on Wikipedia.

The recall is the true positive rate: the number of true positive predictions divided by the number of positive samples in the dataset. In our example, using the recall means zooming in on the model’s performance over the few samples that actually have the rare disease:

$$\text{Recall} = \frac{TP}{TP + FN}$$

The precision is the fraction of the model’s positive predictions that are true positives. Again, in our example, using the precision means looking at all the samples the model predicted as having the disease and checking how many of those predictions were correct:

$$\text{Precision} = \frac{TP}{TP + FP}$$

The F1 is the harmonic mean of precision and recall.

$$F1 = \frac{2}{\frac{1}{\text{Precision}} + \frac{1}{\text{Recall}}} = \frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}$$
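The three definitions translate directly into code. This is a minimal sketch with a hypothetical helper `precision_recall_f1`; returning 0.0 when a denominator is zero is an assumed convention (scikit-learn exposes this choice through a `zero_division` parameter).

```python
def precision_recall_f1(tp, fp, fn):
    """Compute precision, recall, and F1 from confusion-matrix counts.

    Returns 0.0 for a metric whose denominator is zero
    (a convention assumed for this sketch).
    """
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    # F1 is the harmonic mean of precision and recall.
    if precision + recall == 0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Hypothetical counts: 10 sick patients, 2 detected (TP), 8 missed (FN),
# and 5 healthy patients wrongly flagged (FP).
p, r, f1 = precision_recall_f1(tp=2, fp=5, fn=8)
print(round(p, 3), round(r, 3), round(f1, 3))  # 0.286 0.2 0.235
```

With these counts the accuracy would be 98.7%, while recall (0.2) and F1 (about 0.235) expose how poorly the model handles the sick patients.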

Notice that by zooming into specific slices of the confusion matrix and treating different model predictions differently, we can get rid of the distorting effects of the accuracy. Furthermore, if you are serious about evaluating your models and want to move to a new layer of ML maturity, you can also calculate these aggregate metrics for each class, as this often reveals further details about your model.
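Per-class metrics can be computed by treating each class in turn as the “positive” one (one-vs-rest). The sketch below is illustrative, with a hypothetical helper `per_class_report`; scikit-learn’s `sklearn.metrics.classification_report` produces a similar per-class summary.

```python
def per_class_report(y_true, y_pred):
    """Precision and recall for every class, treating each one in turn
    as the "positive" class (one-vs-rest)."""
    report = {}
    for cls in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
        fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
        fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        report[cls] = {"precision": precision, "recall": recall}
    return report


# Hypothetical labels: 8 healthy (0) and 2 sick (1) patients.
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 0, 1, 1, 0]
for cls, metrics in per_class_report(y_true, y_pred).items():
    print(cls, metrics)
# 0 {'precision': 0.875, 'recall': 0.875}
# 1 {'precision': 0.5, 'recall': 0.5}
```

Overall accuracy here is 80%, but the per-class view shows the model finds only half of the sick patients.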

Depending on the problem and the dataset, one metric may be preferable to another, so you must be aware of the trade-offs associated with your choices before committing to one (or a few) of them. As a general guideline, it is worth remembering that:

  • if the dataset is extremely unbalanced, overall accuracy is never a good metric to optimize for. However, looking at the accuracy per class can still be useful;
  • if you are very worried about false negatives, you should probably focus on recall;
  • if you are more worried about false positives, you should probably focus on precision;
  • if you need a (harmonic) equilibrium between both, F1 is the way to go.

Solution #2: evaluating the trade-offs

In practice, many of the models used for classification output class probabilities instead of directly outputting a label. The ML practitioner then defines a threshold above which the model is considered confident enough in a certain class for that class to be used as the model’s prediction.

For example, the rare disease classifier might output the probability of an individual having the disease. Depending on that probability, we might interpret the model’s prediction as either being ‘disease’ or ‘healthy’. A common threshold used in practice is 0.5: if the probability is below 0.5, we consider the model to be predicting label 0, and if above 0.5, label 1.
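Converting probabilities into labels is a one-liner. In this sketch, `apply_threshold` is a hypothetical helper, and treating a probability of exactly 0.5 as positive is an assumption, since the rule above leaves that edge case open.

```python
def apply_threshold(probabilities, threshold=0.5):
    """Turn predicted probabilities of the positive class ("disease")
    into hard labels: 1 if the probability is >= threshold, else 0.
    Mapping a probability equal to the threshold to 1 is an assumption."""
    return [1 if p >= threshold else 0 for p in probabilities]


# Hypothetical model outputs for five patients.
probs = [0.02, 0.40, 0.55, 0.91, 0.30]
print(apply_threshold(probs))                  # [0, 0, 1, 1, 0]
print(apply_threshold(probs, threshold=0.25))  # [0, 1, 1, 1, 1]
```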

Making the model output a probability value instead of the class name directly might seem like a pointless complication, but there is an important reason for following this approach. In practice, not all model mistakes cost the same.

In our disease classifier example, it is preferable to have more false positives than false negatives. Can you see why this is the case?

If the model says a patient has the disease when, in fact, they don’t (false positive), the patient might feel scared at first and seek help from a doctor. Further exams will reveal that the model made a mistake, and the patient will likely feel relieved. However, if the model says a patient is disease-free when, in fact, they aren’t (false negative), the patient will go on with their life and eventually suffer severe symptoms that could have been avoided had the disease been detected earlier.

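The different costs associated with the model’s mistakes can also be observed in other problems. Think about a model that classifies a transaction as fraudulent or normal: missing a fraudulent transaction (a false negative) is usually more expensive than flagging a legitimate one for review (a false positive). In an e-mail spam classifier, the asymmetry often runs the other way, since sending a legitimate e-mail to the spam folder (a false positive) is typically considered worse than letting a spam message through. In both situations, the two types of mistakes do not cost the same.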

Now, back to the threshold: by playing with the threshold value, practitioners can find a balance between the different types of model mistakes. A very low threshold results in more false positives while a very high threshold results in more false negatives.
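The trade-off can be made explicit by sweeping the threshold and counting both types of mistakes. This sketch uses hypothetical labels and probabilities, and the 10:1 cost ratio between false negatives and false positives is an assumption chosen to mirror the disease example; `count_mistakes` and `total_cost` are illustrative helper names.

```python
def count_mistakes(y_true, probabilities, threshold):
    """Count false positives and false negatives at a given threshold
    (predict "disease" when the probability is >= threshold)."""
    fp = fn = 0
    for t, p in zip(y_true, probabilities):
        pred = 1 if p >= threshold else 0
        if pred == 1 and t == 0:
            fp += 1
        elif pred == 0 and t == 1:
            fn += 1
    return fp, fn


# Hypothetical labels (1 = disease) and model probabilities for 10 patients.
y_true = [0, 0, 0, 0, 0, 0, 1, 0, 1, 1]
probs = [0.05, 0.10, 0.20, 0.30, 0.35, 0.45, 0.40, 0.60, 0.70, 0.90]

# Assumed costs: a missed diagnosis (FN) is 10x worse than a false alarm (FP).
COST_FP, COST_FN = 1, 10


def total_cost(threshold):
    fp, fn = count_mistakes(y_true, probs, threshold)
    return COST_FP * fp + COST_FN * fn


thresholds = [0.1, 0.3, 0.5, 0.7, 0.9]
for threshold in thresholds:
    fp, fn = count_mistakes(y_true, probs, threshold)
    print(f"threshold={threshold}: FP={fp}, FN={fn}, cost={total_cost(threshold)}")

best = min(thresholds, key=total_cost)
print("lowest-cost threshold:", best)  # lowest-cost threshold: 0.3
```

At a threshold of 0.1 the sketch reports FP=6 and FN=0, while at 0.9 it reports FP=0 and FN=2. With a different cost ratio, a different threshold would win, which is why the costs should come from the application domain.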

With this idea in mind, it is possible to evaluate models much more comprehensively by assessing how their performance changes for different threshold values.

One of the most popular ways of doing so is by plotting receiver operating characteristic (ROC) curves, which depict the trade-off between the false positive rate and the true positive rate. By plotting the ROC curves of different models, we can observe how each of them behaves with respect to this trade-off.

A perfect model would reach a true positive rate of 1 while keeping the false positive rate at 0, so its curve runs straight up the left edge and along the top of the plot. The closer a model’s curve is to this ideal, the better. In practice, to avoid the hassle of visually comparing ROC curves and deciding what to do when two curves intersect, it is common to compute the area under the ROC curve (known as AUC), use that number as a proxy for model performance, and choose the model with the largest AUC.
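To show what the curve and the AUC are made of, here is a from-scratch sketch: every distinct score is used as a threshold, each threshold yields one (FPR, TPR) point, and the trapezoidal rule approximates the area. The helper names are hypothetical and the data are made up; in practice, scikit-learn’s `sklearn.metrics.roc_curve` and `sklearn.metrics.roc_auc_score` do this for you.

```python
def roc_points(y_true, scores):
    """(FPR, TPR) pairs obtained by using every distinct score as a
    threshold (predict positive when score >= threshold).
    Assumes both classes are present in y_true."""
    pos = sum(y_true)
    neg = len(y_true) - pos
    points = [(0.0, 0.0)]  # threshold above every score: nothing is positive
    for threshold in sorted(set(scores), reverse=True):
        tp = sum(1 for t, s in zip(y_true, scores) if s >= threshold and t == 1)
        fp = sum(1 for t, s in zip(y_true, scores) if s >= threshold and t == 0)
        points.append((fp / neg, tp / pos))
    return points


def trapezoid_auc(points):
    """Area under a curve given as (x, y) points sorted by x."""
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2
    return area


# Hypothetical labels (1 = disease) and model scores.
y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
points = roc_points(y_true, scores)
print(points)  # [(0.0, 0.0), (0.0, 0.5), (0.5, 0.5), (0.5, 1.0), (1.0, 1.0)]
print(trapezoid_auc(points))  # 0.75
```

A model that ranks every sick patient above every healthy one would get an AUC of 1.0 with this sketch.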

Another trade-off that can be evaluated by varying the threshold is the one between precision and recall, giving rise to precision-recall (PR) curves.
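The PR curve is built the same way, except that each threshold yields a (recall, precision) pair. This is an illustrative sketch with a hypothetical helper `pr_points` and made-up data; scikit-learn provides `sklearn.metrics.precision_recall_curve` for real use.

```python
def pr_points(y_true, scores):
    """(recall, precision) pairs, one per distinct score used as a
    threshold (predict positive when score >= threshold).
    Assumes y_true contains at least one positive."""
    pos = sum(y_true)
    points = []
    for threshold in sorted(set(scores), reverse=True):
        # Labels of the samples predicted positive at this threshold
        # (never empty, because the threshold is one of the scores).
        predicted = [t for t, s in zip(y_true, scores) if s >= threshold]
        tp = sum(predicted)
        points.append((tp / pos, tp / len(predicted)))
    return points


# Hypothetical labels (1 = disease) and model scores.
y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print([(r, round(p, 3)) for r, p in pr_points(y_true, scores)])
# [(0.5, 1.0), (0.5, 0.5), (1.0, 0.667), (1.0, 0.5)]
```

Unlike the false positive rate, precision does not have the (typically huge) number of true negatives in its denominator, which is one reason PR curves are often recommended for unbalanced datasets.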

PR curves seem to be more appropriate to use when there is class imbalance. This post explores the comparison between the ROC and PR curves on unbalanced datasets with code samples, so feel free to check it out.

Using ROC and PR curves helps us evaluate model performance on unbalanced datasets because we are not only segregating the types of model predictions (like we did in the previous section) but also evaluating the associated trade-offs.

The commonly used metrics, such as accuracy and error rate, are not the most appropriate to use when we are dealing with unbalanced datasets. Hopefully, after reading this post, you understand what the issue with these common metrics is and which alternatives can overcome it. Being prepared to deal with class imbalance is a must for ML practitioners, and we hope our posts on the topic have helped a little on this journey.

* A previous version of this article listed the company name as Unbox, which has since been rebranded to Openlayer.
