Tuesday, December 22, 2015

Leetcode: Range Sum Query - Mutable

Given an integer array nums, find the sum of the elements between indices i and j (i ≤ j), inclusive.
The update(i, val) function modifies nums by updating the element at index i to val.
Example:
Given nums = [1, 3, 5]

sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8
Note:
  1. The array is only modifiable by the update function.
  2. You may assume the number of calls to update and sumRange function is distributed evenly.
Naive Solution 1:
The most straightforward method is to store the numbers in a plain array. update() then takes O(1) time, but sumRange() takes O(n) time because it must add up every element in the range.

Naive Solution 2:
As in Range Sum Query - Immutable, precompute the prefix sums of the array. sumRange() then takes O(1) time (it is the difference of two prefix sums), but update() takes O(n) time because every prefix sum after the updated index must be adjusted.

Better Solution:
Since the problem states that update() and sumRange() are called with roughly equal frequency, a better choice is a Segment Tree, in which both operations run in O(log n) time.

Code (Java):
/**
 * Range-sum structure over an int array that supports point updates, backed by a segment tree.
 * Both {@link #update(int, int)} and {@link #sumRange(int, int)} run in O(log n) time.
 */
public class NumArray {
    /** Segment-tree node covering the inclusive index range [start, end] and storing its sum. */
    static class SegmentTreeNode {
        int start, end;
        int sum;
        SegmentTreeNode left, right;

        // Internal-node constructor: sum starts at 0 and is filled in from the children.
        public SegmentTreeNode(int start, int end) {
            this(start, end, 0);
        }

        public SegmentTreeNode(int start, int end, int sum) {
            this.start = start;
            this.end = end;
            this.sum = sum;
        }
    }

    /** Root of the tree; {@code null} when the input array is null or empty. */
    private SegmentTreeNode root;

    /** Number of input elements; valid update indices are [0, size). */
    private int size;

    /**
     * Builds the segment tree over {@code nums} in O(n) time.
     *
     * @param nums the initial values; may be null or empty, which yields an empty structure
     */
    public NumArray(int[] nums) {
        if (nums == null || nums.length == 0) {
            return;
        }
        size = nums.length;
        root = buildSegmentTree(nums, 0, nums.length - 1);
    }

    /**
     * Sets the element at index {@code i} to {@code val}.
     *
     * @param i   index to update, in [0, nums.length)
     * @param val new value
     * @throws IndexOutOfBoundsException if {@code i} is outside the array (always, when it is empty)
     */
    public void update(int i, int val) {
        if (i < 0 || i >= size) {
            throw new IndexOutOfBoundsException("Index " + i + " out of bounds for length " + size);
        }
        updateHelper(root, i, val);
    }

    // Descends to the leaf for index i, sets it, then recomputes sums on the way back up.
    private void updateHelper(SegmentTreeNode node, int i, int val) {
        // Check for the leaf first: i was validated in update() and we always descend toward it,
        // so the leaf reached here is exactly index i.
        if (node.start == node.end) {
            node.sum = val;
            return;
        }

        int mid = node.start + (node.end - node.start) / 2;
        if (i <= mid) {
            updateHelper(node.left, i, val);
        } else {
            updateHelper(node.right, i, val);
        }

        // Internal nodes always have both children (start < end splits into two non-empty halves).
        node.sum = node.left.sum + node.right.sum;
    }

    /**
     * Returns the sum of the elements between indices {@code i} and {@code j}, inclusive.
     * Parts of the range outside the array are ignored; a reversed or empty range yields 0.
     */
    public int sumRange(int i, int j) {
        return sumRangeHelper(root, i, j);
    }

    // Sums the overlap of [start, end] with the node's range.
    private int sumRangeHelper(SegmentTreeNode root, int start, int end) {
        if (root == null || end < root.start || start > root.end ||
            start > end) {
            return 0;
        }

        // Node range fully inside the query: use the precomputed sum.
        if (start <= root.start && end >= root.end) {
            return root.sum;
        }

        int mid = root.start + (root.end - root.start) / 2;

        return sumRangeHelper(root.left, start, Math.min(end, mid)) +
               sumRangeHelper(root.right, Math.max(mid + 1, start), end);
    }

    // Recursively builds the subtree for nums[start..end]; returns null for an empty range.
    private SegmentTreeNode buildSegmentTree(int[] nums, int start, int end) {
        if (nums == null || nums.length == 0 || start > end) {
            return null;
        }

        // Single element: a leaf holding that value.
        if (start == end) {
            return new SegmentTreeNode(start, end, nums[start]);
        }

        SegmentTreeNode root = new SegmentTreeNode(start, end);
        int mid = start + (end - start) / 2;
        root.left = buildSegmentTree(nums, start, mid);
        root.right = buildSegmentTree(nums, mid + 1, end);

        root.sum = root.left.sum + root.right.sum;

        return root;
    }
}


// Your NumArray object will be instantiated and called as such:
// NumArray numArray = new NumArray(nums);
// numArray.sumRange(0, 1);
// numArray.update(1, 10);
// numArray.sumRange(1, 2);

