Balancing Training Data
Data imbalance occurs when the classes in a dataset are not equally distributed, which can lead to risks when training a model. There are several methods for overcoming imbalanced data, including resampling and weight balancing.
What You Need to Know
Imagine that you have a model that identifies whether there is a dog or a cat in a picture. During testing, you realize that your model correctly identifies all the dogs in the pictures, but fails to identify the cats. In reviewing your training dataset, you discover that it contains 10,000 pictures of dogs and only 100 pictures of cats. This is an example of data imbalance, where a dataset does not have a comparable number of instances for each class.
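A first step in diagnosing imbalance is simply counting the class distribution. Here is a minimal sketch in plain Python, using a hypothetical `labels` list standing in for the dogs-and-cats dataset above:

```python
from collections import Counter

# Hypothetical labels for the dogs-vs-cats training set described above.
labels = ["dog"] * 10000 + ["cat"] * 100

counts = Counter(labels)
imbalance_ratio = max(counts.values()) / min(counts.values())

print(counts)           # Counter({'dog': 10000, 'cat': 100})
print(imbalance_ratio)  # 100.0 -- a strong warning sign of imbalance
```

A ratio this large is a signal to intervene before training, using one of the methods below.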
The truth is that imbalanced data is everywhere, and it is nearly impossible to avoid. Consider a survey of electric car owners' opinions on electric car maintenance fees. Because most people driving electric cars have high annual incomes, 80 percent of the responses are "the fee is pretty reasonable". In other words, the dataset is biased. A model trained to predict survey responses would mostly predict that a person, regardless of income, driving tendencies, or car preferences, would consider the fees reasonable. The same concern arises when examining crime data: an imbalanced crime dataset would perpetuate the racial and gender biases it contains when artificial intelligence (AI) is used to predict criminal behavior. Having methods to improve the training process in the face of imbalanced data is therefore crucial, and there are two major ways to tackle the problem: focusing on the datasets or focusing on the weights.
In situations where we don't want to change the model itself, we can address imbalance through data preprocessing alone. In other words, we look at our dataset, understand the data distribution, and decide how to resample the data. There are two common approaches:
- Over/under-sampling: increase the number of samples in the minority classes or decrease the number of samples in the majority classes. In the example of Fig. 1(a), if there are 100 samples in class "A" and 30 samples in class "B", we would either duplicate samples in class "B" or remove samples from class "A". Note that this method can introduce other problems, such as overfitting (from duplicated minority samples) or information loss (from discarded majority samples).
- Clustering techniques: this is similar to resampling, but instead of adding samples to whole classes, we first find the subclasses, or sub-clusters, within each class, and then replicate samples within the subclasses to ensure equal size, as in Fig. 1(b).
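The over/under-sampling approach can be sketched with nothing but the standard library. This is a minimal illustration, not a production pipeline; the helper names `oversample` and `undersample` and the toy 100-vs-30 dataset (mirroring the Fig. 1(a) example) are assumptions for the sketch:

```python
import random

random.seed(0)  # for reproducibility of the random draws

def oversample(minority, target_size):
    """Randomly duplicate minority-class samples until target_size is reached."""
    return minority + random.choices(minority, k=target_size - len(minority))

def undersample(majority, target_size):
    """Randomly keep only target_size samples from the majority class."""
    return random.sample(majority, target_size)

# Toy dataset matching the Fig. 1(a) example: 100 "A" samples, 30 "B" samples.
class_a = [("A", i) for i in range(100)]
class_b = [("B", i) for i in range(30)]

balanced_up = class_a + oversample(class_b, 100)    # 200 samples, 100 per class
balanced_down = undersample(class_a, 30) + class_b  # 60 samples, 30 per class
```

In practice, libraries such as imbalanced-learn offer more sophisticated variants (e.g., synthesizing new minority samples rather than copying them), which mitigate the overfitting risk noted above.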
Weight balancing is another good way to tackle imbalanced data. The idea is to define class weights that give additional weight to the minority classes, and then scale each sample's loss by the weight of its class. In TensorFlow, one can do something like the following:
loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=pred)
weighted_loss = loss * class_weights  # class_weights: one weight per example, looked up from its class
Why multiply the weights by the original loss? This turns the total loss into a weighted average, where each sample's weight is the class weight of its corresponding class, so errors on minority-class samples contribute more to the gradient than errors on majority-class samples.
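A common convention for choosing the class weights is inverse class frequency. Here is a minimal sketch in plain Python, using the hypothetical dog/cat counts from earlier; the formula `total / (num_classes * count)` is one common choice (it keeps the average weight near 1), not the only one:

```python
# Hypothetical class counts from the dogs-vs-cats example.
counts = {"dog": 10000, "cat": 100}
total = sum(counts.values())

# Inverse-frequency weights: rarer classes get larger weights.
class_weights = {c: total / (len(counts) * n) for c, n in counts.items()}
# dog -> 0.505, cat -> 50.5: each cat sample counts ~100x more than a dog sample.

# Applying the weights: each sample's loss is scaled by its class's weight,
# exactly as in the TensorFlow snippet above.
per_sample_loss = [0.3, 0.7, 0.2]        # hypothetical per-sample losses
sample_classes = ["dog", "cat", "dog"]
weighted = [l * class_weights[c]
            for l, c in zip(per_sample_loss, sample_classes)]
```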
At Modzy, our data scientists treat data preprocessing as a crucial task. Before training our models, we make sure that our datasets will not create potential risks and that our resulting models are robust.
What This Means for You
In a world where AI is proliferating, it is important that we place a particular focus on training data to reduce the risk of biased outputs.