Table of contents:
- What is Transfer Learning?
- Why Transfer Learning Now?
- A Definition of Transfer Learning
- Transfer Learning Scenarios
- Applications of Transfer Learning
- Transfer Learning Methods
- Related Research Areas
- Conclusion
In recent years, we have become increasingly good at training deep neural networks to learn a very accurate mapping from inputs to outputs, whether they are images, sentences, or label predictions, from large amounts of labeled data.
What our models still frightfully lack is the ability to generalize to conditions that are different from the ones encountered during training. When is this necessary? Every time you apply your model not to a carefully constructed dataset but to the real world. The real world is messy and contains an infinite number of novel scenarios, many of which your model has not encountered during training and for which it is in turn ill-prepared to make predictions. The ability to transfer knowledge to new conditions is generally known as transfer learning and is what we will discuss in the rest of this post.
Over the course of this blog post, I will first contrast transfer learning with machine learning’s most pervasive and successful paradigm, supervised learning. I will then outline reasons why transfer learning warrants our attention. Subsequently, I will give a more technical definition and detail different transfer learning scenarios. I will then provide examples of applications of transfer learning before delving into practical methods that can be used to transfer knowledge. Finally, I will give an overview of related directions and provide an outlook into the future.
What is Transfer Learning?
In the classic supervised learning scenario of machine learning, if we intend to train a model for some task and domain (A), we assume that we are provided with labeled data for the same task and domain. We can see this clearly in Figure 1, where the task and domain of the training and test data of our model (A) are the same. (We will later define in more detail what exactly a task and a domain are.) For the moment, let us assume that a task is the objective our model aims to perform, e.g. recognize objects in images, and a domain is where our data is coming from, e.g. images taken in San Francisco coffee shops.
We can now train a model (A) on this dataset and expect it to perform well on unseen data of the same task and domain. On another occasion, when given data for some other task or domain (B), we again require labeled data of the same task or domain, which we can use to train a new model (B) so that we can expect it to perform well on this data.
The traditional supervised learning paradigm breaks down when we do not have sufficient labeled data for the task or domain we care about to train a reliable model.
If we want to train a model to detect pedestrians on night-time images, we could apply a model that has been trained on a similar domain, e.g. on day-time images. In practice, however, we often experience a deterioration or collapse in performance as the model has inherited the bias of its training data and does not know how to generalize to the new domain.
If we want to train a model to perform a new task, such as detecting bicyclists, we cannot even reuse an existing model, as the labels between the tasks differ.
Transfer learning allows us to deal with these scenarios by leveraging the already existing labeled data of some related task or domain. We try to store this knowledge gained in solving the source task in the source domain and apply it to our problem of interest as can be seen in Figure 2.
In practice, we seek to transfer as much knowledge as we can from the source setting to our target task or domain. This knowledge can take on various forms depending on the data: it can pertain to how objects are composed to allow us to more easily identify novel objects; it can be with regard to the general words people use to express their opinions, etc.
Why Transfer Learning Now?
Andrew Ng, chief scientist at Baidu and professor at Stanford, said during his widely popular NIPS 2016 tutorial that transfer learning will be, after supervised learning, the next driver of ML commercial success.
In particular, he sketched out a chart on a whiteboard that I’ve sought to replicate as faithfully as possible in Figure 4 below (sorry about the unlabelled axes).
It is indisputable that ML use and success in industry have so far been mostly driven by supervised learning. Fuelled by advances in Deep Learning, more capable computing utilities, and large labeled datasets, supervised learning has been largely responsible for the wave of renewed interest in AI, funding rounds and acquisitions, and in particular the applications of machine learning that we have seen in recent years and that have become part of our daily lives. If we disregard naysayers and heralds of another AI winter and instead trust the prescience of Andrew Ng, this success will likely continue.
It is less clear, however, why transfer learning, which has been around for decades and is currently little utilized in industry, will see the explosive growth predicted by Ng. This is even more so as transfer learning currently receives relatively little visibility compared to other areas of machine learning such as unsupervised learning and reinforcement learning, which have come to enjoy increasing popularity. Unsupervised learning, the key ingredient on the quest to General AI according to Yann LeCun (see Figure 5), has seen a resurgence of interest, driven in particular by Generative Adversarial Networks. Reinforcement learning, in turn, spearheaded by Google DeepMind, has led to advances in game-playing AI exemplified by the success of AlphaGo and has already seen success in the real world, e.g. by reducing Google’s data center cooling bill by 40%. Both of these areas, while promising, will likely only have a comparatively small commercial impact in the foreseeable future and mostly remain within the confines of cutting-edge research papers as they still face many challenges.
What makes transfer learning different? In the following, we will look at the factors that — in our opinion — motivate Ng’s prognosis and outline the reasons why just now is the time to pay attention to transfer learning.
The current use of machine learning in industry is characterised by a dichotomy:
On the one hand, over the course of the last years, we have obtained the ability to train more and more accurate models. We are now at the stage that for many tasks, state-of-the-art models have reached a level where their performance is so good that it is no longer a hindrance for users. How good? The newest residual networks [1] on ImageNet achieve superhuman performance at recognising objects; Google’s Smart Reply [2] automatically handles 10% of all mobile responses; speech recognition error has consistently dropped and is more accurate than typing [3]; we can automatically identify skin cancer as well as dermatologists; Google’s NMT system [4] is used in production for more than 10 language pairs; Baidu can generate realistic sounding speech in real-time; the list goes on and on. This level of maturity has allowed the large-scale deployment of these models to millions of users and has enabled widespread adoption.
On the other hand, these successful models are immensely data-hungry and rely on huge amounts of labeled data to achieve their performance. For some tasks and domains, this data is available as it has been painstakingly gathered over many years. In a few cases, it is public, e.g. ImageNet [5], but large amounts of labeled data are usually proprietary or expensive to obtain, as in the case of many speech or MT datasets, as they provide an edge over the competition.
At the same time, when applying a machine learning model in the wild, it is faced with a myriad of conditions which the model has never seen before and does not know how to deal with; each client and every user has their own preferences, possesses or generates data that is different than the data used for training; a model is asked to perform many tasks that are related to but not the same as the task it was trained for. In all of these situations, our current state-of-the-art models, despite exhibiting human-level or even super-human performance on the task and domain they were trained on, suffer a significant loss in performance or even break down completely.
Transfer learning can help us deal with these novel scenarios and is necessary for production-scale use of machine learning that goes beyond tasks and domains where labeled data is plentiful. So far, we have applied our models to the tasks and domains that, while impactful, are the low-hanging fruit in terms of data availability. To also serve the long tail of the distribution, we must learn to transfer the knowledge we have acquired to new tasks and domains.
To be able to do this, we need to understand the concepts that transfer learning involves. For this reason, we will give a more technical definition in the following section.
A Definition of Transfer Learning
For this definition, we will closely follow the excellent survey by Pan and Yang (2010) [6] with binary document classification as a running example.
Transfer learning involves the concepts of a domain and a task. A domain \(\mathcal{D}\) consists of a feature space \(\mathcal{X}\) and a marginal probability distribution \(P(X)\) over the feature space, where \(X = \{x_1, \cdots, x_n\} \in \mathcal{X}\). For document classification with a bag-of-words representation, \(\mathcal{X}\) is the space of all document representations, \(x_i\) is the binary feature of the \(i\)-th word and \(X\) is a particular document.
Given a domain, \(\mathcal{D} = \{\mathcal{X}, P(X)\}\), a task \(\mathcal{T}\) consists of a label space \(\mathcal{Y}\) and a conditional probability distribution \(P(Y|X)\) that is typically learned from the training data consisting of pairs \(x_i \in X\) and \(y_i \in \mathcal{Y}\). In our document classification example, \(\mathcal{Y}\) is the set of all labels, i.e. True, False, and \(y_i\) is either True or False.
Given a source domain \(\mathcal{D}_S\), a corresponding source task \(\mathcal{T}_S\), as well as a target domain \(\mathcal{D}_T\) and a target task \(\mathcal{T}_T\), the objective of transfer learning now is to enable us to learn the target conditional probability distribution \(P(Y_T|X_T)\) in \(\mathcal{D}_T\) with the information gained from \(\mathcal{D}_S\) and \(\mathcal{T}_S\) where \(\mathcal{D}_S \neq \mathcal{D}_T\) or \(\mathcal{T}_S \neq \mathcal{T}_T\). In most cases, a limited number of labeled target examples, far smaller than the number of labeled source examples, is assumed to be available.
As both the domain \(\mathcal{D}\) and the task \(\mathcal{T}\) are defined as tuples, these inequalities give rise to four transfer learning scenarios, which we will discuss below.
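To make the definitions above a little more concrete, here is a minimal sketch of the bag-of-words running example. The toy vocabulary, the example documents, and the helper names `binary_bow` and `marginal_word_frequencies` are all assumptions made for illustration; they do not come from Pan and Yang's survey. The sketch encodes a document as a binary vector and estimates the per-word marginals of a source and a target corpus that share the same feature space.

```python
def binary_bow(document, vocabulary):
    """Encode a document as X = (x_1, ..., x_n); x_i is 1 if the i-th word occurs."""
    tokens = set(document.lower().split())
    return [1 if word in tokens else 0 for word in vocabulary]


def marginal_word_frequencies(documents, vocabulary):
    """Empirical estimate of P(x_i = 1) for each vocabulary word in a corpus."""
    vectors = [binary_bow(doc, vocabulary) for doc in documents]
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(len(vocabulary))]


# Hypothetical toy data: both corpora share one feature space (the vocabulary).
vocabulary = ["great", "terrible", "battery", "plot"]
source_docs = ["great battery", "terrible battery life"]  # e.g. electronics reviews
source_labels = [True, False]                             # label space Y = {True, False}
target_docs = ["great plot", "terrible plot twist"]       # e.g. book reviews

x = binary_bow("Great battery", vocabulary)
p_source = marginal_word_frequencies(source_docs, vocabulary)
p_target = marginal_word_frequencies(target_docs, vocabulary)

print(x)         # binary feature vector of one document
print(p_source)  # per-word marginals in the source corpus
print(p_target)  # per-word marginals in the target corpus
```

In this toy setting the two corpora use the same features but the words "battery" and "plot" occur with very different frequencies, so the estimated marginals differ even though the labels mean the same thing in both.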
Transfer Learning Scenarios
Given source and target domains \(\mathcal{D}_S\) and \(\mathcal{D}_T\) where \(\mathcal{D} = \{\mathcal{X}, P(X)\}\) and source and target tasks \(\mathcal{T}_S\) and \(\mathcal{T}_T\) where \(\mathcal{T} = \{\mathcal{Y}, P(Y|X)\}\), source and target conditions can vary in four ways, which we will illustrate in the following, again using our document classification example:
1. \(\mathcal{X}_S \neq \mathcal{X}_T\). The feature spaces of the source and target domain are different, e.g. the documents are written in two different languages. In the context of natural language processing, this is generally referred to as cross-lingual adaptation.
2. \(P(X_S) \neq P(X_T)\). The marginal probability distributions of source and target domain are different, e.g. the documents discuss different topics. This scenario is generally known as domain adaptation.
3. \(\mathcal{Y}_S \neq \mathcal{Y}_T\). The label spaces between the two tasks are different, e.g. documents need to be assigned different labels in the target task. In practice, this scenario usually occurs with scenario 4, as it is extremely rare for two different tasks to have different label spaces, but exactly the same conditional probability distributions.
4. \(P(Y_S|X_S) \neq P(Y_T|X_T)\). The conditional probability distributions of the source and target tasks are different, e.g. source and target documents are unbalanced with regard to their classes. This scenario is quite common in practice and approaches such as over-sampling, under-sampling, or SMOTE [7] are widely used.
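As an illustration of the last scenario, the following is a minimal sketch of random over-sampling, the simplest of the rebalancing approaches mentioned above. It is not SMOTE, which synthesizes new examples rather than duplicating existing ones. The function name `random_oversample` and the toy documents are assumptions made for this example.

```python
import random


def random_oversample(examples, labels, seed=0):
    """Duplicate randomly chosen minority-class examples until all classes match in size."""
    rng = random.Random(seed)
    by_class = {}
    for example, label in zip(examples, labels):
        by_class.setdefault(label, []).append(example)

    # Every class is grown to the size of the largest class.
    target_size = max(len(items) for items in by_class.values())
    balanced = []
    for label, items in by_class.items():
        extra = [rng.choice(items) for _ in range(target_size - len(items))]
        balanced.extend((example, label) for example in items + extra)

    rng.shuffle(balanced)
    return balanced


# Hypothetical unbalanced document set: four positive documents, one negative.
docs = ["d1", "d2", "d3", "d4", "d5"]
labels = [True, True, True, True, False]
balanced = random_oversample(docs, labels)
print(len(balanced), sum(1 for _, y in balanced if y))
```

Duplicating examples is cheap but adds no new information, which is one reason interpolation-based methods such as SMOTE are often preferred in practice.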
Now that we are aware of the concepts relevant for transfer learning and the scenarios in which it is applied, we will look at different applications of transfer learning that illustrate some of its potential.
Applications of Transfer Learning
Learning from simulations
One particular application of transfer learning that I’m very excited about and that I assume we’ll see more of in the future is learning from simulations. For many machine learning applications that rely on hardware for interaction, gathering data and training a model in the real world is either expensive, time-consuming, or simply too dangerous. It is thus advisable to gather data in some other, less risky way.
Simulation is the preferred tool for this and is used to enable many advanced ML systems in the real world. Learning from a simulation and applying the acquired knowledge to the real world is an instance of transfer learning scenario 2, as the feature spaces between source and target domain are the same (both generally rely on pixels), but the marginal probability distributions between simulation and reality are different, i.e. objects in the simulation and in the real world look different, although this difference diminishes as simulations get more realistic. At the same time, the conditional probability distributions between simulation and real world might be different as the simulation is not able to fully replicate all reactions in the real world, e.g. a physics engine cannot completely mimic the complex interactions of real-world objects.
Learning from simulations has the benefit of making data gathering easy as objects can be easily bounded and analyzed, while simultaneously enabling fast training, as learning can be parallelized across multiple instances. Consequently, it is a prerequisite for large-scale machine learning projects that need to interact with the real world, such as self-driving cars (Figure 6). According to Zhaoyin Jia, Google’s self-driving car tech lead, “Simulation is essential if you really want to do a self-driving car”. Udacity has open-sourced the simulator it uses for teaching its self-driving car engineer nanodegree, which can be seen in Figure 7, and OpenAI’s Universe will potentially allow us to train a self-driving car using GTA 5 or other video games.
Another area where learning from simulations is key is robotics: Training models on a real robot is too slow and robots are expensive to train. Learning from a simulation and transferring the knowledge to a real-world robot alleviates this problem and has recently been garnering additional interest [8]. An example of a data manipulation task in the real world and in a simulation can be seen in Figure 8.
Finally, another direction where simulation will be an integral part is on the path towards general AI. Training an agent to achieve general artificial intelligence directly in the real world is too costly and initially hinders learning through unnecessary complexity. Rather, learning may be more successful if it is based on a simulated environment such as CommAI-env [9], shown in Figure 9.
Adapting to new domains
While learning from simulations is a particular instance of domain adaptation, it is worth outlining some other examples of domain adaptation.
Domain adaptation is a common requirement in vision, as the data for which labeled information is easily accessible and the data that we actually care about often differ, whether this pertains to identifying bikes as in Figure 10 or some other objects in the wild. Even if the training and the test data look the same, the training data may still contain a bias that is imperceptible to humans but which the model will exploit to overfit on the training data [10].
Another common domain adaptation scenario pertains to adapting to different text types: Standard NLP tools such as part-of-speech taggers or parsers are typically trained on news data such as the Wall Street Journal, which has historically been used to evaluate these models. Models trained on news data, however, have difficulty coping with more novel text forms such as social media messages and the challenges they present.
Even within one domain such as product reviews, people employ different words and phrases to express the same opinion. A model trained on one type of review should thus be able to disentangle the general and domain-specific opinion words that people use in order not to be confused by the shift in domain.
Finally, while the above challenges deal with general text or image types, problems are amplified if we look at domains that pertain to individual users or groups of users: Consider the case of automatic speech recognition (ASR). Speech is poised to become the next big platform, with 50% of all our searches predicted to be performed by voice by 2020. Most ASR systems are traditionally evaluated on the Switchboard dataset, which comprises 500 speakers. Most people with a standard accent are thus fortunate, while immigrants, people with non-standard accents, people with a speech impediment, or children have trouble being understood. Now more than ever do we need systems that are able to adapt to individual users and minorities to ensure that everyone’s voice is heard.
Transferring knowledge across languages
Finally, learning from one language and applying our knowledge to another language is, in my opinion, another killer application of transfer learning, which I have written about before in the context of cross-lingual embedding models. Reliable cross-lingual adaptation methods would allow us to leverage the vast amounts of labeled data we have in English and apply them to any language, particularly underserved and truly low-resource languages. Given the current state-of-the-art, this still seems utopian, but recent advances such as zero-shot translation [11] promise rapid progress in this area.
While we have so far considered particular applications of transfer learning, we will now look at practical methods and directions in the literature that are used to solve some of the presented challenges.
Transfer Learning Methods
Transfer learning has a long history of research and techniques exist to tackle each of the four transfer learning scenarios described above. The advent of Deep Learning has led to a range of new transfer learning approaches, some of which we will review in the following. For a survey of earlier methods, refer to [6].
Using pre-trained CNN features
In order to motivate the most common way transfer learning is currently applied, we must understand what accounts for the outstanding success of large convolutional neural networks on ImageNet [12].
Understanding convolutional neural networks
While many details of how these models work still remain a mystery, we are by now aware that lower convolutional layers capture low-level image features, e.g. edges (see Figure 14), while higher convolutional layers capture more and more complex details, such as body parts, faces, and other compositional features.
The final fully-connected layers are generally assumed to capture information that is relevant for solving the respective task, e.g. AlexNet’s fully-connected layers would indicate which features are relevant to classify an image into one of 1000 object categories.
However, while knowing that a cat has whiskers, paws, fur, etc. is necessary for identifying an animal as a cat (for an example, see Figure 15), it does not help us identify new objects or solve other common vision tasks such as scene recognition, fine-grained recognition, attribute detection and image retrieval.
What can help us, however, are representations that capture general information of how an image is composed and what combinations of edges and shapes it contains. This information is contained in one of the final convolutional layers or early fully-connected layers in large convolutional neural networks trained on ImageNet as we have described above.
For a new task, we can thus simply use the off-the-shelf features of a state-of-the-art CNN pre-trained on ImageNet and train a new model on these extracted features. In practice, we either keep the pre-trained parameters fixed or tune them with a small learning rate in order to ensure that we do not unlearn the previously acquired knowledge. This simple approach has been shown to achieve astounding results on an array of vision tasks [13] as well as tasks that rely on visual input such as image captioning. A model trained on ImageNet seems to capture details about the way animals and objects are structured and composed that are generally relevant when dealing with images. As such, the ImageNet task seems to be a good proxy for general computer vision problems, as the same knowledge that is required to excel in it is also relevant for many other tasks.
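The "fixed features plus new model" recipe can be sketched in a few lines of NumPy. This is only an illustration under strong assumptions: the frozen layer `W_frozen` is a random projection standing in for a CNN pre-trained on ImageNet, the target data is synthetic, and the names `extract_features` and `train_head` are hypothetical. What the sketch does show is the mechanics: the pre-trained parameters are never updated, and only a small logistic-regression head is trained on the extracted features.

```python
import numpy as np

rng = np.random.default_rng(0)

# ASSUMPTION: a fixed random projection stands in for pre-trained CNN layers.
# In practice these weights would be loaded from a network trained on ImageNet.
W_frozen = rng.normal(size=(20, 64)) / np.sqrt(20)


def extract_features(inputs):
    """Apply the frozen layer (linear map + ReLU); its weights are never updated."""
    return np.maximum(inputs @ W_frozen, 0.0)


def train_head(features, labels, lr=0.1, epochs=500):
    """Fit a logistic-regression head on the fixed features with gradient descent."""
    w = np.zeros(features.shape[1])
    b = 0.0
    for _ in range(epochs):
        probs = 1.0 / (1.0 + np.exp(-(features @ w + b)))
        error = probs - labels                      # gradient of the log loss w.r.t. logits
        w -= lr * features.T @ error / len(labels)
        b -= lr * error.mean()
    return w, b


# Hypothetical small target dataset: 40 examples with 20 raw input dimensions.
X_target = rng.normal(size=(40, 20))
y_target = (X_target[:, 0] > 0).astype(float)

frozen_before = W_frozen.copy()
features = extract_features(X_target)
w, b = train_head(features, y_target)
train_accuracy = float((((features @ w + b) > 0) == (y_target == 1)).mean())
print(features.shape, train_accuracy)
```

Fine-tuning, the second option mentioned above, would additionally update `W_frozen`, typically with a much smaller learning rate than the one used for the head.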
Learning the underlying structure of images
A similar assumption is used to motivate generative models: When training generative models, we assume that the ability to generate realistic images requires an understanding of the underlying structure of images, which in turn can be applied to many other tasks. This assumption itself relies on the premise that all images lie on a low-dimensional manifold, i.e. that there is some underlying structure to images that can be extracted by a model. Recent advances in generating photorealistic images with Generative Adversarial Networks [14] indicate that such a structure might ind