Saturday, December 28, 2024

Length of Longest Increasing Subsequences (LIS) using Segment Tree

Given an array arr[] of size N, the task is to count the number of longest increasing subsequences present in the given array.

Example:

Input: arr[] = {2, 2, 2, 2, 2}
Output: 5
Explanation: The length of the longest increasing subsequence is 1, i.e. {2}. Therefore, count of longest increasing subsequences of length 1 is 5. 
 

Input: arr[] = {1, 3, 5, 4, 7}
Output: 2
Explanation: The length of the longest increasing subsequence is 4, and there are 2 longest increasing subsequences of length 4, i.e. {1, 3, 4, 7} and {1, 3, 5, 7}.
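
The two examples above can be checked against a simple quadratic baseline. The sketch below is a minimal O(N^2) dynamic programming version, not the segment tree method of this article; the function name count_lis_dp and its helper arrays are illustrative assumptions, and "increasing" is taken to mean strictly increasing, as the first example implies.

```python
def count_lis_dp(arr):
    """Count the longest strictly increasing subsequences in O(N^2) time."""
    n = len(arr)
    if n == 0:
        return 0

    length = [1] * n  # length[i]: LIS length among subsequences ending at i
    count = [1] * n   # count[i]: how many subsequences reach that length

    for i in range(n):
        for j in range(i):
            if arr[j] < arr[i]:
                if length[j] + 1 > length[i]:
                    # A strictly longer subsequence ends at i
                    length[i] = length[j] + 1
                    count[i] = count[j]
                elif length[j] + 1 == length[i]:
                    # Another way to reach the same length
                    count[i] += count[j]

    best = max(length)
    return sum(c for l, c in zip(length, count) if l == best)


print(count_lis_dp([2, 2, 2, 2, 2]))  # 5
print(count_lis_dp([1, 3, 5, 4, 7]))  # 2
```

A baseline like this is handy for cross-checking a faster implementation on small random arrays.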

Approach: A dynamic programming approach to this problem has already been discussed in this article. 
The approach described here uses a segment tree instead. Follow the steps below to solve the given problem:

  • Initialise the segment tree as an array of pairs, each set to (0, 0), where the 1st element represents the length of the LIS and the 2nd element represents the count of LIS of that length.
  • Process the array elements in increasing order of value, taking equal values in decreasing order of index. For the element at index i, query the range [0, i - 1] for the pair (length, count) and store (length + 1, max(1, count)) at position i.
  • The 1st element of the segment tree can be calculated similarly to the approach discussed in this article.
  • The 2nd element of the segment tree can be calculated using the following steps:
    • If the length of the left child > the length of the right child, the parent node becomes equal to the left child, as the LIS is that of the left child.
    • If the length of the left child < the length of the right child, the parent node becomes equal to the right child, as the LIS is that of the right child.
    • If the length of the left child = the length of the right child, the parent node keeps that length and its count becomes the sum of the counts of LIS of the left child and the right child.
  • The required answer is the 2nd element of the root of the segment tree.
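
The three cases for combining a left and a right child can be sketched as one small function. This is only an illustration of the merge rule on (length, count) pairs; the name combine and the use of Python tuples are assumptions made for the example, not part of the listings below.

```python
from functools import reduce


def combine(left, right):
    """Merge two (length, count) pairs the way a parent node is computed."""
    if left[0] > right[0]:
        return left   # the LIS comes from the left child only
    if left[0] < right[0]:
        return right  # the LIS comes from the right child only
    # Same length on both sides: keep the length, add the counts
    return (left[0], left[1] + right[1])


print(combine((3, 1), (2, 4)))  # (3, 1)
print(combine((2, 1), (3, 2)))  # (3, 2)
print(combine((3, 1), (3, 1)))  # (3, 2)

# Leaf values for arr = {1, 3, 5, 4, 7} once every index has been processed
leaves = [(1, 1), (2, 1), (3, 1), (3, 1), (4, 2)]
print(reduce(combine, leaves, (0, 0)))  # (4, 2): length 4, count 2
```

Because (0, 0) never beats a real entry and adds nothing on a tie, it works as the neutral starting value, which is why the tree can be initialised with it.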

Below is the implementation of the above approach:

C++




// C++ implementation of the above approach
#include <bits/stdc++.h>
using namespace std;
 
#define M 100000
 
// Stores the Segment tree
vector<pair<int, int> > tree(4 * M + 1);
 
// Function to update Segment tree, the root
// of which contains the length of the LIS
void update_tree(int start, int end,
                 int update_idx, int length_t,
                 int count_c, int idx)
{
    // If the intervals
    // are overlapping completely
    if (start == end
        && start == update_idx) {
        tree[idx].first
            = max(tree[idx].first, length_t);
        tree[idx].second = count_c;
        return;
    }
 
    // If intervals are not overlapping
    if (update_idx < start
        || end < update_idx) {
        return;
    }
 
    // If intervals are partially overlapping
    int mid = (start + end) / 2;
 
    update_tree(start, mid, update_idx,
                length_t, count_c,
                2 * idx);
    update_tree(mid + 1, end, update_idx,
                length_t, count_c,
                2 * idx + 1);
 
    // If length_t of left and
    // right child are equal
    if (tree[2 * idx].first
        == tree[2 * idx + 1].first) {
        tree[idx].first
            = tree[2 * idx].first;
        tree[idx].second
            = tree[2 * idx].second
              + tree[2 * idx + 1].second;
    }
 
    // If length_t of left > length_t right child
    else if (tree[2 * idx].first
             > tree[2 * idx + 1].first) {
        tree[idx] = tree[2 * idx];
    }
 
    // If length_t of left < length_t right child
    else {
        tree[idx] = tree[2 * idx + 1];
    }
}
 
// Function to find the LIS length
// and count in the given range
pair<int, int> query(int start, int end,
                     int query_start,
                     int query_end, int idx)
{
    // If the intervals
    // are overlapping completely
    if (query_start <= start
        && end <= query_end) {
        return tree[idx];
    }
 
    // If intervals are not overlapping
    pair<int, int> temp({ INT32_MIN, 0 });
    if (end < query_start
        || query_end < start) {
        return temp;
    }
 
    // If intervals are partially overlapping
    int mid = (start + end) / 2;
    auto left_child
        = query(start, mid, query_start,
                query_end, 2 * idx);
    auto right_child
        = query(mid + 1, end, query_start,
                query_end, 2 * idx + 1);
 
    // If length_t of left child is greater
    // than length_t of right child
    if (left_child.first > right_child.first) {
        return left_child;
    }
 
    // If length_t of right child is
    // greater than length_t of left child
    if (right_child.first > left_child.first) {
        return right_child;
    }
 
    // If length_t of left and right child
    // are equal, keep the length and
    // return the sum of their counts
    return make_pair(left_child.first,
                     left_child.second
                         + right_child.second);
}
 
// Comparator function to sort an array of pairs
// in increasing order of their 1st element and
// thereafter in decreasing order of the 2nd
bool comp(pair<int, int> a, pair<int, int> b)
{
    if (a.first == b.first) {
        return a.second > b.second;
    }
    return a.first < b.first;
}
 
// Function to find count
// of LIS in the given array
int countLIS(int arr[], int n)
{
    // Generating value-index pair array
    vector<pair<int, int> > pair_array(n);
    for (int i = 0; i < n; i++) {
        pair_array[i].first = arr[i];
        pair_array[i].second = i;
    }
 
    // Sort array of pairs with increasing order
    // of value and decreasing order of index
    sort(pair_array.begin(),
         pair_array.end(), comp);
 
    // Traverse the array
    // and perform query updates
    for (int i = 0; i < n; i++) {
 
        int update_idx = pair_array[i].second;
 
        // If update index is the 1st index
        if (update_idx == 0) {
            update_tree(0, n - 1, 0, 1, 1, 1);
            continue;
        }
 
        // Query over the interval [0, update_idx -1]
        pair<int, int> temp
            = query(0, n - 1, 0,
                    update_idx - 1, 1);
 
        // Update the segment tree
        update_tree(0, n - 1, update_idx,
                    temp.first + 1,
                    max(1, temp.second), 1);
    }
 
    // Stores the final answer
    pair<int, int> ans
        = query(0, n - 1, 0, n - 1, 1);
 
    // Return answer
    return ans.second;
}
 
// Driver Code
int main()
{
    int arr[] = { 1, 3, 5, 4, 7 };
    int n = sizeof(arr) / sizeof(int);
 
    cout << countLIS(arr, n);
 
    return 0;
}


Java




import java.util.*;
import java.io.*;
 
// Java program for the above approach
public class GFG{
 
    public static int M = 100000;
 
    // Stores the Segment tree
    public static ArrayList<ArrayList<Integer>> tree =
      new ArrayList<ArrayList<Integer>>();
 
    // Function to update Segment tree, the root
    // of which contains the length of the LIS
    public static void update_tree(int start, int end,
                                   int update_idx, int length_t,
                                   int count_c, int idx)
    {
        // If the intervals
        // are overlapping completely
        if (start == end && start == update_idx) {
            tree.get(idx).set(0, Math.max(tree.get(idx).get(0), length_t));
            tree.get(idx).set(1, count_c);
            return;
        }
 
        // If intervals are not overlapping
        if (update_idx < start || end < update_idx) {
            return;
        }
 
        // If intervals are partially overlapping
        int mid = (start + end) / 2;
 
        update_tree(start, mid, update_idx,
                    length_t, count_c, 2 * idx);
        update_tree(mid + 1, end, update_idx,
                    length_t, count_c, 2 * idx + 1);
 
        // If length_t of left and
        // right child are equal
        // (equals() compares the values; == would compare Integer references)
        if (tree.get(2 * idx).get(0).equals(tree.get(2 * idx + 1).get(0))) {
            tree.set(idx, new ArrayList<Integer>(
                List.of(tree.get(2 * idx).get(0),
                        tree.get(2 * idx).get(1) +
                        tree.get(2 * idx + 1).get(1))
            ));
        }
 
        // If length_t of left > length_t right child
        else if (tree.get(2 * idx).get(0) > tree.get(2 * idx + 1).get(0)) {
            tree.set(idx, new ArrayList<Integer>(
                List.of(tree.get(2 * idx).get(0), tree.get(2 * idx).get(1))
            ));
        }
 
        // If length_t of left < length_t right child
        else {
            tree.set(idx, new ArrayList<Integer>(
                List.of(tree.get(2 * idx + 1).get(0), tree.get(2 * idx + 1).get(1))
            ));
        }
    }
 
    // Function to find the LIS length
    // and count in the given range
    public static ArrayList<Integer> query(int start, int end,
                                           int query_start,
                                           int query_end, int idx)
    {
        // If the intervals
        // are overlapping completely
        if (query_start <= start && end <= query_end) {
            return new ArrayList<Integer>(tree.get(idx));
        }
 
        // If intervals are not overlapping
        ArrayList<Integer> temp = new ArrayList<Integer>(
            List.of(Integer.MIN_VALUE, 0 )
        );
 
        if (end < query_start || query_end < start) {
            return new ArrayList<Integer>(temp);
        }
 
        // If intervals are partially overlapping
        int mid = (start + end) / 2;
        ArrayList<Integer> left_child = query(start, mid,
                                              query_start,
                                              query_end, 2 * idx);
        ArrayList<Integer> right_child = query(mid + 1, end,
                                               query_start,
                                               query_end, 2 * idx + 1);
 
        // If length_t of left child is greater
        // than length_t of right child
        if (left_child.get(0) > right_child.get(0)) {
            return new ArrayList<Integer>(left_child);
        }
 
        // If length_t of right child is
        // greater than length_t of left child
        if (right_child.get(0) > left_child.get(0)) {
            return new ArrayList<Integer>(right_child);
        }
 
        // If length_t of left and right child
        // are equal, keep the length and
        // return the sum of their counts
        return new ArrayList<Integer>(
            List.of(
                left_child.get(0),
                left_child.get(1) + right_child.get(1)
            )
        );
    }
 
    // Function to find count
    // of LIS in the given array
    public static int countLIS(int arr[], int n)
    {
        // Generating value-index pair array
        ArrayList<ArrayList<Integer>> pair_array = new ArrayList<ArrayList<Integer>>();
 
        for(int i = 0 ; i < n ; i++){
            pair_array.add(new ArrayList<Integer>(
                List.of(arr[i], i)
            ));
        }
 
        // Sort array of pairs with increasing order
        // of value and decreasing order of index
        Collections.sort(pair_array, new comp());
 
        // Traverse the array
        // and perform query updates
        for (int i = 0 ; i < n ; i++) {
 
            int update_idx = pair_array.get(i).get(1);
 
            // If update index is the 1st index
            if (update_idx == 0) {
                update_tree(0, n - 1, 0, 1, 1, 1);
                continue;
            }
 
            // Query over the interval [0, update_idx -1]
            ArrayList<Integer> temp = query(0, n - 1, 0,
                                            update_idx - 1, 1);
 
            // Update the segment tree
            update_tree(0, n - 1, update_idx, temp.get(0) + 1,
                        Math.max(1, temp.get(1)), 1);
        }
 
        // Stores the final answer
        ArrayList<Integer> ans = query(0, n - 1, 0, n - 1, 1);
 
        // Return answer
        return ans.get(1);
    }
 
 
    // Driver code
    public static void main(String args[])
    {
        int arr[] = { 1, 3, 5, 4, 7 };
        int n = arr.length;
 
        // Every node starts as (0, 0): LIS length 0 and count 0
        for(int i = 0 ; i < 4*M + 1 ; i++){
            tree.add(new ArrayList<Integer>(
                List.of(0, 0)
            ));
        }
 
        System.out.println(countLIS(arr, n));
    }
}
 
// Comparator function to sort an array of pairs
// in increasing order of their 1st element and
// thereafter in decreasing order of the 2nd
// (not public, so it can share the file with GFG)
class comp implements Comparator<ArrayList<Integer>>{
    public int compare(ArrayList<Integer> a, ArrayList<Integer> b)
    {
        if (a.get(0).equals(b.get(0))) {
            return b.get(1).compareTo(a.get(1));
        }
        return a.get(0).compareTo(b.get(0));
    }
}
 
// This code is contributed by subhamgoyal2014.


C#




// C# implementation of the above approach

using System;

class GFG {
    // Stores the Segment tree: for node i, len[i] is the LIS
    // length and cnt[i] is the count of LIS of that length
    static int[] len;
    static int[] cnt;

    // Function to update Segment tree, the root
    // of which contains the length of the LIS
    static void UpdateTree(int start, int end, int updateIdx,
                           int lengthT, int countC, int idx)
    {
        // Leaf node that corresponds to updateIdx
        if (start == end) {
            len[idx] = Math.Max(len[idx], lengthT);
            cnt[idx] = countC;
            return;
        }

        // Descend into the half that contains updateIdx
        int mid = (start + end) / 2;
        if (updateIdx <= mid) {
            UpdateTree(start, mid, updateIdx, lengthT, countC,
                       2 * idx);
        }
        else {
            UpdateTree(mid + 1, end, updateIdx, lengthT, countC,
                       2 * idx + 1);
        }

        int left = 2 * idx, right = 2 * idx + 1;

        // If lengths of left and right child are equal
        if (len[left] == len[right]) {
            len[idx] = len[left];
            cnt[idx] = cnt[left] + cnt[right];
        }
        // If length of left > length of right child
        else if (len[left] > len[right]) {
            len[idx] = len[left];
            cnt[idx] = cnt[left];
        }
        // If length of left < length of right child
        else {
            len[idx] = len[right];
            cnt[idx] = cnt[right];
        }
    }

    // Function to find the LIS length and count in the given
    // range; the result is accumulated in bestLen and bestCnt
    static void Query(int start, int end, int queryStart,
                      int queryEnd, int idx, ref int bestLen,
                      ref int bestCnt)
    {
        // If intervals are not overlapping
        if (end < queryStart || queryEnd < start) {
            return;
        }

        // If the intervals are overlapping completely
        if (queryStart <= start && end <= queryEnd) {
            if (len[idx] > bestLen) {
                bestLen = len[idx];
                bestCnt = cnt[idx];
            }
            else if (len[idx] == bestLen) {
                bestCnt += cnt[idx];
            }
            return;
        }

        // If intervals are partially overlapping
        int mid = (start + end) / 2;
        Query(start, mid, queryStart, queryEnd, 2 * idx,
              ref bestLen, ref bestCnt);
        Query(mid + 1, end, queryStart, queryEnd, 2 * idx + 1,
              ref bestLen, ref bestCnt);
    }

    // Function to find count
    // of LIS in the given array
    static int CountLIS(int[] arr)
    {
        int n = arr.Length;
        if (n == 0) {
            return 0;
        }

        // Every node starts as (0, 0)
        len = new int[4 * n];
        cnt = new int[4 * n];

        // Indices sorted by increasing value and,
        // for equal values, by decreasing index
        int[] order = new int[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        Array.Sort(order, (a, b) => arr[a] != arr[b]
                                        ? arr[a].CompareTo(arr[b])
                                        : b.CompareTo(a));

        // Traverse the array and perform query updates
        foreach (int updateIdx in order) {
            int bestLen = 0, bestCnt = 0;

            // Query over the interval [0, updateIdx - 1]
            // (empty when updateIdx is 0)
            Query(0, n - 1, 0, updateIdx - 1, 1, ref bestLen,
                  ref bestCnt);

            // Update the segment tree
            UpdateTree(0, n - 1, updateIdx, bestLen + 1,
                       Math.Max(1, bestCnt), 1);
        }

        // The root stores the count of LIS of the whole array
        return cnt[1];
    }

    // Driver code
    static void Main(string[] args)
    {
        int[] arr = { 1, 3, 5, 4, 7 };
        Console.WriteLine(CountLIS(arr));
    }
}


Javascript




<script>
// Javascript implementation of the above approach
 
 
let M = 100000
 
// Stores the Segment tree
// (every node starts as [0, 0]: LIS length 0 and count 0)
let tree = new Array(4 * M + 1).fill(0).map(() => [0, 0]);
 
// Function to update Segment tree, the root
// of which contains the length of the LIS
function update_tree(start, end, update_idx, length_t, count_c, idx) {
  // If the intervals
  // are overlapping completely
  if (start == end
    && start == update_idx) {
    tree[idx][0]
      = Math.max(tree[idx][0], length_t);
    tree[idx][1] = count_c;
    return;
  }
 
  // If intervals are not overlapping
  if (update_idx < start
    || end < update_idx) {
    return;
  }
 
  // If intervals are partially overlapping
  let mid = Math.floor((start + end) / 2);
 
  update_tree(start, mid, update_idx,
    length_t, count_c,
    2 * idx);
  update_tree(mid + 1, end, update_idx,
    length_t, count_c,
    2 * idx + 1);
 
  // If length_t of left and
  // right child are equal
  if (tree[2 * idx][0]
    == tree[2 * idx + 1][0]) {
    tree[idx] = [tree[2 * idx][0],
      tree[2 * idx][1]
      + tree[2 * idx + 1][1]];
  }

  // If length_t of left > length_t right child
  // (copy the pair so parent and child do not share one array)
  else if (tree[2 * idx][0]
    > tree[2 * idx + 1][0]) {
    tree[idx] = [...tree[2 * idx]];
  }

  // If length_t of left < length_t right child
  else {
    tree[idx] = [...tree[2 * idx + 1]];
  }
}
 
// Function to find the LIS length
// and count in the given range
function query(start, end, query_start, query_end, idx) {
  // If the intervals
  // are overlapping completely
  if (query_start <= start
    && end <= query_end) {
    return tree[idx];
  }
 
  // If intervals are not overlapping
  let temp = [Number.MIN_SAFE_INTEGER, 0];
  if (end < query_start
    || query_end < start) {
    return temp;
  }
 
  // If intervals are partially overlapping
  let mid = Math.floor((start + end) / 2);
  let left_child
    = query(start, mid, query_start,
      query_end, 2 * idx);
  let right_child
    = query(mid + 1, end, query_start,
      query_end, 2 * idx + 1);
 
  // If length_t of left child is greater
  // than length_t of right child
  if (left_child[0] > right_child[0]) {
    return left_child;
  }
 
  // If length_t of right child is
  // greater than length_t of left child
  if (right_child[0] > left_child[0]) {
    return right_child;
  }
 
  // If length_t of left and right child
  // are equal, keep the length and
  // return the sum of their counts
  return [left_child[0],
  left_child[1]
  + right_child[1]];
}
 
// Comparator function to sort an array of pairs
// in increasing order of their 1st element and
// thereafter in decreasing order of the 2nd
function comp(a, b) {
  // Array.prototype.sort expects a negative, zero
  // or positive number, not a boolean
  if (a[0] == b[0]) {
    return b[1] - a[1];
  }
  return a[0] - b[0];
}
 
// Function to find count
// of LIS in the given array
function countLIS(arr, n) {
  // Generating value-index pair array
  let pair_array = new Array(n).fill(0).map(() => []);
  for (let i = 0; i < n; i++) {
    pair_array[i][0] = arr[i];
    pair_array[i][1] = i;
  }
 
  // Sort array of pairs with increasing order
  // of value and decreasing order of index
  pair_array.sort(comp);
 
  // Traverse the array
  // and perform query updates
  for (let i = 0; i < n; i++) {
 
    let update_idx = pair_array[i][1];
 
    // If update index is the 1st index
    if (update_idx == 0) {
      update_tree(0, n - 1, 0, 1, 1, 1);
      continue;
    }
 
    // Query over the interval [0, update_idx -1]
    let temp = query(0, n - 1, 0, update_idx - 1, 1);
 
    // Update the segment tree
    update_tree(0, n - 1, update_idx,
      temp[0] + 1,
      Math.max(1, temp[1]), 1);
  }
 
  // Stores the final answer
  let ans = query(0, n - 1, 0, n - 1, 1);
 
  // Return answer
  return ans[1];
}
 
// Driver Code
 
let arr = [1, 3, 5, 4, 7];
let n = arr.length;
 
document.write(countLIS(arr, n));
 
// This code is contributed by saurabh_jaiswal.
</script>


Python3




# Python program for the above approach


# Merges two (length, count) pairs: the longer LIS wins and,
# if the lengths are equal, the counts are added
def combine(left, right):
    if left[0] > right[0]:
        return left
    if left[0] < right[0]:
        return right
    return (left[0], left[1] + right[1])


# Function to update Segment tree, the root
# of which contains the length of the LIS
def update_tree(tree, start, end, update_idx, length_t, count_c, idx):

    # Leaf node that corresponds to update_idx
    if start == end:
        tree[idx] = (max(tree[idx][0], length_t), count_c)
        return

    # Descend into the half that contains update_idx
    mid = (start + end) // 2
    if update_idx <= mid:
        update_tree(tree, start, mid, update_idx,
                    length_t, count_c, 2 * idx)
    else:
        update_tree(tree, mid + 1, end, update_idx,
                    length_t, count_c, 2 * idx + 1)

    # Recompute the parent from its two children
    tree[idx] = combine(tree[2 * idx], tree[2 * idx + 1])


# Function to find the LIS length
# and count in the given range
def query(tree, start, end, query_start, query_end, idx):

    # If intervals are not overlapping
    # (this also covers an empty query range)
    if end < query_start or query_end < start:
        return (0, 0)

    # If the intervals are overlapping completely
    if query_start <= start and end <= query_end:
        return tree[idx]

    # If intervals are partially overlapping
    mid = (start + end) // 2
    left_child = query(tree, start, mid,
                       query_start, query_end, 2 * idx)
    right_child = query(tree, mid + 1, end,
                        query_start, query_end, 2 * idx + 1)
    return combine(left_child, right_child)


# Function to find count
# of LIS in the given array
def countLIS(arr):
    n = len(arr)
    if n == 0:
        return 0

    # Stores the Segment tree, every node starts as (0, 0)
    tree = [(0, 0)] * (4 * n)

    # Indices sorted by increasing value and,
    # for equal values, by decreasing index
    order = sorted(range(n), key=lambda i: (arr[i], -i))

    # Traverse the array and perform query updates
    for update_idx in order:

        # Query over the interval [0, update_idx - 1]
        length_t, count_c = query(tree, 0, n - 1, 0, update_idx - 1, 1)

        # Update the segment tree
        update_tree(tree, 0, n - 1, update_idx,
                    length_t + 1, max(1, count_c), 1)

    # The root stores the count of LIS of the whole array
    return tree[1][1]


# Driver code
arr = [1, 3, 5, 4, 7]
print(countLIS(arr))
 
 
# This code is contributed by princekumaras


Output

2

Time Complexity: O(N*log N)
Auxiliary Space: O(N)
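
The implementations above sort equal values by decreasing index before updating the tree. A small sketch of why that tie-break matters is given below; it replaces the segment tree range query with a plain linear scan over the prefix, so it runs in O(N^2) and is meant only as an illustration. The function name and the ties_by_decreasing_index flag are hypothetical, introduced just for this comparison.

```python
def count_lis_in_value_order(arr, ties_by_decreasing_index=True):
    """Process indices in value order; a linear scan stands in for the tree."""
    n = len(arr)
    if n == 0:
        return 0

    sign = -1 if ties_by_decreasing_index else 1
    order = sorted(range(n), key=lambda i: (arr[i], sign * i))

    # best[p] holds (length, count) for position p; (0, 0) means "not set yet"
    best = [(0, 0)] * n

    for pos in order:
        length, count = 0, 0
        for l, c in best[:pos]:  # stand-in for query(0, pos - 1)
            if l > length:
                length, count = l, c
            elif l == length:
                count += c
        best[pos] = (length + 1, max(1, count))

    top = max(l for l, _ in best)
    return sum(c for l, c in best if l == top)


print(count_lis_in_value_order([2, 2, 2, 2, 2]))         # 5
print(count_lis_in_value_order([2, 2, 2, 2, 2], False))  # 1 (wrong)
```

With increasing-index ties, an earlier equal element is already stored when a later one queries its prefix, so equal values get chained into a "subsequence" of length 5. Processing ties from right to left keeps the prefix free of equal values and preserves strictness.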

Related Topic: Segment Tree
