Friday, July 25, 2014

Leetcode: 3Sum

Given an array S of n integers, are there elements a, b, c in S such that a + b + c = 0? Find all unique triplets in the array that sum to zero.
Note:
  • Elements in a triplet (a,b,c) must be in non-descending order. (ie, a ≤ b ≤ c)
  • The solution set must not contain duplicate triplets.
    For example, given array S = {-1 0 1 2 -1 -4},

    A solution set is:
    (-1, 0, 1)
    (-1, -1, 2)
 

Analysis:
This problem differs from two sum in several ways:
1. It returns the values themselves rather than their indices in the array.
2. Elements in a triplet (a, b, c) must be in non-descending order (i.e., a <= b <= c).
3. The solution set must contain every unique triplet but no duplicates. For instance, if (-1, 0, 1) is in the solution set, then (0, 1, -1) must not be: it is a duplicate of the same triplet, and it is not in order either.
4. What is the return type? It is List<List<Integer>>. For Java's List, there is a good introductory article: http://tutorials.jenkov.com/java-collections/list.html
I will summarize List later.

Naive Solution:
In the two sum problem, we introduced three solutions: naive (O(n^2)), hash map, and sorting. For this problem, the naive solution is simple: try every combination of three elements from the array. One thing to take care of is that each triplet must be in non-descending order, which is easily handled by sorting the array first. To avoid duplicate triplets, the code checks whether the result list already contains a triplet before adding it. The Java program is below:

Code (Java):
public class Solution {
    public ArrayList<ArrayList<Integer>> threeSum(int[] num) {
        final int length = num.length;
        ArrayList<ArrayList<Integer>> result = new ArrayList<ArrayList<Integer>>();
        
        
        // for array less than 3 elements, no need to check
        if (length < 3) return result;
        
        // Sort the array
        Arrays.sort(num);
        
        for (int i = 0; i < length; i++) {
            if (num[i] > 0) break;
            for (int j = i + 1; j < length; j++) {
                if (num[i] + num[j] > 0) break;
                for (int k = j + 1; k < length; k++) {
                    if (num[i] + num[j] + num[k] == 0) {
                        ArrayList<Integer> elem = new ArrayList<Integer>();
                        elem.add(num[i]);
                        elem.add(num[j]);
                        elem.add(num[k]);
                        if (!result.contains(elem)) result.add(elem);
                    }
                }
            }
        }
        return result;
    }
}

The analysis is straightforward. The three nested loops give O(n^3) time. On top of that, whenever a triplet sums to zero, contains() scans the result list linearly, so the worst case could be as bad as O(n^4). As a result, it failed the OJ with "time exceeded".
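One way to remove the linear contains() scan is to collect triplets in a hash-based set, since List.equals() and List.hashCode() compare contents. The following is only a sketch under that assumption, not the solution submitted above; the class name NaiveThreeSumWithSet is made up for illustration. The three loops are still O(n^3), so this alone would probably not pass the OJ either.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical class name; a sketch, not the post's original solution.
class NaiveThreeSumWithSet {
    static List<List<Integer>> threeSum(int[] num) {
        int[] sorted = num.clone();      // leave the caller's array untouched
        Arrays.sort(sorted);             // guarantees a <= b <= c in each triplet
        // A LinkedHashSet rejects equal lists in expected O(1) time
        // and keeps the order in which triplets were first found.
        Set<List<Integer>> seen = new LinkedHashSet<>();
        int n = sorted.length;
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                for (int k = j + 1; k < n; k++) {
                    if (sorted[i] + sorted[j] + sorted[k] == 0) {
                        seen.add(Arrays.asList(sorted[i], sorted[j], sorted[k]));
                    }
                }
            }
        }
        return new ArrayList<>(seen);
    }
}
```

Calling NaiveThreeSumWithSet.threeSum(new int[] {-1, 0, 1, 2, -1, -4}) yields the two triplets from the problem statement, each exactly once.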

Before moving on to an improved solution, here is another implementation that reveals an interesting bug:
Code (Java):
public class Solution {
     public ArrayList<ArrayList<Integer>> threeSum(int[] num) {
        final int length = num.length;
        ArrayList<ArrayList<Integer>> result = new ArrayList<ArrayList<Integer>>();
        ArrayList<Integer> elem = new ArrayList<Integer>();
        
        // for array less than 3 elements, no need to check
        if (length < 3) return result;
        
        // Sort the array
        Arrays.sort(num);
        
        for (int i = 0; i < length; i++) {
            if (num[i] > 0) break;
            for (int j = i + 1; j < length; j++) {
                if (num[i] + num[j] > 0) break;
                for (int k = j + 1; k < length; k++) {
                    if (num[i] + num[j] + num[k] == 0) {
                        elem.clear();
                        elem.add(num[i]);
                        elem.add(num[j]);
                        elem.add(num[k]);
                        if (!result.contains(elem)) result.add(elem);
                        
                    }
                }
            }
        }
        return result;
    }
}

It looks almost the same as the previous solution; the main difference is that the triplet list is created ("newed") once, before the nested loops. Whenever a triplet sums to zero, we empty the list with the instance method clear() before adding the new values. Where is the problem? Let's look at the output first:

Input:[-2,0,1,1,2]
Output:[[-2,1,1]]
Expected:[[-2,0,2],[-2,1,1]]

Where is the first triplet? If you suspect that clear() caused the error, you are on the right track. In Java, a variable of object type holds a reference to the object. In line 23 (the result.add(elem) call), the result list stores only a reference to that same elem object, not a copy of its contents. So when we later find another triplet that sums to zero and call clear(), we empty the very list that result already holds, and the new values then overwrite it. Because the stored list is elem itself, result.contains(elem) is now true and nothing new is added. That is why the earlier triplet is lost and only the last triplet found is returned.
There is an interesting discussion on this topic: http://stackoverflow.com/questions/7080546/add-an-object-to-an-arraylist-and-modify-it-later

Now we know where the problem comes from. How do we resolve it? ArrayList has a clone() instance method, which returns a shallow copy of the ArrayList instance. We can use it to add an independent copy of the list rather than the shared reference. In the previous code, line 23 (the result.add(elem) call) could be changed to:
if (!result.contains(elem)) result.add((ArrayList<Integer>)elem.clone());
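The shared-reference behaviour can be reproduced in isolation. The sketch below feeds the two expected triplets for the input [-2,0,1,1,2] through one reused list, once adding the list itself and once adding a copy. The class and method names (SharedListDemo, collect) are made up for illustration, and the copy uses the ArrayList copy constructor, which, like clone(), makes a shallow copy but needs no unchecked cast.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical demo class; not part of the original solution.
class SharedListDemo {
    static List<List<Integer>> collect(boolean copyBeforeAdd) {
        List<List<Integer>> result = new ArrayList<>();
        List<Integer> elem = new ArrayList<>();       // one list reused for every triplet
        int[][] triplets = { {-2, 0, 2}, {-2, 1, 1} };
        for (int[] t : triplets) {
            elem.clear();                             // also empties any stored alias of elem
            for (int v : t) elem.add(v);
            if (!result.contains(elem)) {
                // Either store an independent copy or the shared reference.
                result.add(copyBeforeAdd ? new ArrayList<Integer>(elem) : elem);
            }
        }
        return result;
    }
}
```

With collect(false), the result holds a single list whose contents are the last triplet, matching the wrong output above; with collect(true), both triplets survive.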

Improved Solution:
Now let's move to the improved solution. Remember that in two sum we used a hash table to reduce the time complexity from O(n^2) to O(n). Can we borrow a similar idea for the three sum problem? The answer is yes.

Code (Java):
public class Solution {
    public ArrayList<ArrayList<Integer>> threeSum(int[] num) {
        final int length = num.length;
        ArrayList<ArrayList<Integer>> result = new ArrayList<ArrayList<Integer>>();
        HashMap<Integer, int[]> hashMap = new HashMap<Integer, int[]>();
        
        // if length is less than 3, return empty result set
        if (length < 3) return result;
        
        Arrays.sort(num);
        
        for (int i = 0; i < length - 2; i++) {
            if (num[i] > 0) break;
            hashMap.clear();
            
            if (i == 0 || num[i] > num[i - 1]) {
                for (int j = i + 1; j < length; j++) {
                    if (hashMap.containsKey(num[j])) { // found target
                        ArrayList<Integer> elem = new ArrayList<Integer>(3);
                    
                        elem.add(hashMap.get(num[j])[0]);
                        elem.add(hashMap.get(num[j])[1]);
                        elem.add(num[j]);
                        
                        result.add(elem);
                        
                        // remove duplicated elements
                        while (j < (length - 1) && num[j] == num[j + 1]) j++;
                    } else {
                        int[] temp = new int[2];
                        temp[0] = num[i];
                        temp[1] = num[j];
                        hashMap.put(0 - (num[i] + num[j]), temp);
                    }
                }
            }
        }
        return result;
    }
}
Discussion:
In this solution, the time complexity is O(n^2). The extra space is O(n) rather than O(n^2): the hash map is cleared at the start of each outer iteration, so it never holds more than n - 1 entries at a time (not counting the result list itself).

Compared to the two-sum solution, the code above is much trickier. First of all, it clears the hash map with hashMap.clear() before each inner loop. Why? The basic idea of the hash map is to ask whether num[j] is already a key; if not, we store the pair (num[i], num[j]) under the key 0 - (num[i] + num[j]). Take the input array [-2, -1, 1, 2]. In the first round, when i = 0, the key-value pairs in the hash table are:
<3, [-2, -1]>, <1, [-2, 1]>, <0, [-2, 2]>. In the second round, when i = 1, the inner loop starts from j = 2 and checks whether num[j] = 1 is in the hash table. If we did not clear the table after the first round, we would find the stale key 1 and output the triplet <-2, 1, 1>, which does not exist at all!
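The stale-key scenario just described can be checked with a few lines. This is a sketch that hard-codes the example array and replays only the map operations of the first round; StaleKeyDemo and lookupInSecondRound are hypothetical names used for illustration.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical demo class replaying the example [-2, -1, 1, 2].
class StaleKeyDemo {
    // Returns the pair stored under the key num[2] at the start of round i = 1,
    // or null if no such key exists.
    static int[] lookupInSecondRound(boolean clearBetweenRounds) {
        int[] num = { -2, -1, 1, 2 };                 // already sorted
        Map<Integer, int[]> map = new HashMap<>();
        // Round i = 0: no num[j] matches an existing key, so every j stores an entry.
        for (int j = 1; j < num.length; j++) {
            map.put(0 - (num[0] + num[j]), new int[] { num[0], num[j] });
        }
        if (clearBetweenRounds) {
            map.clear();                              // what hashMap.clear() does in the solution
        }
        // Round i = 1 starts at j = 2 and asks whether num[2] == 1 is a key.
        return map.get(num[2]);
    }
}
```

Without the clear, the lookup returns the leftover pair [-2, 1], which the solution would combine with num[2] = 1 into the bogus triplet; with the clear, it returns null.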

The second trick is removing duplicate triplets, which relies on the array being sorted. Before each inner loop, we check whether the starting number num[i] equals the previous one; if so, it has already been handled and we skip it. Likewise, after adding a triplet to the result list, we check whether the following elements have the same value as num[j]; if so, we simply skip them.

An even better solution:
Remember that in the two-sum problem we proposed a two-pointer solution with O(n log n) time complexity. We can extend that solution here:
Code (Java):
public class Solution {
    public ArrayList<ArrayList<Integer>> threeSum(int[] num) {
        final int length = num.length;
        ArrayList<ArrayList<Integer>> result = new ArrayList<ArrayList<Integer>>();
        
        if (length < 3) return result;
        
        // Sort the array
        Arrays.sort(num);
        
        for (int i = 0; i < length - 2; i++) {
            if (num[i] > 0) break;
            if (i == 0 || num[i] > num[i - 1]) {
                int target = 0 - num[i];
                int start = i + 1;
                int end = length - 1;
                while (start < end) {
                    if (num[start] + num[end] == target) {
                        ArrayList<Integer> elem = new ArrayList<Integer>();
                        elem.add(num[i]);
                        elem.add(num[start]);
                        elem.add(num[end]);
                        
                        result.add(elem);
                        start++;
                        end--;
                        
                        // Remove duplicated results
                        while (start < end && num[end + 1] == num[end]) end--;
                        while (start < end && num[start - 1] == num[start]) start++;
                    } else if (num[start] + num[end] > target) {
                        end--;
                    } else {
                        start++;
                    }
                }
            }
        }
        return result;
    }
}



It has the same time complexity as the hash table solution, i.e., O(n^2), but requires no additional storage beyond the result list. Furthermore, indexing into an int array avoids the hashing and boxing that HashMap put and get involve, so it should be faster than the hash table solution in practice. Pay attention to how the solution skips duplicate values after each match.
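The two-sum post is not reproduced here, so the following is only a sketch of the building block that the while loop specializes: finding unique pairs with a given sum in the sorted tail of an array. The names TwoPointerPairs and pairsWithSum are made up for illustration; three sum then amounts to calling it with from = i + 1 and target = -num[i] for each distinct num[i].

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper; a sketch of the two-pointer step, not the original two-sum code.
class TwoPointerPairs {
    // Unique pairs (x, y) with x <= y and x + y == target, taken from sorted[from..].
    static List<int[]> pairsWithSum(int[] sorted, int from, int target) {
        List<int[]> pairs = new ArrayList<>();
        int start = from;
        int end = sorted.length - 1;
        while (start < end) {
            int sum = sorted[start] + sorted[end];
            if (sum == target) {
                pairs.add(new int[] { sorted[start], sorted[end] });
                start++;
                end--;
                // Skip values equal to the ones just used so no pair repeats.
                while (start < end && sorted[start] == sorted[start - 1]) start++;
                while (start < end && sorted[end] == sorted[end + 1]) end--;
            } else if (sum > target) {
                end--;      // sum too large: move the right pointer to a smaller value
            } else {
                start++;    // sum too small: move the left pointer to a larger value
            }
        }
        return pairs;
    }
}
```

For the sorted example array {-4, -1, -1, 0, 1, 2} with from = 2 and target = 1 (that is, i = 1 and num[i] = -1), it finds the pairs (-1, 2) and (0, 1), which correspond to the two triplets in the problem statement.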

Conclusion:
Three sum is very close to the two-sum problem, but the implementation is much trickier. It is important to understand how duplicates are handled. Last but not least, each of the three solutions has its pros and cons in both time and space complexity, so it is important to analyze them and be able to justify why you chose a particular approach.

3 comments:

  1. "array access is much faster than hash table put and get" hashtables get and put has avg runtime of O(1)?

    Replies
    1. The **amortized** average cost of get() in a hash table is O(1). In worst-case, it is O(n). The cost of get() in an array is always O(1).
