
How to Interpret Any Machine Learning Prediction

Local Interpretable Model-agnostic Explanations (LIME) is a Python project developed by Ribeiro et al. [1] to interpret the predictions of any supervised Machine Learning (ML) model.

Most ML models are black boxes: we cannot properly understand how they arrive at a specific prediction. This is a huge drawback of ML, and as Artificial Intelligence (AI) becomes more and more widespread, the importance of understanding ‘the Why?’ keeps increasing.

In this post, we will discuss how and why the LIME project works. We will also go through an example using a real-life dataset to further understand the results of LIME.

Understanding the Basics of Machine Learning

Before we can understand and truly appreciate the awesomeness of LIME, we must first understand the basic intuition of ML.

Any supervised problem can be summarised by two main components: 𝒙 (our features) and 𝑦 (our target). We want to build a model ƒ(𝒙) that generates a prediction 𝑦’ whenever we provide some sample 𝒙’.

During training, the ML model repeatedly adjusts the weights of the mapping function ƒ(𝒙). For most models, the learned function ends up far too complex to inspect directly, and it is not trivial to understand how these weights shape any given prediction. This is what makes the model a black box.

Understanding the predictions of any ML model therefore boils down to understanding the mapping function behind that model.

Understanding Model Explainability Types

There are two main types of model explainability:

1. Global Explanations

Given a model ƒ, global explanations are generated using the entire training dataset. Global explanations show the overall feature importance to the model. For every feature, global explanations typically answer the question: “overall, how important was this feature for ƒ?”

2. Local Explanations

Local explanations are based directly on a single observation. Using local explanations, we attempt to understand why ƒ generated that particular prediction for that specific sample. For any given sample, local explanations typically answer the question: “Which features most influenced this specific prediction?”
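To make the distinction concrete, consider a hypothetical linear model f(𝒙) = w·𝒙 + b, where both views can be computed exactly. The sketch below is purely illustrative and is not how LIME works: as one simple choice of global view, it scores each feature by |w_j| times that feature's spread across a made-up dataset, while the local view measures how much each feature pushes one particular prediction away from the prediction at the average input. The weights, intercept and data are all invented for the example.

```python
import numpy as np

# Hypothetical linear model f(x) = w . x + b (made-up weights, illustration only).
w = np.array([2.0, -0.5, 0.0])
b = 1.0

def f(X):
    return X @ w + b

# A tiny made-up dataset: three samples, three features.
X = np.array([[2.0, 6.0, 3.0],
              [1.0, 2.0, 1.0],
              [3.0, 1.0, 2.0]])

# Global view: how strongly each feature moves f across the whole dataset.
global_importance = np.abs(w) * X.std(axis=0)

# Local view for one sample: each feature's push away from f(average input).
x = X[0]
local_contribution = w * (x - X.mean(axis=0))

print(global_importance)   # feature 0 scores highest overall
print(local_contribution)  # for this sample, only feature 1 moves the prediction

# For a linear model, the local contributions add up exactly to f(x) - f(mean).
print(np.isclose(local_contribution.sum(), f(x) - f(X.mean(axis=0))))  # True
```

Globally, feature 0 looks like the most important feature, yet for this particular sample it sits exactly at its average value and contributes nothing; the prediction is driven by feature 1 instead. That gap between "important on average" and "important here" is what local explanations are meant to capture.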

Throughout the rest of this post, we will focus on and discuss local explanations.

Intuition of LIME

LIME attempts to approximate the model’s mapping function ƒ(𝒙) locally, around the instance we want to explain, by sampling instances (a process referred to as input perturbation). In layman’s terms, LIME generates a bunch of synthetic samples 𝒙’ that are closely based on the original instance 𝒙. LIME then passes each 𝒙’ to the original model ƒ and records the respective prediction. This process enables LIME to determine how the different input fluctuations influence ƒ. At the end of the process, for a given sample 𝒙, LIME is able to approximate the prediction of ƒ by estimating the individual influence of every feature. LIME can therefore explain a specific prediction by identifying which features contributed the most to it.

In summary

  • LIME samples instances 𝒙’ around the original instance 𝒙
  • LIME passes each 𝒙’ through ƒ to obtain a set of predictions 𝑦’
  • LIME weights each sample by its proximity to the original instance 𝒙, so that nearby samples count more
  • LIME fits a simple, interpretable model on the weighted samples and uses it to determine which features are the most influential for that individual prediction.
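The four steps above can be sketched in a few lines of NumPy. This is a simplified, hypothetical illustration (lime_tabular_sketch is not part of the lime package): it assumes Gaussian perturbations scaled by each feature's training standard deviation, the exponential kernel from [1] applied to a scaled Euclidean distance, a default kernel width of 0.75·√(number of features) (mirroring the lime package's tabular default), and a weighted linear regression as the simple model. The real LimeTabularExplainer differs in several details (for example, it discretizes continuous features by default and selects a subset of features), so its numbers will not match this sketch.

```python
import numpy as np

def lime_tabular_sketch(f, x, X_train, num_samples=5000, kernel_width=None, seed=0):
    """Simplified LIME for one tabular instance x (hypothetical helper, not the lime API).

    Returns the intercept and per-feature weights of a locally weighted
    linear surrogate fitted to the black-box predictions f(x').
    """
    rng = np.random.default_rng(seed)
    n_features = x.shape[0]
    scale = X_train.std(axis=0)
    scale[scale == 0] = 1.0  # guard against constant features

    # 1. Sample synthetic instances x' around x (Gaussian noise, per-feature scale).
    Z = x + rng.normal(size=(num_samples, n_features)) * scale

    # 2. Ask the black-box model for its prediction on every x'.
    y = f(Z)

    # 3. Weight each x' by its proximity to x: exponential kernel on scaled L2 distance.
    if kernel_width is None:
        kernel_width = 0.75 * np.sqrt(n_features)
    d = np.linalg.norm((Z - x) / scale, axis=1)
    weights = np.exp(-(d ** 2) / kernel_width ** 2)

    # 4. Fit g(z) = b + w . z by weighted least squares (rows scaled by sqrt(weight)).
    A = np.column_stack([np.ones(num_samples), Z])
    sw = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    return coef[0], coef[1:]


# Usage with a made-up linear "black box", so the expected answer is known.
def black_box(X):
    return 3.0 * X[:, 0] - 2.0 * X[:, 1] + 0.0 * X[:, 2] + 1.0

rng = np.random.default_rng(42)
X_train = rng.normal(size=(200, 3))
intercept, w = lime_tabular_sketch(black_box, X_train[0], X_train)
ranking = np.argsort(-np.abs(w))  # most influential feature first
print(ranking[:2])  # [0 1]
```

Because the made-up black box is itself linear, the surrogate recovers its coefficients (3, -2, 0) essentially exactly. For a genuinely non-linear ƒ, the fitted weights only describe how ƒ behaves in the neighbourhood of 𝒙, which is precisely the "local" in LIME.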

Diving Deeper

As previously mentioned, the whole idea of LIME is to attempt to interpret ƒ(𝒙). LIME achieves this through surrogate models. A surrogate model g is any model which is used to interpret the results of another predictive algorithm. Typically, g would be a much simpler and much more interpretable model (like a decision tree or a linear model). We can formally define the set of surrogate models as G, such that g ∈ G.

But, how does LIME select which g to use for interpreting the original model?

There are two main deciding factors that are considered by LIME:

*WARNING: Smart-sounding words incoming*

  1. Local Faithfulness, denoted by L(f, g, π) (also called the fidelity function)
  2. Complexity, denoted by Ω(g)

What is Local Faithfulness?

We have two parts:

  • Local

We have already discussed this. Local simply means that we focus on one specific prediction at a time rather than considering ƒ holistically.

  • Faithfulness

As the name implies, this is a measure of how accurately our selected g follows the original model ƒ. The closer the predictions of g are to those of ƒ, the more faithful g is said to be to ƒ. Because we only care about the neighbourhood of the instance being explained, this comparison is weighted by how close each perturbed sample is to the original instance 𝒙. This closeness is called the proximity and is denoted π (written π_x in [1]).

Simple, right?

What is Complexity?

Explaining 2+2 to a 5-year-old is easier than explaining ∫ tan(𝒙). Why? Because the ‘mapping function’ behind 2+2 is much simpler than that of integration.

The main motivation behind g is to interpret ƒ. Therefore, g must be interpretable. The simpler the g, the more interpretable it becomes.

Complexity is measured in different ways, depending on the type of model being evaluated. For instance, in decision trees, the complexity can be directly given by the depth of the tree (the deeper the tree, the more complex and the less interpretable). In linear models, complexity might be measured by the number of non-zero weights.
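To make the linear-model case concrete, here is a minimal Python sketch of Ω(g) as the number of non-zero weights. The optional cap follows [1], which rules out linear surrogates that use more than K features by giving them infinite complexity. The function name linear_complexity and the example weights are hypothetical.

```python
import math
import numpy as np

def linear_complexity(weights, max_features=None):
    """Omega(g) for a linear surrogate: its number of non-zero weights.

    If max_features (K) is given, any surrogate using more than K features
    gets infinite complexity, i.e. it is effectively ruled out.
    """
    n_nonzero = int(np.count_nonzero(weights))
    if max_features is not None and n_nonzero > max_features:
        return math.inf
    return n_nonzero

w = np.array([0.8, 0.0, -0.3, 0.0])  # made-up surrogate weights
print(linear_complexity(w))                  # 2
print(linear_complexity(w, max_features=1))  # inf
```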

LIME attempts to minimise complexity and maximise faithfulness.

The fidelity function L(f, g, π) can, in principle, be defined by any loss function. In [1], LIME uses a locally weighted square loss: the squared difference between the predictions of ƒ and g, weighted by the proximity π. What changes with the type of data being explained is mainly the distance function used to compute π (for example, [1] uses cosine distance for text and L2 distance for images).
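For reference, [1] states the trade-off between faithfulness and complexity as a single optimisation problem. In the paper's notation, z is a perturbed sample (one of the synthetic samples 𝒙’ above), z’ is its interpretable representation, and π_x weights z by its proximity to the instance x being explained:

```latex
\xi(x) = \operatorname*{arg\,min}_{g \in G} \; \mathcal{L}(f, g, \pi_x) + \Omega(g)

\mathcal{L}(f, g, \pi_x) = \sum_{z,\, z' \in \mathcal{Z}} \pi_x(z) \, \bigl( f(z) - g(z') \bigr)^2,
\qquad
\pi_x(z) = \exp\!\left( -\frac{D(x, z)^2}{\sigma^2} \right)
```

Here D is a data-type-specific distance function and σ is the kernel width. The explanation ξ(x) is the surrogate g that best balances a low weighted loss against a low complexity Ω(g).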

That is the gist of LIME. There is more to it, so if you enjoyed this article, I highly recommend reading [1].

Now, let’s get our hands dirty with some Python examples!

Working LIME Example in Python

First things first, we need to install LIME using pip. You can find the source code for LIME in [2].

pip install lime

We will use the iris dataset provided by Scikit-learn [3] to demonstrate how the package is used.

Next, we import the packages we will need.

# imports
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

We can also import LIME as follows:

from lime.lime_tabular import LimeTabularExplainer

Our problem is a supervised learning task on tabular data, so we import the LimeTabularExplainer. When working with LIME, it’s also a good idea to set NumPy’s random seed. LIME draws its perturbed samples using NumPy’s random number generator, so fixing the seed to a number of our choice makes our experiments repeatable. (Here, the seed also fixes the train/test split and the random forest, since neither is given an explicit random_state.) We can set the random seed using:

np.random.seed(1)
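As a quick illustration of why seeding helps, the hypothetical perturb helper below draws its noise from NumPy's global generator, just like code that is never handed an explicit seed. Re-seeding with np.random.seed reproduces exactly the same synthetic samples, while a second draw without re-seeding does not.

```python
import numpy as np

def perturb(x, num_samples):
    """Hypothetical stand-in for a sampler that uses NumPy's global generator."""
    return x + np.random.normal(size=(num_samples, x.shape[0]))

x = np.array([5.1, 3.5, 1.4, 0.2])  # first sample of the iris dataset

np.random.seed(1)
first = perturb(x, 3)
np.random.seed(1)
second = perturb(x, 3)  # same seed, same samples
third = perturb(x, 3)   # no re-seed: the generator has moved on

print(np.array_equal(first, second))  # True
print(np.array_equal(second, third))  # False
```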

We then create a helper function that takes in our training and testing sets, trains a base RandomForestClassifier, and reports its accuracy on the test set. The goal here is not to build the most robust model, but rather to get a base model to interpret.

def train_model(X_train: np.ndarray,
                X_test: np.ndarray,
                y_train: np.ndarray,
                y_test: np.ndarray) -> RandomForestClassifier:
    """Fit a default random forest and print its test-set accuracy."""
    model = RandomForestClassifier()
    model.fit(X_train, y_train)
    
    predictions = model.predict(X_test)
    
    print(f'Accuracy: {np.round((accuracy_score(y_test, predictions) * 100), 3)}%')
    return model

We pull in our dataset from Scikit-learn, split it, and train our model as follows:

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, train_size=0.8)
rf = train_model(X_train, X_test, y_train, y_test)

Next, we create our LIME explainer. Here we need to specify the training data (which LIME uses to compute the feature statistics it needs for sampling), the feature names, the class labels, and whether to discretize continuous features.

explainer = LimeTabularExplainer(X_train, 
                                 feature_names=iris.feature_names, 
                                 class_names=iris.target_names, 
                                 discretize_continuous=True)
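With discretize_continuous=True, the lime package (using its default 'quartile' discretizer) splits each continuous feature into four bins at the 25th, 50th and 75th percentiles of the training data, which is where conditions such as "petal width > 0.3" in the explanations come from. The sketch below reproduces that idea in plain NumPy; quartile_condition is a hypothetical helper and its label format only approximates lime's.

```python
import numpy as np

def quartile_condition(value, column, name):
    """Describe `value` by the quartile bin it falls into (hypothetical helper)."""
    q1, q2, q3 = np.percentile(column, [25, 50, 75])
    labels = [f"{name} <= {q1:.2f}",
              f"{q1:.2f} < {name} <= {q2:.2f}",
              f"{q2:.2f} < {name} <= {q3:.2f}",
              f"{name} > {q3:.2f}"]
    # searchsorted returns 0..3: the number of quartile edges strictly below `value`.
    return labels[int(np.searchsorted([q1, q2, q3], value))]

column = np.arange(1.0, 9.0)  # made-up training values 1, 2, ..., 8
print(quartile_condition(2.75, column, "x"))  # x <= 2.75
print(quartile_condition(5.0, column, "x"))   # 4.50 < x <= 6.25
print(quartile_condition(7.0, column, "x"))   # x > 6.25
```

Discretizing this way trades some precision for readability: an explanation phrased as a range condition is easier to read than a raw coefficient on a continuous value.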

Now we can generate an explanation for any prediction we want. At this step, we control how many of the most influential features to show via num_features, which can be any integer between 1 and the number of features in the dataset. Setting top_labels=1 tells LIME to explain only the class with the highest predicted probability.

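# pick a random test sample and explain the model's top predicted class for it
i = np.random.randint(0, X_test.shape[0])
exp = explainer.explain_instance(X_test[i], rf.predict_proba, num_features=2, top_labels=1)
exp.show_in_notebook()  # renders inside Jupyter; outside it, exp.save_to_file('explanation.html') writes an HTML version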

Running this last piece of code for three different predictions, we get:

[Figure: LIME explanations for three randomly selected test predictions (images by author)]

On the left-hand side of the visualisation, we get the predicted probability for each class. On the right-hand side, we get the top 2 most influential features for that prediction (we set this with num_features=2 when calling explain_instance) along with their respective values. In the centre of the plot, we get a condition per influential feature (derived from the discretized training data) and its weight, i.e. its contribution to the prediction in LIME's local surrogate model.

For example, for the first prediction the model classified the sample as Versicolor with 99% confidence. In the explanation, the condition "petal length greater than 1.58cm" carries a weight of about 0.24 towards Versicolor, and "petal width greater than 0.3cm" adds a further 0.14 or so. These weights are the coefficients of the local surrogate, not literal shares of the 99% score.

Conclusion

And that’s really it! The beauty of this package is that it largely follows the same ‘code template’ we just went over, even when explaining image or text classifiers. The main difference is which explainer we import: here we used LimeTabularExplainer because we wanted to interpret a model trained on tabular data, while the package also provides LimeTextExplainer (in lime.lime_text) and LimeImageExplainer (in lime.lime_image).

One important aspect to keep in mind is that an explanation can only be as good as the model it is trying to explain. So, in practice, always make sure that the model is robustly trained and properly validated, for example with cross-validation. That said, LIME can also help assess whether a model can be trusted, for instance by revealing that its predictions rely on features that should be irrelevant.

LIME is an excellent introduction to the world of eXplainable AI (XAI). Both LIME and the field are continuously growing and maturing, which makes now the perfect time to start incorporating XAI into your data modelling pipelines.

References

[1] Ribeiro, M.T., Singh, S. and Guestrin, C., 2016, August. “Why should I trust you?”: Explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (pp. 1135–1144).

[2] https://github.com/marcotcr/lime

[3] https://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html

Article originally posted by David Farrugia. Reposted with permission.
