Build a Question Answering System with the BERT Model


We all remember that during our school days, teachers used to assign homework, and the next day they would ask questions on those topics. BERT is built to do exactly that kind of question answering job for us: we can feed a passage of text to BERT, and it can answer questions based on the information in that passage.

In this tutorial, I will guide you through creating a question answering system with the BERT model in Python.

What is BERT?

Bidirectional Encoder Representations from Transformers (BERT for short) was introduced by researchers at Google in 2018. BERT is a transformer-based model designed to solve various natural language processing (NLP) tasks like sentiment analysis, text classification, named entity recognition, and question answering.

BERT is a type of language model that works differently from older ones. Earlier models either read text in a single direction (left to right or right to left) or looked only at a small fixed window of context, as Skip-gram and CBOW do.

But BERT is special because it reads the entire text in both directions at once. This means it understands each word based not just on what comes before it, but also on what comes after.

By doing this, BERT captures the meaning of each word from its context, which makes it very good at understanding how words relate to each other in a sentence. This is why BERT performs much better than traditional approaches like Word2Vec or plain RNNs.

Why BERT for Question Answering?

BERT is particularly well suited to tasks like question answering and chatbots because it can understand the contextual information in the input text.

Understanding context is essential for answering a question from a given text document or paragraph.

Install Libraries

Let’s first install the Python libraries required to build our simple QnA (question answering) system.

Create Virtual Environment

Before installing any packages, it is good practice to create an isolated virtual environment in Python. The commands below do that using Anaconda.

conda create -n nlp_test python=3.9
conda activate nlp_test

Install torch

Since the transformers library is built on top of the PyTorch framework, we need to install PyTorch first. It comes in two builds: GPU and CPU. You can find the installation command for your system on the official PyTorch website.

I am using a Windows system with an NVIDIA GPU. My command looked like this:

# For GPU
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
conda install cudatoolkit

To install the CPU-only build of torch, run the command below instead.

conda install pytorch torchvision torchaudio cpuonly -c pytorch

Install Transformers

Python has a package called transformers that gives access to all transformer-based models. We need to install it using the command below.

pip install transformers
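
To verify the installation, you can run a quick sanity check like the one below (printing False for cuda is expected on a CPU-only install):

import torch
import transformers

# True if the GPU build of torch can see your NVIDIA GPU
print(torch.cuda.is_available())

# Installed transformers version
print(transformers.__version__)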

That’s all. We have now installed all the libraries required to implement a question answering system in Python with the BERT model.

BERT Question Answering Example

Now let’s see how to create a question answering system using the BERT model in Python. I will break the entire process of building QnA with the BERT model into a few steps.

Step 1: Load the BERT Question Answering Model

For this demo project, I am going to use a pre-trained BERT model that has been fine-tuned on the SQuAD (Stanford Question Answering Dataset) task.

import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

# Download Pre-trained Model from Huggingface
bert_model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

# Tokenizer
bert_tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

The above code downloads the pre-trained BERT question answering model from Hugging Face and saves all model files in the default Hugging Face cache location. For me, it was: C:\Users\Anindya\.cache
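
If you prefer to keep the model files elsewhere, from_pretrained() accepts a cache_dir argument (the path below is just an example):

# Optional: store the downloaded model files in a custom cache directory
bert_model = BertForQuestionAnswering.from_pretrained(
    'bert-large-uncased-whole-word-masking-finetuned-squad',
    cache_dir='./bert_cache')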

Step 2: Define the Question and Passage

Now let’s define a passage and a question. Note that the BERT model can process a maximum of 512 tokens at a time, and a single word can be split into several tokens, so keep the question and paragraph together comfortably under that limit.

question = '''What is the date of birth ronaldo'''

paragraph = '''

Cristiano Ronaldo dos Santos Aveiro was born on February 5, 1985, in Funchal, the capital of Madeira, Portugal. 
He spent his early years in the neighboring parish of Santo António. 
Ronaldo is the youngest of four children born to Maria Dolores dos Santos Viveiros da Aveiro, a cook, 
and José Dinis Aveiro, a municipal gardener and part-time kit man. His family heritage includes his great-grandmother, 
Isabel da Piedade, who hailed from São Vicente, Cape Verde. 
Ronaldo has an older brother named Hugo and two older sisters, Elma and Liliana Cátia "Katia." 
Ronaldo's mother revealed that she considered aborting him due to their challenging circumstances, 
including poverty, his father's alcoholism, and already having a large family. However, 
her doctor declined to perform the procedure as abortions were illegal in Portugal at that time. 
Ronaldo grew up in a modest Catholic Christian household, sharing a room with his siblings, amidst financial difficulties.

'''
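
Since BERT is limited to 512 tokens, it is a good idea to check the combined token count of the question and paragraph before going further; a minimal check:

# Count the tokens produced by the question + paragraph pair
# (includes the special [CLS] and [SEP] tokens)
token_count = len(bert_tokenizer.encode(question, paragraph))
print('Total tokens:', token_count)  # must stay within 512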

Step 3: Encode the Question and Paragraph

As you may know, to apply any machine learning model to text data, we first need to convert the text to numeric values. This is the basis of any NLP model.


There are various ways to convert text to numeric values: one-hot encoding, word2vec, fastText, TF-IDF vectorization, etc.

BERT comes with its own tokenizer to convert text to numeric token ids, and we are going to use it here.

# Encode the question and paragraph using the BERT tokenizer
encoding = bert_tokenizer.encode_plus(text=question, text_pair=paragraph)

# Token ids
token_ids = encoding['input_ids']
# Input tokens (human-readable form of the ids)
tokens = bert_tokenizer.convert_ids_to_tokens(token_ids)

# Segment ids (question vs. paragraph)
sentence_embedding = encoding['token_type_ids']

Here, we use the bert_tokenizer.encode_plus() function to convert the input text to numeric values (encoding the question and paragraph together).
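
For inputs that might exceed BERT’s 512-token limit, encode_plus() also accepts truncation arguments. A defensive variant (not needed for our short paragraph; the variable name safe_encoding is my own):

# Optional: cut off anything beyond BERT's 512-token limit
safe_encoding = bert_tokenizer.encode_plus(text=question, text_pair=paragraph,
                                           max_length=512, truncation=True)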

If we print encoding, it looks like this:

print(encoding)
{'input_ids': [101, 2054, 2003, 1996, 3058, 1997, 4182, 8923, 2080, 102, 13675, 2923, 15668, 8923, 2080, 9998, 11053, 13642, 9711, 2001, 2141, 2006, 2337, 1019, 1010, 3106, 1010, 1999, 4569, 18598, 1010, 1996, 3007, 1997, 27309, 1010, 5978, 1012, 2002, 2985, 2010, 2220, 2086, 1999, 1996, 8581, 3583, 1997, 11685, 4980, 1012, 8923, 2080, 2003, 1996, 6587, 1997, 2176, 2336, 2141, 2000, 3814, 21544, 9998, 11053, 6819, 3726, 9711, 2015, 4830, 13642, 9711, 1010, 1037, 5660, 1010, 1998, 4560, 11586, 2483, 13642, 9711, 1010, 1037, 4546, 19785, 1998, 2112, 1011, 2051, 8934, 2158, 1012, 2010, 2155, 4348, 2950, 2010, 2307, 1011, 7133, 1010, 11648, 4830, 11345, 14697, 2063, 1010, 2040, 16586, 2013, 7509, 17280, 1010, 4880, 16184, 1012, 8923, 2080, 2038, 2019, 3080, 2567, 2315, 9395, 1998, 2048, 3080, 5208, 1010, 17709, 2050, 1998, 13451, 11410, 4937, 2401, 1000, 10645, 2401, 1012, 1000, 8923, 2080, 1005, 1055, 2388, 3936, 2008, 2016, 2641, 11113, 11589, 2075, 2032, 2349, 2000, 2037, 10368, 6214, 1010, 2164, 5635, 1010, 2010, 2269, 1005, 1055, 25519, 1010, 1998, 2525, 2383, 1037, 2312, 2155, 1012, 2174, 1010, 2014, 3460, 6430, 2000, 4685, 1996, 7709, 2004, 11324, 2015, 2020, 6206, 1999, 5978, 2012, 2008, 2051, 1012, 8923, 2080, 3473, 2039, 1999, 1037, 10754, 3234, 3017, 4398, 1010, 6631, 1037, 2282, 2007, 2010, 9504, 1010, 17171, 3361, 8190, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

As you can see, the encoding is a dictionary containing the encoded input for the BERT model. It has three keys: input_ids, token_type_ids, and attention_mask. Let’s understand each of them separately.

input_ids

input_ids is a list of integers representing the tokenized input sequence. Each integer corresponds to a specific token. For example, 2054 represents the token “what”, 2003 represents “is”, and so on.
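
You can look up any individual ids with the tokenizer yourself:

# Map single ids back to their tokens
print(bert_tokenizer.convert_ids_to_tokens([2054, 2003]))  # ['what', 'is']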

Note:

101: This is the token id for the [CLS] (classification) token, which marks the beginning of the input sequence.

102: This is the token id for the [SEP] (separator) token, which marks the boundary between two different sentences or the end of a sequence.

For question answering tasks, the 102 ([SEP]) token separates the question from the paragraph (context). You can see that input_ids starts with the token ids of the question text, followed by the token ids of the paragraph text, with a [SEP] token in between.
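
You can find that boundary programmatically using the tokenizer’s sep_token_id:

# Position of the first [SEP] token, which marks the end of the question part
sep_index = token_ids.index(bert_tokenizer.sep_token_id)
print('Question part ends at index:', sep_index)  # 9 for our input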

Printing the tokens variable makes this clearer, so let’s do that.

print(tokens)
['[CLS]', 'what', 'is', 'the', 'date', 'of', 'birth', 'ronald', '##o', '[SEP]', 'cr', '##ist', '##iano', 'ronald', '##o', 'dos', 'santos', 'ave', '##iro', 'was', 'born', 'on', 'february', '5', ',', '1985', ',', 'in', 'fun', '##chal', ',', 'the', 'capital', 'of', 'madeira', ',', 'portugal', '.', 'he', 'spent', 'his', 'early', 'years', 'in', 'the', 'neighboring', 'parish', 'of', 'santo', 'antonio', '.', 'ronald', '##o', 'is', 'the', 'youngest', 'of', 'four', 'children', 'born', 'to', 'maria', 'dolores', 'dos', 'santos', 'vi', '##ve', '##iro', '##s', 'da', 'ave', '##iro', ',', 'a', 'cook', ',', 'and', 'jose', 'din', '##is', 'ave', '##iro', ',', 'a', 'municipal', 'gardener', 'and', 'part', '-', 'time', 'kit', 'man', '.', 'his', 'family', 'heritage', 'includes', 'his', 'great', '-', 'grandmother', ',', 'isabel', 'da', 'pie', '##dad', '##e', ',', 'who', 'hailed', 'from', 'sao', 'vicente', ',', 'cape', 'verde', '.', 'ronald', '##o', 'has', 'an', 'older', 'brother', 'named', 'hugo', 'and', 'two', 'older', 'sisters', ',', 'elm', '##a', 'and', 'lil', '##iana', 'cat', '##ia', '"', 'kat', '##ia', '.', '"', 'ronald', '##o', "'", 's', 'mother', 'revealed', 'that', 'she', 'considered', 'ab', '##ort', '##ing', 'him', 'due', 'to', 'their', 'challenging', 'circumstances', ',', 'including', 'poverty', ',', 'his', 'father', "'", 's', 'alcoholism', ',', 'and', 'already', 'having', 'a', 'large', 'family', '.', 'however', ',', 'her', 'doctor', 'declined', 'to', 'perform', 'the', 'procedure', 'as', 'abortion', '##s', 'were', 'illegal', 'in', 'portugal', 'at', 'that', 'time', '.', 'ronald', '##o', 'grew', 'up', 'in', 'a', 'modest', 'catholic', 'christian', 'household', ',', 'sharing', 'a', 'room', 'with', 'his', 'siblings', ',', 'amidst', 'financial', 'difficulties', '.', '[SEP]']

token_type_ids

This is a list of integers with values of either 0 or 1, where 0 marks a token from the question text and 1 marks a token from the paragraph text. It indicates which segment each token belongs to within the input sequence.


In our code, we capture token_type_ids in the sentence_embedding variable.
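
If you ever need to build these segment ids yourself, they follow directly from the [SEP] position (using the sep_index computed earlier):

# Segment 0 covers [CLS] + question + first [SEP]; segment 1 covers the rest
manual_segments = [0] * (sep_index + 1) + [1] * (len(token_ids) - sep_index - 1)
assert manual_segments == sentence_embedding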

attention_mask

The attention mask specifies which tokens in the input sequence the model should attend to and which it should ignore. A value of 1 means the corresponding token is attended to, while 0 means it is ignored.

In our case, all the values in the attention_mask list are 1, which means that all tokens in the input sequence should be attended to by the model during processing.
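
The zeros appear once padding comes into play, for example when batching several inputs to a common length. A small illustration (the texts here are just placeholders):

# Padding to a fixed length introduces [PAD] tokens, masked out with 0s
padded = bert_tokenizer.encode_plus('a short question', 'a short paragraph',
                                    padding='max_length', max_length=16)
print(padded['attention_mask'])  # 1s for real tokens, 0s for padding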

Step 4: Pass the Encoded Input to the BERT Model

We have converted our input text to BERT encodings. Now we need to pass them to the BERT model; the Python code below does that.

# Pass the encoded input to the BERT model
# (no_grad disables gradient tracking, which we don't need for inference)
with torch.no_grad():
    bert_out = bert_model(torch.tensor([token_ids]), token_type_ids=torch.tensor([sentence_embedding]))

Step 5: Retrieve the Start and End Token Index of the Answer

The output from the BERT model (bert_out) contains the predicted start and end logits of the answer. These are scores over the input tokens; the positions with the highest scores are the predicted start and end of the answer span. We can extract those indices using the torch.argmax() function.

# Retrieve the start and end indices of the answer
start_index = torch.argmax(bert_out['start_logits'])

end_index = torch.argmax(bert_out['end_logits'])

print('start_index: ', start_index, 'end_index: ', end_index)
start_index:  tensor(22) end_index:  tensor(25)
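
Since torch.argmax() picks each index independently, it is worth verifying that the span is valid before decoding it; a small optional sanity check:

# A valid answer span must start at or before its end
if start_index > end_index:
    print('BERT could not find a consistent answer span.')

# Optional: a rough confidence score for the predicted span
score = (bert_out['start_logits'][0, start_index] +
         bert_out['end_logits'][0, end_index]).item()
print('Span score:', score)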

Step 6: Extract the Predicted Answer

Since we have the start and end index of the predicted answer span, we can easily extract the corresponding tokens and assemble our final answer.

# Extract the predicted answer
answer = ' '.join(tokens[start_index:end_index + 1])
print(answer)
february 5 , 1985

Step 7: Clean the Answer Text

In our example, we got a proper answer in the previous step. But in some cases, you may see unwanted tokens like ##, [SEP], or [CLS] in the predicted answer.

This step helps refine and present the answer in a more readable and meaningful format. Below is the Python code to do that.

# Clean the answer
final_answer = ''

for word in answer.split():

    # Merge subword tokens (they start with '##') into the previous word
    if word.startswith('##'):
        final_answer += word[2:]

    else:
        final_answer += ' ' + word

final_answer = final_answer.strip()

# If the predicted span starts at [CLS] or [SEP], the model found no answer
if answer.startswith('[CLS]') or answer.startswith('[SEP]'):
    final_answer = 'Unable to find the answer to your question.'

print(final_answer)
february 5 , 1985

This is the final answer to our question “What is the date of birth ronaldo”.
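
To reuse everything above, you can wrap the full pipeline into a single helper function. Here is a minimal sketch (the function name answer_question is my own):

def answer_question(question, paragraph):
    # Encode the pair, truncating anything beyond BERT's 512-token limit
    encoding = bert_tokenizer.encode_plus(text=question, text_pair=paragraph,
                                          max_length=512, truncation=True)
    token_ids = encoding['input_ids']
    tokens = bert_tokenizer.convert_ids_to_tokens(token_ids)

    # Run the model without gradient tracking
    with torch.no_grad():
        out = bert_model(torch.tensor([token_ids]),
                         token_type_ids=torch.tensor([encoding['token_type_ids']]))

    start = int(torch.argmax(out['start_logits']))
    end = int(torch.argmax(out['end_logits']))

    # Reject invalid or empty spans
    if start > end or tokens[start] in ('[CLS]', '[SEP]'):
        return 'Unable to find the answer to your question.'

    # Re-join word pieces: '##' tokens attach to the previous word
    result = ''
    for word in tokens[start:end + 1]:
        result += word[2:] if word.startswith('##') else ' ' + word
    return result.strip()

print(answer_question(question, paragraph))  # february 5 , 1985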

End Note

In this tutorial, we explored how to use the BERT model to build our own question answering system, and we saw how accurately BERT finds answers in the input paragraph.

We also noticed that BERT has limitations. It can only process a maximum of 512 tokens at a time, so if you are working with a document containing a huge number of words, BERT cannot process it in one pass. I will write a separate tutorial on handling such large documents with BERT.

That’s it for this tutorial. If you have any questions or suggestions, please drop them in the comment section below.