Summary:
The take-away message for this problem is that there is no single BEST solution. Ask the interviewer about the use cases (for example, how often updates occur relative to queries) and choose the solution that fits them.

Update on 5/16/19:
/**
 * "Range Sum Query - Mutable" solution that forwards every operation to a {@link SegmentTree}.
 */
class NumArray {
    SegmentTree segmentTree;

    /** Builds the backing segment tree over {@code nums}. */
    public NumArray(int[] nums) {
        this.segmentTree = new SegmentTree(nums);
    }

    /** Forwards the point update to {@link SegmentTree#update(int, int)}. */
    public void update(int i, int val) {
        this.segmentTree.update(i, val);
    }

    /** Returns the inclusive range sum as computed by {@link SegmentTree#search(int, int)}. */
    public int sumRange(int i, int j) {
        final SegmentTree tree = this.segmentTree;
        return tree.search(i, j);
    }
}

/**
 * Segment tree over an int array that answers inclusive range-sum queries and applies point
 * updates, each in O(log n) time.
 */
class SegmentTree {
    /** Root node; {@code null} when the tree was built from a null or empty array. */
    public SegmentTreeNode root;

    /**
     * Builds the tree over {@code nums} in O(n) time.
     *
     * @param nums the initial values; null or empty yields an empty tree
     */
    public SegmentTree(int[] nums) {
        root = build(nums);
    }

    /**
     * Returns the sum of the elements in [start, end], inclusive. Indices outside the array are
     * ignored, so an empty, reversed, or fully out-of-range query returns 0.
     */
    public int search(int start, int end) {
        return searchHelper(root, start, end);
    }

    private int searchHelper(SegmentTreeNode root, int start, int end) {
        if (root == null) {
            return 0;
        }

        // Clamp the query to this node's range. Without this, a query wider than a leaf
        // (e.g. (-1, 5) on a one-element tree) would never match and would wrongly sum to 0.
        int lo = Math.max(start, root.start);
        int hi = Math.min(end, root.end);
        if (lo > hi) {
            return 0;
        }

        if (root.start == lo && root.end == hi) {
            return root.sum;
        }

        int mid = root.start + (root.end - root.start) / 2;
        int leftSum = searchHelper(root.left, lo, Math.min(mid, hi));
        int rightSum = searchHelper(root.right, Math.max(lo, mid + 1), hi);

        return leftSum + rightSum;
    }

    /**
     * Sets the element at {@code index} to {@code val}.
     *
     * @throws IndexOutOfBoundsException if the tree is empty or {@code index} is outside it
     */
    public void update(int index, int val) {
        if (root == null || index < root.start || index > root.end) {
            // The root always covers [0, length - 1] (see build), so length is root.end + 1.
            int length = (root == null) ? 0 : root.end + 1;
            throw new IndexOutOfBoundsException("Index " + index + " out of bounds for length " + length);
        }
        updateHelper(root, index, val);
    }

    // Descends to the leaf for index, sets it, then recomputes sums on the way back up.
    private void updateHelper(SegmentTreeNode root, int index, int val) {
        // index was validated in update() and we always descend toward it,
        // so any leaf reached here is exactly that index.
        if (root.start == root.end) {
            root.sum = val;
            return;
        }

        int mid = root.start + (root.end - root.start) / 2;
        if (index <= mid) {
            updateHelper(root.left, index, val);
        } else {
            updateHelper(root.right, index, val);
        }

        root.sum = root.left.sum + root.right.sum;
    }

    // Returns the root for nums, or null for a null/empty array (avoids reading nums[0]).
    private SegmentTreeNode build(int[] nums) {
        if (nums == null || nums.length == 0) {
            return null;
        }
        return buildHelper(0, nums.length - 1, nums);
    }

    // Builds the subtree for nums[start..end]; requires start <= end.
    private SegmentTreeNode buildHelper(int start, int end, int[] nums) {
        if (start == end) {
            return new SegmentTreeNode(start, end, nums[start]);
        }

        int mid = start + (end - start) / 2;
        SegmentTreeNode leftChild = buildHelper(start, mid, nums);
        SegmentTreeNode rightChild = buildHelper(mid + 1, end, nums);

        SegmentTreeNode root = new SegmentTreeNode(start, end, leftChild.sum + rightChild.sum);
        root.left = leftChild;
        root.right = rightChild;

        return root;
    }

}

/** A segment-tree node covering the inclusive index range [start, end]. */
class SegmentTreeNode {
    int start, end;               // inclusive index range covered by this node
    int sum;                      // aggregate sum stored for this range
    SegmentTreeNode left, right;  // child nodes; initialized to null

    /**
     * Creates a node with no children.
     *
     * @param start first index covered (inclusive)
     * @param end   last index covered (inclusive)
     * @param sum   sum stored for the range
     */
    public SegmentTreeNode(int start, int end, int sum) {
        this.start = start;
        this.end = end;
        this.sum = sum;
        this.left = null;
        this.right = null;
    }
}

/**
 * Your NumArray object will be instantiated and called as such:
 * NumArray obj = new NumArray(nums);
 * obj.update(i,val);
 * int param_2 = obj.sumRange(i,j);
 */

No comments:

Post a Comment