Monday, August 4, 2014

Leetcode: Search a 2D Matrix

Write an efficient algorithm that searches for a value in an m x n matrix. This matrix has the following properties:
  • Integers in each row are sorted from left to right.
  • The first integer of each row is greater than the last integer of the previous row.
For example, consider the following matrix:
[
  [1,   3,  5,  7],
  [10, 11, 16, 20],
  [23, 30, 34, 50]
]
Given target = 3, return true.

Analysis:
From the two properties, the matrix is fully sorted: reading it row by row, from left to right and top to bottom, gives a single sorted sequence.
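To see this concretely, the sketch below walks the matrix in row-major order and checks that no element is smaller than the one before it. The class RowMajorCheck and method isRowMajorSorted are hypothetical names for illustration. It checks for non-decreasing order because the problem only says each row is sorted, which does not by itself rule out duplicates inside a row.

```java
class RowMajorCheck {
    // Returns true if reading the matrix row by row (row-major order)
    // yields a non-decreasing sequence of values.
    static boolean isRowMajorSorted(int[][] matrix) {
        boolean first = true;
        int prev = 0;
        for (int[] row : matrix) {
            for (int value : row) {
                if (!first && value < prev) {
                    return false; // order broken within a row or across a row boundary
                }
                prev = value;
                first = false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        int[][] example = {
            {1, 3, 5, 7},
            {10, 11, 16, 20},
            {23, 30, 34, 50}
        };
        // Swapping the last two rows violates the second property.
        int[][] broken = {
            {1, 3, 5, 7},
            {23, 30, 34, 50},
            {10, 11, 16, 20}
        };
        System.out.println(isRowMajorSorted(example)); // true
        System.out.println(isRowMajorSorted(broken));  // false
    }
}
```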

Naive Solution:
Since the matrix is fully sorted, we can use binary search to find the target. The only trick is to treat the m x n 2D array as a 1D array of length m * n and convert each flattened index k back into row index k / n and column index k % n.
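A small sketch of that index mapping on the example matrix, assuming row-major layout with n columns. The class FlatIndexDemo and its three helpers are hypothetical names; the forward mapping is the same arithmetic as getRowIndex and getColIndex in the solution, and flatIndex is the inverse.

```java
class FlatIndexDemo {
    // Flattened index k -> row, for a matrix with n columns (row-major layout).
    static int rowOf(int k, int n) {
        return k / n;
    }

    // Flattened index k -> column.
    static int colOf(int k, int n) {
        return k % n;
    }

    // Inverse mapping: (row, col) -> flattened index.
    static int flatIndex(int row, int col, int n) {
        return row * n + col;
    }

    public static void main(String[] args) {
        int[][] matrix = {
            {1, 3, 5, 7},
            {10, 11, 16, 20},
            {23, 30, 34, 50}
        };
        int n = matrix[0].length;
        // k = 6 lands in row 6 / 4 = 1, column 6 % 4 = 2, i.e. the value 16.
        int k = 6;
        System.out.println(matrix[rowOf(k, n)][colOf(k, n)]); // 16
        // Round trip: (1, 2) maps back to 6.
        System.out.println(flatIndex(1, 2, n)); // 6
    }
}
```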

Code (Java):
public class Solution {
    public boolean searchMatrix(int[][] matrix, int target) {
        if (matrix == null || matrix.length == 0 || matrix[0].length == 0) return false;
        int m = matrix.length; // row
        int n = matrix[0].length; // column
        
        return binaryMatrixSearch(matrix, target, 0, m * n - 1, n);
    }
    
    private boolean binaryMatrixSearch(int[][] matrix, int target, int lo, int hi, int n) {
        if (lo > hi) return false;
        
        int mid = lo + (hi - lo) / 2; // avoids int overflow of lo + hi
        int i = getRowIndex(mid, n);
        int j = getColIndex(mid, n);
        
        if (matrix[i][j] == target) return true;
        else if (matrix[i][j] > target) return binaryMatrixSearch(matrix, target, lo, mid - 1, n);
        else return binaryMatrixSearch(matrix, target, mid + 1, hi, n);
    }
    
    // Calculate the row index
    private int getRowIndex(int mid, int n) {
        return mid / n;
    }
    
    // Calculate the column index
    private int getColIndex(int mid, int n) {
        return mid % n;
    }
}

Discussion:
This solution is straightforward and almost identical to 1D binary search. The only difference is that after computing the mid index, we have to convert it into the 2D row and column indices. The time complexity is O(log(m*n)), since the search runs over m * n elements. The space complexity of this recursive version is O(log(m*n)) for the call stack; the iterative version below needs only O(1) extra space.
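To make the O(log(m*n)) bound tangible, the sketch below instruments an iterative variant of the flattened search to count how many matrix cells it inspects. The class ProbeCounter and method countProbes are hypothetical. A binary search over N elements inspects at most floor(log2(N)) + 1 cells, so for the 3 x 4 example (N = 12) no search should need more than 4 probes.

```java
class ProbeCounter {
    // Iterative flattened binary search that returns the number of cells
    // it inspected instead of a boolean.
    static int countProbes(int[][] matrix, int target) {
        int n = matrix[0].length;
        int lo = 0;
        int hi = matrix.length * n - 1;
        int probes = 0;
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2;
            int value = matrix[mid / n][mid % n];
            probes++;
            if (value == target) {
                return probes;
            } else if (target < value) {
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return probes;
    }

    public static void main(String[] args) {
        int[][] matrix = {
            {1, 3, 5, 7},
            {10, 11, 16, 20},
            {23, 30, 34, 50}
        };
        int worst = 0;
        // Try every value from 0 to 51, covering both hits and misses.
        for (int target = 0; target <= 51; target++) {
            worst = Math.max(worst, countProbes(matrix, target));
        }
        System.out.println("worst-case probes: " + worst); // worst-case probes: 4
    }
}
```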

Note that the first solution is a recursive binary search. Binary search can also be implemented iteratively:
public class Solution {
    public boolean searchMatrix(int[][] matrix, int target) {
        if (matrix == null || matrix.length == 0 || matrix[0].length == 0)
            return false;
        
        int m = matrix.length; // row
        int n = matrix[0].length; // col
        
        int lo = 0;
        int hi = m * n - 1;
        int mid;
        
        while (lo <= hi) {
            mid = lo + (hi - lo) / 2; // avoids int overflow of lo + hi
            
            int i = getRowIndex(mid, n);
            int j = getColIndex(mid, n);
            
            if (matrix[i][j] == target) return true;
            else if (target < matrix[i][j]) hi = mid - 1;
            else lo = mid + 1;
        }
        return false;
    }

    // Calculate the row index
    private int getRowIndex(int mid, int n) {
        return mid / n;
    }

    // Calculate the column index
    private int getColIndex(int mid, int n) {
        return mid % n;
    }
}
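One way to gain confidence in the iterative version is to compare it against a plain linear scan on random matrices that satisfy both properties. The sketch below is a hypothetical harness (class RandomizedCheck, methods linearScan, randomMatrix and run are illustrative names); it copies the iterative search so that it runs on its own, and any other searchMatrix method can be dropped in the same way.

```java
import java.util.Random;

class RandomizedCheck {
    // Iterative flattened binary search (same logic as the iterative version above).
    static boolean searchMatrix(int[][] matrix, int target) {
        if (matrix == null || matrix.length == 0 || matrix[0].length == 0) return false;
        int n = matrix[0].length;
        int lo = 0;
        int hi = matrix.length * n - 1;
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2;
            int value = matrix[mid / n][mid % n];
            if (value == target) return true;
            else if (target < value) hi = mid - 1;
            else lo = mid + 1;
        }
        return false;
    }

    // Brute-force reference answer.
    static boolean linearScan(int[][] matrix, int target) {
        for (int[] row : matrix) {
            for (int value : row) {
                if (value == target) return true;
            }
        }
        return false;
    }

    // Builds an m x n matrix whose row-major reading is strictly increasing,
    // so both properties from the problem statement hold.
    static int[][] randomMatrix(Random rnd, int m, int n) {
        int[][] matrix = new int[m][n];
        int value = rnd.nextInt(5);
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                matrix[i][j] = value;
                value += 1 + rnd.nextInt(3); // gaps of 1..3 leave room for misses
            }
        }
        return matrix;
    }

    // Returns the number of disagreements with the linear scan over all trials.
    static int run(long seed, int trials) {
        Random rnd = new Random(seed);
        int mismatches = 0;
        for (int t = 0; t < trials; t++) {
            int[][] matrix = randomMatrix(rnd, 1 + rnd.nextInt(5), 1 + rnd.nextInt(5));
            int last = matrix[matrix.length - 1][matrix[0].length - 1];
            // Try every target from below the minimum to above the maximum.
            for (int target = -1; target <= last + 1; target++) {
                if (searchMatrix(matrix, target) != linearScan(matrix, target)) {
                    mismatches++;
                }
            }
        }
        return mismatches;
    }

    public static void main(String[] args) {
        System.out.println("mismatches: " + run(42L, 1000)); // a correct search reports 0
    }
}
```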

Update on 9/29/14:
Another approach is to first binary-search for the row that could contain the target, then binary-search for the target within that row.
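For comparison, the same two-step idea can also be sketched with the classic lo <= hi loop. Step 1 finds the last row whose first element is not greater than the target (because rows do not overlap, that is the only row that can hold it); step 2 hands the in-row search to the JDK's java.util.Arrays.binarySearch, which returns a non-negative index only when the key is present. The class name RowThenColumn is hypothetical, and using Arrays.binarySearch for step 2 is a substitution for a hand-written helper.

```java
import java.util.Arrays;

class RowThenColumn {
    static boolean searchMatrix(int[][] matrix, int target) {
        if (matrix == null || matrix.length == 0 || matrix[0].length == 0) {
            return false;
        }
        // Step 1: find the last row whose first element is <= target.
        int lo = 0;
        int hi = matrix.length - 1;
        int row = -1;
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2;
            if (matrix[mid][0] <= target) {
                row = mid;      // candidate; a later row might still qualify
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        if (row == -1) {
            return false;       // target is smaller than every element
        }
        // Step 2: search inside that row with the JDK's binary search.
        return Arrays.binarySearch(matrix[row], target) >= 0;
    }

    public static void main(String[] args) {
        int[][] matrix = {
            {1, 3, 5, 7},
            {10, 11, 16, 20},
            {23, 30, 34, 50}
        };
        System.out.println(searchMatrix(matrix, 3));  // true
        System.out.println(searchMatrix(matrix, 13)); // false
    }
}
```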

Code (Java):
public class Solution {
    public boolean searchMatrix(int[][] matrix, int target) {
        if (matrix == null || matrix.length == 0 || matrix[0].length == 0) {
            return false;
        }
        
        // Find the row
        int row = findRow(matrix, target);
        if (row == -1) {
            return false;
        }
        
        // Find the column
        int col = findCol(matrix[row], target);
        if (col == -1) {
            return false;
        }
        
        return true;
    }
    
    private int findRow(int[][] matrix, int target) {
        int lo = 0;
        int hi = matrix.length - 1;
        
        while (lo + 1 < hi) {
            int mid = lo + (hi - lo) / 2;
            if (matrix[mid][0] == target) {
                return mid;
            }
            
            if (matrix[mid][0] > target) {
                hi = mid;
            } else {
                lo = mid;
            }
        }
        
        if (matrix[lo][0] == target) {
            return lo;
        } else if (matrix[hi][0] == target) {
            return hi;
        } else if (matrix[lo][0] < target && target < matrix[hi][0]) {
            return lo;
        } else if (matrix[hi][0] < target) {
            return hi;
        } else {
            return -1;
        }
    }
    
    private int findCol(int[] cols, int target) {
        int lo = 0;
        int hi = cols.length - 1;
        
        while (lo + 1 < hi) {
            int mid = lo + (hi - lo) / 2;
            if (cols[mid] == target) {
                return mid;
            }
            
            if (cols[mid] > target) {
                hi = mid;
            } else {
                lo = mid;
            }
        }
        
        if (cols[lo] == target) {
            return lo;
        }
        
        if (cols[hi] == target) {
            return hi;
        }
        
        return -1;
    }
}


Update on 10/7/15:
The same row-then-column search, with both steps in a single method:
public class Solution {
    public boolean searchMatrix(int[][] matrix, int target) {
        if (matrix == null || matrix.length == 0 || matrix[0].length == 0) {
            return false;
        }
        
        int m = matrix.length;
        int n = matrix[0].length;
        
        // Step 1: find the rowId of the target number
        int lo = 0;
        int hi = m - 1;
        
        while (lo + 1 < hi) {
            int mid = lo + (hi - lo) / 2;
            if (matrix[mid][0] == target) {
                return true;
            } else if (matrix[mid][0] < target) {
                lo = mid;
            } else {
                hi = mid - 1;
            }
        }
        
        if (matrix[hi][0] == target || matrix[lo][0] == target) {
            return true;
        }
        
        int rowId;
        if (target > matrix[lo][0] && target <= matrix[lo][n - 1]) {
            rowId = lo;
        } else {
            rowId = hi;
        }
        
        // Step 2: find the target number in the rowId
        lo = 0;
        hi = n - 1;
        
        while (lo + 1 < hi) {
            int mid = lo + (hi - lo) / 2;
            if (matrix[rowId][mid] == target) {
                return true;
            } else if (matrix[rowId][mid] < target) {
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        
        if (matrix[rowId][hi] == target || matrix[rowId][lo] == target) {
            return true;
        }
        
        return false;
    }
}
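Both the 9/29/14 helpers and the 10/7/15 version loop on while (lo + 1 < hi) and, in some branches, shrink the range with lo = mid (or hi = mid) instead of mid + 1 / mid - 1. The reason is termination: when hi == lo + 1, mid = lo + (hi - lo) / 2 rounds down to lo, so lo = mid makes no progress, and a loop guarded by the usual lo <= hi (or lo < hi) would spin forever. Stopping while two candidates remain and checking lo and hi after the loop avoids this. The sketch below (hypothetical class LoopConditionDemo) runs the findCol-style search and, for contrast, the same lo = mid update under lo < hi with an iteration cap, so the stuck loop can be observed rather than hanging.

```java
class LoopConditionDemo {
    // The findCol pattern: loop while more than two candidates remain,
    // then check the last two (lo and hi) explicitly. Assumes a is non-empty.
    static int findWithTwoCandidates(int[] a, int target) {
        int lo = 0;
        int hi = a.length - 1;
        while (lo + 1 < hi) {
            int mid = lo + (hi - lo) / 2;
            if (a[mid] == target) {
                return mid;
            }
            if (a[mid] > target) {
                hi = mid;
            } else {
                lo = mid;
            }
        }
        if (a[lo] == target) return lo;
        if (a[hi] == target) return hi;
        return -1;
    }

    // Same updates (lo = mid, hi = mid) but with the usual lo < hi condition.
    // Returns how many iterations ran before the loop ended or hit the cap.
    static int iterationsWithLoLessThanHi(int[] a, int target, int cap) {
        int lo = 0;
        int hi = a.length - 1;
        int iterations = 0;
        while (lo < hi && iterations < cap) {
            iterations++;
            int mid = lo + (hi - lo) / 2;
            if (a[mid] == target) {
                return iterations;
            }
            if (a[mid] > target) {
                hi = mid;
            } else {
                lo = mid; // when hi == lo + 1, mid == lo, so nothing changes
            }
        }
        return iterations;
    }

    public static void main(String[] args) {
        int[] a = {1, 3};
        System.out.println(findWithTwoCandidates(a, 3));           // 1
        System.out.println(iterationsWithLoLessThanHi(a, 3, 100)); // 100 (stuck until the cap)
    }
}
```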

1 comment:

  1. Hi,
    How would you explain why you use the condition lo + 1 < hi in while (lo + 1 < hi)? Usually for binary search we have lo <= hi.
