Given a binary tree, a target node in it and a positive integer K, the task is to find the sum of all nodes within distance K of the target node (including the value of the target node itself).
Examples:
Input: target = 9, K = 1,
Binary Tree:
            1
          /   \
         2     9
        /     / \
       4     5   7
      / \       / \
     8   19    20  11
    /   /  \
   30  40   50
Output: 22
Explanation: The nodes within distance 1 of 9 are 9, 5, 7 and 1, so the sum is 9 + 5 + 7 + 1 = 22.

Input: target = 40, K = 2,
Binary Tree:
            1
          /   \
         2     9
        /     / \
       4     5   7
      / \       / \
     8   19    20  11
    /   /  \
   30  40   50
Output: 113
Explanation: The nodes within distance 2 of 40 are 40, 19, 50 and 4, so the sum is
40 + 19 + 50 + 4 = 113.
Approach: This problem can be solved using hashing and Depth-First Search (DFS), based on the following idea:
Use a hash map to store the parent of each node. Then use it to perform a DFS from the target, moving both to the children and to the parent of each node, and add up all the nodes within distance K of the target.
Follow the steps mentioned below to implement the approach:
- Create a hash table (say par) to store the parent of each node.
- Perform a DFS and store the parent of each node.
- Now find the target in the tree.
- Create a hash table to mark the visited nodes.
- Start a DFS from target:
- If the current distance does not exceed K, add the node's value to the final sum (stop as soon as the distance becomes K + 1).
- If the node has not been visited yet, continue the DFS to its neighbours (i.e. its parent and its children) using par and the node's child links.
- Once the recursion for the current node completes, the contributions of all its neighbours have been added to the sum.
- Return the sum of all the nodes within K distance from the target.
Below is the implementation of the above approach:
C++
// C++ code to implement above approach#include <bits/stdc++.h>using namespace std;// Structure of a tree nodestruct Node { int data; Node* left; Node* right; Node(int val) { this->data = val; this->left = 0; this->right = 0; }};// Function for marking the parent node// for all the nodes using DFSvoid dfs(Node* root, unordered_map<Node*, Node*>& par){ if (root == 0) return; if (root->left != 0) par[root->left] = root; if (root->right != 0) par[root->right] = root; dfs(root->left, par); dfs(root->right, par);}// Function calling for finding the sumvoid dfs3(Node* root, int h, int& sum, int k, unordered_map<Node*, int>& vis, unordered_map<Node*, Node*>& par){ if (h == k + 1) return; if (root == 0) return; if (vis[root]) return; sum += root->data; vis[root] = 1; dfs3(root->left, h + 1, sum, k, vis, par); dfs3(root->right, h + 1, sum, k, vis, par); dfs3(par[root], h + 1, sum, k, vis, par);}// Function for finding// the target node in the treeNode* dfs2(Node* root, int target){ if (root == 0) return 0; if (root->data == target) return root; Node* node1 = dfs2(root->left, target); Node* node2 = dfs2(root->right, target); if (node1 != 0) return node1; if (node2 != 0) return node2;}// Function to find the sum at distance Kint sum_at_distK(Node* root, int target, int k){ // Hash Table to store // the parent of a node unordered_map<Node*, Node*> par; // Make the parent of root node as NULL // since it does not have any parent par[root] = 0; // Mark the parent node for all the // nodes using DFS dfs(root, par); // Find the target node in the tree Node* node = dfs2(root, target); // Hash Table to mark // the visited nodes unordered_map<Node*, int> vis; int sum = 0; // DFS call to find the sum dfs3(node, 0, sum, k, vis, par); return sum;}// Driver Codeint main(){ // Taking Input Node* root = new Node(1); root->left = new Node(2); root->right = new Node(9); root->left->left = new Node(4); root->right->left = new Node(5); root->right->right = new Node(7); root->left->left->left = 
new Node(8); root->left->left->right = new Node(19); root->right->right->left = new Node(20); root->right->right->right = new Node(11); root->left->left->left->left = new Node(30); root->left->left->right->left = new Node(40); root->left->left->right->right = new Node(50); int target = 9, K = 1; // Function call cout << sum_at_distK(root, target, K); return 0;} |
Java
// Java code to implement above approach
import java.util.*;

public class Main {

    // Structure of a tree node
    static class Node {
        int data;
        Node left;
        Node right;

