Diagnosing Heart Disease
Using ML Explainability Tools and Techniques
- The Data
- The Model
- The Explanation
This dataset gives a number of variables along with a target condition of having or not having heart disease. Below, the data is first used in a simple random forest model, and then the model is investigated using ML explainability tools and techniques.
Of all the applications of machine-learning, diagnosing any serious disease using a black box is always going to be a hard sell. If the output from a model is the particular course of treatment (potentially with side-effects), or surgery, or the absence of treatment, people are going to want to know why.
Learn more in Dan Becker’s course in Kaggle Learn here
First, load the appropriate libraries
import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns #for plotting from sklearn.ensemble import RandomForestClassifier #for the model from sklearn.tree import DecisionTreeClassifier from sklearn.tree import export_graphviz #plot tree from sklearn.metrics import roc_curve, auc #for model evaluation from sklearn.metrics import classification_report #for model evaluation from sklearn.metrics import confusion_matrix #for model evaluation from sklearn.model_selection import train_test_split #for data splitting import eli5 #for purmutation importance from eli5.sklearn import PermutationImportance import shap #for SHAP values from pdpbox import pdp, info_plots #for partial plots np.random.seed(123) #ensure reproducibility
Next, load the data,
dt = pd.read_csv("D:/Dataset/heart.csv")
Let’s take a look,
It’s a clean, easy to understand set of data. However, the meaning of some of the column headers are not obvious. Here’s what they mean,
Customizing Column Headers of Heart Disease dataset
- age: The person’s age in years
- sex: The person’s sex (1 = male, 0 = female)
- cp: The chest pain experienced (Value 1: typical angina, Value 2: atypical angina, Value 3: non-anginal pain, Value 4: asymptomatic)
- trestbps: The person’s resting blood pressure (mm Hg on admission to the hospital)
- chol: The person’s cholesterol measurement in mg/dl
- fbs: The person’s fasting blood sugar (> 120 mg/dl, 1 = true; 0 = false)
- restecg: Resting electrocardiographic measurement (0 = normal, 1 = having ST-T wave abnormality, 2 = showing probable or definite left ventricular hypertrophy by Estes’ criteria)
- thalach: The person’s maximum heart rate achieved
- exang: Exercise induced angina (1 = yes; 0 = no)
- oldpeak: ST depression induced by exercise relative to rest (‘ST’ relates to positions on the ECG plot. See more here)
- slope: the slope of the peak exercise ST segment (Value 1: upsloping, Value 2: flat, Value 3: downsloping)
- ca: The number of major vessels (0-3)
- thal: A blood disorder called thalassemia (3 = normal; 6 = fixed defect; 7 = reversable defect)
- target: Heart disease (0 = no, 1 = yes)
To avoid HARKing (or Hypothesizing After the Results are Known) I’m going to take a look at online guides on how heart disease is diagnosed, and look up some of the terms above.
Diagnosis: The diagnosis of heart disease is done on a combination of clinical signs and test results. The types of tests run will be chosen on the basis of what the physician thinks is going on 1, ranging from electrocardiograms and cardiac computerized tomography (CT) scans, to blood tests and exercise stress tests 2.
Looking at information of heart disease risk factors led me to the following: high cholesterol, high blood pressure, diabetes, weight, family history and smoking 3. According to another source 4, the major factors that can’t be changed are: increasing age, male gender and heredity. Note that thalassemia, one of the variables in this dataset, is heredity. Major factors that can be modified are: Smoking, high cholesterol, high blood pressure, physical inactivity, and being overweight and having diabetes. Other factors include stress, alcohol and poor diet/nutrition.
I can see no reference to the ‘number of major vessels’, but given that the definition of heart disease is “…what happens when your heart’s blood supply is blocked or interrupted by a build-up of fatty substances in the coronary arteries”, it seems logical the more major vessels is a good thing, and therefore will reduce the probability of heart disease.
Given the above, I would hypothesis that, if the model has some predictive ability, we’ll see these factors standing out as the most important.
Let’s change the column names to be a bit clearer.
dt.columns = ['age', 'sex', 'chest_pain_type', 'resting_blood_pressure', 'cholesterol', 'fasting_blood_sugar', 'rest_ecg', 'max_heart_rate_achieved', 'exercise_induced_angina', 'st_depression', 'st_slope', 'num_major_vessels', 'thalassemia', 'target']
I’m also going to change the values of the categorical variables, to improve the interpretation later on,
dt['sex'][dt['sex'] == 0] = 'female' dt['sex'][dt['sex'] == 1] = 'male' dt['chest_pain_type'][dt['chest_pain_type'] == 1] = 'typical angina' dt['chest_pain_type'][dt['chest_pain_type'] == 2] = 'atypical angina' dt['chest_pain_type'][dt['chest_pain_type'] == 3] = 'non-anginal pain' dt['chest_pain_type'][dt['chest_pain_type'] == 4] = 'asymptomatic' dt['fasting_blood_sugar'][dt['fasting_blood_sugar'] == 0] = 'lower than 120mg/ml' dt['fasting_blood_sugar'][dt['fasting_blood_sugar'] == 1] = 'greater than 120mg/ml' dt['rest_ecg'][dt['rest_ecg'] == 0] = 'normal' dt['rest_ecg'][dt['rest_ecg'] == 1] = 'ST-T wave abnormality' dt['rest_ecg'][dt['rest_ecg'] == 2] = 'left ventricular hypertrophy' dt['exercise_induced_angina'][dt['exercise_induced_angina'] == 0] = 'no' dt['exercise_induced_angina'][dt['exercise_induced_angina'] == 1] = 'yes' dt['st_slope'][dt['st_slope'] == 1] = 'upsloping' dt['st_slope'][dt['st_slope'] == 2] = 'flat' dt['st_slope'][dt['st_slope'] == 3] = 'downsloping' dt['thalassemia'][dt['thalassemia'] == 1] = 'normal' dt['thalassemia'][dt['thalassemia'] == 2] = 'fixed defect' dt['thalassemia'][dt['thalassemia'] == 3] = 'reversable defect'
Check the data types,
Some of those aren’t quite right. The code below changes them into categorical variables,
dt['sex'] = dt['sex'].astype('object') dt['chest_pain_type'] = dt['chest_pain_type'].astype('object') dt['fasting_blood_sugar'] = dt['fasting_blood_sugar'].astype('object') dt['rest_ecg'] = dt['rest_ecg'].astype('object') dt['exercise_induced_angina'] = dt['exercise_induced_angina'].astype('object') dt['st_slope'] = dt['st_slope'].astype('object') dt['thalassemia'] = dt['thalassemia'].astype('object') dt.dtypes
For the categorical varibles, we need to create dummy variables. I’m also going to drop the first category of each. For example, rather than having ‘male’ and ‘female’, we’ll have ‘male’ with values of 0 or 1 (1 being male, and 0 therefore being female).
dt = pd.get_dummies(dt, drop_first=True)
Now let’s see,
Looking good. Now, on to the model.
The next part fits a random forest model to the data,
X_train, X_test, y_train, y_test = train_test_split(dt.drop('target', 1), dt['target'], test_size = .2, random_state=10) #split the data
model = RandomForestClassifier(max_depth=5) model.fit(X_train, y_train)
We can plot the consequent decision tree, to see what it’s doing
estimator = model.estimators_ feature_names = [i for i in X_train.columns] y_train_str = y_train.astype('str') y_train_str[y_train_str == '0'] = 'no disease' y_train_str[y_train_str == '1'] = 'disease' y_train_str = y_train_str.values
#code from https://towardsdatascience.com/how-to-visualize-a-decision-tree-from-a-random-forest-in-python-using-scikit-learn-38ad2d75f21c export_graphviz(estimator, out_file='tree.dot', feature_names = feature_names, class_names = y_train_str, rounded = True, proportion = True, label='root', precision = 2, filled = True) from subprocess import call call(['dot', '-Tpng', 'tree.dot', '-o', 'tree.png', '-Gdpi=600']) from IPython.display import Image Image(filename = 'tree.png')
This gives us on explainability tool. However, I can’t glance at this and get a quick sense of the most important features. We’ll revisit those later. Next, let’s evaluate the model,
y_predict = model.predict(X_test) y_pred_quant = model.predict_proba(X_test)[:, 1] y_pred_bin = model.predict(X_test)
Assess the fit with a confusion matrix,
confusion_matrix = confusion_matrix(y_test, y_pred_bin) confusion_matrix
array([[29, 6], [ 4, 22]], dtype=int64)
Diagnostic tests are often sold, marketed, cited and used with sensitivity and specificity as the headline metrics. Sensitivity and specificity are defined as,
Let’s see what this model is giving,
total=sum(sum(confusion_matrix)) sensitivity = confusion_matrix[0,0]/(confusion_matrix[0,0]+confusion_matrix[1,0]) print('Sensitivity : ', sensitivity ) specificity = confusion_matrix[1,1]/(confusion_matrix[1,1]+confusion_matrix[0,1]) print('Specificity : ', specificity)
Sensitivity : 0.8787878787878788 Specificity : 0.7857142857142857
That seems reasonable. Let’s also check with a Receiver Operator Curve (ROC),
fpr, tpr, thresholds = roc_curve(y_test, y_pred_quant) fig, ax = plt.subplots() ax.plot(fpr, tpr) ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls="--", c=".3") plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.0]) plt.rcParams['font.size'] = 12 plt.title('ROC curve for diabetes classifier') plt.xlabel('False Positive Rate (1 - Specificity)') plt.ylabel('True Positive Rate (Sensitivity)') plt.grid(True)
Another common metric is the Area Under the Curve, or AUC. This is a convenient way to capture the performance of a model in a single number, although it’s not without certain issues. As a rule of thumb, an AUC can be classed as follows,
- 0.90 – 1.00 = excellent
- 0.80 – 0.90 = good
- 0.70 – 0.80 = fair
- 0.60 – 0.70 = poor
- 0.50 – 0.60 = fail
Let’s see what the above ROC gives us,
OK, so it’s working well.
Read The explanation of this Post in Part2