Given a binary tree, a target node in that tree and a positive integer K, the task is to find the sum of all nodes that lie within distance K of the target node (the value of the target node itself is included in the sum).
Examples:
Input: target = 9, K = 1,
Binary Tree = 1
/ \
2 9
/ / \
4 5 7
/ \ / \
8 19 20 11
/ / \
30 40 50
Output: 22
Explanation: Nodes within distance 1 from 9 are 9 + 5 + 7 + 1 = 22
Input: target = 40, K = 2,
Binary Tree = 1
/ \
2 9
/ / \
4 5 7
/ \ / \
8 19 20 11
/ / \
30 40 50
Output: 113
Explanation: Nodes within distance 2 from 40 are
40 + 19 + 50 + 4 = 113
Approach: This problem can be solved using hashing and Depth-First-Search based on the following idea:
Use a data structure to store the parent of each node. Now utilise that data structure to perform a DFS traversal from target and calculate the sum of all the nodes within K distance from that node.
Follow the steps mentioned below to implement the approach:
- Create a hash table (say par) to store the parent of each node.
- Perform a DFS and store the parent of each node.
- Now find the target in the tree.
- Create a hash table to mark the visited nodes.
- Start a DFS from target:
- If the distance from the target is at most K, add the node's value to the final sum.
- If the node is not visited then continue the DFS traversal for its neighbours also (i.e. parent and child) with the help of par and the links of each node.
- When the recursion for the current node is complete, its value and the values of all unvisited nodes reachable from it within the remaining distance have been added to the sum.
- Return the sum of all the nodes within K distance from the target.
Below is the implementation of the above approach:
C++
// C++ code to implement above approach #include <bits/stdc++.h> using namespace std; // Structure of a tree node struct Node { int data; Node* left; Node* right; Node( int val) { this ->data = val; this ->left = 0; this ->right = 0; } }; // Function for marking the parent node // for all the nodes using DFS void dfs(Node* root, unordered_map<Node*, Node*>& par) { if (root == 0) return ; if (root->left != 0) par[root->left] = root; if (root->right != 0) par[root->right] = root; dfs(root->left, par); dfs(root->right, par); } // Function calling for finding the sum void dfs3(Node* root, int h, int & sum, int k, unordered_map<Node*, int >& vis, unordered_map<Node*, Node*>& par) { if (h == k + 1) return ; if (root == 0) return ; if (vis[root]) return ; sum += root->data; vis[root] = 1; dfs3(root->left, h + 1, sum, k, vis, par); dfs3(root->right, h + 1, sum, k, vis, par); dfs3(par[root], h + 1, sum, k, vis, par); } // Function for finding // the target node in the tree Node* dfs2(Node* root, int target) { if (root == 0) return 0; if (root->data == target) return root; Node* node1 = dfs2(root->left, target); Node* node2 = dfs2(root->right, target); if (node1 != 0) return node1; if (node2 != 0) return node2; } // Function to find the sum at distance K int sum_at_distK(Node* root, int target, int k) { // Hash Table to store // the parent of a node unordered_map<Node*, Node*> par; // Make the parent of root node as NULL // since it does not have any parent par[root] = 0; // Mark the parent node for all the // nodes using DFS dfs(root, par); // Find the target node in the tree Node* node = dfs2(root, target); // Hash Table to mark // the visited nodes unordered_map<Node*, int > vis; int sum = 0; // DFS call to find the sum dfs3(node, 0, sum, k, vis, par); return sum; } // Driver Code int main() { // Taking Input Node* root = new Node(1); root->left = new Node(2); root->right = new Node(9); root->left->left = new Node(4); root->right->left = new Node(5); root->right->right = 
new Node(7); root->left->left->left = new Node(8); root->left->left->right = new Node(19); root->right->right->left = new Node(20); root->right->right->right = new Node(11); root->left->left->left->left = new Node(30); root->left->left->right->left = new Node(40); root->left->left->right->right = new Node(50); int target = 9, K = 1; // Function call cout << sum_at_distK(root, target, K); return 0; } |
Java
// Java program: sum of the nodes lying within distance K of a target node
import java.util.*;

public class Main {

    /** Binary-tree node: an int payload plus left and right child links. */
    static class Node {
        int data;
        Node left;
        Node right;

        Node(int val) {
            data = val;
            left = null;
            right = null;
        }
    }

    /** Running total that dfs3 adds to; sum_at_distK resets it before each walk. */
    static int sum;