        Node(int val) {
            data = val;
            left = null;
            right = null;
        }
    }

    /**
     * Walks the subtree rooted at {@code root} and stores, for every child it
     * meets, that child's parent in {@code par}.
     */
    static void dfs(Node root, HashMap<Node, Node> par) {
        if (root == null) {
            return;
        }
        Node l = root.left;
        Node r = root.right;
        if (l != null) {
            par.put(l, root);
        }
        if (r != null) {
            par.put(r, root);
        }
        dfs(l, par);
        dfs(r, par);
    }

    /** Running total written by dfs3 and reset by sum_at_distK. */
    static int sum;

    /**
     * Adds to {@link #sum} every not-yet-visited node reachable from {@code root}
     * through child links and parent links, stopping one step beyond distance
     * {@code k}.
     *
     * @param h distance of {@code root} from the node the walk started at
     */
    static void dfs3(Node root, int h, int k,
                     HashMap<Node, Integer> vis,
                     HashMap<Node, Node> par) {
        boolean tooFar = h == k + 1;
        if (tooFar || root == null || vis.containsKey(root)) {
            return;
        }
        sum += root.data;
        vis.put(root, 1);

        int next = h + 1;
        dfs3(root.left, next, k, vis, par);
        dfs3(root.right, next, k, vis, par);
        dfs3(par.get(root), next, k, vis, par);
    }

    /** Returns the first node (in preorder) whose value equals target, or null. */
    static Node dfs2(Node root, int target) {
        if (root == null) {
            return null;
        }
        if (root.data == target) {
            return root;
        }
        Node fromLeft = dfs2(root.left, target);
        Node fromRight = dfs2(root.right, target);
        return (fromLeft != null) ? fromLeft : fromRight;
    }

    /**
     * Returns the sum of all node values within distance {@code k} of the first
     * node (in preorder) holding {@code target}, including that node; 0 if no
     * node holds {@code target}.
     */
    static int sum_at_distK(Node root, int target, int k) {
        // Parent links; the root is mapped to null because it has no parent.
        HashMap<Node, Node> par = new HashMap<>();
        par.put(root, null);
        dfs(root, par);

        Node start = dfs2(root, target);

        sum = 0;
        dfs3(start, 0, k, new HashMap<>(), par);
        return sum;
    }

    public static void main(String[] args) {
        // Taking Input
        Node root = new Node(1);
        Node n2 = new Node(2);
        Node n9 = new Node(9);
        Node n4 = new Node(4);
        Node n7 = new Node(7);
        Node n8 = new Node(8);
        Node n19 = new Node(19);

        root.left = n2;
        root.right = n9;
        n2.left = n4;
        n9.left = new Node(5);
        n9.right = n7;
        n4.left = n8;
        n4.right = n19;
        n7.left = new Node(20);
        n7.right = new Node(11);
        n8.left = new Node(30);
        n19.left = new Node(40);
        n19.right = new Node(50);

        int target = 9, K = 1;

        // Function call
        System.out.println(sum_at_distK(root, target, K));
    }
}
// This code has been contributed by Sachin Sahara (sachin801)
Python3
# Python program to implement the above approach


# Structure of a tree node
class Node:
    def __init__(self, val):
        self.data = val
        self.left = None
        self.right = None


# Function for marking the parent node
# for all the nodes using DFS
def dfs(root, par):
    """Record in ``par`` the parent of every node below ``root``."""
    if root is None:
        return
    if root.left is not None:
        par[root.left] = root
    if root.right is not None:
        par[root.right] = root
    dfs(root.left, par)
    dfs(root.right, par)


# Running total filled in by dfs3(). sum_at_distK() resets it on every
# call so repeated queries do not accumulate.
summ = 0


def dfs3(root, h, k, vis, par):
    """Add to ``summ`` every unvisited node within distance ``k``.

    ``h`` is the distance of ``root`` from the starting node; the walk
    follows child links and parent links (via ``par``).
    """
    global summ
    # Stop one step past the allowed distance.
    if h == k + 1:
        return
    if root is None:
        return
    if vis.get(root) == 1:
        return
    summ += root.data
    vis[root] = 1
    dfs3(root.left, h + 1, k, vis, par)
    dfs3(root.right, h + 1, k, vis, par)
    dfs3(par.get(root), h + 1, k, vis, par)


