Heart Disease? Explaining the ML Model | Part 1


Diagnosing Heart Disease

Using ML Explainability Tools and Techniques

Contents

  1. Introduction
  2. The Data
  3. The Model
  4. The Explanation
  5. Conclusion

Introduction

This dataset gives a number of variables along with a target condition of having or not having heart disease. Below, the data is first used in a simple random forest model, and then the model is investigated using ML explainability tools and techniques.

Of all the applications of machine learning, diagnosing any serious disease using a black box is always going to be a hard sell. If the output from a model is a particular course of treatment (potentially with side-effects), or surgery, or the absence of treatment, people are going to want to know why.

Learn more in Dan Becker's Machine Learning Explainability course on Kaggle Learn.

First, load the appropriate libraries,

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns #for plotting
from sklearn.ensemble import RandomForestClassifier #for the model
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz #plot tree
from sklearn.metrics import roc_curve, auc #for model evaluation
from sklearn.metrics import classification_report #for model evaluation
from sklearn.metrics import confusion_matrix #for model evaluation
from sklearn.model_selection import train_test_split #for data splitting
import eli5 #for permutation importance
from eli5.sklearn import PermutationImportance
import shap #for SHAP values
from pdpbox import pdp, info_plots #for partial plots
np.random.seed(123) #ensure reproducibility

The Data

Next, load the data,

dt = pd.read_csv("D:/Dataset/heart.csv")

Let’s take a look,

dt.head(10)
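While we're at it, a couple of quick sanity checks (my addition, not part of the original walkthrough): the shape of the data, any missing values, and the class balance,

print(dt.shape) #expect (303, 14) for the standard Kaggle version of this dataset
print(dt.isnull().sum().sum()) #total missing values; should be 0
print(dt['target'].value_counts()) #how balanced are the two classes?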

It’s a clean, easy-to-understand set of data. However, the meanings of some of the column headers are not obvious. Here’s what they mean,


  • age: The person’s age in years
  • sex: The person’s sex (1 = male, 0 = female)
  • cp: The chest pain experienced (Value 1: typical angina, Value 2: atypical angina, Value 3: non-anginal pain, Value 4: asymptomatic)
  • trestbps: The person’s resting blood pressure (mm Hg on admission to the hospital)
  • chol: The person’s cholesterol measurement in mg/dl
  • fbs: The person’s fasting blood sugar (> 120 mg/dl, 1 = true; 0 = false)
  • restecg: Resting electrocardiographic measurement (0 = normal, 1 = having ST-T wave abnormality, 2 = showing probable or definite left ventricular hypertrophy by Estes’ criteria)
  • thalach: The person’s maximum heart rate achieved
  • exang: Exercise induced angina (1 = yes; 0 = no)
  • oldpeak: ST depression induced by exercise relative to rest (‘ST’ relates to positions on the ECG plot)
  • slope: the slope of the peak exercise ST segment (Value 1: upsloping, Value 2: flat, Value 3: downsloping)
  • ca: The number of major vessels (0-3)
  • thal: A blood disorder called thalassemia (3 = normal; 6 = fixed defect; 7 = reversable defect)
  • target: Heart disease (0 = no, 1 = yes)

To avoid HARKing (or Hypothesizing After the Results are Known), I’m going to take a look at online guides on how heart disease is diagnosed, and look up some of the terms above.

Diagnosis: The diagnosis of heart disease is made from a combination of clinical signs and test results. The types of tests run will be chosen on the basis of what the physician thinks is going on [1], ranging from electrocardiograms and cardiac computerized tomography (CT) scans to blood tests and exercise stress tests [2].

Conceptual Approach

Looking at information on heart disease risk factors led me to the following: high cholesterol, high blood pressure, diabetes, weight, family history and smoking [3]. According to another source [4], the major risk factors that can’t be changed are increasing age, male gender and heredity. Note that thalassemia, one of the variables in this dataset, is hereditary. Major risk factors that can be modified are smoking, high cholesterol, high blood pressure, physical inactivity, and being overweight or diabetic. Other factors include stress, alcohol and poor diet/nutrition.

I can see no reference to the ‘number of major vessels’, but given that the definition of heart disease is “…what happens when your heart’s blood supply is blocked or interrupted by a build-up of fatty substances in the coronary arteries”, it seems logical that having more major vessels is a good thing, and should therefore reduce the probability of heart disease.

Given the above, I would hypothesise that, if the model has some predictive ability, we’ll see these factors standing out as the most important.

Let’s change the column names to be a bit clearer.

dt.columns = ['age', 'sex', 'chest_pain_type', 'resting_blood_pressure', 'cholesterol', 'fasting_blood_sugar', 'rest_ecg', 'max_heart_rate_achieved',
       'exercise_induced_angina', 'st_depression', 'st_slope', 'num_major_vessels', 'thalassemia', 'target']
dt.head()

I’m also going to change the values of the categorical variables, to improve the interpretation later on. (Note that the Kaggle version of this dataset codes thalassemia as 1 to 3, rather than the 3/6/7 in the description above, which is why the mapping below differs from that list.)

dt['sex'] = dt['sex'].replace({0: 'female', 1: 'male'}) #.replace avoids chained indexing, which raises SettingWithCopyWarning

dt['chest_pain_type'] = dt['chest_pain_type'].replace({1: 'typical angina', 2: 'atypical angina',
                                                       3: 'non-anginal pain', 4: 'asymptomatic'})

dt['fasting_blood_sugar'] = dt['fasting_blood_sugar'].replace({0: 'lower than 120mg/ml',
                                                               1: 'greater than 120mg/ml'})

dt['rest_ecg'] = dt['rest_ecg'].replace({0: 'normal', 1: 'ST-T wave abnormality',
                                         2: 'left ventricular hypertrophy'})

dt['exercise_induced_angina'] = dt['exercise_induced_angina'].replace({0: 'no', 1: 'yes'})

dt['st_slope'] = dt['st_slope'].replace({1: 'upsloping', 2: 'flat', 3: 'downsloping'})

dt['thalassemia'] = dt['thalassemia'].replace({1: 'normal', 2: 'fixed defect', 3: 'reversable defect'})
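As a quick check (my addition) that the recoding behaved as intended, the value counts should now show the text labels,

print(dt['chest_pain_type'].value_counts())
print(dt['thalassemia'].value_counts())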

Check the data types,

dt.dtypes

Some of those aren’t quite right. The code below makes sure each of the recoded variables has the generic object dtype, so that get_dummies treats them as categorical,

for col in ['sex', 'chest_pain_type', 'fasting_blood_sugar', 'rest_ecg',
            'exercise_induced_angina', 'st_slope', 'thalassemia']:
    dt[col] = dt[col].astype('object')
dt.dtypes

For the categorical variables, we need to create dummy variables. I’m also going to drop the first category of each: for example, rather than having separate ‘male’ and ‘female’ columns, we’ll have a single ‘male’ column with values of 0 or 1 (1 being male, and 0 therefore being female).

dt = pd.get_dummies(dt, drop_first=True)
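As a quick illustration (my addition) of what drop_first does, we can list the dummy columns generated for a couple of the variables,

print([c for c in dt.columns if c.startswith('sex')]) #just 'sex_male'; 'female' is the dropped baseline
print([c for c in dt.columns if c.startswith('chest_pain_type')]) #one column per remaining chest pain category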

Now let’s see,

dt.head()

Looking good. Now, on to the model.

The Model

The next part fits a random forest model to the data,

X_train, X_test, y_train, y_test = train_test_split(dt.drop('target', axis=1), dt['target'], test_size = .2, random_state=10) #split the data
model = RandomForestClassifier(max_depth=5)
model.fit(X_train, y_train)
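As a quick first check (my addition, not in the original analysis), compare accuracy on the training and test sets; a large gap would suggest overfitting,

print('Train accuracy: ', model.score(X_train, y_train))
print('Test accuracy: ', model.score(X_test, y_test))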

We can plot one of the individual decision trees from the forest, to see what it’s doing,

estimator = model.estimators_[1]
feature_names = [i for i in X_train.columns]

#the classes are sorted, so class 0 ('no disease') comes first
class_names = ['no disease', 'disease']
#adapted from https://towardsdatascience.com/how-to-visualize-a-decision-tree-from-a-random-forest-in-python-using-scikit-learn-38ad2d75f21c

export_graphviz(estimator, out_file='tree.dot', 
                feature_names = feature_names,
                class_names = class_names,
                rounded = True, proportion = True, 
                label='root',
                precision = 2, filled = True)

from subprocess import call
call(['dot', '-Tpng', 'tree.dot', '-o', 'tree.png', '-Gdpi=600']) #requires the Graphviz 'dot' executable on the system PATH

from IPython.display import Image
Image(filename = 'tree.png')
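If Graphviz isn’t installed, scikit-learn’s own plot_tree function (available from version 0.21 onwards) is a pure-matplotlib alternative; a minimal sketch,

from sklearn.tree import plot_tree

fig, ax = plt.subplots(figsize=(20, 10))
plot_tree(estimator, feature_names=feature_names,
          class_names=['no disease', 'disease'], #classes are sorted, so 'no disease' (0) comes first
          filled=True, rounded=True, proportion=True, precision=2,
          max_depth=3, ax=ax) #max_depth here just limits the drawing for readability
plt.show()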

This gives us one explainability tool. However, I can’t glance at this and get a quick sense of the most important features; we’ll revisit that later. Next, let’s evaluate the model,

y_pred_bin = model.predict(X_test) #predicted classes
y_pred_quant = model.predict_proba(X_test)[:, 1] #predicted probability of heart disease

Assess the fit with a confusion matrix,

cm = confusion_matrix(y_test, y_pred_bin) #use a new name so the imported confusion_matrix function isn't overwritten
cm
array([[29,  6],
       [ 4, 22]], dtype=int64)

Diagnostic tests are often sold, marketed, cited and used with sensitivity and specificity as the headline metrics. In terms of true positives (TP), true negatives (TN), false positives (FP) and false negatives (FN), they are defined as,

Sensitivity = TP / (TP + FN)
Specificity = TN / (TN + FP)

Let’s see what this model is giving,

#sklearn's confusion matrix puts true labels on the rows and predictions on the columns,
#so cm[1,1] = TP, cm[1,0] = FN, cm[0,0] = TN and cm[0,1] = FP
sensitivity = cm[1,1]/(cm[1,1]+cm[1,0])
print('Sensitivity : ', sensitivity )

specificity = cm[0,0]/(cm[0,0]+cm[0,1])
print('Specificity : ', specificity)
Sensitivity :  0.8461538461538461
Specificity :  0.8285714285714286
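classification_report was imported earlier but never used; it gives the per-class precision and recall in one call (recall for the ‘disease’ class is the sensitivity, and recall for the ‘no disease’ class is the specificity),

print(classification_report(y_test, y_pred_bin))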

That seems reasonable. Let’s also check the Receiver Operating Characteristic (ROC) curve,

fpr, tpr, thresholds = roc_curve(y_test, y_pred_quant)

fig, ax = plt.subplots()
ax.plot(fpr, tpr)
ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls="--", c=".3")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.rcParams['font.size'] = 12
plt.title('ROC curve for heart disease classifier')
plt.xlabel('False Positive Rate (1 - Specificity)')
plt.ylabel('True Positive Rate (Sensitivity)')
plt.grid(True)

Another common metric is the Area Under the Curve, or AUC. This is a convenient way to capture the performance of a model in a single number, although it’s not without certain issues. As a rule of thumb, an AUC can be classed as follows (a small helper encoding this rubric appears after the list),

  • 0.90 – 1.00 = excellent
  • 0.80 – 0.90 = good
  • 0.70 – 0.80 = fair
  • 0.60 – 0.70 = poor
  • 0.50 – 0.60 = fail
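As a small illustrative helper (my addition; the function name is mine), the rubric above can be encoded directly,

def auc_grade(auc_value):
    #map an AUC score onto the rule-of-thumb grades listed above
    for threshold, grade in [(0.9, 'excellent'), (0.8, 'good'), (0.7, 'fair'), (0.6, 'poor'), (0.5, 'fail')]:
        if auc_value >= threshold:
            return grade
    return 'worse than chance'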

Let’s see what the above ROC gives us,

auc(fpr, tpr)
0.9076923076923078

OK, so it’s working well.
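A single 80/20 split can flatter or shortchange a model by luck of the draw, so as an extra check (my addition; the random_state here is an arbitrary choice) it’s worth looking at the AUC under cross-validation,

from sklearn.model_selection import cross_val_score

cv_model = RandomForestClassifier(max_depth=5, random_state=10) #a fresh model, refitted on each fold
cv_auc = cross_val_score(cv_model, dt.drop('target', axis=1), dt['target'], cv=5, scoring='roc_auc')
print('Cross-validated AUC: %.3f +/- %.3f' % (cv_auc.mean(), cv_auc.std()))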

Read the explanation of the model in Part 2.
