Thursday, May 16, 2019

Lintcode 817. Range Sum Query 2D - Mutable

817. Range Sum Query 2D - Mutable

中文English
Given a 2D matrix, find the sum of the elements inside the rectangle defined by its upper-left corner (row1, col1) and lower-right corner (row2, col2). The elements of the matrix may also be changed.
You have to implement three functions:
  • NumMatrix(matrix) The constructor.
  • sumRegion(row1, col1, row2, col2) Return the sum of the elements inside the rectangle defined by its upper left corner (row1, col1) and lower right corner (row2, col2).
  • update(row, col, val) Update the element at (row, col) to val.

Example

Example 1:
Input:
  NumMatrix(
    [[3,0,1,4,2],
     [5,6,3,2,1],
     [1,2,0,1,5],
     [4,1,0,1,7],
     [1,0,3,0,5]]
  )
  sumRegion(2,1,4,3)
  update(3,2,2)
  sumRegion(2,1,4,3)
Output:
  8
  10
Example 2:
Input:
  NumMatrix([[1]])
  sumRegion(0, 0, 0, 0)
  update(0, 0, -1)
  sumRegion(0, 0, 0, 0)
Output:
  1
  -1

Notice

  1. The matrix is only modifiable by update.
  2. You may assume that calls to update and sumRegion are distributed evenly.
  3. You may assume that row1 ≤ row2 and col1 ≤ col2.

Code (Java):
/**
 * Answers rectangle-sum queries over a mutable 2D matrix by delegating every
 * operation to a {@link SegmentTree} built once from the initial contents.
 */
class NumMatrix {
    /** Tree holding the current matrix values. */
    private final SegmentTree tree;

    /**
     * Builds the backing tree.
     *
     * @param matrix initial matrix contents
     */
    public NumMatrix(int[][] matrix) {
        this.tree = new SegmentTree(matrix);
    }

    /**
     * Replaces the value at ({@code row}, {@code col}) with {@code val}.
     *
     * @param row row index of the cell
     * @param col column index of the cell
     * @param val new value
     */
    public void update(int row, int col, int val) {
        tree.update(row, col, val);
    }

    /**
     * Returns the sum of the inclusive rectangle whose upper-left corner is
     * ({@code row1}, {@code col1}) and lower-right corner is ({@code row2}, {@code col2}).
     */
    public int sumRegion(int row1, int col1, int row2, int col2) {
        final int total = tree.search(row1, col1, row2, col2);
        return total;
    }
}


/**
 * A 2D segment tree (quad-tree split) over an int matrix supporting point
 * updates and rectangle-sum queries.
 *
 * <p>Every node covers an inclusive rectangle {@code [row1..row2] x [col1..col2]}
 * and is split at its row and column midpoints into up to four children. A child
 * whose range would be empty (which happens when a node spans a single row or a
 * single column) is left {@code null}. Sums use {@code int} arithmetic and wrap
 * on overflow.
 */
class SegmentTree {
    /** Root covering the whole matrix, or {@code null} for an empty matrix. */
    SegmentTreeNode root;

    /**
     * Builds the tree from the current contents of {@code matrix}.
     *
     * <p>The column count is taken from the first row, so the matrix is assumed to
     * be rectangular. A {@code null} matrix, a matrix with no rows, or one whose
     * first row is empty produces an empty tree. On an empty tree, every query
     * returns 0 and every update is ignored.
     *
     * @param matrix the initial values
     */
    public SegmentTree(int[][] matrix) {
        if (matrix != null && matrix.length > 0 && matrix[0].length > 0) {
            root = build(matrix, 0, 0, matrix.length - 1, matrix[0].length - 1);
        }
    }

    /**
     * Recursively builds the subtree for the inclusive rectangle
     * {@code [row1..row2] x [col1..col2]}.
     *
     * @return the subtree root, or {@code null} if the rectangle is empty
     */
    private SegmentTreeNode build(int[][] matrix, int row1, int col1, int row2, int col2) {
        if (row1 > row2 || col1 > col2) {
            return null;
        }
        if (row1 == row2 && col1 == col2) {
            return new SegmentTreeNode(row1, col1, row2, col2, matrix[row1][col1]);
        }

        int midRow = row1 + (row2 - row1) / 2;
        int midCol = col1 + (col2 - col1) / 2;

        SegmentTreeNode node = new SegmentTreeNode(row1, col1, row2, col2, 0);
        node.upperLeft = build(matrix, row1, col1, midRow, midCol);
        node.upperRight = build(matrix, row1, midCol + 1, midRow, col2);
        node.lowerLeft = build(matrix, midRow + 1, col1, row2, midCol);
        node.lowerRight = build(matrix, midRow + 1, midCol + 1, row2, col2);
        node.sum = childrenSum(node);
        return node;
    }