# Function for finding
# the target node in the tree
def dfs2(root, target):
    """Return the first node (preorder) whose data equals ``target``, or None."""
    if root is None:
        return None
    if root.data == target:
        return root
    node1 = dfs2(root.left, target)
    if node1 is not None:
        return node1
    return dfs2(root.right, target)


# Function to find the sum at distance k
def sum_at_distK(root, target, k):
    """Return the sum of all node values within distance ``k`` of the target.

    The target is the first node (preorder) whose data equals ``target``; its
    own value is included. Returns 0 if the tree is empty or the target is
    absent. The result is also stored in the module-level ``summ``.
    """
    global summ
    # Hash table mapping each node to its parent. The root's parent must be
    # None (not 0): dfs3 only stops on None, so a 0 here would be treated as
    # a node and crash on ``0.data`` whenever the walk climbs above the root.
    par = {root: None}
    dfs(root, par)

    # Find the target node in the tree
    node = dfs2(root, target)

    # Hash table to mark the visited nodes
    vis = {}
    summ = 0
    dfs3(node, 0, k, vis, par)
    return summ


# Driver program
root = Node(1)
root.left = Node(2)
root.right = Node(9)
root.left.left = Node(4)
root.right.left = Node(5)
root.right.right = Node(7)
root.left.left.left = Node(8)
root.left.left.right = Node(19)
root.right.right.left = Node(20)
root.right.right.right = Node(11)
root.left.left.left.left = Node(30)
root.left.left.right.left = Node(40)
root.left.left.right.right = Node(50)
target = 9
K = 1

# Function call
print(sum_at_distK(root, target, K))

# This code is contributed by Yash Agarwal(yashagarwal2852002)
C#
// C# code to implement above approach
using System;
using System.Collections.Generic;

public class GFG {
    // Structure of a tree node
    class Node {
        public int data;
        public Node left;
        public Node right;

        public Node(int val)
        {
            data = val;
            left = null;
            right = null;
        }
    }

    // Walks the subtree under `root` and adds a child -> parent entry to
    // `par` for every child it meets.
    static void dfs(Node root, Dictionary<Node, Node> par)
    {
        if (root == null) {
            return;
        }
        foreach (Node child in new[] { root.left, root.right }) {
            if (child != null) {
                par.Add(child, root);
            }
        }
        dfs(root.left, par);
        dfs(root.right, par);
    }

    // Running total written by dfs3 and reset by sum_at_distK.
    static int sum;

    // Adds to `sum` every not-yet-visited node reachable from `root` via
    // child and parent links, stopping one step beyond distance k.
    // `h` is the distance of `root` from the starting node.
    static void dfs3(Node root, int h, int k, Dictionary<Node, int> vis,
                     Dictionary<Node, Node> par)
    {
        if (h == k + 1 || root == null || vis.ContainsKey(root)) {
            return;
        }
        sum += root.data;
        vis.Add(root, 1);

        int next = h + 1;
        dfs3(root.left, next, k, vis, par);
        dfs3(root.right, next, k, vis, par);
        dfs3(par[root], next, k, vis, par);
    }

    // Returns the first node (in preorder) whose value equals target, or null.
    static Node dfs2(Node root, int target)
    {
        if (root == null) {
            return null;
        }
        if (root.data == target) {
            return root;
        }
        Node fromLeft = dfs2(root.left, target);
        Node fromRight = dfs2(root.right, target);
        return fromLeft ?? fromRight;
    }

    // Returns the sum of all node values within distance k of the first node
    // (in preorder) holding `target`, including that node.
    static int sum_at_distK(Node root, int target, int k)
    {
        // Parent links; the root maps to null since it has no parent.
        var par = new Dictionary<Node, Node>();
        par.Add(root, null);
        dfs(root, par);

        Node start = dfs2(root, target);

        sum = 0;
        dfs3(start, 0, k, new Dictionary<Node, int>(), par);
        return sum;
    }