    /**
     * Fills {@code par} with a child -> parent entry for every child found
     * below {@code root}. No entry is written for {@code root} itself.
     */
    static void dfs(Node root, HashMap<Node, Node> par) {
        if (root == null) {
            return;
        }
        Node leftChild = root.left;
        Node rightChild = root.right;
        if (leftChild != null) {
            par.put(leftChild, root);
        }
        if (rightChild != null) {
            par.put(rightChild, root);
        }
        dfs(leftChild, par);
        dfs(rightChild, par);
    }

    /**
     * Adds the value of {@code root} to {@link #sum} and continues through its
     * left child, right child and parent. {@code h} is the number of edges
     * walked so far; the walk stops when {@code h} equals {@code k + 1},
     * when {@code root} is null, or when {@code root} is already in {@code vis}.
     */
    static void dfs3(Node root, int h, int k,
                     HashMap<Node, Integer> vis,
                     HashMap<Node, Node> par) {
        boolean limitReached = (h == k + 1);
        if (limitReached || root == null || vis.containsKey(root)) {
            return;
        }

        vis.put(root, 1);
        sum += root.data;

        int nextDepth = h + 1;
        dfs3(root.left, nextDepth, k, vis, par);
        dfs3(root.right, nextDepth, k, vis, par);
        dfs3(par.get(root), nextDepth, k, vis, par);
    }

    /**
     * Looks for a node whose data equals {@code target}; the node itself is
     * checked first, then the left subtree result is preferred over the right.
     *
     * @return the matching node, or null if there is none
     */
    static Node dfs2(Node root, int target) {
        if (root == null) {
            return null;
        }
        if (root.data == target) {
            return root;
        }
        Node fromLeft = dfs2(root.left, target);
        Node fromRight = dfs2(root.right, target);
        return (fromLeft != null) ? fromLeft : fromRight;
    }

    /**
     * Computes the sum of all node values within {@code k} edges of the node
     * whose data equals {@code target}. The result is also left in {@link #sum}.
     */
    static int sum_at_distK(Node root, int target, int k) {
        // child -> parent table; the root is mapped to null (no parent)
        HashMap<Node, Node> par = new HashMap<>();
        par.put(root, null);
        dfs(root, par);

        // starting point of the walk
        Node node = dfs2(root, target);

        // nodes already counted
        HashMap<Node, Integer> vis = new HashMap<>();

        sum = 0;
        dfs3(node, 0, k, vis, par);
        return sum;
    }

    public static void main(String args[]) {
        // Build the sample tree
        Node root = new Node(1);

        Node n2 = new Node(2);
        Node n9 = new Node(9);
        root.left = n2;
        root.right = n9;

        Node n4 = new Node(4);
        n2.left = n4;
        n9.left = new Node(5);
        Node n7 = new Node(7);
        n9.right = n7;

        Node n8 = new Node(8);
        Node n19 = new Node(19);
        n4.left = n8;
        n4.right = n19;
        n7.left = new Node(20);
        n7.right = new Node(11);

        n8.left = new Node(30);
        n19.left = new Node(40);
        n19.right = new Node(50);

        int target = 9, K = 1;

        // Print the answer for the sample query
        System.out.println(sum_at_distK(root, target, K));
    }
}
// This code has been contributed by Sachin Sahara (sachin801)
Python3
# Python program to implement above approach


# Structure of a tree node
class Node:
    def __init__(self, val):
        self.data = val
        self.left = None
        self.right = None


def dfs(root, par):
    """Record in ``par`` the parent of every child found below ``root``.

    No entry is written for ``root`` itself; the caller adds it.
    """
    if root is None:
        return
    if root.left is not None:
        par[root.left] = root
    if root.right is not None:
        par[root.right] = root
    dfs(root.left, par)
    dfs(root.right, par)


# Running total updated by dfs3; sum_at_distK resets it before each walk.
summ = 0


def dfs3(root, h, k, vis, par):
    """Add to the global ``summ`` every unvisited node reachable from ``root``.

    ``h`` is the number of edges already walked from the target; the walk
    stops as soon as ``h`` equals ``k + 1``. ``vis`` marks counted nodes and
    ``par`` maps a node to its parent (``None`` for the root).
    """
    global summ
    if h == k + 1:
        return
    if root is None:
        return
    if vis.get(root) == 1:
        return
    summ += root.data
    vis[root] = 1
    dfs3(root.left, h + 1, k, vis, par)
    dfs3(root.right, h + 1, k, vis, par)
    # .get() so a node missing from the table is treated as having no parent
    dfs3(par.get(root), h + 1, k, vis, par)