    /**
     * Returns the sum of the inclusive rectangle with corners ({@code row1}, {@code col1})
     * and ({@code row2}, {@code col2}).
     *
     * <p>The rectangle is first clamped to the matrix, so any part of it outside the
     * matrix contributes nothing. Returns 0 for an empty tree or an empty (inverted)
     * rectangle.
     */
    public int search(int row1, int col1, int row2, int col2) {
        if (root == null) {
            return 0;
        }
        // Clamping here establishes the invariant searchHelper relies on: the query
        // always lies within the node it is asked about (or is empty).
        return searchHelper(root,
                Math.max(row1, root.row1),
                Math.max(col1, root.col1),
                Math.min(row2, root.row2),
                Math.min(col2, root.col2));
    }

    /**
     * Sums the query rectangle, which must lie inside {@code node}'s range or be empty.
     */
    private int searchHelper(SegmentTreeNode node, int row1, int col1, int row2, int col2) {
        if (row1 > row2 || col1 > col2 || node == null) {
            return 0;
        }
        if (node.row1 == row1 && node.row2 == row2 && node.col1 == col1 && node.col2 == col2) {
            return node.sum;  // query covers this node exactly
        }

        int midRow = node.row1 + (node.row2 - node.row1) / 2;
        int midCol = node.col1 + (node.col2 - node.col1) / 2;

        // Intersect the query with each quadrant. Because the query is already inside
        // this node, only the split lines (midRow / midCol) need to be applied. A null
        // (empty) child always receives an empty intersection.
        return searchHelper(node.upperLeft,
                        row1, col1, Math.min(row2, midRow), Math.min(col2, midCol))
                + searchHelper(node.upperRight,
                        row1, Math.max(col1, midCol + 1), Math.min(row2, midRow), col2)
                + searchHelper(node.lowerLeft,
                        Math.max(row1, midRow + 1), col1, row2, Math.min(col2, midCol))
                + searchHelper(node.lowerRight,
                        Math.max(row1, midRow + 1), Math.max(col1, midCol + 1), row2, col2);
    }

    /**
     * Sets the cell ({@code row}, {@code col}) to {@code val}. Cells outside the matrix
     * are ignored, as is any update on an empty tree.
     */
    public void update(int row, int col, int val) {
        updateHelper(root, row, col, val);
    }

    /** Descends to the leaf for (row, col), sets it, and recomputes sums on the way back up. */
    private void updateHelper(SegmentTreeNode node, int row, int col, int val) {
        if (node == null || row < node.row1 || row > node.row2 || col < node.col1 || col > node.col2) {
            return;
        }
        if (node.row1 == node.row2 && node.col1 == node.col2) {
            // The range check above guarantees this single-cell node is (row, col).
            node.sum = val;
            return;
        }

        int midRow = node.row1 + (node.row2 - node.row1) / 2;
        int midCol = node.col1 + (node.col2 - node.col1) / 2;

        SegmentTreeNode child;
        if (row <= midRow) {
            child = (col <= midCol) ? node.upperLeft : node.upperRight;
        } else {
            child = (col <= midCol) ? node.lowerLeft : node.lowerRight;
        }
        updateHelper(child, row, col, val);
        node.sum = childrenSum(node);
    }

    /** Returns the sum of {@code node}'s children, treating missing children as 0. */
    private static int childrenSum(SegmentTreeNode node) {
        return sumOf(node.upperLeft) + sumOf(node.upperRight)
                + sumOf(node.lowerLeft) + sumOf(node.lowerRight);
    }

    /** Returns {@code node.sum}, or 0 when {@code node} is {@code null}. */
    private static int sumOf(SegmentTreeNode node) {
        return node == null ? 0 : node.sum;
    }
}

/**
 * A node of {@link SegmentTree}. It holds the inclusive rectangle it covers, the sum
 * of the cells in that rectangle, and up to four child quadrants. A quadrant is
 * {@code null} when it is empty.
 */
class SegmentTreeNode {
    final int row1, col1, row2, col2;
    int sum;
    SegmentTreeNode upperLeft, upperRight, lowerLeft, lowerRight;

    public SegmentTreeNode(int row1, int col1, int row2, int col2, int sum) {
        this.row1 = row1;
        this.col1 = col1;
        this.row2 = row2;
        this.col2 = col2;
        this.sum = sum;
    }
}

/**
 * Your NumMatrix object will be instantiated and called as such:
 * NumMatrix obj = new NumMatrix(matrix);
 * obj.update(row,col,val);
 * int param_2 = obj.sumRegion(row1,col1,row2,col2);
 */

No comments:

Post a Comment