Thursday, May 16, 2019

Lintcode 817. Range Sum Query 2D - Mutable

817. Range Sum Query 2D - Mutable

中文English
Given a 2D matrix, find the sum of the elements inside the rectangle defined by its upper left corner (row1, col1) and lower right corner (row2, col2). The elements of the matrix can be changed between queries.
You have to implement three functions:
  • NumMatrix(matrix) The constructor.
  • sumRegion(row1, col1, row2, col2) Return the sum of the elements inside the rectangle defined by its upper left corner (row1, col1) and lower right corner (row2, col2).
  • update(row, col, val) Update the element at (row, col) to val.

Example

Example 1:
Input:
  NumMatrix(
    [[3,0,1,4,2],
     [5,6,3,2,1],
     [1,2,0,1,5],
     [4,1,0,1,7],
     [1,0,3,0,5]]
  )
  sumRegion(2,1,4,3)
  update(3,2,2)
  sumRegion(2,1,4,3)
Output:
  8
  10
Example 2:
Input:
  NumMatrix([[1]])
  sumRegion(0, 0, 0, 0)
  update(0, 0, -1)
  sumRegion(0, 0, 0, 0)
Output:
  1
  -1

Notice

  1. The matrix is only modifiable by update.
  2. You may assume that calls to the update and sumRegion functions are distributed evenly.
  3. You may assume that row1 ≤ row2 and col1 ≤ col2.

Code (Java):
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
class NumMatrix {
    private SegmentTree segTree;
    public NumMatrix(int[][] matrix) {
        segTree = new SegmentTree(matrix);
    }
     
    public void update(int row, int col, int val) {
        segTree.update(row,col, val);
    }
     
    public int sumRegion(int row1, int col1, int row2, int col2) {
        return segTree.search(row1, col1, row2, col2);
    }
}
 
 
class SegmentTree {
    SegmentTreeNode root;
    public SegmentTree(int[][] matrix) {
        int m = matrix.length;
        int n = matrix[0].length;
        root = build(matrix, 0, 0, m - 1, n - 1);
    }
     
    private SegmentTreeNode build(int[][] matrix, int row1, int col1, int row2, int col2) {
        if (row1 > row2 || col1 > col2) {
            return null;
        }
         
        if (row1 == row2 && col1 == col2) {
            return new SegmentTreeNode(row1, col1, row2, col2, matrix[row1][col1]);
        }
         
        int midRow = row1 + (row2 - row1) / 2;
        int midCol = col1 + (col2 - col1) / 2;
         
        SegmentTreeNode upperLeft = build(matrix, row1, col1, midRow, midCol);
        SegmentTreeNode upperRight = build(matrix, row1, midCol + 1, midRow, col2);
        SegmentTreeNode lowerLeft = build(matrix, midRow + 1, col1, row2, midCol);
        SegmentTreeNode lowerRight = build(matrix, midRow + 1, midCol + 1, row2, col2);
         
        int sum = (upperLeft == null ? 0 : upperLeft.sum) +
                  (upperRight == null ? 0 : upperRight.sum) +
                  (lowerLeft == null ? 0 : lowerLeft.sum) +
                  (lowerRight == null ? 0 : lowerRight.sum);
         
        SegmentTreeNode root =
            new SegmentTreeNode(row1,
                                col1,
                                row2,
                                col2,
                                sum);
                                 
        root.upperLeft = upperLeft;
        root.upperRight = upperRight;
        root.lowerLeft = lowerLeft;
        root.lowerRight = lowerRight;
        return root;
    }
     
    public int search(int row1, int col1, int row2, int col2) {
        return searchHelper(root, row1, col1, row2, col2);
    }
     