    static public void Main()
    {
        // Code
        Node root = new Node(1);
        Node n2 = new Node(2);
        Node n9 = new Node(9);
        Node n4 = new Node(4);
        Node n7 = new Node(7);
        Node n8 = new Node(8);
        Node n19 = new Node(19);

        root.left = n2;
        root.right = n9;
        n2.left = n4;
        n9.left = new Node(5);
        n9.right = n7;
        n4.left = n8;
        n4.right = n19;
        n7.left = new Node(20);
        n7.right = new Node(11);
        n8.left = new Node(30);
        n19.left = new Node(40);
        n19.right = new Node(50);

        int target = 9, K = 1;

        // Function call
        Console.Write(sum_at_distK(root, target, K));
    }
}
// This code is contributed by lokesh(lokeshmvs21).
Javascript
// JavaScript code for the above approach

// Structure of a tree node
class Node {
    constructor(val) {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

// Function for marking the parent node
// for all the nodes using DFS
function dfs(root, par) {
    if (root === null) return;
    if (root.left !== null) par.set(root.left, root);
    if (root.right !== null) par.set(root.right, root);
    dfs(root.left, par);
    dfs(root.right, par);
}

// Running total accumulated by dfs3(); reset by sumAtDistK() before each query.
let sum = 0;

// Adds to `sum` every not-yet-visited node reachable from `root` through
// child links and parent links (`par`), as long as its distance `h` from
// the starting node is at most k.
function dfs3(root, h, k, vis, par) {
    // Stop one step past the allowed distance.
    if (h === k + 1) return;
    if (root === null || root === undefined) return;
    if (vis.has(root)) return;
    sum += root.data;
    vis.set(root, 1);
    dfs3(root.left, h + 1, k, vis, par);
    dfs3(root.right, h + 1, k, vis, par);
    // Always try the parent as well; the guards above reject a missing
    // parent (the root's is null) and nodes that were already counted.
    dfs3(par.has(root) ? par.get(root) : null, h + 1, k, vis, par);
}

// Function for finding
// the target node in the tree (first match in preorder, or null)
function dfs2(root, target) {
    if (root === null) return null;
    if (root.data === target) return root;
    let node1 = dfs2(root.left, target);
    if (node1 !== null) return node1;
    return dfs2(root.right, target);
}

// Returns the sum of all node values within distance k of the target node,
// including the target itself; 0 if no node holds `target`.
function sumAtDistK(root, target, k) {
    // Map to store the parent of a node; the root has no parent.
    let par = new Map();
    par.set(root, null);
    dfs(root, par);

    // Find the target node in the tree
    let node = dfs2(root, target);

    // Map to mark the visited nodes
    let vis = new Map();
    // The sum must start at 0: the target's own value is added by dfs3.
    sum = 0;
    dfs3(node, 0, k, vis, par);
    return sum;
}

// Taking Input
let root = new Node(1);
root.left = new Node(2);
root.right = new Node(9);
root.left.left = new Node(4);
root.right.left = new Node(5);
root.right.right = new Node(7);
root.left.left.left = new Node(8);
root.left.left.right = new Node(19);
root.right.right.left = new Node(20);
root.right.right.right = new Node(11);
root.left.left.left.left = new Node(30);
root.left.left.right.left = new Node(40);
root.left.left.right.right = new Node(50);
let target = 9;
let K = 1;
console.log(sumAtDistK(root, target, K));

// This code is contributed by Potta Lokesh
22
Time Complexity: O(N) where N is the number of nodes in the tree
Auxiliary Space: O(N)
Approach using BFS:
- We will use level-order traversal (BFS) to find the sum of the nodes.
Implementation:
- First, find the target node using level-order traversal.
- While searching for the target node, store the parent of each node so that we can also move from a node to its parent.
- Finally, traverse from the target node in every direction, i.e. towards both children and the parent, up to distance K, adding the value of each node reached to the answer.
C++
// C++ code to implement above approach#include <bits/stdc++.h>using namespace std;// Structure of a tree nodestruct Node { int data; Node* left; Node* right; Node(int val) { this->data = val; this->left = 0; this->right = 0; }};// Function to find the sum at distance Kint sum_at_distK(Node* root, int target, int k){ //variable to store answer int ans = 0; //queue for bfs queue<Node*> q; q.push(root); //to store target node Node* need; //map to store parent of each node unordered_map<Node*, Node*> m; //bfs while(q.size()){ int s = q.size(); //traversing to current level for(int i=0;i<s;i++){ Node* temp = q.front(); q.pop(); //if target value found if(temp->data==target) need=temp; if(temp->left){ q.push(temp->left); m[temp->left]=temp; } if(temp->right){ q.push(temp->right); m[temp->right]=temp; } } } //map to store occurrence of a node //that is the node has taken or not unordered_map<Node*, int> mm; q.push(need); //to store current distance int c = 0; while(q.size()){ int s = q.size(); for(int i=0;i<s;i++){ Node* temp = q.front(); q.pop(); mm[temp] = 1; ans+=temp->data; //moving left if(temp->left&&mm[temp->left]==0){ q.push(temp->left); } //moving right if(temp->right&&mm[temp->right]==0){ q.push(temp->right); } //movinf to parent if(m[temp]&&mm[m[temp]]==0){ q.push(m[temp]); } } c++; if(c>k)break; } return ans;}// Driver Codeint main(){ // Taking Input Node* root = new Node(1); root->left = new Node(2); root->right = new Node(9); root->left->left = new Node(4); root->right->left = new Node(5); root->right->right = new Node(7); root->left->left->left = new Node(8); root->left->left->right = new Node(19); root->right->right->left = new Node(20); root->right->right->right = new Node(11); root->left->left->left->left = new Node(30); root->left->left->right->left = new Node(40); root->left->left->right->right = new Node(50); int target = 9, K = 1; // Function call cout << sum_at_distK(root, target, K); return 0;}//code contributed by shubhamrajput6156 |
Java
import java.util.*;

// Structure of a tree node
class Node {
    int data;
    Node left;
    Node right;

    public Node(int val) {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

public class Main {

    /**
     * Returns the sum of the values of all nodes within distance {@code k} of the
     * target node, including the target itself.
     *
     * <p>The target is the last node, in level order, whose value equals
     * {@code target}.
     *
     * @param root   root of the tree; may be {@code null}
     * @param target value identifying the target node
     * @param k      maximum distance, in edges, from the target
     * @return the sum, or 0 if {@code root} is null or no node holds {@code target}
     */
    public static int sumAtDistK(Node root, int target, int k) {
        // Guard: enqueueing null would otherwise lead to a NullPointerException.
        if (root == null) {
            return 0;
        }
        Map<Node, Node> parentMap = new HashMap<>();
        Node need = findTargetRecordingParents(root, target, parentMap);
        if (need == null) {
            return 0;
        }
        return sumWithinDistance(need, k, parentMap);
    }

    /** Level-order pass: fills parentMap and returns the last node holding target. */
    private static Node findTargetRecordingParents(Node root, int target, Map<Node, Node> parentMap) {
        Deque<Node> q = new ArrayDeque<>();
        q.add(root);
        Node need = null;
        while (!q.isEmpty()) {
            Node temp = q.poll();
            if (temp.data == target) {
                need = temp;
            }
            if (temp.left != null) {
                q.add(temp.left);
                parentMap.put(temp.left, temp);
            }
            if (temp.right != null) {
                q.add(temp.right);
                parentMap.put(temp.right, temp);
            }
        }
        return need;
    }

    /** BFS outward from start over child and parent links, one level per distance. */
    private static int sumWithinDistance(Node start, int k, Map<Node, Node> parentMap) {
        Deque<Node> q = new ArrayDeque<>();
        Set<Node> visited = new HashSet<>();
        q.add(start);
        visited.add(start);

        int ans = 0;
        int currentDistance = 0;
        while (!q.isEmpty()) {
            for (int i = q.size(); i > 0; i--) {
                Node temp = q.poll();
                ans += temp.data;
                enqueueIfUnseen(temp.left, q, visited);
                enqueueIfUnseen(temp.right, q, visited);
                enqueueIfUnseen(parentMap.get(temp), q, visited);
            }
            currentDistance++;
            if (currentDistance > k) {
                break;
            }
        }
        return ans;
    }

    /** Adds next to the queue unless it is null or has already been seen. */
    private static void enqueueIfUnseen(Node next, Deque<Node> q, Set<Node> visited) {
        // Set.add returns false for nodes that were already present.
        if (next != null && visited.add(next)) {
            q.add(next);
        }
    }

    // Driver code
    public static void main(String[] args) {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right = new Node(11);
        root.left.left.left.left = new Node(30);
        root.left.left.right.left = new Node(40);
        root.left.left.right.right = new Node(50);

        int target = 9, K = 1;

        // Function call
        System.out.println(sumAtDistK(root, target, K));
    }
}
Python3
from collections import deque


class Node:
    def __init__(self, val):
        self.data = val
        self.left = None
        self.right = None


# Function to find the sum at distance K
def sum_at_distK(root, target, k):
    """Return the sum of all node values within distance ``k`` of the target.

    The target is the last node, in level order, whose ``data`` equals
    ``target``; its own value is included. Returns 0 when ``root`` is None
    or no node holds ``target``.
    """
    if root is None:
        return 0

    # Pass 1: level-order traversal recording each node's parent and the
    # (last) node whose value matches the target.
    parent = {}
    need = None
    q = deque([root])
    while q:
        temp = q.popleft()
        if temp.data == target:
            need = temp
        for child in (temp.left, temp.right):
            if child is not None:
                q.append(child)
                parent[child] = temp

    # Without this check the second pass would access ``None.data``.
    if need is None:
        return 0

    # Pass 2: BFS outward from the target over child and parent links,
    # one distance level per iteration of the outer loop.
    seen = {need}
    q.append(need)
    ans = 0
    c = 0  # current distance
    while q:
        for _ in range(len(q)):
            temp = q.popleft()
            ans += temp.data
            for nxt in (temp.left, temp.right, parent.get(temp)):
                if nxt is not None and nxt not in seen:
                    seen.add(nxt)
                    q.append(nxt)
        c += 1
        if c > k:
            break
    return ans


# Driver Code
# Taking Input
root = Node(1)
root.left = Node(2)
root.right = Node(9)
root.left.left = Node(4)
root.right.left = Node(5)
root.right.right = Node(7)
root.left.left.left = Node(8)
root.left.left.right = Node(19)
root.right.right.left = Node(20)
root.right.right.right = Node(11)
root.left.left.left.left = Node(30)
root.left.left.right.left = Node(40)
root.left.left.right.right = Node(50)
target = 9
K = 1

# Function call
print(sum_at_distK(root, target, K))
Javascript
class Node {
    constructor(val) {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

// Function to find the sum at distance K.
// Returns the sum of the values of all nodes within distance k of the target
// node, including the target itself. The target is the last node, in level
// order, whose data equals `target`. Returns 0 if `root` is null or no node
// holds `target`.
function sum_at_distK(root, target, k) {
    if (root === null || root === undefined) return 0;

    // Pass 1: level-order traversal recording each node's parent and the
    // (last) node whose value matches the target. A read index replaces
    // Array.prototype.shift(), which re-indexes the array on every call.
    const m = new Map();
    let need = null;
    const order = [root];
    for (let head = 0; head < order.length; head++) {
        const temp = order[head];
        if (temp.data === target) need = temp;
        if (temp.left) {
            order.push(temp.left);
            m.set(temp.left, temp);
        }
        if (temp.right) {
            order.push(temp.right);
            m.set(temp.right, temp);
        }
    }

    // Without this check the second pass would read `null.data`.
    if (need === null) return 0;

    // Pass 2: BFS outward from the target over child and parent links,
    // one distance level per iteration of the outer loop.
    const seen = new Set([need]);
    let level = [need];
    let ans = 0;
    let c = 0; // current distance
    while (level.length) {
        const nextLevel = [];
        for (const temp of level) {
            ans += temp.data;
            for (const nb of [temp.left, temp.right, m.get(temp)]) {
                if (nb && !seen.has(nb)) {
                    seen.add(nb);
                    nextLevel.push(nb);
                }
            }
        }
        level = nextLevel;
        c++;
        if (c > k) break;
    }
    return ans;
}

// Driver Code
// Taking Input
let root = new Node(1);
root.left = new Node(2);
root.right = new Node(9);
root.left.left = new Node(4);
root.right.left = new Node(5);
root.right.right = new Node(7);
root.left.left.left = new Node(8);
root.left.left.right = new Node(19);
root.right.right.left = new Node(20);
root.right.right.right = new Node(11);
root.left.left.left.left = new Node(30);
root.left.left.right.left = new Node(40);
root.left.left.right.right = new Node(50);
let target = 9, K = 1;

// Function call
console.log(sum_at_distK(root, target, K));
22
Time Complexity: O(N), where N is the number of nodes
Auxiliary Space: O(N)
Ready to dive in? Explore our Free Demo Content and join our DSA course, trusted by over 100,000 neveropen!
