Saturday, November 16, 2024

Interactive Pipeline and Composite Estimators for Your End-to-End ML Model

A data science model development pipeline involves several components, including data ingestion, data preprocessing, feature engineering, feature scaling, and modeling. A data scientist needs to write the learning and inference code for each component. For machine learning projects with heterogeneous data, the code can become messy and difficult for other team members to interpret.

A pipeline is a handy tool that chains all your model development components in sequence. With a pipeline, one can perform the learning and inference tasks in a comparatively cleaner code structure.

In this article, we will discuss how to use scikit-learn pipelines to structure your code by chaining estimators and using column transformers while developing an end-to-end machine learning model.

What is a Pipeline?

A pipeline sequentially lists all the data processing and feature engineering estimators in a clean code structure. Basically, it chains multiple estimators into one. Pipelines are convenient for learning and inference tasks and help avoid data leakage, because every transformer is fit only on the data passed to .fit(). One can also perform a grid search over the parameters of all estimators in the pipeline at once.
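
To make the grid search remark concrete, here is a minimal sketch on synthetic data. The step names "scaler" and "clf" and the parameter grid are illustrative assumptions, not part of the case study; the point is only that parameters of any step are addressed as `<step name>__<parameter>`.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic data, used only to keep the sketch self-contained.
X, y = make_classification(n_samples=200, n_features=10, random_state=42)

# Chain a scaler and a classifier into a single estimator.
pipe = Pipeline(steps=[
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression(max_iter=1000)),
])

# Parameters of every step are tuned together in one search.
param_grid = {
    "scaler__with_mean": [True, False],
    "clf__C": [0.1, 1.0, 10.0],
}
search = GridSearchCV(pipe, param_grid, cv=5)
search.fit(X, y)
print(search.best_params_)
```

Because the scaler lives inside the pipeline, it is refit on the training part of each cross-validation fold rather than on the full dataset.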

I will be developing an end-to-end machine learning model for a sample binary classification dataset with heterogeneous features. The dataset has 8 independent features of text, numerical, and categorical data types.

(Image by Author), Snapshot of the sample dataset

Usage:

The sample dataset includes a text feature (Name), categorical features (Sex, Embarked), and numerical features (Pclass, Age, SibSp, Parch, Fare).

A raw real-world dataset might contain many missing values. We can use the SimpleImputer class from the scikit-learn package to impute them. For categorical features, we can chain a one-hot encoder followed by an SVD estimator for feature decomposition.

For text features, we can vectorize the text using a Count Vectorizer or Tf-Idf vectorizer to convert the text data into numerical vectors, followed by a dimensionality reduction estimator.
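
Before chaining these estimators, it can help to see what two of them do on their own. The sketch below uses a tiny made-up categorical column and a few names from the dataset snapshot; the variable names are illustrative assumptions.

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.impute import SimpleImputer

# Made-up categorical column with one missing value, for illustration only.
embarked = np.array([["S"], ["C"], [np.nan], ["S"]], dtype=object)

# A few names as they appear in the dataset snapshot.
names = [
    "Braund, Mr. Owen Harris",
    "Heikkinen, Miss. Laina",
    "Allen, Mr. William Henry",
]

# The most-frequent strategy replaces the missing entry with the modal category.
imputer = SimpleImputer(strategy="most_frequent")
filled = imputer.fit_transform(embarked)
print(filled.ravel())

# The vectorizer turns each name into one row of a sparse numeric matrix.
tfidf = TfidfVectorizer()
name_matrix = tfidf.fit_transform(names)
print(name_matrix.shape)
```

The vectorized matrix has one column per distinct token, which is why a dimensionality reduction step usually follows it.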

Pipeline 1 (for categorical features):
1) Most-frequent value imputer
2) One-hot encoder
3) Truncated SVD decomposition

Pipeline 2 (for text-based features):
1) Tf-Idf vectorizer
2) Truncated SVD decomposition

from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
from sklearn.pipeline import make_pipeline

most_frequent_imputer = SimpleImputer(strategy='most_frequent')
onehot_encoder = OneHotEncoder(handle_unknown='ignore')
vectorizer = TfidfVectorizer()
svd = TruncatedSVD(n_components=2, random_state=42)

pipe1 = make_pipeline(most_frequent_imputer, onehot_encoder, svd)
pipe2 = make_pipeline(vectorizer, svd)

Column Transformer for Heterogeneous data:

The sample dataset contains features of various data types, from text to float and object. Each type of feature therefore requires a separate feature engineering strategy.

ColumnTransformer is a scikit-learn class that enables developers to apply different feature engineering and data transformation steps to different sets of features. The good part is that the transformations happen within the pipeline, which keeps them safe from data leakage issues as long as the pipeline is fit on training data only.
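
As a rough illustration of the leakage point, the sketch below puts a column transformer inside a pipeline and evaluates it with cross-validation, so the encoder and imputer are refit on the training part of every fold. The toy frame, its column names, and the target are invented for this example.

```python
import numpy as np
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder

# Invented toy data: one categorical and one numerical column.
rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "colour": np.tile(["red", "green", "blue"], 20),
    "size": rng.normal(loc=10.0, scale=2.0, size=60),
})
toy.loc[::7, "size"] = np.nan  # introduce some missing values
target = (toy["colour"] == "red").astype(int)

# Each column set gets its own transformation.
preprocess = make_column_transformer(
    (OneHotEncoder(handle_unknown="ignore"), ["colour"]),
    (SimpleImputer(strategy="mean"), ["size"]),
)
model = make_pipeline(preprocess, LogisticRegression())

# The whole pipeline is cloned and refit per fold, so the imputation mean
# is never computed from the held-out rows.
scores = cross_val_score(model, toy, target, cv=3)
print(scores)
```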

Usage:

I have performed different feature transformation strategies for different sets of features.

  • Pipeline 1 for categorical features such as ‘Sex’ and ‘Embarked’
  • Pipeline 2 for text-based features such as ‘Name’
  • Mean imputer for the numerical feature ‘Age’, as it has a lot of missing values.
  • The remaining numerical features don’t require any transformation, so they can be passed through as they are using the ‘passthrough’ keyword
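
from sklearn.compose import make_column_transformer
from sklearn.feature_extraction.text import CountVectorizer

most_frequent_imputer = SimpleImputer(strategy='most_frequent')
onehot_encoder = OneHotEncoder(handle_unknown='ignore')
# Count-based vectorizer here; the TfidfVectorizer from the previous snippet could be swapped in
vectorizer = CountVectorizer()
mean_imputer = SimpleImputer(strategy='mean')
svd = TruncatedSVD(n_components=2, random_state=42)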

pipe1 = make_pipeline(most_frequent_imputer, onehot_encoder, svd)
pipe2 = make_pipeline(vectorizer, svd)

column_trans = make_column_transformer(
    (pipe1, ['Sex', 'Embarked']),
    (pipe2, 'Name'),
    (mean_imputer, ['Age']),
    ('passthrough', ['Fare', 'SibSp', 'Parch', 'Pclass']))

Modeling Pipeline:

After performing the data transformation steps one can move to the model component. I will be using a Logistic Regression estimator to train the transformed dataset. But before moving to the modeling stage we can also include a StandardScaler estimator to standardize the transformed dataset.

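from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

pipe1 = make_pipeline(most_frequent_imputer, onehot_encoder, svd)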
pipe2 = make_pipeline(vectorizer, svd)

column_trans = make_column_transformer(
    (pipe1, ['Sex', 'Embarked']),
    (pipe2, 'Name'),
    (mean_imputer, ['Age']),
    ('passthrough', ['Fare', 'SibSp', 'Parch', 'Pclass']))

scaler = StandardScaler()
classifier = LogisticRegression(random_state=42)

# Final Pipeline
pipeline = make_pipeline(column_trans, scaler, classifier)
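
Since the final pipeline nests estimators several levels deep, it can be useful to look up the full parameter names before tuning. The following is a sketch on a reduced version of the pipeline (the text branch and passthrough columns are left out for brevity, which is a simplification of mine); it discovers the SVD parameter name instead of hard-coding it.

```python
from sklearn.compose import make_column_transformer
from sklearn.decomposition import TruncatedSVD
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Reduced composite estimator: categorical branch plus the Age imputer.
cat_pipe = make_pipeline(
    SimpleImputer(strategy="most_frequent"),
    OneHotEncoder(handle_unknown="ignore"),
    TruncatedSVD(n_components=2, random_state=42),
)
column_trans = make_column_transformer(
    (cat_pipe, ["Sex", "Embarked"]),
    (SimpleImputer(strategy="mean"), ["Age"]),
)
pipeline = make_pipeline(
    column_trans, StandardScaler(), LogisticRegression(random_state=42)
)

# Nested parameters are joined with double underscores, one level per step.
svd_keys = [k for k in pipeline.get_params() if k.endswith("n_components")]
print(svd_keys)

# The same key can be used with set_params or in a grid search.
pipeline.set_params(**{svd_keys[0]: 3})
```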

Visualizing the Pipeline:

A visual representation of the entire pipeline makes the end-to-end flow of the case study easy to interpret. Scikit-learn provides a set_config function that lets the developer display a diagrammatic representation of the entire pipeline.

In scikit-learn versions before 1.2, the set_config display parameter defaults to ‘text’, which prints a textual representation of the pipeline; setting it to ‘diagram’ renders the interactive diagram in a notebook. From version 1.2 onwards, ‘diagram’ is the default.

from sklearn import set_config
set_config(display='diagram')

(Image by Author), Diagrammatic interpretation of the entire pipeline
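
Outside a notebook, the same diagram can be saved to a standalone HTML file with estimator_html_repr. The sketch below uses a small stand-in pipeline and the file name pipeline.html; both are assumptions for illustration.

```python
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils import estimator_html_repr

# Small stand-in pipeline; the full case-study pipeline works the same way.
demo_pipeline = make_pipeline(StandardScaler(), LogisticRegression())

# Render the diagram as an HTML string and write it to disk.
html = estimator_html_repr(demo_pipeline)
with open("pipeline.html", "w", encoding="utf-8") as f:
    f.write(html)
```

Opening the resulting file in a browser shows the same collapsible boxes as the notebook display.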

Learning and Inference:

One can train the model pipeline using the .fit() method and perform inference using the .predict() method.
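
A minimal end-to-end sketch of both calls follows. The rows are made up and only mimic the column layout of the sample dataset, so the predictions carry no meaning; each branch gets its own TruncatedSVD instance here, which is a small departure from the snippets above.

```python
import numpy as np
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Made-up training rows with the same columns as the sample dataset.
X_train = pd.DataFrame({
    "Name": ["Smith, Mr. John", "Jones, Mrs. Mary", "Brown, Miss. Ann",
             "Clark, Mr. Paul", "Lewis, Mrs. Jane", "Hall, Mr. Mark",
             "Young, Miss. Ruth", "King, Mr. Alan"],
    "Sex": ["male", "female", "female", "male",
            "female", "male", "female", "male"],
    "Embarked": ["S", "C", "S", np.nan, "Q", "S", "C", "S"],
    "Age": [22.0, 38.0, np.nan, 35.0, 27.0, np.nan, 19.0, 54.0],
    "Fare": [7.25, 71.28, 7.92, 8.05, 11.13, 8.46, 30.07, 51.86],
    "SibSp": [1, 1, 0, 0, 0, 0, 1, 0],
    "Parch": [0, 0, 0, 0, 2, 0, 0, 0],
    "Pclass": [3, 1, 3, 3, 3, 3, 2, 1],
})
y_train = pd.Series([0, 1, 1, 0, 1, 0, 1, 0])

pipe1 = make_pipeline(
    SimpleImputer(strategy="most_frequent"),
    OneHotEncoder(handle_unknown="ignore"),
    TruncatedSVD(n_components=2, random_state=42),
)
pipe2 = make_pipeline(
    CountVectorizer(),
    TruncatedSVD(n_components=2, random_state=42),
)
column_trans = make_column_transformer(
    (pipe1, ["Sex", "Embarked"]),
    (pipe2, "Name"),
    (SimpleImputer(strategy="mean"), ["Age"]),
    ("passthrough", ["Fare", "SibSp", "Parch", "Pclass"]),
)
pipeline = make_pipeline(
    column_trans, StandardScaler(), LogisticRegression(random_state=42)
)

# Learning: every transformer and the classifier are fit in one call.
pipeline.fit(X_train, y_train)

# Inference: unseen rows go through the same fitted transformations.
X_new = pd.DataFrame({
    "Name": ["Gray, Mr. Peter", "White, Mrs. Emma"],
    "Sex": ["male", "female"],
    "Embarked": ["S", np.nan],
    "Age": [np.nan, 31.0],
    "Fare": [9.50, 52.00],
    "SibSp": [0, 1],
    "Parch": [0, 0],
    "Pclass": [3, 1],
})
preds = pipeline.predict(X_new)
print(preds)
```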

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8a6bc438",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.compose import make_column_transformer\n",
    "from sklearn.utils import estimator_html_repr\n",
    "from sklearn.decomposition import TruncatedSVD\n",
    "from sklearn.metrics import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8126bc5c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(891, 9)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "   Survived  Pclass                                               Name  \\\n",
       "0         0       3                            Braund, Mr. Owen Harris   \n",
       "1         1       1  Cumings, Mrs. John Bradley (Florence Briggs Th...   \n",
       "2         1       3                             Heikkinen, Miss. Laina   \n",
       "3         1       1       Futrelle, Mrs. Jacques Heath (Lily May Peel)   \n",
       "4         0       3                           Allen, Mr. William Henry   \n",
       "\n",
       "      Sex   Age  SibSp  Parch     Fare Embarked  \n",
       "0    male  22.0      1      0   7.2500        S  \n",
       "1  female  38.0      1      0  71.2833        C  \n",
       "2  female  26.0      0      0   7.9250        S  \n",
       "3  female  35.0      1      0  53.1000        S  \n",
       "4    male  35.0      0      0   8.0500        S  "
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv(\"DATA/titanic.csv\")\n",
    "\n",
    "columns_to_drop = ['PassengerId','Ticket','Cabin']\n",
    "df = df.drop(columns_to_drop, axis=1)\n",
    "\n",
    "print(df.shape)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e66a0836",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 891 entries, 0 to 890\n",
      "Data columns (total 10 columns):\n",
      " #   Column    Non-Null Count  Dtype  \n",
      "---  ------    --------------  -----  \n",
      " 0   Survived  891 non-null    int64  \n",
      " 1   Pclass    891 non-null    int64  \n",
      " 2   Name      891 non-null    object \n",
      " 3   Sex       891 non-null    object \n",
      " 4   Age       714 non-null    float64\n",
      " 5   SibSp     891 non-null    int64  \n",
      " 6   Parch     891 non-null    int64  \n",
      " 7   Fare      891 non-null    float64\n",
      " 8   Cabin     204 non-null    object \n",
      " 9   Embarked  889 non-null    object \n",
      "dtypes: float64(2), int64(4), object(4)\n",
      "memory usage: 69.7+ KB\n"
     ]
    }
   ],
   "source": [
    "df.info()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55d8c869",
   "metadata": {},
   "source": [
    "<B> Feature Transformation: </B>\n",
    "* Sex, Embarked: Categorical Imputer, OHE, SVD\n",
    "* Name: Count Vectorizer, SVD\n",
    "* Age: Mean Imputer\n",
    "* Fare, SibSp, Parch, Pclass: passthrough\n",
    "\n",
    "<B> Feature Preprocessing: </B>\n",
    "* Standard Scaler\n",
    "\n",
    "<B> Modeling: </B>\n",
    "* Logistic Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "10c9a630",
   "metadata": {},
   "outputs": [],
   "source": [
    "most_frequent_imputer = SimpleImputer(strategy='most_frequent')\n",
    "onehot_encoder = OneHotEncoder(handle_unknown='ignore')\n",
    "vectorizer = CountVectorizer()\n",
    "mean_imputer = SimpleImputer(strategy='mean')\n",
    "svd = TruncatedSVD(n_components=2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f2f1e0b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe1 = make_pipeline(most_frequent_imputer, onehot_encoder, svd)\n",
    "pipe2 = make_pipeline(vectorizer, svd)\n",
    "\n",
    "column_trans = make_column_transformer(\n",
    "    (pipe1, ['Sex', 'Embarked']),\n",
    "    (pipe2, 'Name'),\n",
    "    (mean_imputer, ['Age']),\n",
    "    ('passthrough', ['Fare', 'SibSp', 'Parch', 'Pclass']))\n",
    "\n",
    "scaler = StandardScaler()\n",
    "\n",
    "classifier = LogisticRegression(random_state=42)\n",
    "\n",
    "pipeline = make_pipeline(column_trans, scaler, classifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d2d4c09c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import set_config\n",
    "set_config(display='diagram')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4ad6b43c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pipeline(steps=[('columntransformer',\n",
       "                 ColumnTransformer(transformers=[('pipeline-1',\n",
       "                                                  Pipeline(steps=[('simpleimputer',\n",
       "                                                                   SimpleImputer(strategy='most_frequent')),\n",
       "                                                                  ('onehotencoder',\n",
       "                                                                   OneHotEncoder(handle_unknown='ignore')),\n",
       "                                                                  ('truncatedsvd',\n",
       "                                                                   TruncatedSVD(random_state=42))]),\n",
       "                                                  ['Sex', 'Embarked']),\n",
       "                                                 ('pipeline-2',\n",
       "                                                  Pipeline(steps=[('countvectorizer',\n",
       "                                                                   CountVectorizer()),\n",
       "                                                                  ('truncatedsvd',\n",
       "                                                                   TruncatedSVD(random_state=42))]),\n",
       "                                                  'Name'),\n",
       "                                                 ('simpleimputer',\n",
       "                                                  SimpleImputer(), ['Age']),\n",
       "                                                 ('passthrough', 'passthrough',\n",
       "                                                  ['Fare', 'SibSp', 'Parch',\n",
       "                                                   'Pclass'])])),\n",
       "                ('standardscaler', StandardScaler()),\n",
       "                ('logisticregression', LogisticRegression(random_state=42))])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "cac91287",
   "metadata": {},
   "outputs": [],
   "source": [
    "# To export the diagram to a html file\n",
    "with open('pipeline.html', 'w') as f:  \n",
    "    f.write(estimator_html_repr(pipeline))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bed7f4f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train Test Split\n",
    "y = df['Survived']\n",
    "X = df.drop('Survived', axis=1)\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "752ee4d9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pipeline(steps=[('columntransformer',\n",
       "                 ColumnTransformer(transformers=[('pipeline-1',\n",
       "                                                  Pipeline(steps=[('simpleimputer',\n",
       "                                                                   SimpleImputer(strategy='most_frequent')),\n",
       "                                                                  ('onehotencoder',\n",
       "                                                                   OneHotEncoder(handle_unknown='ignore')),\n",
       "                                                                  ('truncatedsvd',\n",
       "                                                                   TruncatedSVD(random_state=42))]),\n",
       "                                                  ['Sex', 'Embarked']),\n",
       "                                                 ('pipeline-2',\n",
       "                                                  Pipeline(steps=[('countvectorizer',\n",
       "                                                                   CountVectorizer()),\n",
       "                                                                  ('truncatedsvd',\n",
       "                                                                   TruncatedSVD(random_state=42))]),\n",
       "                                                  'Name'),\n",
       "                                                 ('simpleimputer',\n",
       "                                                  SimpleImputer(), ['Age']),\n",
       "                                                 ('passthrough', 'passthrough',\n",
       "                                                  ['Fare', 'SibSp', 'Parch',\n",
       "                                                   'Pclass'])])),\n",
       "                ('standardscaler', StandardScaler()),\n",
       "                ('logisticregression', LogisticRegression(random_state=42))])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipeline.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "977d3ce9",
   "metadata": {},
   "outputs": [],
   "source": [
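    "def get_performance_metrics(y, y_hat, y_hat_proba):\n",
    "    # y: true labels, y_hat: predicted labels, y_hat_proba: predicted positive-class probabilities\n",
    "    precision = precision_score(y, y_hat)\n",
    "    recall = recall_score(y, y_hat)\n",
    "    roc_auc = roc_auc_score(y, y_hat_proba)  # computed from probabilities, not hard labels\n",
    "    cm = confusion_matrix(y, y_hat)  # rows: true class, columns: predicted class\n",
    "\n",
    "    return {'Precision': precision, 'Recall': recall, 'ROC AUC Score': roc_auc, 'CM': cm}"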
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "899ea452",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Precision': 0.7397260273972602,\n",
       " 'Recall': 0.7297297297297297,\n",
       " 'ROC AUC Score': 0.8720720720720722,\n",
       " 'CM': array([[86, 19],\n",
       "        [20, 54]], dtype=int64)}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test Performance\n",
    "get_performance_metrics(y_test, pipeline.predict(X_test), pipeline.predict_proba(X_test)[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1a844f1e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Precision': 0.8,\n",
       " 'Recall': 0.7313432835820896,\n",
       " 'ROC AUC Score': 0.8623268791179239,\n",
       " 'CM': array([[395,  49],\n",
       "        [ 72, 196]], dtype=int64)}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train Performance\n",
    "get_performance_metrics(y_train, pipeline.predict(X_train), pipeline.predict_proba(X_train)[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a154855",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
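
The precision and recall values reported in the notebook outputs can be checked by hand from the confusion matrices. The sketch below assumes scikit-learn's layout for binary labels (rows are true classes, columns are predicted classes, so the matrix reads [[TN, FP], [FN, TP]]). The helper name precision_recall_from_cm is hypothetical and is not part of the notebook.

```python
def precision_recall_from_cm(cm):
    """Derive precision and recall from a 2x2 confusion matrix [[TN, FP], [FN, TP]]."""
    (tn, fp), (fn, tp) = cm
    precision = tp / (tp + fp)  # share of predicted positives that are correct
    recall = tp / (tp + fn)     # share of actual positives that are found
    return precision, recall


# Confusion matrices copied from the notebook outputs above.
test_cm = [[86, 19], [20, 54]]
train_cm = [[395, 49], [72, 196]]

test_precision, test_recall = precision_recall_from_cm(test_cm)
train_precision, train_recall = precision_recall_from_cm(train_cm)

print(f"test:  precision={test_precision:.4f}, recall={test_recall:.4f}")
print(f"train: precision={train_precision:.4f}, recall={train_recall:.4f}")
```

For the test set this gives 54 / (54 + 19) for precision and 54 / (54 + 20) for recall, which agrees with the values returned by get_performance_metrics. The train and test scores are close to each other, which suggests the pipeline is not overfitting heavily on this split.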

Conclusion:

A pipeline chains all of your estimators into a single object, which keeps the code clean and well structured. Training and inference each reduce to one call on that object (fit, then predict or predict_proba), as the notebook above shows. The fitted pipeline can also be displayed as a readable diagram of its steps, from column-wise preprocessing through scaling to the final model.
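
To make the chaining idea concrete, here is a toy sketch in plain Python. It is not scikit-learn's implementation, and the classes MeanImputer, MeanThresholdClassifier, and MiniPipeline are hypothetical names invented for this illustration. It only shows the pattern: during fit, each intermediate step is fitted and transforms the data before passing it on; during predict, the already-fitted steps are reused, so statistics learned from the training data are not recomputed on new data.

```python
class MeanImputer:
    """Toy transformer: replaces None with the mean learned during fit."""

    def fit(self, X, y=None):
        observed = [x for x in X if x is not None]
        self.mean_ = sum(observed) / len(observed)
        return self

    def transform(self, X):
        return [self.mean_ if x is None else x for x in X]


class MeanThresholdClassifier:
    """Toy final estimator: predicts 1 when a value exceeds the training mean."""

    def fit(self, X, y=None):
        self.threshold_ = sum(X) / len(X)
        return self

    def predict(self, X):
        return [int(x > self.threshold_) for x in X]


class MiniPipeline:
    """Chains transformers and a final estimator behind one fit/predict interface."""

    def __init__(self, steps):
        self.steps = steps

    def fit(self, X, y=None):
        for step in self.steps[:-1]:
            X = step.fit(X, y).transform(X)  # fit on training data, then pass it on
        self.steps[-1].fit(X, y)
        return self

    def predict(self, X):
        for step in self.steps[:-1]:
            X = step.transform(X)  # reuse fitted state, no refitting
        return self.steps[-1].predict(X)


pipe = MiniPipeline([MeanImputer(), MeanThresholdClassifier()])
pipe.fit([1.0, None, 3.0, 8.0])
predictions = pipe.predict([None, 10.0])
print(predictions)
```

The missing value at prediction time is filled with the mean learned during fit, not with a statistic of the new data. This is the property that helps a pipeline avoid data leakage.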

References:

[1] Scikit-learn documentation: https://scikit-learn.org/stable/modules/compose.html

Article originally posted here. Reposted with permission.

Dominic Rubhabha-Wardslaus