Friday, October 31, 2014

Facebook: Weighted Interval Scheduling

Given a set of n jobs, each described by [start time, end time, cost], find a subset in which no two jobs overlap and the total cost is maximized.


A DP Solution:
  • Sort the intervals by end time. 
  • Define the DP array
    • dp[n + 1], where dp[i] is the maximum total cost of a non-overlapping subset chosen from the first i jobs (in end-time order). 
  • Initial state: dp[0] = 0;
  • Transit function:
    • dp[i] = Math.max(dp[i - 1], dp[p[i - 1] + 1] + interval[i - 1].cost), where p[i - 1] is the largest index j < i - 1 with interval[j].end <= interval[i - 1].start (or -1 if no such job exists), found by binary search.
  • Final state: dp[n].
Code (Java):
import java.util.*;
public class Solution {
    /**
     * Returns the maximum total cost of a subset of mutually non-overlapping events.
     * An event that ends exactly when another starts is treated as compatible with it.
     *
     * <p>The caller's list is not modified; a sorted copy is used internally.
     *
     * @param events the candidate events; may be null or empty
     * @return the maximal achievable total cost, or 0 for null/empty input
     */
    public int weightedIntervalScheduling(List<Event> events) {
        if (events == null || events.isEmpty()) {
            return 0;
        }

        // Sort a copy by end time: leaves the caller's list untouched, works for
        // immutable inputs, and guarantees O(1) random access during the search.
        List<Event> sorted = new ArrayList<Event>(events);
        Collections.sort(sorted, new EventComparator());

        int n = sorted.size();
        // dp[i] = best total cost using only the first i events of `sorted`; dp[0] = 0.
        int[] dp = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            Event current = sorted.get(i - 1);
            int skip = dp[i - 1];
            // Only events 0..i-2 are candidates, so a zero-length event can never be
            // paired with itself (that would read the not-yet-computed dp[i]).
            int prev = findNearestIndex(i - 1, sorted, current.start);
            int take = dp[prev + 1] + current.cost;
            dp[i] = Math.max(take, skip);
        }

        return dp[n];
    }

    /**
     * Binary search for the last event, among the first {@code limit} events, that ends
     * no later than {@code target}.
     *
     * @param limit  number of leading events to consider (indices 0..limit-1)
     * @param events events sorted by ascending end time
     * @param target the time the candidate event starts
     * @return the largest index j < limit with events.get(j).end <= target, or -1 if none
     */
    private int findNearestIndex(int limit, List<Event> events, int target) {
        int lo = 0;
        int hi = limit - 1;
        int found = -1;
        while (lo <= hi) {
            int mid = (lo + hi) >>> 1;
            if (events.get(mid).end <= target) {
                // mid is compatible; a later compatible index may still exist.
                found = mid;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return found;
    }

    /** Orders events by ascending end time (overflow-safe, unlike subtraction). */
    private static class EventComparator implements Comparator<Event> {
        @Override
        public int compare(Event a, Event b) {
            return Integer.compare(a.end, b.end);
        }
    }

    public static void main(String[] args) {
        Solution sol = new Solution();

        List<Event> events = new ArrayList<Event>();
        events.add(new Event(1,4,1));
        events.add(new Event(2,5,2));
        events.add(new Event(0,6,3));
        events.add(new Event(2,10,4));

        System.out.println(sol.weightedIntervalScheduling(events));
    }
}

Discussion:
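The time complexity of the DP approach is O(n log n): sorting takes O(n log n), and each of the n DP steps performs an O(log n) binary search. The dp array takes O(n) extra space.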

Follow-up:
What if we want the solution itself, i.e., the events that produce the maximal cost?
Then we need some post processing. 

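The idea is to backtrack from dp[n] toward dp[0]. If taking job i - 1 gives a strictly larger value than dp[i - 1], record it and continue from its latest compatible predecessor. Otherwise, continue from i - 1.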

Code (Java):
import java.util.*;
public class Solution {
    /**
     * Finds a maximum-cost subset of mutually non-overlapping events and returns the
     * positions of the chosen events. An event that ends exactly when another starts
     * is treated as compatible with it.
     *
     * <p>Note: {@code events} is sorted in place by end time, and the returned indices
     * refer to positions in that sorted list. The list must therefore be mutable.
     * Indices are reported from the last chosen event to the first (descending order).
     *
     * @param events the candidate events; may be null or empty
     * @return indices (into the sorted {@code events}) of an optimal subset;
     *         empty for null/empty input
     */
    public List<Integer> weightedIntervalScheduling(List<Event> events) {
        List<Integer> result = new ArrayList<Integer>();
        if (events == null || events.isEmpty()) {
            return result;
        }

        // Sorted in place on purpose: the returned indices refer to this order.
        Collections.sort(events, new EventComparator());

        int n = events.size();
        // dp[i] = best total cost using only the first i events; dp[0] = 0.
        int[] dp = new int[n + 1];
        // prev[j] = latest event (index < j) compatible with event j, or -1.
        int[] prev = new int[n];
        for (int i = 1; i <= n; i++) {
            Event current = events.get(i - 1);
            // Only events 0..i-2 are candidates, so a zero-length event can never be
            // paired with itself (that would read the not-yet-computed dp[i]).
            prev[i - 1] = findNearestIndex(i - 1, events, current.start);
            int skip = dp[i - 1];
            int take = dp[prev[i - 1] + 1] + current.cost;
            dp[i] = Math.max(take, skip);
        }

        outputSolution(dp, prev, events, n, result);

        return result;
    }

    /**
     * Backtracks from dp[i] toward dp[0] and appends the indices of chosen events to
     * {@code result}. Event k-1 is taken only when taking it is strictly better than
     * skipping it; this matches how dp was filled, so the chosen events reach dp[i].
     * Iterative rather than recursive to avoid deep call stacks on large inputs.
     */
    private void outputSolution(int[] dp, int[] prev, List<Event> events, int i,
                                List<Integer> result) {
        int k = i;
        while (k > 0) {
            int p = prev[k - 1];
            if (dp[p + 1] + events.get(k - 1).cost > dp[k - 1]) {
                result.add(k - 1);
                k = p + 1; // continue from the latest compatible predecessor
            } else {
                k--;
            }
        }
    }

    /**
     * Binary search for the last event, among the first {@code limit} events, that ends
     * no later than {@code target}.
     *
     * @param limit  number of leading events to consider (indices 0..limit-1)
     * @param events events sorted by ascending end time
     * @param target the time the candidate event starts
     * @return the largest index j < limit with events.get(j).end <= target, or -1 if none
     */
    private int findNearestIndex(int limit, List<Event> events, int target) {
        int lo = 0;
        int hi = limit - 1;
        int found = -1;
        while (lo <= hi) {
            int mid = (lo + hi) >>> 1;
            if (events.get(mid).end <= target) {
                // mid is compatible; a later compatible index may still exist.
                found = mid;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return found;
    }

    /** Orders events by ascending end time (overflow-safe, unlike subtraction). */
    private static class EventComparator implements Comparator<Event> {
        @Override
        public int compare(Event a, Event b) {
            return Integer.compare(a.end, b.end);
        }
    }

    public static void main(String[] args) {
        Solution sol = new Solution();

        List<Event> events = new ArrayList<Event>();
        events.add(new Event(1,4,1));
        events.add(new Event(2,5,2));
        events.add(new Event(0,6,3));
        events.add(new Event(5,10,4));

        List<Integer> result = sol.weightedIntervalScheduling(events);

        for (Integer elem : result) {
            System.out.print(elem + " ");
        }

        System.out.println();
    }
}





No comments:

Post a Comment