Segment Tree

Segment Tree (sometimes loosely called an interval tree)

1. Introduction

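A segment tree is a binary tree in which each node stores aggregate information (here the sum, maximum, and minimum) about a contiguous range of an array. It answers range sum and range minimum/maximum queries, and supports point updates, each in O(log n) time.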

2. Implementation

public class SegmentTree {
    class SegmentTreeNode {
        int start, end, max, min, sum;
        SegmentTreeNode left, right;

        public SegmentTreeNode(int start, int end, int max, int min, int sum) {
            this.start = start;
            this.end = end;
            this.max = max;
            this.min = min;
            this.sum = sum;
        }
    }

    SegmentTreeNode root;

    public SegmentTree(int[] nums) {
        this.root = build(0, nums.length - 1, nums);
    }

    public SegmentTreeNode build(int start, int end, int[] nums) {
        if (start > end) {
            return null;
        }

        SegmentTreeNode node = new SegmentTreeNode(start, end, nums[start], nums[start], nums[start]);

        if (start != end) {
            int mid = start + (end - start) / 2;
            node.left = build(start, mid, nums);
            node.right = build(mid + 1, end, nums);

            if (node.left != null && node.right != null) {
                node.sum = node.left.sum + node.right.sum;
                node.max = Math.max(node.left.max, node.right.max);
                node.min = Math.min(node.left.min, node.right.min);
            }
            else if (node.left != null) {
                node.sum = node.left.sum;
                node.max = Math.max(node.max, node.left.max);
                node.min = Math.min(node.min, node.left.min);
            }
            else {
                node.sum = node.right.sum;
                node.max = Math.max(node.max, node.right.max);
                node.min = Math.min(node.min, node.right.min);
            }
        }
        return node;
    }

    // Range Sum Query
    public int querySum(SegmentTreeNode node, int start, int end) {
        if (node == null) {
            return 0;
        }

        if (node.start == start && node.end == end) {
            return node.sum;
        }

        int mid = node.start + (node.end - node.start) / 2;

        if (start > mid && end > mid) {
            return querySum(node.right, start, end);
        }
        else if (start <= mid && end <= mid) {
            return querySum(node.left, start, end);
        }
        else {
            return querySum(node.left, start, mid) + querySum(node.right, mid + 1, end);
        }
    }

    // Range Maximum Query
    public int queryMax(SegmentTreeNode node, int start, int end) {
        if (node == null || start > end) {
            return 0;
        }

        if (node.start == start && node.end == end) {
            return node.max;
        }

        int mid = node.start + (node.end - node.start) / 2;

        if (start <= mid && end <= mid) {
            return queryMax(node.left, start, end);
        }
        else if (start > mid && end > mid) {
            return queryMax(node.right, start, end);
        }
        else {
            return Math.max(queryMax(node.left, start, mid), queryMax(node.right, mid + 1, end));
        }
    }

    // Range Minimum Query
    public int queryMin(SegmentTreeNode node, int start, int end) {
        if (node == null || start > end) {
            return 0;
        }

        if (node.start == start && node.end == end) {
            return node.min;
        }

        int mid = node.start + (node.end - node.start) / 2;

        if (start <= mid && end <= mid) {
            return queryMax(node.left, start, end);
        }
        else if (start > mid && end > mid) {
            return queryMax(node.right, start, end);
        }
        else {
            return Math.min(queryMin(node.left, start, mid), queryMin(node.right, mid + 1, end));
        }
    }

    public void modify(SegmentTreeNode node, int index, int value) {
        if (node == null) {
            return;
        }

        if (node.start == index && node.end == index) {
            node.max = value;
            node.min = value;
            node.sum = value;
            return;
        }

        int mid = node.start + (node.end - node.start) / 2;

        if (index <= mid) {
            modify(node.left, index, value);
        }
        else {
            modify(node.right, index, value);
        }

        if (node.left != null && node.right != null) {
            node.sum = node.left.sum + node.right.sum;
            node.max = Math.max(node.left.max, node.right.max);
            node.min = Math.min(node.left.min, node.right.min);
        }
        else if (node.left != null) {
            node.sum = node.left.sum;
            node.max = Math.max(node.max, node.left.max);
            node.min = Math.min(node.min, node.left.min);
        }
        else {
            node.sum = node.right.sum;
            node.max = Math.max(node.max, node.right.max);
            node.min = Math.min(node.min, node.right.min);
        }
    }
}

3. Time & Space Complexity

  • build: O(n)

  • querySum, queryMax, queryMin: O(log n)

  • modify: O(log n)

  • Space: O(n)

Last updated