Efficient Segment Trees

Memory and Time Efficient Segment Trees.

A Segment Tree is a data structure that answers range queries over an array efficiently, while still being flexible enough to allow modifying the array. This includes finding the sum of consecutive array elements a[l..r], or finding the minimum element in such a range, in O(log n) time. Between answering such queries, the Segment Tree allows modifying the array by replacing one element, or even changing the elements of a whole subsegment (e.g. assigning all elements a[l..r] to some value, or adding a value to all elements in the subsegment).

I will not dive deeper into how segment trees work or how to implement them in the simple, recursive way. Instead, I will explain how to implement them in an iterative and more memory-efficient way. Go through the following link if you are not up to date with implementing segment trees in a recursive manner:

The above picture shows a visual representation of a recursively implemented segment tree to store cumulative sum.

Instead, let's imagine the segment tree in the form of a perfect binary tree:



The notation is node_index: corresponding segment (left border included, right excluded). The bottom row holds our array (0-indexed), the leaves of the tree. For now suppose its length is a power of 2 (16 in the example), so we get a perfect binary tree. When going up the tree we take pairs of nodes with indices (2 * i, 2 * i + 1) and combine their values in their parent with index i. This way, when we're asked to find the sum on the interval [3, 11), we need to sum up only the values in nodes 19, 5, 12 and 26 (marked in bold), not all 8 values inside the interval. Let's jump directly to the implementation (in C++) to see how it works:

#include <cstdio>

const int N = 1e5;  // limit for array size
int n;  // array size
int t[2 * N];

void build() {  // build the tree
  for (int i = n - 1; i > 0; --i) t[i] = t[i<<1] + t[i<<1|1];
}

void modify(int p, int value) {  // set value at position p
  for (t[p += n] = value; p > 1; p >>= 1) t[p>>1] = t[p] + t[p^1];
}

int query(int l, int r) {  // sum on interval [l, r)
  int res = 0;
  for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
    if (l&1) res += t[l++];
    if (r&1) res += t[--r];
  }
  return res;
}

int main() {
  scanf("%d", &n);
  for (int i = 0; i < n; ++i) scanf("%d", t + n + i);
  build();
  modify(0, 1);
  printf("%d\n", query(3, 11));
  return 0;
}


That's it! The code works for modifying a single element and querying on an interval. Following is the code explanation:

  1. As you may have noticed from the picture, leaves are stored in consecutive nodes with indices starting at n; the element with index i corresponds to the node with index i + n. So we can read the initial values directly into the tree where they belong.

  2. Before doing any queries we need to build the tree, which is quite straightforward and takes O(n) time. Since a parent always has a smaller index than its children, we just process all the internal nodes in decreasing order. In case you're confused by the bit operations, the code in build() is equivalent to t[i] = t[2*i] + t[2*i+1].

  3. Modifying an element is also quite straightforward and takes time proportional to the height of the tree, which is O(log(n)). We only need to update the values in the ancestors of the given node. So we just go up the tree, knowing that the parent of node p is p / 2, or p>>1, which means the same. p^1 turns 2 * i into 2 * i + 1 and vice versa, so it represents the other child of p's parent.

  4. Finding the sum also works in O(log(n)) time. To better understand its logic, you can go through the example with the interval [3, 11) and verify that the result is composed exactly of the values in nodes 19, 26, 12 and 5 (in that order).

The general idea is the following. If l, the left interval border, is odd (which is equivalent to l&1), then l is the right child of its parent. Then our interval includes node l but doesn't include its parent. So we add t[l] and move to the right of l's parent by setting l = (l + 1) / 2. If l is even, it is the left child, and the interval includes its parent as well (unless the right border interferes), so we just move to it by setting l = l / 2. A similar argument applies to the right border. We stop once the borders meet.
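To make the walk concrete, here is a small sketch that records which nodes the query loop touches instead of summing them. The helper name query_nodes is my own and is not part of the code above; it assumes the same layout (leaves start at index n).

```cpp
#include <vector>

// Returns the node indices that the iterative query on [l, r) would add,
// in the order the loop visits them. n is the number of leaves.
// query_nodes is a hypothetical helper used only for illustration.
std::vector<int> query_nodes(int n, int l, int r) {
  std::vector<int> nodes;
  for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
    // l is odd: node l is inside the interval but its parent is not, so take l.
    if (l & 1) nodes.push_back(l++);
    // r is odd: node r - 1 is inside the interval but its parent is not, so take r - 1.
    if (r & 1) nodes.push_back(--r);
  }
  return nodes;
}
```

For n = 16 and the interval [3, 11) this yields 19, 26, 12, 5, the same nodes and the same order as in item 4.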

No recursion and no additional computations like finding the middle of the interval are involved; we just go through all the nodes we need, so this is very efficient.
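If the bit operations from items 2 and 3 still look suspicious, a quick check like the following may help. It is only a sketch: the helper names (left_child, right_child, parent, sibling, bit_tricks_agree) are made up for this example and simply restate the index rules in plain arithmetic.

```cpp
// Index rules written with plain arithmetic; all names here are hypothetical.
constexpr int left_child(int i) { return 2 * i; }
constexpr int right_child(int i) { return 2 * i + 1; }
constexpr int parent(int p) { return p / 2; }
constexpr int sibling(int p) { return p % 2 == 0 ? p + 1 : p - 1; }

// Returns true if the bit tricks match the arithmetic for every index in [1, limit).
bool bit_tricks_agree(int limit) {
  for (int i = 1; i < limit; ++i) {
    if ((i << 1) != left_child(i)) return false;       // i<<1   is 2*i
    if ((i << 1 | 1) != right_child(i)) return false;  // i<<1|1 is 2*i+1
    if ((i >> 1) != parent(i)) return false;           // i>>1   is i/2
    if ((i ^ 1) != sibling(i)) return false;           // i^1 flips the last bit
  }
  return true;
}
```

Note that for i = 1 the "sibling" is index 0, which the tree never uses; that is why the loop in modify stops at p > 1.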

Some people begin to struggle and invent something too complex when the operations are inverted, for example:

  1. add a value to all elements in some interval;
  2. compute an element at some position.

But all we need to do in this case is to switch the code in methods modify and query as follows:


void modify(int l, int r, int value) {  // add value on interval [l, r)
  for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
    if (l&1) t[l++] += value;
    if (r&1) t[--r] += value;
  }
}

int query(int p) {
  int res = 0;
  for (p += n; p > 0; p >>= 1) res += t[p];
  return res;
}
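Here is how the inverted version might look when wrapped in a small struct, so it can be tried without the global array. This is a sketch under my own assumptions: the name RangeAddTree and the use of std::vector and long long are not from the code above, only the two loops are.

```cpp
#include <vector>

// Range-add / point-query tree. RangeAddTree is a hypothetical wrapper;
// leaves t[n + i] start as a[i], internal nodes hold pending additions.
struct RangeAddTree {
  int n;
  std::vector<long long> t;

  explicit RangeAddTree(const std::vector<int>& a)
      : n(static_cast<int>(a.size())), t(2 * a.size(), 0) {
    for (int i = 0; i < n; ++i) t[n + i] = a[i];
  }

  void modify(int l, int r, long long value) {  // add value on [l, r)
    for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
      if (l & 1) t[l++] += value;
      if (r & 1) t[--r] += value;
    }
  }

  long long query(int p) const {  // current value at position p
    long long res = 0;
    for (p += n; p > 0; p >>= 1) res += t[p];  // leaf plus all its ancestors
    return res;
  }
};
```

For example, starting from {1, 2, 3, 4, 5}, adding 10 on [1, 4) and then 1 on [0, 5) makes query(1) return 13 and query(4) return 6.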

Also, if at some point after the modifications we need to inspect all the elements in the array, we can push all the modifications down to the leaves using the following code. After that we can just traverse the elements starting at index n. This way we reduce the complexity from O(n log(n)) to O(n), similarly to using build instead of n modifications:

void push() {
  for (int i = 1; i < n; ++i) {
    t[i<<1] += t[i];
    t[i<<1|1] += t[i];
    t[i] = 0;
  }
}
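To see push in action, the sketch below applies a few range additions and then reads the whole array back from the leaves. It works on a std::vector rather than the global t, and the names Update, range_add, push_all and apply_updates are hypothetical, chosen just for this example.

```cpp
#include <vector>

struct Update { int l, r; long long value; };  // add value on [l, r)

// Same loop as the range modify above, on a tree stored in a vector of size 2n.
void range_add(std::vector<long long>& t, int l, int r, long long value) {
  int n = static_cast<int>(t.size() / 2);
  for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
    if (l & 1) t[l++] += value;
    if (r & 1) t[--r] += value;
  }
}

// Same loop as push above: parents are processed before their children.
void push_all(std::vector<long long>& t) {
  int n = static_cast<int>(t.size() / 2);
  for (int i = 1; i < n; ++i) {
    t[i << 1] += t[i];
    t[i << 1 | 1] += t[i];
    t[i] = 0;
  }
}

// Applies all updates to a copy of a and returns the resulting array.
std::vector<long long> apply_updates(const std::vector<long long>& a,
                                     const std::vector<Update>& updates) {
  int n = static_cast<int>(a.size());
  std::vector<long long> t(2 * a.size(), 0);
  for (int i = 0; i < n; ++i) t[n + i] = a[i];
  for (const Update& u : updates) range_add(t, u.l, u.r, u.value);
  push_all(t);                                             // one O(n) pass
  return std::vector<long long>(t.begin() + n, t.end());  // leaves = final array
}
```

Starting from {1, 2, 3, 4, 5} with updates +10 on [1, 4) and +1 on [0, 5), the returned array is {2, 13, 14, 15, 6}.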

So, instead of writing all those cumbersome recursive functions with multiple arguments, try solving your next problem with this code. And don't worry: it works for an array of arbitrary size, i.e. the array length does not have to be a power of 2.
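If you would rather not take the arbitrary-size claim on faith, a brute-force comparison along these lines is one way to convince yourself. It is a sketch, not a proof: SumTree and matches_brute_force are names I made up, the test values are arbitrary, and it only exercises sums.

```cpp
#include <vector>

// Sum tree over a vector; SumTree is a hypothetical wrapper around the
// build / modify / query loops from the first listing.
struct SumTree {
  int n;
  std::vector<long long> t;

  explicit SumTree(const std::vector<int>& a)
      : n(static_cast<int>(a.size())), t(2 * a.size(), 0) {
    for (int i = 0; i < n; ++i) t[n + i] = a[i];
    for (int i = n - 1; i > 0; --i) t[i] = t[i << 1] + t[i << 1 | 1];
  }

  void modify(int p, long long value) {  // set value at position p
    for (t[p += n] = value; p > 1; p >>= 1) t[p >> 1] = t[p] + t[p ^ 1];
  }

  long long query(int l, int r) const {  // sum on [l, r)
    long long res = 0;
    for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
      if (l & 1) res += t[l++];
      if (r & 1) res += t[--r];
    }
    return res;
  }
};

// Compares every interval [l, r) against a direct loop for one array size n.
bool matches_brute_force(int n) {
  if (n <= 0) return true;
  std::vector<int> a(n);
  for (int i = 0; i < n; ++i) a[i] = (i * 7 + 3) % 11;  // arbitrary test values
  SumTree tree(a);
  tree.modify(n / 2, 100);  // also exercise a point update
  a[n / 2] = 100;
  for (int l = 0; l <= n; ++l) {
    long long expected = 0;
    for (int r = l; r <= n; ++r) {
      if (tree.query(l, r) != expected) return false;
      if (r < n) expected += a[r];
    }
  }
  return true;
}
```

Calling matches_brute_force for sizes that are not powers of 2, such as 13 or 37, is the interesting case here.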

Practice problems: Practice SEGTREE-GFG

Thanks for reading. :)