Part 10: Following along MIT intro to deep learning
Abhijit Ramesh / April 16, 2021
9 min read
Taming Dataset Bias via Domain Adaptation
Introduction
Deep learning has been very successful in recent years and has shown especially strong performance on computer vision tasks.
We can do things like detecting objects on roads with bounding boxes, and even adapt this to work with non-human or cartoon characters. Deep learning can also detect facial features and emotions from faces.
But deep learning also has problems, and one of the main ones is dataset bias.
Dataset bias mainly arises from domain shift. Say we train a self-driving car on images from California, where there is not much snow, and the model performs well. If we then want the same model to work in a snowy region such as New England, a new environment with snow and people wearing heavy clothes, the model will fail to detect the people or, more generally, will have much lower accuracy. This bias happens because of the change in environment, and that change is called domain shift.
When does dataset bias occur?
Generally speaking, dataset bias occurs when we do things like:
- Train a model on data from one city and try to use it in another city.
In this case, things like the environment and the lighting conditions change.
- Train a model on data taken from the web and apply it to a task a robot has to perform.
- Train a model in simulation and then test it in the real world.
- Train on data from one demographic of people and use the model on another demographic.
- Train on images taken in one culture and apply the model to images of the same kind from a different culture.
This might seem understandable, but the fact is that these kinds of biases show up even in toy datasets such as MNIST.
The image above shows which dataset a model is trained on and which dataset it is tested on; the bar graph below shows the accuracy of each of these tests.
Even in cases that look very similar to us, like the USPS dataset compared to the MNIST dataset, machine learning models can still lose accuracy because of dataset bias.
Real-world implications of dataset bias
In the real world, many deployed commercial applications are known to show these biases. There have been cases where face detection or gender classification models showed a lot of bias because most of their training data consisted of Caucasian faces, and they struggled when applied to other faces. There has also been a very sad case where a bicyclist was hit by a self-driving car and suffered a fatal accident. When the investigation was carried out, the finding was that the algorithm had not been trained to handle situations where pedestrians are outside the crosswalk, which is also a case of algorithmic bias.
Generally, the answer to this problem would be to collect more data, but that is a very expensive task.
Labelling 1,000 pedestrian polygons costs around $1,000, and depending on how many new conditions we want data for, this price only keeps multiplying.
So what causes this poor performance?
Mainly this is because the test and training data distributions are different.
Here we can see the distributions of the test and training data. Even though the digits are the same and only the colours have changed, they are clustered in different places, whereas they should really be clustered together. There is a distribution shift between the blue and red points.
Another thing to notice is that the blue points form tight clusters: they represent similar data points, and clustering into groups suggests they are being classified well. The test data points, on the other hand, are spread far apart, which suggests they are not classified well.
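If we want to reproduce this kind of plot ourselves, one simple option is to embed both sets of features in 2D with t-SNE and colour them by domain. A minimal sketch (the function name and arguments are mine, assuming the features have already been extracted from the encoder as NumPy arrays):

```python
# Quick way to inspect a distribution shift: embed source and target
# features with t-SNE and colour them by domain.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_feature_shift(src_feats, tgt_feats):
    feats = np.concatenate([src_feats, tgt_feats])
    emb = TSNE(n_components=2).fit_transform(feats)
    n = len(src_feats)
    plt.scatter(emb[:n, 0], emb[:n, 1], c="blue", s=5, label="train (source)")
    plt.scatter(emb[n:, 0], emb[n:, 1], c="red", s=5, label="test (target)")
    plt.legend()
    plt.show()
```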
Dealing with data bias
Some of the techniques that can be used to deal with data bias are,
- Collect some labelled data from the target domain
- Better backbone CNNs
- Batch Normalization
- Instance Normalization + Batch Normalization
- Data Augmentation, MixMatch
- Semi-supervised methods, such as pseudo-labelling
- Domain Adaptation
Domain adaptation
As the name suggests, this means adapting knowledge from one domain to another domain.
In the source domain we have our data and its labels; in the target domain we only have the data and not the labels, because labelling data is an expensive task. The goal is to learn a model that performs well under the target distribution. We assume that we get to see the unlabelled target data even though we can't see its labels.
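Stated slightly more formally (the notation below is mine, not taken from the lecture slides): we have labelled samples from a source distribution, unlabelled samples from a shifted target distribution, and we want a model that does well on the target.

```latex
% Labelled source data and unlabelled target data, with shifted distributions:
%   (x, y) \sim p_S(x, y), \qquad x \sim p_T(x), \qquad p_T \neq p_S
% Goal: a model f that minimises the risk under the target distribution:
\min_f \; \mathbb{E}_{(x, y) \sim p_T}\!\left[ \ell\big(f(x), y\big) \right]
```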
Adversarial domain alignment
As we said earlier, we have a set of labelled source data. First we have an encoder CNN that produces features we can plot and visualise, and a classifier that assigns these points to their respective classes. We also have some unlabelled target data; we pass it through the same encoder to generate features, but there is a distribution shift between the two sets of features, and our goal is to fix this. We start by introducing a domain discriminator that predicts whether a feature point is a blue (source) point or an orange (target) point.
The next step is to turn this around and update the encoder so that the domain discriminator's accuracy becomes poor: the encoder fools the discriminator by generating target features that look just like source features. The classification loss is still measured and used to update the encoder, so that the model does not do something degenerate like collapsing everything onto a single point.
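To make this two-player game concrete, here is a minimal PyTorch sketch of the idea. The architecture sizes, loss weights, and optimiser settings are illustrative assumptions on my part, not the exact setup from the lecture:

```python
# Minimal sketch of adversarial domain alignment (DANN/ADDA-style).
import torch
import torch.nn as nn

encoder = nn.Sequential(                 # CNN that maps images to feature vectors
    nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
)
classifier = nn.Linear(64, 10)           # task classifier, trained on source labels
discriminator = nn.Sequential(           # domain discriminator: source vs target
    nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 1),
)

cls_loss_fn = nn.CrossEntropyLoss()
dom_loss_fn = nn.BCEWithLogitsLoss()
opt_disc = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
opt_feat = torch.optim.Adam(
    list(encoder.parameters()) + list(classifier.parameters()), lr=1e-4)

def train_step(x_src, y_src, x_tgt):
    src_dom = torch.zeros(len(x_src), 1)   # domain label 0 = source
    tgt_dom = torch.ones(len(x_tgt), 1)    # domain label 1 = target

    # 1) Update the discriminator to tell source features from target features.
    opt_disc.zero_grad()
    d_loss = (dom_loss_fn(discriminator(encoder(x_src).detach()), src_dom) +
              dom_loss_fn(discriminator(encoder(x_tgt).detach()), tgt_dom))
    d_loss.backward()
    opt_disc.step()

    # 2) Update the encoder (and classifier): fool the discriminator by making
    #    target features look like source features, while keeping the source
    #    classification loss low so the features stay discriminative.
    opt_feat.zero_grad()
    fool_loss = dom_loss_fn(discriminator(encoder(x_tgt)),
                            torch.zeros(len(x_tgt), 1))   # pretend target is source
    cls_loss = cls_loss_fn(classifier(encoder(x_src)), y_src)
    (cls_loss + fool_loss).backward()
    opt_feat.step()
```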
So how do the results look?
When we visualise the adapted features, the red and blue points now follow the same distribution and are also well clustered. Did this improve the classifier?
We can see that the performance has become much better across different datasets as well. SVHN to MNIST is the most complex of these tasks, and it also shows a very noticeable gain.
Pixel space alignment
The method here, as the name suggests, is to change the pixels of the test images so that they look like they came from the training set. The classifier can then map the test data to the same probability distribution as the training data, because as far as the classifier is concerned everything looks the same.
The trick is to use a GAN: we learn to take images from the target domain and reconstruct them so that they look like they came from the training domain. We are able to do this because we have unlabelled target data.
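A rough sketch of the loss behind such a pixel-level translator, in the style of CycleGAN (only one translation direction shown; G, F, and D_src are hypothetical generator/discriminator networks, and the loss weight is an arbitrary choice):

```python
# Sketch of CycleGAN-style pixel alignment losses (one direction shown).
# G: target -> source-looking images, F: source -> target-looking images,
# D_src: discriminator that scores how "source-like" an image is.
import torch
import torch.nn as nn

adv_loss = nn.MSELoss()   # least-squares GAN objective
cyc_loss = nn.L1Loss()    # cycle-consistency objective

def generator_loss(G, F, D_src, x_tgt, lambda_cyc=10.0):
    fake_src = G(x_tgt)                      # restyle a target image as source
    score = D_src(fake_src)
    # Adversarial term: translated images should fool the source discriminator.
    adv = adv_loss(score, torch.ones_like(score))
    # Cycle term: translating back should recover the original target image,
    # which keeps the content (road, pedestrians, ...) intact.
    cyc = cyc_loss(F(fake_src), x_tgt)
    return adv + lambda_cyc * cyc
```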
Some examples of how this has been done can be seen here,
Here, the training data comes from the Grand Theft Auto game and the test data is real city scenes. The training data is already labelled, since the labels can be taken directly from the game's elements.
Pixel adaptation is used to make the real images look like they came from GTA, so that the model can easily do segmentation.
So how well is this reflected in the accuracy?
In the case where SVHN was under-performing with feature-level domain alignment, we now start to see a significant gain in performance. This shows that we can use unsupervised image-to-image translation methods to align corresponding structures across domains.
Few-shot pixel alignment
In both cases above we assume we have a good number of target images, but what if we only have a few images from the target domain? In that case we need to generate new images, and this is where few-shot pixel alignment comes in.
The idea here is to have a source image take the pose of the target image, impose the style of the target image, and generate a translated image.
Previously, the architecture used for this was FUNIT; an improvement on this architecture is called COCO-FUNIT, where the style encoder is enhanced to incorporate patterns from both the content and the style image.
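At the heart of such translators is the idea of combining a content code with a style code, for example via adaptive instance normalization (AdaIN). A tiny sketch of that operation (the real FUNIT/COCO-FUNIT decoders are much more involved):

```python
# Sketch of adaptive instance normalization (AdaIN): normalise the content
# feature map, then re-scale and re-shift it with statistics predicted from
# the style (target) image. Not the actual FUNIT/COCO-FUNIT implementation.
import torch

def adain(content_feat, style_mean, style_std, eps=1e-5):
    # content_feat: (N, C, H, W) features from the content encoder
    # style_mean, style_std: (N, C, 1, 1) statistics from the style encoder
    mean = content_feat.mean(dim=(2, 3), keepdim=True)
    std = content_feat.std(dim=(2, 3), keepdim=True) + eps
    normalized = (content_feat - mean) / std
    return normalized * style_std + style_mean
```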
The results of this can be seen here,
This also shows a very significant improvement over the FUNIT model that was used previously. From this we can infer that conditioning on both the content and the style image results in a better encoding of style.
Moving beyond alignment
Above, we always assumed that the categories stay the same across domains. But what if we move beyond this, to cases where some categories exist in one domain but not in the other?
Here the idea is to use a three-step domain adaptation method.
First, the target features are grouped using neighbourhood clustering; then an entropy-separation step uses an entropy loss to decide whether a target sample belongs to a known category or should be treated as unknown; finally, the aligned feature distributions are obtained.
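As an illustration of the entropy-separation step, here is a sketch of an entropy-based loss; the threshold and margin values are placeholders of mine, not tuned values from the lecture or paper:

```python
# Sketch of an entropy-separation loss: target samples whose prediction
# entropy is clearly below a threshold are pushed towards a known class
# (entropy goes down), clearly-above ones are pushed away as "unknown"
# (entropy goes up). Ambiguous samples near the threshold are ignored.
import torch
import torch.nn.functional as F

def entropy_separation_loss(logits, threshold=0.5, margin=0.1):
    probs = F.softmax(logits, dim=1)
    entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=1)  # per-sample entropy
    diff = entropy - threshold
    confident = diff.abs() > margin
    if not confident.any():
        return logits.new_zeros(())
    # Minimising -|diff| pushes entropy further below or above the threshold.
    return -diff[confident].abs().mean()
```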
This method shows a very good improvement on the VisDA challenge where synthetic data is adapted to the real target.
Enforcing Consistency
Here, instead of clustering first, the data is passed through a feature extractor with a head that predicts the rotation applied to each image (a self-supervised task). The features are then fed to an n-way classification head. This model is then used to predict the class probability distribution for both the original and an augmented version of each unlabelled image.
The consistency loss ensures that these two distributions agree: the original and augmented views should belong to the same category. The model does not know which category that is, since there is no label, but the two views should still be categorised together.
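A minimal sketch of such a consistency loss, using the KL divergence between the predictions on the original and augmented views (the exact loss used in the lecture may differ):

```python
# Sketch of a consistency loss: the prediction on an augmented view should
# match the prediction on the original unlabelled image.
import torch
import torch.nn.functional as F

def consistency_loss(logits_orig, logits_aug):
    p_orig = F.softmax(logits_orig.detach(), dim=1)   # stop-gradient "target"
    log_p_aug = F.log_softmax(logits_aug, dim=1)
    # Penalise the augmented prediction for diverging from the original one.
    return F.kl_div(log_p_aug, p_orig, reduction="batchmean")
```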
Here the model is trained on the same dataset as before and some initial results are shown.
Sources
MIT introtodeeplearning : http://introtodeeplearning.com/