def dfs2(root, target):
    """Return the first node (pre-order) whose data equals ``target``, else None."""
    if root is None:
        return None
    if root.data == target:
        return root
    node1 = dfs2(root.left, target)
    if node1 is not None:
        return node1
    return dfs2(root.right, target)


def sum_at_distK(root, target, k):
    """Return the sum of all nodes within distance ``k`` of the target node.

    The result is also left in the module-level ``summ``. Returns 0 when the
    tree is empty or no node carries the value ``target``.
    """
    global summ
    # Reset the total so that repeated calls do not accumulate.
    summ = 0

    # Hash table to store the parent of a node. The root has no parent, so
    # it maps to None. The original used 0 here, which is not caught by the
    # "is None" guard in dfs3 and crashed once the walk went above the root.
    par = {root: None}

    # Mark the parent node for all the nodes using DFS
    dfs(root, par)

    # Find the target node in the tree
    node = dfs2(root, target)

    # Hash table to mark the visited nodes
    vis = {}

    # DFS call to find the sum
    dfs3(node, 0, k, vis, par)
    return summ


# Driver program
root = Node(1)
root.left = Node(2)
root.right = Node(9)
root.left.left = Node(4)
root.right.left = Node(5)
root.right.right = Node(7)
root.left.left.left = Node(8)
root.left.left.right = Node(19)
root.right.right.left = Node(20)
root.right.right.right = Node(11)
root.left.left.left.left = Node(30)
root.left.left.right.left = Node(40)
root.left.left.right.right = Node(50)

target = 9
K = 1

# Function call
print(sum_at_distK(root, target, K))

# This code is contributed by Yash Agarwal(yashagarwal2852002)
C#
// C# code to implement above approach
using System;
using System.Collections.Generic;

public class GFG {

    // Structure of a tree node
    class Node {
        public int data;
        public Node left;
        public Node right;

        public Node(int val)
        {
            this.data = val;
            this.left = null;
            this.right = null;
        }
    }

    // Records in par the parent of every child found below root.
    // No entry is written for root itself; the caller adds it.
    static void dfs(Node root, Dictionary<Node, Node> par)
    {
        if (root == null)
            return;
        // Indexer assignment instead of Add, so an existing entry is
        // overwritten rather than raising ArgumentException.
        if (root.left != null)
            par[root.left] = root;
        if (root.right != null)
            par[root.right] = root;
        dfs(root.left, par);
        dfs(root.right, par);
    }

    // Running total updated by dfs3; sum_at_distK resets it before each walk.
    static int sum;

    // Adds to sum the value of every not-yet-visited node reachable from
    // root through child or parent links. h is the number of edges already
    // walked from the target; the walk stops when h equals k + 1.
    static void dfs3(Node root, int h, int k,
                     Dictionary<Node, int> vis,
                     Dictionary<Node, Node> par)
    {
        if (h == k + 1)
            return;
        if (root == null)
            return;
        if (vis.ContainsKey(root))
            return;

        sum += root.data;
        vis.Add(root, 1);

        dfs3(root.left, h + 1, k, vis, par);
        dfs3(root.right, h + 1, k, vis, par);

        // TryGetValue instead of the indexer, which throws
        // KeyNotFoundException for a node missing from the table.
        Node parent;
        if (par.TryGetValue(root, out parent))
            dfs3(parent, h + 1, k, vis, par);
    }

    // Returns the first node (pre-order) whose data equals target,
    // or null when there is none.
    static Node dfs2(Node root, int target)
    {
        if (root == null)
            return null;
        if (root.data == target)
            return root;
        Node node1 = dfs2(root.left, target);
        if (node1 != null)
            return node1;
        return dfs2(root.right, target);
    }

    // Returns the sum of all nodes within distance k of the node whose
    // value is target; 0 for an empty tree or a missing target.
    static int sum_at_distK(Node root, int target, int k)
    {
        sum = 0;

        // Dictionary keys may not be null, so an empty tree is handled
        // here instead of letting par.Add(null, ...) throw.
        if (root == null)
            return 0;

        // Hash Map to store the parent of a node
        Dictionary<Node, Node> par = new Dictionary<Node, Node>();

        // The root has no parent
        par.Add(root, null);

        // Mark the parent node for all the nodes using DFS
        dfs(root, par);

        // Find the target node in the tree
        Node node = dfs2(root, target);

        // Hash Map to mark the visited nodes
        Dictionary<Node, int> vis = new Dictionary<Node, int>();

        // DFS call to find the sum
        dfs3(node, 0, k, vis, par);
        return sum;
    }

