Thursday, January 30, 2025
Google search engine
Home › Data Modelling & AI › Sum of nodes within K distance from target

Sum of nodes within K distance from target

Given a binary tree, the value of a target node in it, and a positive integer K, the task is to find the sum of all nodes within distance K of the target node (the target node's own value is included in the sum).


Input: target = 9, K = 1
Binary Tree:
              1
            /   \
           2     9
          /     / \
         4     5   7
        / \       / \
       8   19    20  11
      /   /  \
     30  40   50
Output: 22
Explanation: The nodes within distance 1 of 9 are 9, 5, 7 and 1, so the sum is 9 + 5 + 7 + 1 = 22.

Input: target = 40, K = 2
Binary Tree:
              1
            /   \
           2     9
          /     / \
         4     5   7
        / \       / \
       8   19    20  11
      /   /  \
     30  40   50
Output: 113
Explanation: The nodes within distance 2 of 40 are 40, 19, 50 and 4, so the sum is
40 + 19 + 50 + 4 = 113.


Approach: This problem can be solved using hashing and Depth-First Search (DFS), based on the following idea:

Use a data structure to store the parent of each node. Then use that data structure to perform a DFS traversal starting from the target, and calculate the sum of all the nodes within distance K of it.

Follow the steps mentioned below to implement the approach:

  • Create a hash table (say par) to store the parent of each node.
  • Perform a DFS and store the parent of each node.
  • Now find the target in the tree.
  • Create a hash table to mark the visited nodes.
  • Start a DFS from target:
    • If the distance from the target does not exceed K and the node has not been visited, add its value to the final sum.
    • Mark the node as visited and continue the DFS into all of its neighbours: its children (via its left/right links) and its parent (via par).
    • When the recursion for the current node completes, the values of all reachable nodes within distance K have been added to the sum.
  • Return the sum of all the nodes within K distance from the target.

Below is the implementation of the above approach:


// C++ code to implement above approach
#include <bits/stdc++.h>
using namespace std;
// Structure of a tree node
struct Node {
    int data;
    Node* left;
    Node* right;
    Node(int val)
    {
        this->data = val;
        this->left = nullptr;
        this->right = nullptr;
    }
};

// Function for marking the parent node for all the nodes using DFS.
// Afterwards par[child] == parent for every node strictly below root;
// root's own entry is left for the caller to seed.
void dfs(Node* root,
         std::unordered_map<Node*, Node*>& par)
{
    if (root == nullptr)
        return;
    if (root->left != nullptr)
        par[root->left] = root;
    if (root->right != nullptr)
        par[root->right] = root;
    dfs(root->left, par);
    dfs(root->right, par);
}

// Function for finding the sum.
// Adds to `sum` the value of every not-yet-visited node reachable from
// root (through children and parents) whose distance h from the starting
// node is at most k.
void dfs3(Node* root, int h, int& sum, int k,
          std::unordered_map<Node*, int>& vis,
          std::unordered_map<Node*, Node*>& par)
{
    // `h > k` (rather than `h == k + 1`) also stops at once for k < 0,
    // instead of walking the whole tree.
    if (h > k)
        return;
    if (root == nullptr)
        return;
    if (vis[root])
        return;
    sum += root->data;
    vis[root] = 1;
    dfs3(root->left, h + 1, sum, k, vis, par);
    dfs3(root->right, h + 1, sum, k, vis, par);
    // A node without an entry in par is treated as having no parent.
    auto it = par.find(root);
    dfs3(it == par.end() ? nullptr : it->second, h + 1, sum, k, vis, par);
}

// Function for finding the target node in the tree.
// Returns the first node in pre-order whose data equals target, or
// nullptr when no node holds it.
Node* dfs2(Node* root, int target)
{
    if (root == nullptr)
        return nullptr;
    if (root->data == target)
        return root;
    Node* found = dfs2(root->left, target);
    if (found != nullptr)
        return found;
    return dfs2(root->right, target);
}

// Function to find the sum at distance K.
// Returns the sum of the values of all nodes within distance k of the
// node holding target (0 if target is absent or the tree is empty).
int sum_at_distK(Node* root, int target,
                 int k)
{
    // Hash table to store the parent of a node; the root has none.
    std::unordered_map<Node*, Node*> par;
    par[root] = nullptr;
    // Mark the parent node for all the nodes using DFS
    dfs(root, par);
    // Find the target node in the tree
    Node* node = dfs2(root, target);
    // Hash table to mark the visited nodes
    std::unordered_map<Node*, int> vis;
    int sum = 0;
    // DFS call to find the sum
    dfs3(node, 0, sum, k, vis, par);
    return sum;
}
// Driver Code
int main()
    // Taking Input
    Node* root = new Node(1);
    root->left = new Node(2);
    root->right = new Node(9);
    root->left->left = new Node(4);
    root->right->left = new Node(5);
    root->right->right = new Node(7);
    root->left->left->left = new Node(8);
    root->left->left->right = new Node(19);
    root->right->right->left = new Node(20);
        = new Node(11);
        = new Node(30);
        = new Node(40);
        = new Node(50);
    int target = 9, K = 1;
    // Function call
    cout << sum_at_distK(root, target, K);
    return 0;


// Java code to implement above approach
import java.util.*;
public class Main {
    /** Structure of a tree node. */
    static class Node {
        int data;
        Node left;
        Node right;

        Node(int val) {
            this.data = val;
            this.left = null;
            this.right = null;
        }
    }

    /**
     * Marks the parent of every node strictly below {@code root} using DFS.
     *
     * @param root subtree root; {@code null} is a no-op
     * @param par  receives {@code child -> parent} entries
     */
    static void dfs(Node root, HashMap<Node, Node> par) {
        if (root == null) {
            return;
        }
        if (root.left != null) {
            par.put(root.left, root);
        }
        if (root.right != null) {
            par.put(root.right, root);
        }
        dfs(root.left, par);
        dfs(root.right, par);
    }

    /**
     * Accumulator written by {@link #dfs3} and reset by {@link #sum_at_distK}.
     * This is shared static state, so concurrent calls would interfere.
     */
    static int sum;

    /**
     * Adds to {@link #sum} the value of every unvisited node reachable from
     * {@code root} (through children and parents) at distance {@code h <= k}.
     *
     * @param root current node, may be {@code null}
     * @param h    distance of {@code root} from the starting node
     * @param k    maximum distance
     * @param vis  nodes already counted
     * @param par  {@code child -> parent} map built by {@link #dfs}
     */
    static void dfs3(Node root, int h, int k,
            HashMap<Node, Integer> vis,
            HashMap<Node, Node> par) {
        // h > k (rather than h == k + 1) also stops at once for negative k
        if (h > k || root == null || vis.containsKey(root)) {
            return;
        }
        sum += root.data;
        vis.put(root, 1);
        dfs3(root.left, h + 1, k, vis, par);
        dfs3(root.right, h + 1, k, vis, par);
        // get() yields null for the root, which ends that branch
        dfs3(par.get(root), h + 1, k, vis, par);
    }

    /**
     * Finds the target node in the tree.
     *
     * @return the first node in pre-order whose value is {@code target},
     *         or {@code null} if none
     */
    static Node dfs2(Node root, int target) {
        if (root == null) {
            return null;
        }
        if (root.data == target) {
            return root;
        }
        Node found = dfs2(root.left, target);
        return found != null ? found : dfs2(root.right, target);
    }

    /**
     * Returns the sum of the values of all nodes within distance {@code k}
     * of the node holding {@code target}; 0 if it is absent.
     */
    static int sum_at_distK(Node root, int target, int k) {
        // Hash Map to store the parent of a node; the root has none
        HashMap<Node, Node> par = new HashMap<>();
        par.put(root, null);
        // Mark the parent node for all the nodes using DFS
        dfs(root, par);
        // Find the target node in the tree
        Node node = dfs2(root, target);
        // Hash Map to mark the visited nodes
        HashMap<Node, Integer> vis = new HashMap<>();
        sum = 0;
        // DFS call to find the sum
        dfs3(node, 0, k, vis, par);
        return sum;
    }

    public static void main(String[] args) {
        // Taking Input
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right = new Node(11);
        root.left.left.left.left = new Node(30);
        root.left.left.right.left = new Node(40);
        root.left.left.right.right = new Node(50);
        int target = 9, K = 1;
        // Function call
        System.out.println(sum_at_distK(root, target, K));
    }
}
// This code has been contributed by Sachin Sahara (sachin801)


# python program to implement above approach
# structure of tree node
class Node:
    def __init__(self, val):
        self.data = val
        self.left = None
        self.right = None


# function for marking the parent node
# for all the nodes using DFS
def dfs(root, par):
    """Record ``par[child] = parent`` for every node strictly below ``root``.

    ``root``'s own entry is left for the caller to seed.
    """
    if root is None:
        return
    if root.left is not None:
        par[root.left] = root
    if root.right is not None:
        par[root.right] = root
    dfs(root.left, par)
    dfs(root.right, par)


# running total written by dfs3 and reset by sum_at_distK
summ = 0


# function calling for finding the sum
def dfs3(root, h, k, vis, par):
    """Add to ``summ`` the value of every unvisited node within distance ``k``.

    ``h`` is the distance of ``root`` from the starting node; the walk goes
    to both children and to the parent (looked up in ``par``).
    """
    global summ
    # h > k (rather than h == k + 1) also stops at once for negative k
    if h > k:
        return
    if root is None:
        return
    if vis.get(root) == 1:
        return
    summ += root.data
    vis[root] = 1
    dfs3(root.left, h + 1, k, vis, par)
    dfs3(root.right, h + 1, k, vis, par)
    # .get() treats a node without a parent entry as having no parent
    dfs3(par.get(root), h + 1, k, vis, par)


# function for finding
# the target node in the tree
def dfs2(root, target):
    """Return the first node in pre-order whose value is ``target``, else None."""
    if root is None:
        return None
    if root.data == target:
        return root
    node1 = dfs2(root.left, target)
    if node1 is not None:
        return node1
    return dfs2(root.right, target)


# function to find the sum at distance k
def sum_at_distK(root, target, k):
    """Return the sum of all node values within distance ``k`` of ``target``.

    The first node (in pre-order) holding ``target`` is used; 0 is returned
    when no node holds it.
    """
    global summ
    # hash table to store the parent of each node; the root has none, so it
    # maps to None (not 0, which dfs3 would otherwise treat as a node)
    par = {root: None}
    # mark the parent node for all the nodes using DFS
    dfs(root, par)
    # find the target node in the tree
    node = dfs2(root, target)
    # hash table to mark the visited nodes
    vis = {}
    # reset the accumulator so repeated calls don't add up
    summ = 0
    # dfs call to find the sum
    dfs3(node, 0, k, vis, par)
    return summ
# driver program: build the sample tree from the problem statement
root = Node(1)
root.left = Node(2)
root.right = Node(9)
root.left.left = Node(4)
root.right.left = Node(5)
root.right.right = Node(7)
root.left.left.left = Node(8)
root.left.left.right = Node(19)
root.right.right.left = Node(20)
root.right.right.right = Node(11)
root.left.left.left.left = Node(30)
root.left.left.right.left = Node(40)
root.left.left.right.right = Node(50)
target = 9
K = 1
# function call; the result has to be printed to be seen
print(sum_at_distK(root, target, K))
# this code is contributed by Yash Agarwal(yashagarwal2852002)


// C# code to implement above approach
using System;
using System.Collections.Generic;
public class GFG {
  // Structure of a tree node
  class Node {
    public int data;
    public Node left;
    public Node right;
    public Node(int val)
    {
      this.data = val;
      this.left = null;
      this.right = null;
    }
  }

  // Function for marking the parent node for all the nodes using DFS:
  // records par[child] = parent for every node strictly below root.
  static void dfs(Node root, Dictionary<Node, Node> par)
  {
    if (root == null)
      return;
    if (root.left != null)
      par.Add(root.left, root);
    if (root.right != null)
      par.Add(root.right, root);
    dfs(root.left, par);
    dfs(root.right, par);
  }

  // Accumulator written by dfs3 and reset by sum_at_distK.
  // Shared static state: concurrent calls would interfere.
  static int sum;

  // Function for finding the sum: adds the value of every unvisited node
  // at distance h <= k, then recurses into the children and the parent.
  static void dfs3(Node root, int h, int k,
                   Dictionary<Node, int> vis,
                   Dictionary<Node, Node> par)
  {
    // h > k (rather than h == k + 1) also stops at once for negative k
    if (h > k)
      return;
    if (root == null)
      return;
    if (vis.ContainsKey(root))
      return;
    sum += root.data;
    vis.Add(root, 1);
    dfs3(root.left, h + 1, k, vis, par);
    dfs3(root.right, h + 1, k, vis, par);
    // TryGetValue leaves parent null when root has no entry
    Node parent;
    par.TryGetValue(root, out parent);
    dfs3(parent, h + 1, k, vis, par);
  }

  // Function for finding the target node in the tree: returns the first
  // node in pre-order whose data equals target, or null.
  static Node dfs2(Node root, int target)
  {
    if (root == null)
      return null;
    if (root.data == target)
      return root;
    Node found = dfs2(root.left, target);
    return found ?? dfs2(root.right, target);
  }

  // Returns the sum of the values of all nodes within distance k of the
  // node holding target; 0 if it is absent or the tree is empty.
  static int sum_at_distK(Node root, int target, int k)
  {
    sum = 0;
    // Dictionary keys cannot be null, so handle an empty tree up front
    if (root == null)
      return 0;
    // Hash Map to store the parent of a node; the root has none
    Dictionary<Node, Node> par
      = new Dictionary<Node, Node>();
    par.Add(root, null);
    // Mark the parent node for all the nodes using DFS
    dfs(root, par);
    // Find the target node in the tree
    Node node = dfs2(root, target);
    // Hash Map to mark the visited nodes
    Dictionary<Node, int> vis
      = new Dictionary<Node, int>();
    // DFS call to find the sum
    dfs3(node, 0, k, vis, par);
    return sum;
  }

  static public void Main()
  {
    // Code
    Node root = new Node(1);
    root.left = new Node(2);
    root.right = new Node(9);
    root.left.left = new Node(4);
    root.right.left = new Node(5);
    root.right.right = new Node(7);
    root.left.left.left = new Node(8);
    root.left.left.right = new Node(19);
    root.right.right.left = new Node(20);
    root.right.right.right = new Node(11);
    root.left.left.left.left = new Node(30);
    root.left.left.right.left = new Node(40);
    root.left.left.right.right = new Node(50);
    int target = 9, K = 1;
    // Function call
    Console.Write(sum_at_distK(root, target, K));
  }
}
// This code is contributed by lokesh(lokeshmvs21).


       // JavaScript code for the above approach
       // Structure of a tree node
       class Node {
           constructor(val) {
      = val;
               this.left = null;
               this.right = null;
       // Function for marking the parent node
       // for all the nodes using DFS
       function dfs(root, par) {
           if (root === null) return;
           if (root.left !== null) par.set(root.left, root);
           if (root.right !== null) par.set(root.right, root);
           dfs(root.left, par);
           dfs(root.right, par);
       let sum = 0;
       // Function calling for finding the sum
       function dfs3(root, h, k, vis, par) {
           if (h === k + 1) return;
           if (root === null) return;
           if (vis.has(root)) return;
           sum +=;
           vis.set(root, 1);
           dfs3(root.left, h + 1, k, vis, par);
           dfs3(root.right, h + 1, k, vis, par);
           if (par.get(root) !== null && vis.has(par.get(root))) {
               dfs3(par.get(root), h + 1, k, vis, par);
       // Function for finding
       // the target node in the tree
       function dfs2(root, target) {
           if (root === null) return null;
           if ( === target) return root;
           let node1 = dfs2(root.left, target);
           let node2 = dfs2(root.right, target);
           if (node1 !== null) return node1;
           if (node2 !== null) return node2;
           return null;
       function sumAtDistK(root, target, k)
           // Map to store the parent of a node
           let par = new Map();
           // Make the parent of root node as NULL
           // since it does not have any parent
           par.set(root, null);
           // Mark the parent node for all the
           // nodes using DFS
           dfs(root, par);
           // Find the target node in the tree
           let node = dfs2(root, target);
           // Map to mark the visited nodes
           let vis = new Map();
           sum = 1;
           // DFS call to find the sum
           dfs3(node, 0, k, vis, par);
           return sum;
       // Taking Input
       let root = new Node(1);
       root.left = new Node(2);
       root.right = new Node(9);
       root.left.left = new Node(4);
       root.right.left = new Node(5);
       root.right.right = new Node(7);
       root.left.left.left = new Node(8);
       root.left.left.right = new Node(19);
       root.right.right.left = new Node(20);
       root.right.right.right = new Node(11);
       root.left.left.left.left = new Node(30);
       root.left.left.right.left = new Node(40);
       root.left.left.right.right = new Node(50);
       let target = 9;
       let K = 1;
       console.log(sumAtDistK(root, target, K));
// This code is contributed by Potta Lokesh



Time Complexity: O(N) where N is the number of nodes in the tree
Auxiliary Space: O(N)

Approach using BFS:

  • We will use level-order traversal (BFS) to find the sum of the nodes.


  • First we will find the target node using level order traversal.
  • While finding the target node we will store the parent of each node so that we can move towards the parent of the node as well.
  • After this, we traverse outward from the target node in every direction (towards both the children and the parent), up to distance K, adding the value of each node reached to the answer.


// C++ code to implement above approach
#include <bits/stdc++.h>
using namespace std;
// Structure of a tree node
struct Node {
    int data;
    Node* left;
    Node* right;
    Node(int val)
    {
        this->data = val;
        this->left = nullptr;
        this->right = nullptr;
    }
};

// Function to find the sum at distance K.
// Returns the sum of the values of all nodes within distance k of the
// first node (in level order) whose value is target. Returns 0 for an
// empty tree, a negative k, or a target that is not present.
int sum_at_distK(Node* root, int target,
                 int k)
{
    if (root == nullptr || k < 0)
        return 0;

    // Phase 1: level-order traversal that records each node's parent and
    // locates the target node.
    std::queue<Node*> q;
    std::unordered_map<Node*, Node*> m;  // child -> parent
    Node* need = nullptr;                // initialised: target may be absent
    q.push(root);
    while (!q.empty()) {
        Node* temp = q.front();
        q.pop();
        if (need == nullptr && temp->data == target)
            need = temp;
        if (temp->left != nullptr) {
            m[temp->left] = temp;
            q.push(temp->left);
        }
        if (temp->right != nullptr) {
            m[temp->right] = temp;
            q.push(temp->right);
        }
    }
    if (need == nullptr)
        return 0;

    // Phase 2: BFS outward from the target, one distance ring per pass.
    // Nodes are marked when enqueued so none is queued twice.
    std::unordered_map<Node*, int> mm;  // visited marks
    int ans = 0;
    mm[need] = 1;
    q.push(need);
    for (int c = 0; c <= k && !q.empty(); c++) {
        for (auto s = q.size(); s > 0; --s) {
            Node* temp = q.front();
            q.pop();
            ans += temp->data;
            auto it = m.find(temp);
            Node* neighbours[3] = {temp->left, temp->right,
                                   it == m.end() ? nullptr : it->second};
            for (Node* nb : neighbours) {
                if (nb != nullptr && !mm[nb]) {
                    mm[nb] = 1;
                    q.push(nb);
                }
            }
        }
    }
    return ans;
}
// Driver Code
int main()
    // Taking Input
    Node* root = new Node(1);
    root->left = new Node(2);
    root->right = new Node(9);
    root->left->left = new Node(4);
    root->right->left = new Node(5);
    root->right->right = new Node(7);
    root->left->left->left = new Node(8);
    root->left->left->right = new Node(19);
    root->right->right->left = new Node(20);
        = new Node(11);
        = new Node(30);
        = new Node(40);
        = new Node(50);
    int target = 9, K = 1;
    // Function call
    cout << sum_at_distK(root, target, K);
    return 0;
//code contributed by shubhamrajput6156


import java.util.*;
// Structure of a tree node
class Node {
    int data;
    Node left;
    Node right;

    public Node(int val) {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

public class Main {
    /**
     * Returns the sum of the values of all nodes within distance {@code k} of
     * the first node (in level order) whose value is {@code target}.
     *
     * @param root   tree root, may be {@code null}
     * @param target value identifying the start node
     * @param k      maximum distance; negative means no node qualifies
     * @return the sum, or 0 if the tree is empty, {@code k < 0}, or the
     *         target is absent
     */
    public static int sumAtDistK(Node root, int target, int k) {
        if (root == null || k < 0) {
            return 0;
        }

        // Phase 1: level-order traversal recording each node's parent and
        // locating the target node.
        Map<Node, Node> parentMap = new HashMap<>();
        Queue<Node> q = new ArrayDeque<>();
        Node need = null;
        q.add(root);
        while (!q.isEmpty()) {
            Node temp = q.poll();
            if (need == null && temp.data == target) {
                need = temp;
            }
            if (temp.left != null) {
                parentMap.put(temp.left, temp);
                q.add(temp.left);
            }
            if (temp.right != null) {
                parentMap.put(temp.right, temp);
                q.add(temp.right);
            }
        }
        if (need == null) {
            return 0;
        }

        // Phase 2: BFS outward from the target, one distance ring per pass.
        // Nodes are marked when enqueued so none is queued twice.
        Set<Node> visited = new HashSet<>();
        visited.add(need);
        q.add(need);
        int ans = 0;
        for (int currentDistance = 0; currentDistance <= k && !q.isEmpty(); currentDistance++) {
            for (int i = 0, size = q.size(); i < size; i++) {
                Node temp = q.poll();
                ans += temp.data;
                enqueueIfUnvisited(temp.left, visited, q);
                enqueueIfUnvisited(temp.right, visited, q);
                // parentMap has no entry for the root, so get() returns null
                enqueueIfUnvisited(parentMap.get(temp), visited, q);
            }
        }
        return ans;
    }

    /** Adds {@code node} to {@code q} if it is non-null and not yet visited. */
    private static void enqueueIfUnvisited(Node node, Set<Node> visited, Queue<Node> q) {
        if (node != null && visited.add(node)) {
            q.add(node);
        }
    }

    // Driver code
    public static void main(String[] args) {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right = new Node(11);
        root.left.left.left.left = new Node(30);
        root.left.left.right.left = new Node(40);
        root.left.left.right.right = new Node(50);
        int target = 9, K = 1;
        // Function call
        System.out.println(sumAtDistK(root, target, K));
    }
}


from collections import deque
class Node:
    def __init__(self, val):
        self.data = val
        self.left = None
        self.right = None


# Function to find the sum at distance K
def sum_at_distK(root, target, k):
    """Return the sum of the values of all nodes within distance ``k`` of target.

    The first node in level order whose value is ``target`` is used.
    Returns 0 for an empty tree, a negative ``k``, or an absent target.
    """
    if root is None or k < 0:
        return 0

    # Phase 1: BFS over the whole tree, recording each node's parent and
    # locating the target node.
    q = deque([root])
    need = None
    m = {}  # child -> parent
    while q:
        temp = q.popleft()
        if need is None and temp.data == target:
            need = temp
        for child in (temp.left, temp.right):
            if child is not None:
                m[child] = temp
                q.append(child)
    if need is None:
        return 0

    # Phase 2: BFS outward from the target, one distance ring per pass.
    # Nodes are marked when enqueued so none is queued twice.
    mm = {need}  # visited nodes
    q.append(need)
    ans = 0
    for _ in range(k + 1):
        if not q:
            break
        for _ in range(len(q)):
            temp = q.popleft()
            ans += temp.data
            # m.get() is None for the root, which has no parent
            for nb in (temp.left, temp.right, m.get(temp)):
                if nb is not None and nb not in mm:
                    mm.add(nb)
                    q.append(nb)
    return ans
# Driver Code
# Taking Input: the sample tree is assembled bottom-up by _build, which
# creates a node and attaches its (optional) children.
def _build(val, left=None, right=None):
    node = Node(val)
    node.left, node.right = left, right
    return node


root = _build(
    1,
    _build(2,
           _build(4,
                  _build(8, _build(30)),
                  _build(19, _build(40), _build(50)))),
    _build(9,
           _build(5),
           _build(7, _build(20), _build(11))),
)
target, K = 9, 1
# Print the sum of nodes within distance K of target
print(sum_at_distK(root, target, K))


class Node {
    constructor(val) {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

// Function to find the sum at distance K.
// Returns the sum of the values of all nodes within distance k of the
// first node (in level order) whose value is target; 0 for an empty tree,
// a negative k, or an absent target.
function sum_at_distK(root, target, k) {
    if (root === null || k < 0) return 0;

    // Phase 1: level-order traversal recording each node's parent and
    // locating the target. An array plus a head index is used as the
    // queue, avoiding the O(n) cost of Array.prototype.shift().
    const m = new Map(); // child -> parent
    let need = null;
    const order = [root];
    for (let head = 0; head < order.length; head++) {
        const temp = order[head];
        if (need === null && temp.data === target) need = temp;
        for (const child of [temp.left, temp.right]) {
            if (child !== null) {
                m.set(child, temp);
                order.push(child);
            }
        }
    }
    if (need === null) return 0;

    // Phase 2: BFS outward from the target, one distance ring per pass.
    // Nodes are marked when enqueued so none is queued twice.
    const mm = new Set([need]); // visited nodes
    let level = [need];
    let ans = 0;
    for (let c = 0; c <= k && level.length > 0; c++) {
        const next = [];
        for (const temp of level) {
            ans += temp.data;
            // m.get() is undefined for the root, which has no parent
            for (const nb of [temp.left, temp.right, m.get(temp)]) {
                if (nb != null && !mm.has(nb)) {
                    mm.add(nb);
                    next.push(nb);
                }
            }
        }
        level = next;
    }
    return ans;
}
// Driver Code
// Taking Input: buildNode creates a node and attaches its (optional)
// children, so the sample tree can be written as one nested expression.
const buildNode = (val, left = null, right = null) => {
    const created = new Node(val);
    created.left = left;
    created.right = right;
    return created;
};
let root = buildNode(1,
    buildNode(2,
        buildNode(4,
            buildNode(8, buildNode(30)),
            buildNode(19, buildNode(40), buildNode(50)))),
    buildNode(9,
        buildNode(5),
        buildNode(7, buildNode(20), buildNode(11))));
let target = 9, K = 1;
// Print the sum of nodes within distance K of target
console.log(sum_at_distK(root, target, K));



Time Complexity: O(N), where N is the number of nodes
Auxiliary Space:- O(N)

Feeling lost in the world of random DSA topics, wasting time without progress? It’s time for a change! Join our DSA course, where we’ll guide you on an exciting journey to master DSA efficiently and on schedule.
Ready to dive in? Explore our Free Demo Content and join our DSA course, trusted by over 100,000 neveropen!


Most Popular

Recent Comments