Explainable AI: Language Models

Introduction

Just like a coin, explainability in AI has two faces - one it shows to the developers (who actually build the models) and the other to the users (the end customers). The former face, intrinsic explainability (IE), is a technical indicator that explains the inner workings of the model to its builders, whereas the latter, extrinsic explainability (EE), serves as proof to the users that the model's predictions can be trusted. While IE is required for any reasonable model improvement, EE is needed for factual confirmation: a layperson who ends up acting on the model's prediction needs to know why the model is suggesting something. This article provides a brief introduction to both types of explainability and the popular techniques used to infer them, with a focus on large language models.

Large Language Models

As the name suggests, language models (LMs) try to model a language by learning a probability distribution over sequences of tokens (which can be words). This can be applied to the text generation task: once the model is good enough at estimating these probabilities, it can start predicting the next words given some contextual words (a prompt). This idea forms the basis of research in deep learning-based (DL) Natural Language Processing (NLP), where newer, state-of-the-art (SotA) models keep getting better at modeling this distribution. Recently, several powerful models like GPT-3 have become able to generate text that is very hard to distinguish from human-written text!

GPT-3 is used to predict the next words for the prompt 'To be or not to be'. The output is quite impressive!
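Since GPT-3 itself is not openly available, here is a minimal sketch of the same idea using GPT-2 via the Hugging Face transformers library: we generate a continuation for the prompt and also inspect the probability distribution the model assigns to the very next token.

```python
# A minimal sketch using GPT-2 (publicly available, unlike GPT-3) via
# Hugging Face transformers: generate a continuation for a prompt and
# inspect the probability distribution over the next token.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

prompt = "To be or not to be"
inputs = tokenizer(prompt, return_tensors="pt")

# Greedy continuation of the prompt
generated = model.generate(
    **inputs, max_new_tokens=20, do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(generated[0]))

# Probability distribution over the very next token
with torch.no_grad():
    logits = model(**inputs).logits          # (1, seq_len, vocab_size)
probs = torch.softmax(logits[0, -1], dim=-1)
top = torch.topk(probs, k=5)
for p, idx in zip(top.values, top.indices):
    print(f"{tokenizer.decode([idx.item()]):>10s}  {p.item():.3f}")
```

The exact continuation will differ from GPT-3's, but the mechanics - a probability distribution over the next token, decoded greedily or by sampling - are the same.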

Large language models (LLMs) are leading the leaderboards for several tasks in NLP. But one existing problem with them (or with any DL model, in fact) is that they are inherently black boxes, i.e. it becomes extremely difficult to explain their predictions. Let us try to shine some grey light on these black boxes!

Intrinsic explainability

Intrinsic explainability tries to provide technical clarity on why the model made a certain prediction. LLMs are mostly DL-based models, and while they get better at prediction with more data or more network parameters, this comes at the cost of explainability. They are not explainable out of the box, so we need to derive ways to infer why a model made a prediction. Below, let's briefly go through some of those ways.

Attention-based

The recent LLMs utilize Transformer-based architectures for enhanced performance. What makes Transformers unique is the presence of multiple attention layers and heads that capture bidirectional, context-aware token representations. Each attention head learns certain patterns that can help in language modeling. These patterns can be positional (attend to the previous word) or lexical (attend to acronyms, synonyms, etc.) in nature. When we combine multiple such heads across multiple layers, the overall model is able to capture sophisticated patterns and provide accurate predictions.

These kinds of internal, low-level patterns can be observed using attention score visualization tools like bertviz [1]. There are two major types of visualizations: (1) Attention head visualization, which helps to see how the model attends to certain words given other words in context, and (2) Neuron visualization, which can be used to connect the model's behaviors to individual neuron values.
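As a rough sketch, the head view can be generated along the following lines inside a Jupyter notebook (the model and sentence here are just assumptions for illustration):

```python
# A rough sketch of generating the attention head view with bertviz.
# Assumes a Jupyter/Colab environment, since the view is an interactive widget.
from transformers import AutoTokenizer, AutoModel
from bertviz import head_view

model_name = "gpt2"  # any Hugging Face model works; GPT-2 is just an example choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

sentence = "The doctor asked the nurse a question and she answered politely"
input_ids = tokenizer.encode(sentence, return_tensors="pt")
outputs = model(input_ids)

attention = outputs.attentions                          # attention weights from every layer and head
tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
head_view(attention, tokens)                            # pick the layer and head interactively
```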

For the Attention head visualization, you can select the model, the layer, and the head to explore the different attention patterns. One example is shown below. 

Example of Attention head view to detect association of words in the model. Also useful to detect gender biases. [1]

Here, the 11th head in layer 5 of the model (the second-to-last blue box in the color palette) seems to have incorporated gender-specific (lexical) patterns of the language. That is why the attention scores for gender-related word pairs - (She, girl), (He, boy), (He, Bob), etc. - are high. As such, this head can also be used to identify gender bias, which is evident from the associations (nurse, She) and (doctor, He). Similar biases can be identified from other heads.

The neuron view is useful to drill down even deeper into the attention scores and identify the specific neurons leading to certain behaviors in the model. One such example is shown below,

Neuron view of GPT-2 for layer 1, head 10 [1]

Here, for layer 1 and head 10, some of the neurons' values decrease as the distance between tokens increases. This showcases that those particular neurons have incorporated distance (positional) patterns!
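For the curious, here is a rough sketch (not the bertviz API itself) of how the quantities behind the neuron view can be pulled out of a Hugging Face GPT-2 model: we capture the query and key vectors of one head and look at their element-wise products, which is exactly what the neuron view visualizes.

```python
# A rough sketch (not bertviz itself) of extracting the per-neuron values behind
# the neuron view from Hugging Face GPT-2: capture the query/key vectors of one
# head and compute their element-wise products.
import torch
from transformers import GPT2Model, GPT2Tokenizer

LAYER, HEAD = 1, 10                      # the layer/head discussed above
captured = {}

def grab_qkv(module, inputs, output):
    # GPT-2's c_attn projects hidden states to concatenated [query, key, value]
    captured["qkv"] = output.detach()

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")
model.eval()

hook = model.h[LAYER].attn.c_attn.register_forward_hook(grab_qkv)
inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt")
with torch.no_grad():
    model(**inputs)
hook.remove()

hidden = model.config.n_embd                     # 768
head_dim = hidden // model.config.n_head         # 64
q, k, _ = captured["qkv"].split(hidden, dim=-1)  # each of shape (1, seq, 768)

# Slice out the chosen head: (seq, head_dim)
q_h = q[0, :, HEAD * head_dim:(HEAD + 1) * head_dim]
k_h = k[0, :, HEAD * head_dim:(HEAD + 1) * head_dim]

# Element-wise products q_i * k_j for the last token's query against every key:
# each row sums to the raw (unscaled) attention score between the two tokens,
# and the individual columns are the per-neuron contributions shown in the neuron view.
products = q_h[-1] * k_h                         # (seq, head_dim)
print(products.shape)
```

Plotting how individual columns of products change with the distance between tokens reproduces the distance pattern observed above.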

Saliency methods

Attention presents a very myopic view, as it is limited to a certain part of the language model, so it only explains the behavior of that part. As an analogy, it's like explaining just one line of a complete program. This has led to a debate on whether attention really provides true explainability of the complete model. To mitigate these concerns, there has been renewed interest in old, tried, and tested saliency methods [2], such as gradient-based attribution of the input tokens.

There are several open-source packages, like Ecco, that provide the option to visualize such saliency measures. One example is shown below,

Gradient-based method to visualize the importance of the input words for the next word prediction (India)

Here, for the input "New Delhi is the capital of", the LLM returned "India", which is factually correct. But intuitively, we would like to know which part of the input motivated the model to make such a prediction. The importance measure is highlighted with color in the image, and as visible, it gave high scores to "New Delhi" and "capital"!
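For reference, a simplified version of such a gradient-based importance score can be sketched directly with PyTorch and Hugging Face transformers. This is a bare-bones gradient-norm sketch, not Ecco's exact implementation:

```python
# A simplified sketch of gradient-based saliency for next-word prediction
# (not Ecco's exact implementation): back-propagate the predicted token's logit
# to the input embeddings and use the gradient norm as an importance score.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

prompt = "New Delhi is the capital of"
inputs = tokenizer(prompt, return_tensors="pt")

# Run the forward pass on input embeddings so gradients can flow back to them
embeddings = model.transformer.wte(inputs["input_ids"]).detach()
embeddings.requires_grad_(True)
logits = model(inputs_embeds=embeddings, attention_mask=inputs["attention_mask"]).logits

# Logit of the most likely next token (" India" in the article's example)
next_token_logit, next_token_id = logits[0, -1].max(dim=-1)
next_token_logit.backward()

# Importance of each input token = norm of the gradient at its embedding
saliency = embeddings.grad[0].norm(dim=-1)
saliency = saliency / saliency.sum()

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
print("Predicted next token:", tokenizer.decode([next_token_id.item()]))
for tok, score in zip(tokens, saliency):
    print(f"{tok:>12s}  {score.item():.3f}")
```

Tokens with larger scores contributed more (by this measure) to the model choosing "India" as the next word.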

Extrinsic explainability

The intrinsic measures fall short of covering the "needs" of explainability for two reasons. Firstly, while they may help the model developer, who has access to the internal metrics, it is quite difficult for the stakeholders on the other end of the spectrum to make sense of these complicated numbers. We are talking about the end users of the model, who could be a doctor in healthcare, an engineer in the mechanical domain, a sales executive in the sales domain, and so on. They care more about factual correctness than the internal reasons behind a model's prediction. Secondly, even the model developers will be concerned if the internal metrics make sense but the output doesn't (i.e. the output is incorrect).

The problem with the existing language models is that they contain all their knowledge within their parameters. GPT-3 is made up of 175 billion parameters and has been shown to work on a large number of downstream tasks. It generates text based on the information it learned and stored in its parameters during the training and fine-tuning phases. Take the task of question answering as an example, with the question "Who is the Prime Minister of India?". While the model may answer it correctly now, how about after the next election? The answer is obvious: it will fail. This is because the stored information is static, and the model would need constant re-training to stay calibrated with the latest events in the world (which is a costly affair [3]).

Because of all these reasons, there has been research interest in enhancing language models by complementing them with a connection to a dynamic knowledge base - one that is kept up to date just like any other database. Some examples from the leading AI teams in the world are GopherCite [4], WebGPT [5], and LaMDA [6], shown below.

Comparing Gopher vs GopherCite
Final output generated by WebGPT for the question 'How do neural networks work?'. Note the presence of References at the bottom. Taken from the official blog.
Example of a conversation with Mt. Everest using LaMDA. Taken from blog.
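To make the underlying idea concrete, here is a purely hypothetical sketch of pairing a language model with an external, up-to-date knowledge source. The retrieve() helper and the prompt format are illustrative placeholders - none of GopherCite, WebGPT, or LaMDA work exactly this way.

```python
# A purely hypothetical sketch of retrieval-augmented question answering:
# fetch up-to-date passages from an external knowledge base, let the model
# answer conditioned on them, and return the sources alongside the answer.
# retrieve() is a placeholder, not a real API.
from typing import Callable, List, Tuple

def retrieve(query: str, top_k: int = 3) -> List[Tuple[str, str]]:
    """Placeholder: return (source_url, passage) pairs from a search index
    or database that is kept up to date, unlike the frozen model weights."""
    raise NotImplementedError

def answer_with_citations(question: str, generate: Callable[[str], str]) -> str:
    """`generate` is any text-generation function (e.g. a wrapper around an LM)."""
    passages = retrieve(question)
    context = "\n".join(f"[{i + 1}] {text}" for i, (_, text) in enumerate(passages))
    prompt = (
        "Answer the question using only the numbered passages below, "
        f"and cite the passages you used.\n\n{context}\n\n"
        f"Question: {question}\nAnswer:"
    )
    answer = generate(prompt)
    sources = "\n".join(f"[{i + 1}] {url}" for i, (url, _) in enumerate(passages))
    return f"{answer}\n\nReferences:\n{sources}"
```

Because the knowledge now lives outside the model, answering "Who is the Prime Minister of India?" correctly after an election only requires refreshing the knowledge base, not re-training the model, and the returned references double as extrinsic proof for the end user.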

References

[1] A Multiscale Visualization of Attention in the Transformer Model

[2] Why use attention as explanation when we have saliency methods?

[3] Cost of training GPT-3 @ Reddit

[4] Teaching language models to support answers with verified quotes

[5] WebGPT: Browser-assisted question-answering with human feedback

[6] LaMDA: Language Models for Dialog Applications

Cheers.