    static public void Main()
    {
        // Code
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right = new Node(11);
        root.left.left.left.left = new Node(30);
        root.left.left.right.left = new Node(40);
        root.left.left.right.right = new Node(50);

        int target = 9, K = 1;

        // Function call
        Console.Write(sum_at_distK(root, target, K));
    }
}
// This code is contributed by lokesh(lokeshmvs21).
Javascript
// JavaScript code for the above approach

// Structure of a tree node
class Node {
    constructor(val) {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

// Records in par (a Map) the parent of every child found below root.
// No entry is written for root itself; the caller adds it.
function dfs(root, par) {
    if (root === null) return;
    if (root.left !== null) par.set(root.left, root);
    if (root.right !== null) par.set(root.right, root);
    dfs(root.left, par);
    dfs(root.right, par);
}

// Running total updated by dfs3; sumAtDistK resets it before each walk.
let sum = 0;

// Adds to sum the value of every not-yet-visited node reachable from root
// through child or parent links. h is the number of edges already walked
// from the target; the walk stops when h equals k + 1.
function dfs3(root, h, k, vis, par) {
    if (h === k + 1) return;
    if (root == null) return; // covers both null and undefined
    if (vis.has(root)) return;

    sum += root.data;
    vis.set(root, 1);

    dfs3(root.left, h + 1, k, vis, par);
    dfs3(root.right, h + 1, k, vis, par);

    // Walk upwards as well. The original only recursed when the parent was
    // ALREADY visited, so the call returned immediately and the parent
    // side of the tree was never counted.
    const parent = par.get(root);
    if (parent) {
        dfs3(parent, h + 1, k, vis, par);
    }
}

// Returns the first node (pre-order) whose data equals target,
// or null when there is none.
function dfs2(root, target) {
    if (root === null) return null;
    if (root.data === target) return root;
    const node1 = dfs2(root.left, target);
    if (node1 !== null) return node1;
    return dfs2(root.right, target);
}

// Returns the sum of all nodes within distance k of the node whose value
// is target; 0 for an empty tree or a missing target.
function sumAtDistK(root, target, k) {
    // Map to store the parent of a node
    const par = new Map();

    // The root has no parent
    par.set(root, null);

    // Mark the parent node for all the nodes using DFS
    dfs(root, par);

    // Find the target node in the tree
    const node = dfs2(root, target);

    // Map to mark the visited nodes
    const vis = new Map();

    // Start from zero. The original seeded the total with 1, which only
    // happened to compensate for the skipped parent in the sample input.
    sum = 0;

    // DFS call to find the sum
    dfs3(node, 0, k, vis, par);
    return sum;
}

// Taking Input
let root = new Node(1);
root.left = new Node(2);
root.right = new Node(9);
root.left.left = new Node(4);
root.right.left = new Node(5);
root.right.right = new Node(7);
root.left.left.left = new Node(8);
root.left.left.right = new Node(19);
root.right.right.left = new Node(20);
root.right.right.right = new Node(11);
root.left.left.left.left = new Node(30);
root.left.left.right.left = new Node(40);
root.left.left.right.right = new Node(50);

let target = 9;
let K = 1;

console.log(sumAtDistK(root, target, K));

// This code is contributed by Potta Lokesh
22
Time Complexity: O(N) where N is the number of nodes in the tree
Auxiliary Space: O(N)
Approach using BFS:-
- We will be using level order traversal to find the sum of nodes
Implementation:-
- First we will find the target node using level order traversal.
- While finding the target node we will store the parent of each node so that we can move towards the parent of the node as well.
- After this, we traverse outward from the target node in every direction of the tree (towards both children and the parent) up to distance K, adding the value of each node reached to the answer.
C++
// C++ code to implement above approach #include <bits/stdc++.h> using namespace std; // Structure of a tree node struct Node { int data; Node* left; Node* right; Node( int val) { this ->data = val; this ->left = 0; this ->right = 0; } }; // Function to find the sum at distance K int sum_at_distK(Node* root, int target, int k) { //variable to store answer int ans = 0; //queue for bfs queue<Node*> q; q.push(root); //to store target node Node* need; //map to store parent of each node unordered_map<Node*, Node*> m; //bfs while (q.size()){ int s = q.size(); //traversing to current level for ( int i=0;i<s;i++){ Node* temp = q.front(); q.pop(); //if target value found if (temp->data==target) need=temp; if (temp->left){ q.push(temp->left); m[temp->left]=temp; } if (temp->right){ q.push(temp->right); m[temp->right]=temp; } } } //map to store occurrence of a node //that is the node has taken or not unordered_map<Node*, int > mm; q.push(need); //to store current distance int c = 0; while (q.size()){ int s = q.size(); for ( int i=0;i<s;i++){ Node* temp = q.front(); q.pop(); mm[temp] = 1; ans+=temp->data; //moving left if (temp->left&&mm[temp->left]==0){ q.push(temp->left); } //moving right if (temp->right&&mm[temp->right]==0){ q.push(temp->right); } //movinf to parent if (m[temp]&&mm[m[temp]]==0){ q.push(m[temp]); } } c++; if (c>k) break ; } return ans; } // Driver Code int main() { // Taking Input Node* root = new Node(1); root->left = new Node(2); root->right = new Node(9); root->left->left = new Node(4); root->right->left = new Node(5); root->right->right = new Node(7); root->left->left->left = new Node(8); root->left->left->right = new Node(19); root->right->right->left = new Node(20); root->right->right->right = new Node(11); root->left->left->left->left = new Node(30); root->left->left->right->left = new Node(40); root->left->left->right->right = new Node(50); int target = 9, K = 1; // Function call cout << sum_at_distK(root, target, K); return 0; } //code contributed 
by shubhamrajput6156 |
Java
import java.util.ArrayDeque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

// Structure of a tree node
class Node {
    int data;
    Node left;
    Node right;

    public Node(int val) {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

public class Main {

    /**
     * Returns the sum of all node values within {@code k} edges of the node
     * whose data equals {@code target}, the target node included.
     *
     * <p>If several nodes carry the target value, the last one met in level
     * order is used.
     *
     * @param root   root of the tree; may be null
     * @param target value identifying the start node
     * @param k      maximum distance, in edges, from the start node
     * @return the sum, or 0 when the tree is empty or the target is absent
     */
    public static int sumAtDistK(Node root, int target, int k) {
        // An empty tree has nothing to sum; ArrayDeque also rejects null elements.
        if (root == null) {
            return 0;
        }

        // Queue for BFS
        Queue<Node> q = new ArrayDeque<>();
        q.add(root);

        // To store the target node
        Node need = null;

        // Map to store the parent of each node (the root gets no entry)
        Map<Node, Node> parentMap = new HashMap<>();

        // First BFS: locate the target and record parents
        while (!q.isEmpty()) {
            Node temp = q.poll();

            // If the target value is found
            if (temp.data == target) {
                need = temp;
            }
            if (temp.left != null) {
                q.add(temp.left);
                parentMap.put(temp.left, temp);
            }
            if (temp.right != null) {
                q.add(temp.right);
                parentMap.put(temp.right, temp);
            }
        }

        // The original enqueued null here and then failed on temp.data.
        if (need == null) {
            return 0;
        }

        // Variable to store the answer
        int ans = 0;

        // Nodes that have already been added to the answer
        Set<Node> visited = new HashSet<>();
        q.add(need);

        // Current distance
        int currentDistance = 0;

        // Second BFS: expand one ring of nodes per iteration
        while (!q.isEmpty()) {
            int size = q.size();
            for (int i = 0; i < size; i++) {
                Node temp = q.poll();
                visited.add(temp);
                ans += temp.data;

                // Moving left
                if (temp.left != null && !visited.contains(temp.left)) {
                    q.add(temp.left);
                }
                // Moving right
                if (temp.right != null && !visited.contains(temp.right)) {
                    q.add(temp.right);
                }
                // Moving to parent (null for the root, which has no entry)
                Node parent = parentMap.get(temp);
                if (parent != null && !visited.contains(parent)) {
                    q.add(parent);
                }
            }
            currentDistance++;
            // Rings 0..k have been added; anything further is too far away
            if (currentDistance > k) {
                break;
            }
        }
        return ans;
    }

    // Driver code
    public static void main(String[] args) {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right = new Node(11);
        root.left.left.left.left = new Node(30);
        root.left.left.right.left = new Node(40);
        root.left.left.right.right = new Node(50);

        int target = 9, K = 1;

        // Function call
        System.out.println(sumAtDistK(root, target, K));
    }
}
Python3
from collections import deque


class Node:
    def __init__(self, val):
        self.data = val
        self.left = None
        self.right = None


# Function to find the sum at distance K
def sum_at_distK(root, target, k):
    """Return the sum of all nodes within distance ``k`` of the target node.

    The target node's own value is included. If several nodes carry the
    value ``target``, the last one met in level order is used. Returns 0
    when the tree is empty or the target value does not occur.
    """
    # An empty tree has nothing to sum.
    if root is None:
        return 0

    # Queue for BFS
    q = deque([root])
    need = None

    # Dictionary to store parent of each node (the root gets no entry)
    m = {}

    # BFS traversal to find the target node
    while q:
        temp = q.popleft()
        if temp.data == target:
            need = temp
        for child in (temp.left, temp.right):
            if child is not None:
                q.append(child)
                m[child] = temp

    # Without a start node the second traversal would fail on None.data.
    if need is None:
        return 0

    ans = 0

    # Nodes that have already been added to the answer
    mm = set()
    q.append(need)
    c = 0

    # BFS traversal within K distance: one ring of nodes per iteration
    while q:
        for _ in range(len(q)):
            temp = q.popleft()
            mm.add(temp)
            ans += temp.data
            # Moving left, right and to the parent
            for nxt in (temp.left, temp.right, m.get(temp)):
                if nxt is not None and nxt not in mm:
                    q.append(nxt)
        c += 1
        # Rings 0..k have been added; anything further is too far away
        if c > k:
            break
    return ans


# Driver Code
# Taking Input
root = Node(1)
root.left = Node(2)
root.right = Node(9)
root.left.left = Node(4)
root.right.left = Node(5)
root.right.right = Node(7)
root.left.left.left = Node(8)
root.left.left.right = Node(19)
root.right.right.left = Node(20)
root.right.right.right = Node(11)
root.left.left.left.left = Node(30)
root.left.left.right.left = Node(40)
root.left.left.right.right = Node(50)

target = 9
K = 1

# Function call
print(sum_at_distK(root, target, K))
Javascript
class Node {
    constructor(val) {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

// Function to find the sum of all nodes within distance k of the node
// whose data equals target (the target node included). If several nodes
// carry the target value, the last one met in level order is used.
// Returns 0 when the tree is empty or the target value does not occur.
function sum_at_distK(root, target, k) {
    // An empty tree has nothing to sum.
    if (root == null) return 0;

    let need = null;

    // Map to store parent of each node (the root gets no entry)
    const m = new Map();

    // BFS traversal to find the target node. The array doubles as the
    // queue: reading by index avoids the O(n) cost of Array.shift().
    const order = [root];
    for (let i = 0; i < order.length; i++) {
        const temp = order[i];
        if (temp.data === target) {
            need = temp;
        }
        if (temp.left) {
            order.push(temp.left);
            m.set(temp.left, temp);
        }
        if (temp.right) {
            order.push(temp.right);
            m.set(temp.right, temp);
        }
    }

    // Without a start node the second traversal would fail on null.data.
    if (need === null) return 0;

    let ans = 0;

    // Nodes that have already been added to the answer
    const mm = new Set();

    // BFS traversal within K distance: `level` holds one ring of nodes
    let level = [need];
    let c = 0;
    while (level.length) {
        const next = [];
        for (const temp of level) {
            mm.add(temp);
            ans += temp.data;

            // Moving left
            if (temp.left && !mm.has(temp.left)) {
                next.push(temp.left);
            }
            // Moving right
            if (temp.right && !mm.has(temp.right)) {
                next.push(temp.right);
            }
            // Moving to parent
            const parent = m.get(temp);
            if (parent && !mm.has(parent)) {
                next.push(parent);
            }
        }
        level = next;
        c++;
        // Rings 0..k have been added; anything further is too far away
        if (c > k) break;
    }
    return ans;
}

// Driver Code
// Taking Input
let root = new Node(1);
root.left = new Node(2);
root.right = new Node(9);
root.left.left = new Node(4);
root.right.left = new Node(5);
root.right.right = new Node(7);
root.left.left.left = new Node(8);
root.left.left.right = new Node(19);
root.right.right.left = new Node(20);
root.right.right.right = new Node(11);
root.left.left.left.left = new Node(30);
root.left.left.right.left = new Node(40);
root.left.left.right.right = new Node(50);

let target = 9, K = 1;

// Function call
console.log(sum_at_distK(root, target, K));
22
Time Complexity:- O(N) Where N is the number of nodes
Auxiliary Space:- O(N)
Ready to dive in? Explore our Free Demo Content and join our DSA course, trusted by over 100,000 neveropen!