Tuesday, July 29, 2014

Leetcode: 4Sum

Given an array S of n integers, are there elements a, b, c, and d in S such that a + b + c + d = target? Find all unique quadruplets in the array which give the sum of target.
Note:
  • Elements in a quadruplet (a,b,c,d) must be in non-descending order. (ie, a ≤ b ≤ c ≤ d)
  • The solution set must not contain duplicate quadruplets.
    For example, given array S = {1 0 -1 0 -2 2}, and target = 0.

    A solution set is:
    (-1,  0, 0, 1)
    (-2, -1, 1, 2)
    (-2,  0, 0, 2)

Understand the problem:
This question is very similar to 3Sum and Two Sum, which we have discussed before. There are several requirements in the question:
1. Find all unique quadruplets in the array, so you cannot simply return once you have found one solution.
2. Elements in a quadruplet (a,b,c,d) must be in non-descending order (a <= b <= c <= d).
3. The solution set must not contain duplicate quadruplets.
4. How to store the result: as defined by the return type of the method, each quadruplet is stored in an ArrayList, and the result is an ArrayList of ArrayLists. That is the basic data structure used in the problem.

Naive Solution:
As with the Two Sum and 3Sum problems, the naive solution is very straightforward: pick every combination of four elements from the array and check whether the sum equals the target. The time complexity is O(n^4) in every case, because the problem requires us to find all unique solutions, so we can never stop early. The Java implementation is omitted here.
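For readers who want something concrete anyway, the naive approach might be sketched roughly as follows. The class name FourSumBruteForce and the use of a LinkedHashSet to drop duplicates are assumptions made for this illustration, not part of the original post.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical helper class for illustration only.
class FourSumBruteForce {
    static List<List<Integer>> fourSum(int[] num, int target) {
        int[] sorted = num.clone();
        Arrays.sort(sorted); // sorting makes every quadruplet non-descending
        // The set drops duplicate quadruplets; LinkedHashSet keeps discovery order.
        Set<List<Integer>> seen = new LinkedHashSet<>();
        int n = sorted.length;
        for (int a = 0; a < n - 3; a++) {
            for (int b = a + 1; b < n - 2; b++) {
                for (int c = b + 1; c < n - 1; c++) {
                    for (int d = c + 1; d < n; d++) {
                        // long arithmetic avoids int overflow on extreme inputs
                        long sum = (long) sorted[a] + sorted[b] + sorted[c] + sorted[d];
                        if (sum == target) {
                            seen.add(Arrays.asList(sorted[a], sorted[b], sorted[c], sorted[d]));
                        }
                    }
                }
            }
        }
        return new ArrayList<>(seen);
    }
}
```

Calling FourSumBruteForce.fourSum(new int[] {1, 0, -1, 0, -2, 2}, 0) should yield the three quadruplets from the problem statement. The four nested loops always run to completion, which is where the O(n^4) cost comes from.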

A better Solution:
Now we consider the two-pointer solution. We consider it first because it has the same time complexity as the hash table solution, but requires no additional space to store key-value pairs as the hash table does.

Note that the problem asks for a + b + c + d = target, which is equivalent to c + d = target - (a + b). So the basic idea is: given a pair (a, b), we look for another pair (c, d) whose sum equals target minus the sum of (a, b). For example, when the target is zero, we look for a pair (c, d) whose sum equals the negative of the sum of (a, b). The trick is to sort the array first; then, for a given pair (a, b), use two pointers, one starting just to the right of b and the other starting at the end of the array, and move them toward each other.
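The inner two-pointer search on its own might look like the sketch below. The class TwoPointerPairs and the method findPairs are hypothetical names introduced here; the sketch assumes the array is already sorted and, to keep it short, does not filter duplicate pairs.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper class for illustration only.
class TwoPointerPairs {
    // Collects pairs (c, d) taken from sorted[from..end] with c + d == pairTarget.
    // Assumes the array is sorted in ascending order; duplicates are not filtered.
    static List<int[]> findPairs(int[] sorted, int from, int pairTarget) {
        List<int[]> pairs = new ArrayList<>();
        int start = from;              // just to the right of b
        int end = sorted.length - 1;   // end of the array
        while (start < end) {
            long sum = (long) sorted[start] + sorted[end];
            if (sum == pairTarget) {
                pairs.add(new int[] {sorted[start], sorted[end]});
                start++;
                end--;
            } else if (sum < pairTarget) {
                start++;               // sum too small: move to a larger left value
            } else {
                end--;                 // sum too large: move to a smaller right value
            }
        }
        return pairs;
    }
}
```

For the sorted array [-2, -1, 0, 0, 1, 2] with (a, b) = (-2, -1) and target 0, the call findPairs(sorted, 2, 3) should find the single pair (1, 2). Each call scans the remaining elements once, so the search for one (a, b) costs O(n).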

There is still another question: how do you handle duplicate quadruplets?
E.g. for the input array [-2, -1, -1, 0, 0, 2, 2, 2]. As you've seen before, duplicate elements are placed next to each other by the sorting. Here is the idea: if you have already seen a number in a given position, you don't need to check it again; simply jump ahead until you reach a different value. Here is the Java code:

Code (Java):
public class Solution {
    public ArrayList<ArrayList<Integer>> fourSum(int[] num, int target) {
        final int length = num.length;
        ArrayList<ArrayList<Integer>> result = new ArrayList<ArrayList<Integer>>();
        
        if (length < 4) return result;
        
        // Sort the array
        Arrays.sort(num);
        
        for (int i = 0; i < length - 3; i++) {
            if (num[i] > target) break;
            if (i == 0 || num[i] > num[i - 1]) {
                for (int j = i + 1; j < length - 2; j++) {
                    if (j == i + 1 || num[j] > num[j - 1]) {
                        if ((num[i] + num[j]) > target) break;
                        int start = j + 1;
                        int end = length - 1;
                        int newTarget = target - (num[i] + num[j]);
                        while (start < end) {
                            if ((num[start] + num[end]) == newTarget) {
                                ArrayList<Integer> temp = new ArrayList<Integer>(4);
                                temp.add(num[i]);
                                temp.add(num[j]);
                                temp.add(num[start]);
                                temp.add(num[end]);
                            
                                result.add(temp);
                                start++;
                                end--;
                                while (start < end && num[start] == num[start - 1]) start++;
                                while (start < end && num[end] == num[end + 1]) end--;
                            } else if (num[start] + num[end] < newTarget) {  
                                start++;
                            } else {
                                end--;
                            }
                        }
                    }
                }
            }
        }
        return result;
    }
}

Look at the code carefully and you will find a bug. Let's see the results first. According to the OJ, it failed this test:
Input:[1,-2,-5,-4,-3,3,3,5], -11
Output:[]
Expected:[[-5,-4,-3,1]]
But why? According to the code above, we sorted the array first, so it becomes [-5, -4, -3, -2, 1, 3, 3, 5], and the target is -11. 
Now you might spot the problem. Line 12 of the code checks

if (num[i] > target) break;

However, when the target is a negative number, this condition no longer holds. As the input above shows, even if num[i] is greater than the target, the numbers to the right of num[i] may still be negative, and adding them can bring the sum back down to the target. So we still have to check. The same reasoning applies to the check (num[i] + num[j]) > target in line 16.
For the case when the target is greater than or equal to 0, the condition is safe, because num[i] > target means that num[i] and all the numbers after it are positive. Consequently, after we remove lines 12 and 16, the code becomes correct, as shown below:
import java.util.ArrayList;
import java.util.Arrays;

public class Solution {
    public ArrayList<ArrayList<Integer>> fourSum(int[] num, int target) {
        final int length = num.length;
        ArrayList<ArrayList<Integer>> result = new ArrayList<ArrayList<Integer>>();
        
        if (length < 4) return result;
        
        // Sort the array
        Arrays.sort(num);
        
        for (int i = 0; i < length - 3; i++) {
            // Skip values already used as the first element
            if (i == 0 || num[i] > num[i - 1]) {
                for (int j = i + 1; j < length - 2; j++) {
                    // Skip values already used as the second element
                    if (j == i + 1 || num[j] > num[j - 1]) {
                        int start = j + 1;
                        int end = length - 1;
                        int newTarget = target - (num[i] + num[j]);
                        while (start < end) {
                            if ((num[start] + num[end]) == newTarget) {
                                ArrayList<Integer> temp = new ArrayList<Integer>(4);
                                temp.add(num[i]);
                                temp.add(num[j]);
                                temp.add(num[start]);
                                temp.add(num[end]);
                            
                                result.add(temp);
                                start++;
                                end--;
                                // Skip duplicates of the third and fourth elements
                                while (start < end && num[start] == num[start - 1]) start++;
                                while (start < end && num[end] == num[end + 1]) end--;
                            } else if (num[start] + num[end] < newTarget) {  
                                start++;
                            } else {
                                end--;
                            }
                        }
                    }
                }
            }
        }
        return result;
    }
}
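If an early exit is still wanted, one possible variant (a different technique from the one in this post, offered only as a sketch) is to compare the target with the smallest sum that can still be formed from position i onward. Because the array is sorted, that smallest sum is the sum of the four consecutive elements starting at i, and this comparison stays valid for negative targets. The class FourSumPruning and both method names are hypothetical.

```java
// Hypothetical helper class for illustration only.
class FourSumPruning {
    // The check from the buggy version: only valid when target >= 0.
    static boolean naiveBreak(int[] sorted, int i, int target) {
        return sorted[i] > target;
    }

    // Safer variant: the smallest possible sum with sorted[i] as the first element
    // is sorted[i] + sorted[i+1] + sorted[i+2] + sorted[i+3].
    // Assumes the array is sorted ascending and i <= sorted.length - 4.
    static boolean safeBreak(int[] sorted, int i, int target) {
        long minSum = (long) sorted[i] + sorted[i + 1] + sorted[i + 2] + sorted[i + 3];
        return minSum > target;
    }
}
```

On the failing input, sorted to [-5, -4, -3, -2, 1, 3, 3, 5] with target -11, naiveBreak at i = 0 returns true (since -5 > -11) and would wrongly stop the search, while safeBreak at i = 0 returns false because the smallest sum is -14.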

Discussion:
In this solution, sorting the array takes O(n log n) time and the nested loops take O(n^3) time, so the total time complexity is O(n^3). Since the solution uses no additional storage apart from the result list, the space complexity is O(1).

HashMap Solution:
Here I also introduce the HashMap solution. Its time complexity is still O(n^3), but it needs additional space: O(n) for the key-value pairs in the map, plus a HashSet that holds a copy of every quadruplet in the result. That is not desirable when storage space is constrained.
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;

public class Solution {
    public ArrayList<ArrayList<Integer>> fourSum(int[] num, int target) {
        ArrayList<ArrayList<Integer>> result = new ArrayList<ArrayList<Integer>>();
        HashMap<Integer, ArrayList<Integer>> hashMap = new HashMap<Integer, ArrayList<Integer>>();
        HashSet<ArrayList<Integer>> hashSet = new HashSet<ArrayList<Integer>>();
        
        final int length = num.length;
        if (length < 4) return result;
        
        Arrays.sort(num);
        
        for (int i = 0; i < length - 3; i++) {
            for (int j = i + 1; j < length - 2; j++) {
                hashMap.clear();
                for (int k = j + 1; k < length; k++) {
                    if (hashMap.containsKey(num[k])) {
                        ArrayList<Integer> temp = new ArrayList<Integer>(4);
                        temp.add(0, hashMap.get(num[k]).get(0));
                        temp.add(1, hashMap.get(num[k]).get(1));
                        temp.add(2, hashMap.get(num[k]).get(2));
                        temp.add(3, num[k]);
                        
                        // Remove duplicated keys
                        if (!hashSet.contains(temp)) {
                            result.add(temp);
                            hashSet.add(temp);
                        }
                    } else {
                        ArrayList<Integer> tmp = new ArrayList<Integer>(3);
                        tmp.add(0, num[i]);
                        tmp.add(1, num[j]);
                        tmp.add(2, num[k]);
                        int key = target - (num[i] + num[j] + num[k]);
                        hashMap.put(key, tmp);
                    }
                }
            }
        }
        return result;
    }
}
  
Note that the code above uses a Java HashSet. Its purpose is to detect duplicate quadruplets in the final result list: if the set does not contain a quadruplet, we add it to both the result list and the HashSet. Since a HashSet lookup takes O(1) time on average, it does not change the overall time complexity, and it keeps the code neat. However, it comes at the cost of additional space for a second copy of the final result. Also note that since we clear the HashMap before each innermost loop, it stores at most n - 2 key-value pairs, which costs O(n) space.

In this solution we used the Java HashSet. What is the difference between HashMap and HashSet? The difference is straightforward:
In a HashMap, you store key-value pairs;
In a HashSet, you store only keys.

The Java HashSet has the following common instance methods:
  • boolean add(E e) -- Adds the specified element e to the set if it is not already present and returns true. If the set already contains the element, the call leaves the set unchanged and returns false. Note that this is very different from HashMap put(K key, V value): if the map already contains the key, put simply replaces the old value with the new one.
  • void clear() -- Removes all elements from the set.
  • boolean contains(Object o) -- Returns true if the set contains the specified element.
  • boolean remove(Object o) -- Removes the specified element o and returns true if the set contained it. If the set does not contain that element, returns false.
  • int size() -- Returns the number of elements in the set.
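The contrast between add and put described above can be checked with a small sketch. The class HashSetVsHashMapDemo and its method names are made up for this illustration; the collection calls themselves are the standard java.util ones.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;

// Hypothetical demo class for illustration only.
class HashSetVsHashMapDemo {
    // Adds two equal quadruplets to a HashSet and reports what each add() returned.
    static boolean[] addEqualListsTwice() {
        HashSet<ArrayList<Integer>> set = new HashSet<>();
        ArrayList<Integer> first = new ArrayList<>(Arrays.asList(-2, 0, 0, 2));
        ArrayList<Integer> second = new ArrayList<>(Arrays.asList(-2, 0, 0, 2));
        boolean firstAdded = set.add(first);
        // ArrayList compares by content, so the second list counts as already present.
        boolean secondAdded = set.add(second);
        return new boolean[] {firstAdded, secondAdded};
    }

    // Puts the same key twice into a HashMap and returns the value that survives.
    static String putSameKeyTwice() {
        HashMap<Integer, String> map = new HashMap<>();
        map.put(1, "old");
        map.put(1, "new"); // replaces the value stored under key 1
        return map.get(1);
    }
}
```

Here addEqualListsTwice() should return {true, false}, and putSameKeyTwice() should return "new". The first result is also why the HashSet in the solution can detect duplicate quadruplets even though each one is a separate ArrayList object.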

Summary:
So far we have seen the k-sum problems where k equals 2, 3, or 4. To summarize, we observe that the k-sum problem can be solved with a time complexity of O(n^(k-1)). 
And the best space complexity is O(1). 
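To make the pattern concrete, here is one way the sort-plus-two-pointer idea could be generalized to any k >= 2. This is a sketch under assumptions rather than code from the earlier posts: the class KSum and the helper search are hypothetical names, and the recursion uses extra stack and intermediate lists, so it does not achieve the O(1) space of the hand-written loops.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper class for illustration only.
class KSum {
    static List<List<Integer>> kSum(int[] num, int k, int target) {
        int[] sorted = num.clone();
        Arrays.sort(sorted);
        return search(sorted, 0, k, target);
    }

    // Finds unique k-tuples in sorted[from..] that sum to target.
    private static List<List<Integer>> search(int[] sorted, int from, int k, long target) {
        List<List<Integer>> result = new ArrayList<>();
        if (k < 2 || sorted.length - from < k) return result;
        if (k == 2) { // base case: two pointers
            int start = from, end = sorted.length - 1;
            while (start < end) {
                long sum = (long) sorted[start] + sorted[end];
                if (sum == target) {
                    result.add(new ArrayList<>(Arrays.asList(sorted[start], sorted[end])));
                    start++;
                    end--;
                    while (start < end && sorted[start] == sorted[start - 1]) start++;
                    while (start < end && sorted[end] == sorted[end + 1]) end--;
                } else if (sum < target) {
                    start++;
                } else {
                    end--;
                }
            }
            return result;
        }
        for (int i = from; i <= sorted.length - k; i++) {
            if (i > from && sorted[i] == sorted[i - 1]) continue; // skip duplicate first elements
            for (List<Integer> rest : search(sorted, i + 1, k - 1, target - sorted[i])) {
                List<Integer> combo = new ArrayList<>(k);
                combo.add(sorted[i]);
                combo.addAll(rest);
                result.add(combo);
            }
        }
        return result;
    }
}
```

Each level of recursion above the base case adds one loop over the array, and the base case is linear, which gives the O(n^(k-1)) bound for k >= 2. For instance, KSum.kSum(new int[] {1, -2, -5, -4, -3, 3, 3, 5}, 4, -11) should return the single quadruplet [-5, -4, -3, 1] from the failing test case.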

Last but not least, the big lesson from this post is that it is very important to make every single line of code count. It is not a good idea to rush into writing the code. You need to analyze the requirements of the problem clearly before you implement anything, and you should make sure every single line in the code makes sense. That is essential for writing bug-free code. Another tip for the Leetcode OJ: whenever you finish your code, don't rush to submit it. Go through the code and use several test cases to check whether there is any bug. In a real coding interview, there is no online compiler to help with debugging. So do your best to submit only bug-free code to the OJ. 

1 comment:

  1. There is a O(n^2) solution to this problem using hashtable, key idea is to hash all the pair sums and then find out pairs which sum upto target
