A data science model development pipeline involves several components, including data ingestion, data preprocessing, feature engineering, feature scaling, and modeling. A data scientist needs to write the learning and inference code for all of these components. For machine learning projects with heterogeneous data, the code structure can become messy and difficult for other team members to interpret.
A pipeline is a handy utility that chains all your model development components in sequence. Using a pipeline, one can perform the learning and inference tasks in a comparatively cleaner code structure.
In this article, we will discuss how to use scikit-learn pipelines to structure your code by chaining estimators and using column transformers while developing an end-to-end machine learning model.
What is a Pipeline?
A pipeline lists all the data processing and feature engineering estimators in sequence in a clean code structure. In essence, it chains multiple estimators into one. Pipelines are convenient for both learning and inference, and they help avoid data leakage because every transformer is fitted only on the data passed to fit. One can also run a grid search over the parameters of all estimators in the pipeline at once.
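As a minimal sketch of chaining estimators and searching over their parameters together, the example below builds a two-step pipeline on synthetic data. The data, the step names ("scaler", "clf"), and the parameter grid are assumptions chosen for illustration; they are not part of the case study.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic data, used purely for illustration.
X, y = make_classification(n_samples=200, n_features=6, random_state=42)

# Two estimators chained into one: a scaler followed by a classifier.
pipe = Pipeline(steps=[
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression(random_state=42)),
])

# Parameters of any step are addressed as "<step name>__<parameter>".
param_grid = {
    "scaler__with_mean": [True, False],
    "clf__C": [0.1, 1.0, 10.0],
}

# In each cross-validation split the scaler is refitted on the training
# folds only, so the held-out fold never influences the scaling.
search = GridSearchCV(pipe, param_grid, cv=3)
search.fit(X, y)

print(search.best_params_)
```

Because the whole pipeline is passed to the search, preprocessing and model parameters are tuned in a single call.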
I will be developing an end-to-end machine learning model for a sample binary classification dataset with heterogeneous features. The dataset has 8 independent features of text, numerical, and categorical data types.
(Image by Author), Snapshot of the sample dataset
Usage:
The sample dataset includes a text feature (Name), categorical features (Sex, Embarked), and numerical features (Pclass, Age, SibSp, Parch, Fare).
Raw real-world datasets often contain many missing values. We can use the SimpleImputer class from the scikit-learn package to impute them. For categorical features, we can chain a most-frequent-value imputer and a one-hot encoder, followed by a truncated SVD estimator for dimensionality reduction.
For text features, we can vectorize the text using CountVectorizer or TfidfVectorizer to convert the text data into numerical vectors, followed by a dimensionality reduction estimator.
Pipeline 1 (for categorical features): 1) Most-frequent-value imputer 2) One-hot encoder 3) Truncated SVD decomposition
Pipeline 2 (for text-based features): 1) Tf-Idf vectorizer 2) Truncated SVD decomposition
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
from sklearn.pipeline import make_pipeline

most_frequent_imputer = SimpleImputer(strategy='most_frequent')
onehot_encoder = OneHotEncoder(handle_unknown='ignore')
vectorizer = TfidfVectorizer()
svd = TruncatedSVD(n_components=2, random_state=42)

pipe1 = make_pipeline(most_frequent_imputer, onehot_encoder, svd)
pipe2 = make_pipeline(vectorizer, svd)
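To see what these two pipelines produce, here is a minimal sketch that fits equivalent pipelines on a tiny hand-made frame. The toy values and the names cat_pipe and text_pipe are assumptions for illustration. Each pipeline gets its own TruncatedSVD instance here so that fitting one standalone pipeline does not overwrite the fitted state of the other.

```python
import numpy as np
import pandas as pd
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder

# Toy rows made up for illustration; one Embarked value is missing.
toy = pd.DataFrame({
    "Sex": ["male", "female", "female", "female", "male"],
    "Embarked": ["S", "C", np.nan, "Q", "S"],
    "Name": [
        "Braund, Mr. Owen Harris",
        "Cumings, Mrs. John Bradley",
        "Heikkinen, Miss. Laina",
        "Futrelle, Mrs. Jacques Heath",
        "Allen, Mr. William Henry",
    ],
})

# Categorical pipeline: impute -> one-hot encode -> reduce to 2 components.
cat_pipe = make_pipeline(
    SimpleImputer(strategy="most_frequent"),
    OneHotEncoder(handle_unknown="ignore"),
    TruncatedSVD(n_components=2, random_state=42),
)
cat_out = cat_pipe.fit_transform(toy[["Sex", "Embarked"]])

# Text pipeline: the vectorizer expects a 1-D collection of strings,
# so a single column (a Series) is passed rather than a DataFrame.
text_pipe = make_pipeline(
    TfidfVectorizer(),
    TruncatedSVD(n_components=2, random_state=42),
)
text_out = text_pipe.fit_transform(toy["Name"])

print(cat_out.shape, text_out.shape)
```

Both outputs have one row per sample and two columns, one per SVD component.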
Column Transformer for Heterogeneous data:
The sample dataset contains features of several data types: text, float, integer, and categorical values stored as objects. Each type of feature therefore requires its own feature engineering strategy.
ColumnTransformer is a scikit-learn class that enables developers to apply different feature engineering and data transformation steps to different sets of features. The good part of column transformers is that the data transformation happens inside the pipeline, which keeps it safe from data leakage issues.
Usage:
I apply different feature transformation strategies to different sets of features.
- Pipeline 1 for categorical features such as ‘Sex’ and ‘Embarked’
- Pipeline 2 for text-based features such as ‘Name’
- Mean imputer for numerical feature ‘Age’ as it has a lot of missing values.
- The remaining numerical features don’t require any feature transformation, so they can be passed through unchanged using the ‘passthrough’ keyword
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.compose import make_column_transformer

most_frequent_imputer = SimpleImputer(strategy='most_frequent')
onehot_encoder = OneHotEncoder(handle_unknown='ignore')
vectorizer = CountVectorizer()
mean_imputer = SimpleImputer(strategy='mean')
svd = TruncatedSVD(n_components=2, random_state=42)

pipe1 = make_pipeline(most_frequent_imputer, onehot_encoder, svd)
pipe2 = make_pipeline(vectorizer, svd)

# 'Name' is given as a plain string so the vectorizer receives a 1-D column
column_trans = make_column_transformer(
    (pipe1, ['Sex', 'Embarked']),
    (pipe2, 'Name'),
    (mean_imputer, ['Age']),
    ('passthrough', ['Fare', 'SibSp', 'Parch', 'Pclass']))
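The sketch below applies an equivalent column transformer to a small hand-made frame to show how the outputs of the individual transformers are concatenated. The toy rows are assumptions for illustration only.

```python
import numpy as np
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder

# Toy rows made up for illustration; Age and Embarked contain missing values.
toy = pd.DataFrame({
    "Pclass": [3, 1, 3, 1, 3, 2],
    "Name": [
        "Braund, Mr. Owen Harris",
        "Cumings, Mrs. John Bradley",
        "Heikkinen, Miss. Laina",
        "Futrelle, Mrs. Jacques Heath",
        "Allen, Mr. William Henry",
        "Doe, Mr. John",
    ],
    "Sex": ["male", "female", "female", "female", "male", "male"],
    "Age": [22.0, 38.0, np.nan, 35.0, 35.0, np.nan],
    "SibSp": [1, 1, 0, 1, 0, 0],
    "Parch": [0, 0, 0, 0, 0, 1],
    "Fare": [7.25, 71.28, 7.93, 53.10, 8.05, 13.00],
    "Embarked": ["S", "C", "S", np.nan, "S", "Q"],
})

column_trans = make_column_transformer(
    (make_pipeline(SimpleImputer(strategy="most_frequent"),
                   OneHotEncoder(handle_unknown="ignore"),
                   TruncatedSVD(n_components=2, random_state=42)),
     ["Sex", "Embarked"]),                       # -> 2 columns
    (make_pipeline(CountVectorizer(),
                   TruncatedSVD(n_components=2, random_state=42)),
     "Name"),                                    # -> 2 columns
    (SimpleImputer(strategy="mean"), ["Age"]),   # -> 1 column
    ("passthrough", ["Fare", "SibSp", "Parch", "Pclass"]),  # -> 4 columns
)

transformed = column_trans.fit_transform(toy)
print(transformed.shape)
```

The result is a single numeric array whose columns are the outputs of the four branches placed side by side, in the order the branches were declared.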
Modeling Pipeline:
After the data transformation steps, one can move on to the model component. I will be using a Logistic Regression estimator trained on the transformed dataset. Before the modeling stage, we can also include a StandardScaler estimator to standardize the transformed features.
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

pipe1 = make_pipeline(most_frequent_imputer, onehot_encoder, svd)
pipe2 = make_pipeline(vectorizer, svd)

column_trans = make_column_transformer(
    (pipe1, ['Sex', 'Embarked']),
    (pipe2, 'Name'),
    (mean_imputer, ['Age']),
    ('passthrough', ['Fare', 'SibSp', 'Parch', 'Pclass']))

scaler = StandardScaler()
classifier = LogisticRegression(random_state=42)

# Final Pipeline
pipeline = make_pipeline(column_trans, scaler, classifier)
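Since make_pipeline and make_column_transformer generate step names automatically, it can help to know how to address a nested step. The sketch below rebuilds an equivalent pipeline and reads and sets parameters by name; the parameter values chosen (C=0.5, n_components=3) are arbitrary and only for illustration.

```python
from sklearn.compose import make_column_transformer
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

column_trans = make_column_transformer(
    (make_pipeline(SimpleImputer(strategy="most_frequent"),
                   OneHotEncoder(handle_unknown="ignore"),
                   TruncatedSVD(n_components=2, random_state=42)),
     ["Sex", "Embarked"]),
    (make_pipeline(CountVectorizer(),
                   TruncatedSVD(n_components=2, random_state=42)),
     "Name"),
    (SimpleImputer(strategy="mean"), ["Age"]),
    ("passthrough", ["Fare", "SibSp", "Parch", "Pclass"]),
)
pipeline = make_pipeline(column_trans, StandardScaler(),
                         LogisticRegression(random_state=42))

# Step names are the lowercased class names of the estimators.
step_names = list(pipeline.named_steps)
print(step_names)

# Nested parameters are reached by joining names with a double underscore.
# The two inner pipelines are named 'pipeline-1' and 'pipeline-2'.
pipeline.set_params(**{
    "logisticregression__C": 0.5,
    "columntransformer__pipeline-1__truncatedsvd__n_components": 3,
})
params = pipeline.get_params()
print(params["columntransformer__pipeline-1__truncatedsvd__n_components"])
```

The same double-underscore names are what a grid search over the full pipeline would use as keys in its parameter grid.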
Visualizing the Pipeline:
A visual representation of the pipeline makes the end-to-end flow of the case study easy to follow. Scikit-learn provides a set_config function that lets the developer display a diagram of the entire end-to-end pipeline.
In scikit-learn versions before 1.2, the display parameter of set_config defaults to ‘text’, which prints a textual representation of the pipeline; from version 1.2 onward, ‘diagram’ is the default. Setting the parameter to ‘diagram’ explicitly renders the diagram in a notebook in either case.
from sklearn import set_config
set_config(display='diagram')
(Image by Author), Diagrammatic interpretation of the entire pipeline
Learning and Inference:
One can train the model pipeline using the .fit() method and perform inference using the .predict() method.
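As a minimal end-to-end sketch of learning and inference, the example below fits an equivalent pipeline on a synthetic Titanic-like frame and predicts on a held-out split. The generated rows (including the random labels) are assumptions made up for illustration, so the predictions carry no meaning; only the calling pattern matters.

```python
import numpy as np
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Synthetic stand-in for the dataset (hypothetical values).
rng = np.random.default_rng(42)
n = 60
titles = ["Mr.", "Mrs.", "Miss."]
first_names = ["Owen", "Laina", "William", "Florence", "Lily"]
df = pd.DataFrame({
    "Pclass": rng.integers(1, 4, n),
    "Name": [f"Passenger{i}, {titles[i % 3]} {first_names[i % 5]}"
             for i in range(n)],
    "Sex": rng.choice(["male", "female"], n),
    "Age": rng.uniform(1, 70, n),
    "SibSp": rng.integers(0, 3, n),
    "Parch": rng.integers(0, 3, n),
    "Fare": rng.uniform(5, 100, n),
    "Embarked": rng.choice(["S", "C", "Q"], n),
    "Survived": rng.integers(0, 2, n),
})
df.loc[df.index % 7 == 0, "Age"] = np.nan  # inject some missing ages

column_trans = make_column_transformer(
    (make_pipeline(SimpleImputer(strategy="most_frequent"),
                   OneHotEncoder(handle_unknown="ignore"),
                   TruncatedSVD(n_components=2, random_state=42)),
     ["Sex", "Embarked"]),
    (make_pipeline(CountVectorizer(),
                   TruncatedSVD(n_components=2, random_state=42)),
     "Name"),
    (SimpleImputer(strategy="mean"), ["Age"]),
    ("passthrough", ["Fare", "SibSp", "Parch", "Pclass"]),
)
pipeline = make_pipeline(column_trans, StandardScaler(),
                         LogisticRegression(random_state=42))

y = df["Survived"]
X = df.drop("Survived", axis=1)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)

# Learning: every transformer and the classifier are fitted on X_train only.
pipeline.fit(X_train, y_train)

# Inference: the raw test frame goes through the same fitted steps.
preds = pipeline.predict(X_test)
proba = pipeline.predict_proba(X_test)
print(preds.shape, proba.shape)
```

Note that the raw DataFrame is passed to both fit and predict; no transformation code has to be repeated at inference time.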
import pandas as pd

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import LabelEncoder
from sklearn.pipeline import make_pipeline
from sklearn.compose import make_column_transformer
from sklearn.utils import estimator_html_repr
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import *

# Load the dataset and drop the columns that are not used
df = pd.read_csv("DATA/titanic.csv")

columns_to_drop = ['PassengerId', 'Ticket', 'Cabin']
df = df.drop(columns_to_drop, axis=1)

print(df.shape)  # (891, 9)
df.head()

df.info()

# The pipeline object is built exactly as in the Modeling Pipeline section

# To export the diagram to an HTML file
with open('pipeline.html', 'w') as f:
    f.write(estimator_html_repr(pipeline))

# Train Test Split
y = df['Survived']
X = df.drop('Survived', axis=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
" ('truncatedsvd',\n", " TruncatedSVD(random_state=42))]),\n", " 'Name'),\n", " ('simpleimputer',\n", " SimpleImputer(), ['Age']),\n", " ('passthrough', 'passthrough',\n", " ['Fare', 'SibSp', 'Parch',\n", " 'Pclass'])])),\n", " ('standardscaler', StandardScaler()),\n", " ('logisticregression', LogisticRegression(random_state=42))])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 11, "id": "977d3ce9", "metadata": {}, "outputs": [], "source": [ "def get_performance_metrics(y, y_hat, y_hat_proba):\n", " precision = precision_score(y, y_hat)\n", " recall = recall_score(y, y_hat)\n", " roc_auc = roc_auc_score(y, y_hat_proba)\n", " cm = confusion_matrix(y, y_hat)\n", " \n", " return {'Precision':precision, 'Recall':recall, 'ROC AUC Score':roc_auc, 'CM':cm}" ] }, { "cell_type": "code", "execution_count": 12, "id": "899ea452", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'Precision': 0.7397260273972602,\n", " 'Recall': 0.7297297297297297,\n", " 'ROC AUC Score': 0.8720720720720722,\n", " 'CM': array([[86, 19],\n", " [20, 54]], dtype=int64)}" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Test Performance\n", "get_performance_metrics(y_test, pipeline.predict(X_test), pipeline.predict_proba(X_test)[:,1])" ] }, { "cell_type": "code", "execution_count": 13, "id": "1a844f1e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'Precision': 0.8,\n", " 'Recall': 0.7313432835820896,\n", " 'ROC AUC Score': 0.8623268791179239,\n", " 'CM': array([[395, 49],\n", " [ 72, 196]], dtype=int64)}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Train Performance\n", "get_performance_metrics(y_train, pipeline.predict(X_train), pipeline.predict_proba(X_train)[:,1])" ] }, { "cell_type": "code", "execution_count": null, "id": "7a154855", "metadata": {}, 
"outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }
Conclusion:
A pipeline is a handy utility that chains all your estimators into a single object and keeps the code clean and well structured. With a Pipeline, training and inference each reduce to a single call (fit and predict) on raw input data. You can also display the entire end-to-end pipeline as an interpretable diagram.
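To make the points above concrete, here is a minimal sketch of fitting a pipeline, running inference on a raw unseen row, and rendering the diagram representation. The tiny DataFrame, the variable names (`train`, `target`, `demo`, `new_rows`), and the reduced set of steps are assumptions made for illustration; they are not the article's dataset or its exact pipeline.

```python
import pandas as pd
from sklearn import set_config
from sklearn.compose import make_column_transformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils import estimator_html_repr

# Hypothetical hand-made rows, for illustration only.
train = pd.DataFrame({
    "Name": ["Smith, Mr. John", "Brown, Mrs. Anna", "Lee, Miss. Kim",
             "Jones, Mr. Paul", "Clark, Mrs. Emma", "Hall, Mr. Tom"],
    "Sex": ["male", "female", "female", "male", "female", "male"],
    "Fare": [7.25, 71.3, 8.05, 13.0, 53.1, 8.46],
})
target = [0, 1, 1, 0, 1, 0]

demo = make_pipeline(
    make_column_transformer(
        # Text column is given as a scalar name so the vectorizer gets 1-D input.
        (CountVectorizer(), "Name"),
        # Categorical columns are given as a list so the encoder gets 2-D input.
        (OneHotEncoder(handle_unknown="ignore"), ["Sex"]),
        ("passthrough", ["Fare"]),
    ),
    LogisticRegression(random_state=42, max_iter=1000),
)

# Learning: one call fits every transformer and the final classifier.
demo.fit(train, target)

# Inference: raw rows go straight in; preprocessing is applied automatically.
new_rows = pd.DataFrame({
    "Name": ["Young, Miss. Ada"], "Sex": ["female"], "Fare": [30.0],
})
predicted_label = demo.predict(new_rows)
predicted_proba = demo.predict_proba(new_rows)[:, 1]
print(predicted_label, predicted_proba)

# Diagram: notebooks render this automatically when display="diagram";
# estimator_html_repr returns the same HTML as a string.
set_config(display="diagram")
html = estimator_html_repr(demo)
```

In a notebook, evaluating `demo` as the last expression of a cell shows the interactive diagram. Outside a notebook, the `html` string can be saved to a file and opened in a browser.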
References:
[1] Scikit-learn documentation: https://scikit-learn.org/stable/modules/compose.html
Article originally posted here. Reposted with permission.