    private int searchHelper(SegmentTreeNode root, int row1, int col1, int row2, int col2) {
        if (row1 > row2 || col1 > col2) {
            return 0;
        }
         
        if (root.row1 == row1 && root.row2 == row2 && root.col1 == col1 && root.col2 == col2) {
            return root.sum;
        }
         
        int midRow = root.row1 + (root.row2 - root.row1) / 2;
        int midCol = root.col1 + (root.col2 - root.col1) / 2;
        int sumUpperLeft = searchHelper(root.upperLeft,
                                        Math.max(row1, root.row1),
                                        Math.max(root.col1, col1),
                                        Math.min(row2, midRow),
                                        Math.min(col2, midCol));
        int sumUpperRight = searchHelper(root.upperRight,
                                        Math.max(row1, root.row1),
                                        Math.max(col1, midCol + 1),
                                        Math.min(row2, midRow),
                                        Math.min(col2, root.col2));
        int sumLowerLeft = searchHelper(root.lowerLeft,
                                        Math.max(row1,  midRow + 1),
                                        Math.max(col1, root.col1),
                                        Math.min(root.row2, row2),
                                        Math.min(col2, midCol));
         
        int sumLowerRight = searchHelper(root.lowerRight,
                                        Math.max(row1, midRow + 1),
                                        Math.max(col1, midCol + 1),
                                        Math.min(root.row2, row2),
                                        Math.min(col2, root.col2));
                                         
        return sumUpperLeft + sumUpperRight + sumLowerLeft + sumLowerRight;
    }
     
    public void update(int row, int col, int val) {
        updateHelper(root, row, col, val);
    }
     
    private void updateHelper(SegmentTreeNode root, int row, int col, int val) {
        if (row < root.row1 || row > root.row2 || col < root.col1 || col > root.col2) {
            return;
        }
         
        if (root.row1 == root.row2 && root.col1 == root.col2 && root.row1 == row && root.col1 == col) {
            root.sum = val;
            return;
        }
         
        int midRow = root.row1 + (root.row2 - root.row1) / 2;
        int midCol = root.col1 + (root.col2 - root.col1) / 2;
         
        if (row <= midRow && col <= midCol) {
            updateHelper(root.upperLeft, row, col, val);
        } else if (row <= midRow && col > midCol) {
            updateHelper(root.upperRight, row, col, val);
        } else if (row > midRow && col <= midCol) {
            updateHelper(root.lowerLeft, row, col, val);
        } else if (row > midRow && col > midCol) {
            updateHelper(root.lowerRight, row, col, val);
        }
         
        int sum = (root.upperLeft == null ? 0 : root.upperLeft.sum) +
                  (root.upperRight == null ? 0 : root.upperRight.sum) +
                  (root.lowerLeft == null ? 0 : root.lowerLeft.sum) +
                  (root.lowerRight == null ? 0 : root.lowerRight.sum);
         
        root.sum = sum;
    }
     
}
 
/**
 * Node of the 2D segment tree: an inclusive rectangle, the sum recorded for
 * it, and links to up to four child quadrants.
 */
class SegmentTreeNode {
    /** Upper-left corner (row1, col1) and lower-right corner (row2, col2), inclusive. */
    int row1, col1, row2, col2;
    /** Sum recorded for this rectangle. */
    int sum;
    /** Child quadrants; each stays {@code null} until assigned by the tree. */
    SegmentTreeNode upperLeft, upperRight, lowerLeft, lowerRight;

    /**
     * Creates a node with no children.
     *
     * @param row1 first row of the rectangle
     * @param col1 first column of the rectangle
     * @param row2 last row of the rectangle
     * @param col2 last column of the rectangle
     * @param sum  sum to record for the rectangle
     */
    public SegmentTreeNode(int row1, int col1, int row2, int col2, int sum) {
        // Corners first, then the payload.
        this.row1 = row1;
        this.col1 = col1;
        this.row2 = row2;
        this.col2 = col2;
        this.sum = sum;
        // The four child references keep their default value of null.
    }
}
 
/**
 * Your NumMatrix object will be instantiated and called as such:
 * NumMatrix obj = new NumMatrix(matrix);
 * obj.update(row,col,val);
 * int param_2 = obj.sumRegion(row1,col1,row2,col2);
 */

No comments:

Post a Comment