You may already know how to build a decision-tree classifier, but you might be wondering how scikit-learn actually builds it. This article walks through the process step by step. Run the code in sequence for a better understanding.
A decision tree repeatedly splits nodes into sub-nodes until each node is pure with respect to the target classes. To measure the purity of a split we calculate the Gini impurity (also called the Gini index). It is 0 for a pure node and always stays below 1; for a two-class problem its maximum is 0.5. When a node's Gini index is 0 the tree stops splitting there; otherwise it continues to split. At each node the algorithm picks the split that gives the lowest size-weighted Gini index across the resulting child nodes. This algorithm is called CART (Classification and Regression Trees). Entropy can be used as the splitting criterion instead of Gini impurity.
To extract the decision rules from the decision tree we use the scikit-learn library. Let’s see how this is done with an example.
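To make the impurity measures concrete, here is a minimal pure-Python sketch of Gini impurity, entropy, and the size-weighted impurity of a candidate split. The function names and the sample labels are illustrative assumptions; scikit-learn's own implementation is compiled code and is organised differently.

```python
from collections import Counter
from math import log2


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def entropy(labels):
    """Shannon entropy in bits: minus the sum of p * log2(p) over classes."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def weighted_gini(left, right):
    """Impurity of a candidate split: child impurities weighted by child size."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


# Hypothetical node holding 6 samples of class 0 and 10 of class 1.
node = [0] * 6 + [1] * 10
print(gini(node))     # 0.46875
print(entropy(node))

# Hypothetical split of that node into a pure child and a mixed child.
pure_left = [1] * 4
mixed_right = [0] * 6 + [1] * 6
print(weighted_gini(pure_left, mixed_right))  # 0.375
```

A CART-style learner evaluates many candidate splits in this way and keeps the one with the lowest weighted impurity; here the split lowers the impurity from 0.46875 to 0.375.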
Importing Libraries
Python libraries make it very easy for us to handle the data and perform typical and complex tasks with a single line of code.
- Pandas – This library loads data into a 2D tabular structure (a DataFrame) and has multiple functions to perform analysis tasks in one go.
- Numpy – Numpy arrays are very fast and can perform large computations in a very short time.
- Matplotlib – This library is used to draw visualizations.
- Sklearn – This package (scikit-learn) contains multiple modules with pre-implemented functions to perform tasks from data preprocessing to model development and evaluation.
Python3
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
Importing Dataset
In this example, we will look at a binary-classification problem in which the label is 1 if an employee's salary is greater than 100k and 0 otherwise. The code to load the data is as follows:
Python3
url = 'https://raw.githubusercontent.com/Stitaprajna/Practise-codes/main/salaries.csv'
df = pd.read_csv(url)
df
Output:
Next, we will build a decision tree using scikit-learn, and then plot the tree using sklearn.tree.plot_tree(). The code is as follows:
Python3
# Separate the input features from the target column
# (named `inputs` to avoid shadowing the built-in input())
inputs = df.drop('salary_more_then_100k', axis=1)

# Label-encode the categorical columns;
# fit_transform refits the encoder on each column
le = LabelEncoder()
inputs['company_n'] = le.fit_transform(inputs['company'])
inputs['jobs_n'] = le.fit_transform(inputs['job'])
inputs['degree_n'] = le.fit_transform(inputs['degree'])

# Build the decision tree on the encoded columns
inputs_n = inputs[['company_n', 'jobs_n', 'degree_n']]
model = DecisionTreeClassifier()
model.fit(inputs_n, df['salary_more_then_100k'])

# Set the figure size before plotting so that it applies to this plot
plt.figure(figsize=(10, 10))
tree.plot_tree(model, filled=True)
plt.show()
Output:
Now, let’s try to understand this diagram and try to extract the decision rules used by sklearn for splitting.
- If we begin from the root node, which is the topmost light blue box, you can see that the splitting criterion is X[0] <= 0.5, with a Gini index of 0.469. Here X[0] is the first feature column passed to the model, company_n. The samples are split into two groups depending on whether they satisfy this condition: those that satisfy it go to the left child, and the others go to the right.
- Now we have two child nodes, and each is split again based on the Gini index. If a node's Gini index is 0.0, the grouped data is considered pure and no further splitting occurs.
- You should be able to see the ‘deep orange’ and ‘deep blue’ boxes with a Gini index of zero; these are the pure nodes, containing only class 0 or only class 1. The hue (‘blue’ or ‘orange’) shows which class of salary_more_then_100k, 0 or 1, is in the majority at that node. The intensity of the color reflects the node's purity, and if the color is ‘white’, the grouped data has equal numbers of 0 and 1 samples.
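The rules read off the diagram above can also be printed as text with sklearn.tree.export_text. The sketch below uses a small made-up table rather than salaries.csv, so the column names, values, and the `toy`, `encoders`, and `features` variables are all assumptions for illustration. It also looks up the LabelEncoder mapping, which is needed to translate a threshold such as `company_n <= 0.50` back into a category.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical toy data, not the article's salaries.csv.
toy = pd.DataFrame({
    "company": ["alpha", "alpha", "beta", "beta", "gamma", "gamma"],
    "degree": ["bachelors", "masters", "bachelors",
               "masters", "bachelors", "masters"],
    "high_salary": [0, 0, 1, 1, 1, 1],
})

# One fitted encoder per column, kept so each mapping can be inspected later.
encoders = {col: LabelEncoder().fit(toy[col]) for col in ["company", "degree"]}
features = pd.DataFrame(
    {f"{col}_n": enc.transform(toy[col]) for col, enc in encoders.items()}
)

model = DecisionTreeClassifier(random_state=0)
model.fit(features, toy["high_salary"])

# export_text returns the tree as indented text, one line per node,
# using the supplied feature names instead of X[0], X[1], ...
rules = export_text(model, feature_names=list(features.columns))
print(rules)

# LabelEncoder assigns integer codes in sorted order of the categories,
# so code 0 is the alphabetically first company.
company_codes = dict(enumerate(encoders["company"].classes_.tolist()))
print(company_codes)
```

In this toy data only the company coded 0 has label 0, so the printed tree has a single split on `company_n`, and the mapping shows that code 0 stands for "alpha". The same two calls can be applied to the model trained on the real data to list its rules with readable feature names.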