Given an integer X and a binary tree, the task is to count the number of triplets of nodes that form a grandparent -> parent -> child chain and whose sum is greater than X.
Example:
Input: X = 100
                   10
              /          \
         1                    22
      /     \              /      \
    35        4         15          67
   /  \      / \       /  \        /  \
  57   38   9   10   110   312   131   414
                          /               \
                         8                 39
Output: 6
The triplets are:
22 -> 15 -> 110
22 -> 15 -> 312
15 -> 312 -> 8
22 -> 67 -> 131
22 -> 67 -> 414
67 -> 414 -> 39
Approach: The problem can be solved with a rolling-sum approach that uses a window (period) of 3. A window of 3 covers exactly one grandparent -> parent -> child chain.
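- Traverse the tree in preorder or postorder (INORDER WON’T WORK). A node must be pushed before its children are visited and popped only after both of its subtrees are finished, so that the stack always holds exactly the current root-to-node path.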
- Maintain a stack whose i-th entry is the rolling sum of the last 3 node values on the current root-to-node path, ending at depth i (the rolling period is 3).
- Whenever the stack holds at least 3 elements and the topmost value (the sum of the current node, its parent and its grandparent) is greater than X, we increment the result by 1.
- When we move up the recursion tree we do a POP operation on the stack so that all the rolling sums of lower levels get removed from the stack.
Below is the implementation of the above approach:
C++
#include <iostream> #include <vector> using namespace std; // Class to store node information class Node { public : int val; Node* left; Node* right; Node( int value = 0, Node* l = nullptr, Node* r = nullptr) : val(value) , left(l) , right(r) { } }; // Stack to perform stack operations class Stack { public : vector< int > stack; int size() { return stack.size(); } int top() { return (size() > 0) ? stack[size() - 1] : 0; } void push( int val) { stack.push_back(val); } void pop() { if (size() >= 1) { stack.pop_back(); } else { stack.clear(); } } // Period is 3 to satisfy grandparent-parent-child // relationship void rolling_push( int val, int period = 3) { // Find the index of element to remove int to_remove_idx = size() - period; // If index is out of bounds then we remove nothing, // i.e, 0 int to_remove = (to_remove_idx < 0) ? 0 : stack[to_remove_idx]; // For rolling sum what we want is that at each // index i, we remove out-of-period elements, but // since we are not maintaining a separate list of // actual elements, we can get the actual element by // taking diff between current and previous element. // So every time we remove an element, we also add // the element just before it, which is equivalent // to removing the actual value and not the rolling // sum. // If index is out of bounds or 0 then we add // nothing i.e, 0, because there is no previous // element int to_add = (to_remove_idx <= 0) ? 
0 : stack[to_remove_idx - 1]; // If stack is empty then just push the value // Else add last element to current value to get // rolling sum then subtract out-of-period elements, // then finally add the element just before // out-of-period element if (size() <= 0) { push(val); } else { push(val + stack[size() - 1] - to_remove + to_add); } } void show() { for ( auto item : stack) { cout << item << " " ; } cout << endl; } }; // Global variables used by count_with_greater_sum() int count = 0; Stack s; void count_with_greater_sum(Node* root_node, int x) { if (!root_node) { return ; } s.rolling_push(root_node->val); if (s.size() >= 3 && s.top() > x) { count++; } count_with_greater_sum(root_node->left, x); count_with_greater_sum(root_node->right, x); // Moving up the tree so pop the last element s.pop(); } int main() { Node* root = new Node(10); root->left = new Node(1); root->right = new Node(22); root->left->left = new Node(35); root->left->right = new Node(4); root->right->left = new Node(15); root->right->right = new Node(67); root->left->left->left = new Node(57); root->left->left->right = new Node(38); root->left->right->left = new Node(9); root->left->right->right = new Node(10); root->right->left->left = new Node(110); root->right->left->right = new Node(312); root->right->right->left = new Node(131); root->right->right->right = new Node(414); root->right->left->right->left = new Node(8); root->right->right->right->right = new Node(39); count_with_greater_sum(root, 100); cout << count << endl; } |
Python3
# Python3 implementation of the approach


# Class to store node information
class Node:
    def __init__(self, val=None, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# Stack of rolling (windowed) sums along the current root-to-node path
class Stack:
    def __init__(self):
        # stack[i] is the sum of the last `period` node values ending at depth i
        # (fewer values near the root, where the path is shorter than `period`)
        self.stack = []
        # Raw node values of the current path, so that the value leaving the
        # window is known exactly. The difference of two adjacent rolling sums
        # is a[k] - a[k - period], not a single node value, so it cannot be
        # used for that once the path is deep enough.
        self._values = []

    @property
    def size(self):
        return len(self.stack)

    def top(self):
        """Rolling sum of the deepest level, or 0 if the stack is empty."""
        return self.stack[self.size - 1] if self.size > 0 else 0

    def push(self, val):
        """Append `val` to the stack as-is (use rolling_push for path tracking)."""
        self.stack.append(val)

    def pop(self):
        """Remove the deepest level; a no-op on an empty stack."""
        if self.stack:
            self.stack.pop()
        if self._values:
            self._values.pop()

    # Period is 3 to satisfy grandparent-parent-child relationship
    def rolling_push(self, val, period=3):
        """Push the sum of `val` and the previous `period - 1` path values."""
        self._values.append(val)
        window = self._values[-period:] if period > 0 else []
        self.push(sum(window))

    def show(self):
        for item in self.stack:
            print(item)


# Global variables used by count_with_greater_sum()
count = 0
s = Stack()


def count_with_greater_sum(root_node, x):
    """Preorder DFS that counts grandparent -> parent -> child sums above x.

    The result is accumulated in the module-level `count`.
    """
    global s, count
    if not root_node:
        return 0
    s.rolling_push(root_node.val)
    if s.size >= 3 and s.top() > x:
        count += 1
    count_with_greater_sum(root_node.left, x)
    count_with_greater_sum(root_node.right, x)
    # Moving up the tree so pop the last element
    s.pop()


if __name__ == '__main__':
    root = Node(10)
    root.left = Node(1)
    root.right = Node(22)
    root.left.left = Node(35)
    root.left.right = Node(4)
    root.right.left = Node(15)
    root.right.right = Node(67)
    root.left.left.left = Node(57)
    root.left.left.right = Node(38)
    root.left.right.left = Node(9)
    root.left.right.right = Node(10)
    root.right.left.left = Node(110)
    root.right.left.right = Node(312)
    root.right.right.left = Node(131)
    root.right.right.right = Node(414)
    root.right.left.right.left = Node(8)
    root.right.right.right.right = Node(39)
    count_with_greater_sum(root, 100)
    print(count)
Javascript
// Class to store node information class Node { constructor(val = null , left = null , right = null ) { this .val = val; this .left = left; this .right = right; } } // Stack to perform stack operations class Stack { constructor() { this .stack = []; } get size() { return this .stack.length; } top() { return this .size > 0 ? this .stack[ this .size - 1] : 0; } push(val) { this .stack.push(val); } pop() { if ( this .size >= 1) { this .stack.pop(); } else { this .stack = []; } } // Period is 3 to satisfy grandparent-parent-child relationship rolling_push(val, period = 3) { // Find the index of element to remove const to_remove_idx = this .size - period; // If index is out of bounds then we remove nothing, i.e, 0 const to_remove = to_remove_idx < 0 ? 0 : this .stack[to_remove_idx]; // For rolling sum what we want is that at each index i, // we remove out-of-period elements, but since we are not // maintaining a separate list of actual elements, // we can get the actual element by taking diff between current // and previous element. So every time we remove an element, // we also add the element just before it, which is // equivalent to removing the actual value and not the rolling sum. // If index is out of bounds or 0 then we add nothing // i.e, 0, because there is no previous element const to_add = to_remove_idx <= 0 ? 0 : this .stack[to_remove_idx - 1]; // If stack is empty then just push the value // Else add last element to current value to get rolling sum // then subtract out-of-period elements, // then finally add the element just before out-of-period element this .push( val + ( this .size <= 0 ? 
0 : this .stack[ this .size - 1]) - to_remove + to_add ); } show() { for (const item of this .stack) { console.log(item); } } } // Global variables used by count_with_greater_sum() let count = 0; const s = new Stack(); function count_with_greater_sum(root_node, x) { if (!root_node) { return 0; } s.rolling_push(root_node.val); if (s.size >= 3 && s.top() > x) { count += 1; } count_with_greater_sum(root_node.left, x); count_with_greater_sum(root_node.right, x); // Moving up the tree so pop the last element s.pop(); } const root = new Node(10); root.left = new Node(1); root.right = new Node(22); root.left.left = new Node(35); root.left.right = new Node(4); root.right.left = new Node(15); root.right.right = new Node(67); root.left.left.left = new Node(57); root.left.left.right = new Node(38); root.left.right.left = new Node(9); root.left.right.right = new Node(10); root.right.left.left = new Node(110); root.right.left.right = new Node(312); root.right.right.left = new Node(131); root.right.right.right = new Node(414); root.right.left.right.left = new Node(8); root.right.right.right.right = new Node(39); count_with_greater_sum(root, 100); console.log(count); // This code is contributed by Prince |
Java
import java.util.*;

class Node {
    int val;
    Node left;
    Node right;

    Node(int value, Node l, Node r) {
        val = value;
        left = l;
        right = r;
    }

    Node(int value) {
        this(value, null, null);
    }
}

// Stack of rolling (windowed) sums along the current root-to-node path.
// stack.get(i) is the sum of the last `period` node values ending at depth i.
class Stack {
    ArrayList<Integer> stack = new ArrayList<>();
    // Raw node values of the current path. The difference of two adjacent
    // rolling sums is a[k] - a[k - period], not the value leaving the window.
    private final ArrayList<Integer> values = new ArrayList<>();

    int size() {
        return stack.size();
    }

    // Rolling sum of the deepest level, or 0 if the stack is empty.
    int top() {
        return (size() > 0) ? stack.get(size() - 1) : 0;
    }

    // Appends val as-is (use rolling_push for path tracking).
    void push(int val) {
        stack.add(val);
    }

    // Removes the deepest level; a no-op on an empty stack.
    void pop() {
        if (!stack.isEmpty()) {
            stack.remove(stack.size() - 1);
        }
        if (!values.isEmpty()) {
            values.remove(values.size() - 1);
        }
    }

    // Pushes the sum of val and the previous (period - 1) path values.
    void rolling_push(int val, int period) {
        values.add(val);
        int start = Math.max(0, values.size() - period);
        int windowSum = 0;
        for (int i = start; i < values.size(); i++) {
            windowSum += values.get(i);
        }
        push(windowSum);
    }

    void show() {
        for (int item : stack) {
            System.out.print(item + " ");
        }
        System.out.println();
    }
}

public class Main {
    static int count = 0;
    static Stack s = new Stack();

    // Preorder DFS that counts grandparent -> parent -> child sums above x.
    static void count_with_greater_sum(Node root_node, int x) {
        if (root_node == null) {
            return;
        }
        s.rolling_push(root_node.val, 3);
        if (s.size() >= 3 && s.top() > x) {
            count++;
        }
        count_with_greater_sum(root_node.left, x);
        count_with_greater_sum(root_node.right, x);
        s.pop();
    }

    public static void main(String[] args) {
        Node root = new Node(10);
        root.left = new Node(1);
        root.right = new Node(22);
        root.left.left = new Node(35);
        root.left.right = new Node(4);
        root.right.left = new Node(15);
        root.right.right = new Node(67);
        root.left.left.left = new Node(57);
        root.left.left.right = new Node(38);
        root.left.right.left = new Node(9);
        root.left.right.right = new Node(10);
        root.right.left.left = new Node(110);
        root.right.left.right = new Node(312);
        root.right.right.left = new Node(131);
        root.right.right.right = new Node(414);
        root.right.left.right.left = new Node(8);
        root.right.right.right.right = new Node(39);
        count_with_greater_sum(root, 100);
        System.out.println(count);
    }
}
C#
using System;
using System.Collections.Generic;
using System.Collections;
using System.Linq;

// C# implementation of the rolling-sum approach
class Node {
    public int val;
    public Node left;
    public Node right;

    public Node(int value, Node l, Node r) {
        val = value;
        left = l;
        right = r;
    }

    public Node(int value) {
        val = value;
        left = null;
        right = null;
    }
}

// Stack of rolling (windowed) sums along the current root-to-node path.
// st[i] is the sum of the last `period` node values ending at depth i.
class stack {
    public List<int> st = new List<int>();
    // Raw node values of the current path. The difference of two adjacent
    // rolling sums is a[k] - a[k - period], not the value leaving the window.
    private readonly List<int> values = new List<int>();

    public int size() {
        return st.Count;
    }

    // Rolling sum of the deepest level, or 0 if the stack is empty.
    public int top() {
        return (size() > 0) ? st[size() - 1] : 0;
    }

    // Appends val as-is (use rolling_push for path tracking).
    public void push(int val) {
        st.Add(val);
    }

    // Removes the deepest level; a no-op on an empty stack.
    public void pop() {
        if (st.Count > 0) {
            st.RemoveAt(st.Count - 1);
        }
        if (values.Count > 0) {
            values.RemoveAt(values.Count - 1);
        }
    }

    // Pushes the sum of val and the previous (period - 1) path values.
    public void rolling_push(int val, int period) {
        values.Add(val);
        int start = Math.Max(0, values.Count - period);
        int windowSum = 0;
        for (int i = start; i < values.Count; i++) {
            windowSum += values[i];
        }
        push(windowSum);
    }

    public void show() {
        foreach (var item in st) {
            Console.Write(item + " ");
        }
        Console.WriteLine();
    }
}

public class HelloWorld {
    internal static int count = 0;
    internal static stack s = new stack();

    // Preorder DFS that counts grandparent -> parent -> child sums above x.
    internal static void count_with_greater_sum(Node root_node, int x) {
        if (root_node == null) {
            return;
        }
        s.rolling_push(root_node.val, 3);
        if (s.size() >= 3 && s.top() > x) {
            count++;
        }
        count_with_greater_sum(root_node.left, x);
        count_with_greater_sum(root_node.right, x);
        s.pop();
    }

    static void Main() {
        Node root = new Node(10);
        root.left = new Node(1);
        root.right = new Node(22);
        root.left.left = new Node(35);
        root.left.right = new Node(4);
        root.right.left = new Node(15);
        root.right.right = new Node(67);
        root.left.left.left = new Node(57);
        root.left.left.right = new Node(38);
        root.left.right.left = new Node(9);
        root.left.right.right = new Node(10);
        root.right.left.left = new Node(110);
        root.right.left.right = new Node(312);
        root.right.right.left = new Node(131);
        root.right.right.right = new Node(414);
        root.right.left.right.left = new Node(8);
        root.right.right.right.right = new Node(39);
        count_with_greater_sum(root, 100);
        Console.WriteLine(count);
    }
}

// The code is contributed by Arushi Jindal.
6
Efficient Approach: The problem can also be solved by passing the last three nodes of the current path (grandparent, parent and child) down the recursion. No auxiliary data structure is needed; the only extra space used is the recursion stack.
- Traverse the tree in preorder
- Maintain 3 variables called grandParent, parent, and child
- Whenever all three nodes exist and their sum is greater than X, increment the count (or print the triplet).
Below is the implementation of the above approach:
C++
// CPP implementation to print // the nodes having a single child #include <bits/stdc++.h> using namespace std; // Class of the Binary Tree node struct Node { int data; Node *left, *right; Node( int x) { data = x; left = right = NULL; } }; // global variable int count = 0; void preorder(Node* grandParent, Node* parent, Node* child, int sum) { if (grandParent != NULL && parent != NULL && child != NULL && (grandParent -> data + parent -> data + child->data) > sum) { count++; //uncomment below lines if you // want to print triplets /*System->out->print(grandParent ->data+"-->"+parent->data+"--> "+child->data); System->out->println();*/ } if (child == NULL) return ; preorder(parent, child, child -> left, sum); preorder(parent, child, child -> right, sum); } //Driver code int main() { Node *r10 = new Node(10); Node *r1 = new Node(1); Node *r22 = new Node(22); Node *r35 = new Node(35); Node *r4 = new Node(4); Node *r15 = new Node(15); Node *r67 = new Node(67); Node *r57 = new Node(57); Node *r38 = new Node(38); Node *r9 = new Node(9); Node *r10_2 = new Node(10); Node *r110 = new Node(110); Node *r312 = new Node(312); Node *r131 = new Node(131); Node *r414 = new Node(414); Node *r8 = new Node(8); Node *r39 = new Node(39); r10 -> left = r1; r10 -> right = r22; r1 -> left = r35; r1 -> right = r4; r22 -> left = r15; r22 -> right = r67; r35 -> left = r57; r35 -> right = r38; r4 -> left = r9; r4 -> right = r10_2; r15 -> left = r110; r15 -> right = r312; r67 -> left = r131; r67 -> right = r414; r312 -> left = r8; r414 -> right = r39; preorder(NULL, NULL, r10, 100); cout << cont; } // This code is contributed by Mohit Kumar 29 |
Java
class Node{ int data; Node left; Node right; public Node( int data) { this .data=data; } } class TreeTriplet { static int count= 0 ; // global variable public void preorder(Node grandParent,Node parent,Node child, int sum) { if (grandParent!= null && parent!= null && child!= null && (grandParent.data+parent.data+child.data) > sum) { count++; //uncomment below lines if you want to print triplets /*System.out.print(grandParent.data+"-->"+parent.data+"-->"+child.data); System.out.println();*/ } if (child== null ) return ; preorder(parent,child,child.left,sum); preorder(parent,child,child.right,sum); } public static void main(String args[]) { Node r10 = new Node( 10 ); Node r1 = new Node( 1 ); Node r22 = new Node( 22 ); Node r35 = new Node( 35 ); Node r4 = new Node( 4 ); Node r15 = new Node( 15 ); Node r67 = new Node( 67 ); Node r57 = new Node( 57 ); Node r38 = new Node( 38 ); Node r9 = new Node( 9 ); Node r10_2 = new Node( 10 ); Node r110 = new Node( 110 ); Node r312 = new Node( 312 ); Node r131 = new Node( 131 ); Node r414 = new Node( 414 ); Node r8 = new Node( 8 ); Node r39 = new Node( 39 ); r10.left=r1; r10.right=r22; r1.left=r35; r1.right=r4; r22.left=r15; r22.right=r67; r35.left=r57; r35.right=r38; r4.left=r9; r4.right=r10_2; r15.left=r110; r15.right=r312; r67.left=r131; r67.right=r414; r312.left=r8; r414.right=r39; TreeTriplet p = new TreeTriplet(); p.preorder( null , null , r10, 100 ); System.out.println(count); } } // This code is contributed by Akshay Siddhpura |
Python3
# Python3 program to count the grandparent -> parent -> child
# triplets of a binary tree whose sum is greater than a given value


class Node:
    def __init__(self, data):
        self.left = None
        self.right = None
        self.data = data


# global variable: number of qualifying triplets found so far
count = 0


def preorder(grandParent, parent, child, sum):
    """Preorder walk that counts grandparent -> parent -> child sums above `sum`.

    The last three nodes of the current root-to-node path are passed down
    the recursion. Each time all three exist and their data sums to more
    than `sum`, the module-level `count` is incremented.
    (`sum` shadows the builtin inside this function; the name is kept so
    that keyword callers keep working.)
    """
    global count
    if (grandParent is not None and parent is not None and child is not None
            and grandParent.data + parent.data + child.data > sum):
        count += 1
        # uncomment below line if you want to print triplets
        # print(grandParent.data, "-->", parent.data, "-->", child.data)
    if child is None:
        return
    preorder(parent, child, child.left, sum)
    preorder(parent, child, child.right, sum)


# Driver code
if __name__ == "__main__":
    r10 = Node(10)
    r1 = Node(1)
    r22 = Node(22)
    r35 = Node(35)
    r4 = Node(4)
    r15 = Node(15)
    r67 = Node(67)
    r57 = Node(57)
    r38 = Node(38)
    r9 = Node(9)
    r10_2 = Node(10)
    r110 = Node(110)
    r312 = Node(312)
    r131 = Node(131)
    r414 = Node(414)
    r8 = Node(8)
    r39 = Node(39)
    r10.left = r1
    r10.right = r22
    r1.left = r35
    r1.right = r4
    r22.left = r15
    r22.right = r67
    r35.left = r57
    r35.right = r38
    r4.left = r9
    r4.right = r10_2
    r15.left = r110
    r15.right = r312
    r67.left = r131
    r67.right = r414
    r312.left = r8
    r414.right = r39
    preorder(None, None, r10, 100)
    print(count)

# This code is contributed by Rutvik_56
C#
// C# program to count the grandparent -> parent -> child triplets
// of a binary tree whose sum is greater than a given value.
using System;

public class Node {
    public int data;
    public Node left;
    public Node right;

    public Node(int data) {
        this.data = data;
    }
}

class GFG {
    // global variable: number of qualifying triplets found so far
    internal static int count = 0;

    // Preorder walk that carries the last three nodes of the current
    // root-to-node path and counts every triplet whose sum exceeds `sum`.
    public void preorder(Node grandParent, Node parent, Node child, int sum) {
        // The sum is computed in long: adding three ints can silently wrap
        // around (unchecked context) and turn a large total negative.
        if (grandParent != null && parent != null && child != null
            && (long)grandParent.data + parent.data + child.data > sum) {
            count++;
            // uncomment below line if you want to print triplets
            // Console.WriteLine(grandParent.data + "-->" + parent.data + "-->" + child.data);
        }
        if (child == null)
            return;
        preorder(parent, child, child.left, sum);
        preorder(parent, child, child.right, sum);
    }

    // Driver Code
    public static void Main(String []args) {
        Node r10 = new Node(10);
        Node r1 = new Node(1);
        Node r22 = new Node(22);
        Node r35 = new Node(35);
        Node r4 = new Node(4);
        Node r15 = new Node(15);
        Node r67 = new Node(67);
        Node r57 = new Node(57);
        Node r38 = new Node(38);
        Node r9 = new Node(9);
        Node r10_2 = new Node(10);
        Node r110 = new Node(110);
        Node r312 = new Node(312);
        Node r131 = new Node(131);
        Node r414 = new Node(414);
        Node r8 = new Node(8);
        Node r39 = new Node(39);
        r10.left = r1;
        r10.right = r22;
        r1.left = r35;
        r1.right = r4;
        r22.left = r15;
        r22.right = r67;
        r35.left = r57;
        r35.right = r38;
        r4.left = r9;
        r4.right = r10_2;
        r15.left = r110;
        r15.right = r312;
        r67.left = r131;
        r67.right = r414;
        r312.left = r8;
        r414.right = r39;
        GFG p = new GFG();
        p.preorder(null, null, r10, 100);
        Console.WriteLine(count);
    }
}

// This code is contributed by 29AjayKumar
Javascript
<script>
// Binary tree node
class Node {
    constructor(data) {
        this.data = data;
        this.left = null;
        this.right = null;
    }
}

// Number of qualifying triplets found so far (global).
let count = 0;

/**
 * Preorder walk carrying the last three nodes of the current root-to-node
 * path; counts every full triplet whose data sums to more than `sum`.
 */
function preorder(grandParent, parent, child, sum) {
    const isFullTriplet = grandParent != null && parent != null && child != null;
    if (isFullTriplet && grandParent.data + parent.data + child.data > sum) {
        count++;
        // Uncomment to print the triplet:
        // console.log(grandParent.data + "-->" + parent.data + "-->" + child.data);
    }
    if (child == null) {
        return;
    }
    // Left subtree first, then right, exactly as a preorder walk requires.
    for (const next of [child.left, child.right]) {
        preorder(parent, child, next, sum);
    }
}

// Attach the given children to a node.
const link = (node, left, right) => {
    node.left = left;
    node.right = right;
};

const [r10, r1, r22, r35, r4, r15, r67] =
    [10, 1, 22, 35, 4, 15, 67].map((v) => new Node(v));
const [r57, r38, r9, r10_2, r110, r312, r131, r414, r8, r39] =
    [57, 38, 9, 10, 110, 312, 131, 414, 8, 39].map((v) => new Node(v));

link(r10, r1, r22);
link(r1, r35, r4);
link(r22, r15, r67);
link(r35, r57, r38);
link(r4, r9, r10_2);
link(r15, r110, r312);
link(r67, r131, r414);
link(r312, r8, null);
link(r414, null, r39);

preorder(null, null, r10, 100);
document.write(count);

// This code is contributed by unknown2108
</script>
6
Time complexity: O(n)
Auxiliary Space: O(h), where h is the height of the tree, for the recursion stack. Apart from the recursion, only a few variables are used, so the extra space is O(1) if the recursion stack is not counted.
Ready to dive in? Explore our Free Demo Content and join our DSA course, trusted by over 100,000 neveropen!