Tutorial: Doc2Vec and t-SNE

This post is a tutorial on using doc2vec and t-SNE visualization in Python for disease clustering. The same code works for other kinds of text input as well (e.g., movie reviews, product reviews). The Disease Ontology is the input data for this tutorial. The ontology lists diseases and gives a short definition for each. Our doc2vec model uses the definition text to learn a vector for each disease. The input data is available at the link.

Doc2vec (Quoc Le and Tomas Mikolov) is an extension of word2vec. It learns representation vectors for chunks of text (i.e., sentences, paragraphs, documents) as well as for words. We train the model with the Doc2vec implementation in Gensim, a topic-modeling Python library. For visualization we use the t-SNE implementation in scikit-learn.
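Gensim's dm=1 mode, which this tutorial uses, is the distributed memory model (PV-DM). The block below gives a simplified view of it in its averaged form. Assume a window of k words on each side of a target word w_t in document d. D holds one vector per document (here, per disease id), W holds one vector per word, U holds the output weights, and V is the vocabulary. Le and Mikolov's paper draws the context from the preceding words; this sketch uses a window on both sides, as Gensim does. Gensim also approximates the softmax with negative sampling by default, and it sums the vectors instead of averaging them unless dm_mean=1.

```latex
h_t = \frac{1}{2k+1}\Big(D_d + \sum_{0 < |j| \le k} W_{w_{t+j}}\Big)

p\big(w_t \mid d, \{w_{t+j}\}_{0<|j|\le k}\big) = \frac{\exp\big(U_{w_t}^{\top} h_t\big)}{\sum_{w \in V} \exp\big(U_{w}^{\top} h_t\big)}

\max_{D,\,W,\,U} \; \sum_{d} \sum_{t} \log p\big(w_t \mid d, \{w_{t+j}\}_{0<|j|\le k}\big)
```

Training adjusts the document vectors D_d together with the word vectors. After training, each D_d is the vector we use for the corresponding disease.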

Preparing Data

First of all, let’s take a look at the data.

It lists diseases, and each disease has multiple properties. I will use id, name, def and is_a for our doc2vec model and visualization. Now, let’s load the libraries we need.

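import pandas as pd

from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from nltk.tokenize import word_tokenize
# word_tokenize needs NLTK's tokenizer data; if it is missing, run
# nltk.download('punkt') once (or 'punkt_tab' on recent NLTK releases)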

The parser() function parses the input data into a pandas dataframe. Its implementation is at the bottom of this post, because I want to focus on the doc2vec training first.

input_file = "doid.obo"
df_dx = parser(input_file)

The parsed data looks like:

is_a is a list of the ids of higher-level (parent) diseases. This dataframe includes 6,256 diseases that have a definition. Now, let’s prepare the input data for our doc2vec model.

list_id = list(df_dx["id"])
list_def = list(df_dx["def"])

tagged_data = [TaggedDocument(words=word_tokenize(term_def.lower()), tags=[list_id[i]]) for i, term_def in enumerate(list_def)]

The definitions are the input text of our model. A tag is a label attached to a definition. Each disease's id is used as the tag for its definition. Here the tags are strings; Gensim also accepts plain integer tags, which it uses directly as vector indices.
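Gensim's TaggedDocument is a named tuple with two fields, words and tags. To show the shape of tagged_data, here is a minimal sketch that mimics that structure with only the standard library. simple_tokenize is a hypothetical stand-in for NLTK's word_tokenize. It is a rough regex that will not split text exactly the same way. The ids and definitions are also made up for illustration.

```python
import re
from collections import namedtuple

# Same shape as gensim's TaggedDocument: a list of tokens plus a list of tags
TaggedDoc = namedtuple("TaggedDoc", "words tags")

def simple_tokenize(text):
    # Hypothetical stand-in for nltk's word_tokenize: lowercase the text, then
    # keep runs of word characters and single punctuation marks as tokens
    return re.findall(r"\w+|[^\w\s]", text.lower())

def make_tagged_data(ids, definitions):
    # One document per disease; the disease id is its only tag
    return [TaggedDoc(words=simple_tokenize(d), tags=[i])
            for i, d in zip(ids, definitions)]

# Hypothetical ids and definitions, for illustration only
tagged = make_tagged_data(["ID:1", "ID:2"],
                          ["A disease of the lung.", "An allergic reaction."])
print(tagged[0])
```

Each element pairs the token list of one definition with a one-element tag list. The real tagged_data has the same structure.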

Training a Doc2Vec Model

max_epochs = 500
vec_size = 100
alpha = 0.025

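model = Doc2Vec(vector_size=vec_size,
                alpha=alpha,           # initial learning rate
                min_alpha=0.00025,     # final learning rate
                min_count=1,           # keep every word, even those seen only once
                dm=1)                  # 1 = distributed memory (PV-DM)
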
model.build_vocab(tagged_data)

This initializes the model with a few hyper-parameters. dm=1 selects the distributed memory model, and min_count=1 keeps every word in the vocabulary. Details of all the parameters are in the link. In this case, the vector size is 100 and the number of training epochs is 500. Feel free to try different values to improve performance. The following code trains the doc2vec model and saves it:

# A single train() call runs all epochs; Gensim lowers the learning rate
# linearly from alpha to min_alpha over the whole run
model.train(tagged_data,
            total_examples=model.corpus_count,
            epochs=max_epochs)

model.save("d2v_do_v100_e500.model")

We trained the model and saved it! Now, let’s find the words most similar to “cancer.” Calling wv.most_similar() returns the words closest to a keyword in the vector space, along with their similarity scores. As you can see, the result shows diseases and words related to cancer. Feel free to explore other useful methods in the link and play with them.
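wv.most_similar() ranks words by cosine similarity to the query. The sketch below shows that ranking in pure Python. It uses made-up toy vectors in place of the trained model, and the helper names are hypothetical.

```python
import math

def cosine(u, v):
    # Cosine similarity: dot product divided by the product of the norms
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def most_similar(query, vectors, topn=3):
    # Score every other key against the query vector, highest similarity first
    q = vectors[query]
    scores = [(key, cosine(q, vec)) for key, vec in vectors.items() if key != query]
    return sorted(scores, key=lambda kv: kv[1], reverse=True)[:topn]

# Toy 3-dimensional vectors, made up for illustration only
toy = {
    "cancer":    [0.9, 0.1, 0.0],
    "carcinoma": [0.8, 0.2, 0.1],
    "tumor":     [0.7, 0.0, 0.3],
    "allergy":   [0.0, 0.9, 0.4],
}
print(most_similar("cancer", toy, topn=2))
```

On the trained model, the equivalent call is model.wv.most_similar("cancer") for words. In Gensim 4.x, model.dv.most_similar() with a disease id returns the most similar diseases.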

Visualization with t-SNE

t-Distributed Stochastic Neighbor Embedding (t-SNE) is a dimensionality reduction technique that is particularly well suited to visualizing high-dimensional vectors. Let’s see how t-SNE lays out our trained disease vectors in two dimensions.
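For reference, here is the standard t-SNE formulation (van der Maaten and Hinton). It compares two sets of pairwise similarities: P, computed in the original space, and Q, computed with a Student-t kernel in the 2-D embedding. The embedding is found by minimizing the Kullback-Leibler divergence between them. Here x_i are the 100-dimensional document vectors, y_i are their 2-D positions, N is the number of diseases, and σ_i is a per-point bandwidth chosen to match the perplexity parameter.

```latex
p_{j|i} = \frac{\exp\left(-\lVert x_i - x_j \rVert^2 / 2\sigma_i^2\right)}{\sum_{k \neq i} \exp\left(-\lVert x_i - x_k \rVert^2 / 2\sigma_i^2\right)},
\qquad
p_{ij} = \frac{p_{j|i} + p_{i|j}}{2N}

q_{ij} = \frac{\left(1 + \lVert y_i - y_j \rVert^2\right)^{-1}}{\sum_{k \neq l} \left(1 + \lVert y_k - y_l \rVert^2\right)^{-1}},
\qquad
C = \mathrm{KL}(P \,\Vert\, Q) = \sum_{i \neq j} p_{ij} \log \frac{p_{ij}}{q_{ij}}
```

The heavy-tailed Student-t kernel in q_ij lets moderately dissimilar points sit far apart in 2-D, which helps separate the clusters visually.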

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Gensim 4.x API: model.dv holds the document vectors (model.docvecs in Gensim 3.x)
doc_tags = list(model.dv.index_to_key)
X = model.dv[doc_tags]

This extracts the tags (i.e., disease ids) and their vectors from the trained model.

tsne = TSNE(n_components=2, random_state=0)  # fixed seed so the layout is reproducible
X_tsne = tsne.fit_transform(X)
df = pd.DataFrame(X_tsne, index=doc_tags, columns=['x', 'y'])

t-SNE reduces our 100-dimensional trained vectors to two dimensions. We store the reduced vectors in a pandas dataframe indexed by disease id.

The following scatter plots show some interesting results. A disease is shown in red if its name or definition contains a keyword (e.g., cancer, allergy), or if the name or definition of one of its direct parents (is_a) does. All other diseases are shown in blue. The implementation of plotScatter() is at the bottom of this post.
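The coloring rule in plotScatter() can be isolated as a small function. This sketch assumes the ontology has been loaded into a plain dict keyed by disease id. The dict layout, the function name and the sample entries are hypothetical. The checks mirror plotScatter(). The keyword is searched for in two places: the name and lowercased definition of the disease itself, and the same fields of each direct is_a parent.

```python
def matches_keyword(term_id, terms, keyword):
    # terms maps an id to {"name": ..., "def": ..., "is_a": [parent ids]}
    def mentions(tid):
        term = terms.get(tid)
        # Ids without an entry (e.g., parents that had no definition) never match
        return term is not None and (keyword in term["name"] or keyword in term["def"].lower())

    # Match if the disease itself, or any direct parent, mentions the keyword
    return mentions(term_id) or any(mentions(p) for p in terms[term_id]["is_a"])

# Hypothetical mini ontology, for illustration only
terms = {
    "ID:1": {"name": "lung cancer", "def": "A cancer of the lung.", "is_a": []},
    "ID:2": {"name": "lung carcinoma", "def": "A carcinoma of the lung.", "is_a": ["ID:1"]},
    "ID:3": {"name": "food allergy", "def": "An allergic reaction.", "is_a": ["ID:9"]},
}
colors = {tid: ("red" if matches_keyword(tid, terms, "cancer") else "blue") for tid in terms}
print(colors)
```

ID:2 is colored red only because its parent's name contains the keyword. ID:3 stays blue because its parent has no entry in the dict.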

# Allergy
plotScatter(keyword="allergy")

Allergies and allergy-related diseases tend to appear in the center.

# Cancer
plotScatter(keyword="cancer")

Most cancers appear on the right and at the bottom.

# Syndrome
plotScatter(keyword="syndrome")

Many syndromes appear at the top.

The following plots show diseases related to fever, cough, viruses, bacteria or diarrhea. It is easy to imagine that these diseases are related to one another in some way. The plots show that they have similar distributions.

# Fever
plotScatter(keyword="fever")
# Cough
plotScatter(keyword="cough")
# Virus
plotScatter(keyword="virus")
# Bacterial
plotScatter(keyword="bacterial")
# Diarrhea
plotScatter(keyword="diarrhea")

Summary

We have walked through doc2vec training and t-SNE visualization with disease data. The doc2vec technique shows potential for NLP in bioinformatics. The Disease Ontology used in this tutorial provides a description (i.e., definition) for many diseases. However, most definitions are short (e.g., only one sentence), and around half of the diseases have no definition at all. We might get better clustering results by running doc2vec on longer descriptions. The following code implements parser() and plotScatter().

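def parser(input_file):
    # Collect rows in a list and build the DataFrame once at the end;
    # appending with df.loc[len(df)] inside the loop is much slower
    rows = []

    with open(input_file, "r") as f:
        in_term = False
        term_id = term_name = term_def = None
        term_is_a = []

        for line in f:
            line = line.rstrip('\n')

            if line.startswith("[Term]") or line.startswith("[Typedef]"):
                # A new stanza starts: keep the previous [Term] if it has a definition
                if in_term and term_def:
                    rows.append([term_id, term_name, term_def, term_is_a])

                in_term = line.startswith("[Term]")
                term_id = term_name = term_def = None
                term_is_a = []
            elif line.startswith("id: "):
                term_id = line[len("id: "):]
            elif line.startswith("name: "):
                term_name = line[len("name: "):]
            elif line.startswith("def: "):
                # def: "definition text" [references]
                term_def = line.split("\"")[1]
            elif line.startswith("is_a: "):
                # is_a: <parent id> ! <parent name>
                term_is_a.append(line[len("is_a: "):].split(" ! ")[0])

        # The last stanza is not followed by another header, so store it here
        if in_term and term_def:
            rows.append([term_id, term_name, term_def, term_is_a])

    return pd.DataFrame(rows, columns=['id', 'name', 'def', 'is_a'])
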

def plotScatter(keyword):
    # Build an id -> {name, def, is_a} lookup once instead of filtering
    # df_dx for every point (assumes disease ids are unique)
    terms = df_dx.set_index('id')[['name', 'def', 'is_a']].to_dict('index')

    def mentions(term_id):
        # True if the name or the lowercased definition contains the keyword;
        # ids without an entry (e.g., parents with no definition) never match
        term = terms.get(term_id)
        return term is not None and (keyword in term['name'] or keyword in term['def'].lower())

    fig = plt.figure(figsize=(10, 15))
    ax = fig.add_subplot(1, 1, 1)

    pos_found_x, pos_found_y = [], []
    pos_rest_x, pos_rest_y = [], []

    # df holds the t-SNE coordinates, indexed by disease id
    for term_id, pos in df.iterrows():
        # Red if the disease itself or one of its direct parents (is_a) mentions the keyword
        if mentions(term_id) or any(mentions(p) for p in terms[term_id]['is_a']):
            pos_found_x.append(pos['x'])
            pos_found_y.append(pos['y'])
        else:
            pos_rest_x.append(pos['x'])
            pos_rest_y.append(pos['y'])

    ax.scatter(pos_rest_x, pos_rest_y, c='blue')
    ax.scatter(pos_found_x, pos_found_y, c='red')
    plt.show()