Are you preparing for your upcoming coding interview? This question was asked by Google, as reported on multiple occasions by programmers all around the world. Can you solve it optimally?
Let’s dive into the problem first.
Problem Formulation
Given an integer array or Python list nums and an integer value k, find and return the k-th largest element in the array.
Constraints: You can assume that k is a number between 1 and the length of the nums list:

1 <= k <= len(nums)

Therefore, it is implicitly ensured that the list nums has at least one element and that there is always exactly one solution.
Examples
Let’s have a look at some examples to improve our understanding of this problem.
Example 1
Input: [1, 2, 3, 4, 5], k=2
Output: 4

Example 2
Input: [42, 1, 3, 2], k=1
Output: 42

Example 3
Input: [3], k=1
Output: 3

Example 4
Input: [3, 42, 30, 1, 32, 100, 44, 13, 28, 99, 100000], k=4
Output: 44
Naive Solution: Sorting
The most straightforward way to return the k-th largest element from a list is as follows:
- Sort the list in descending order, so the largest element is at position 0.
- Access the element at index k-1 of the sorted list and return it. This is the k-th largest element.
Here’s the code that accomplishes that:
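```python
def find_k_largest_element(nums, k):
    # Create a new list sorted from largest to smallest
    sorted_nums = sorted(nums, reverse=True)
    # Zero-based indexing: the k-th largest element sits at index k-1
    return sorted_nums[k-1]
```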
You use the sorted() function to create a new sorted list. As the first argument, you pass the list to be sorted. In addition, you pass the keyword argument reverse=True, which ensures that the largest element appears at the first position, the second largest element at the second position, and so on.
Given the sorted list, you now need to access the k-th element from the list. As we use zero-based indexing in Python, the k-th largest element has index (k-1).
Let’s run this on our examples:
```python
# Example 1
lst = [1, 2, 3, 4, 5]
k = 2
print(find_k_largest_element(lst, k))  # 4

# Example 2
lst = [42, 1, 3, 2]
k = 1
print(find_k_largest_element(lst, k))  # 42

# Example 3
lst = [3]
k = 1
print(find_k_largest_element(lst, k))  # 3

# Example 4
lst = [3, 42, 30, 1, 32, 100, 44, 13, 28, 99, 100000]
k = 4
print(find_k_largest_element(lst, k))  # 44
```
Yes, this passes all tests!
Analysis: The code consists of two lines: sorting the list and accessing the k-th element from the sorted list. Accessing an element with a given index has constant runtime complexity O(1). The runtime of the algorithm, therefore, is dominated by the runtime for sorting a list with n elements. Without any further information about the list, we must assume that the worst-case runtime complexity of sorting is O(n log n), so it grows superlinearly with an increasing number of elements.
Discussion: Intuitively, we do a lot of unnecessary work when sorting the list given that we’re only interested in the k-th largest element. All smaller elements are of no interest to us. We observe that we do need to know the (k-1) larger elements, so that we can figure out the k-th largest. Is there a better way than O(n log n)?
Iteratively Removing the Maximum
Observation: Finding the largest element only has linear runtime complexity O(n): we need to traverse the list once and compare each element against the current maximum. If the current element is larger, we simply update our maximum. After traversing the whole list, we’ve determined the maximum with only n-1 comparisons.
- If k=1, this is already the solution and the runtime complexity is O(n) instead of O(n log n).
- If k>1, we can repeat the same procedure on the smaller list, each time removing the current maximum from the list.
The overall runtime complexity would be O(k*n) because we need to perform n comparisons to find one maximum, and repeat this k times.
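The linear scan from the observation above can be written out by hand. The following sketch (the helper name find_max_with_count is hypothetical, not part of the original solution) returns the maximum together with the number of comparisons it needed, which confirms the n-1 count for a list of n elements:

```python
def find_max_with_count(nums):
    """Return (maximum, comparisons) for a non-empty list, scanning it once."""
    current_max = nums[0]
    comparisons = 0
    for element in nums[1:]:
        comparisons += 1           # one comparison per remaining element
        if element > current_max:
            current_max = element  # found a new maximum
    return current_max, comparisons

print(find_max_with_count([3, 42, 30, 1, 32, 100, 44, 13, 28, 99, 100000]))
# (100000, 10)
```

The list in this usage example has 11 elements, so the scan needs 10 comparisons. Python's built-in max() performs the same kind of single pass.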
The following code implements the repeated-removal algorithm:
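```python
def find_k_largest_element(nums, k):
    nums = list(nums)  # work on a copy so the caller's list is not modified
    # Remove the current maximum (k-1) times
    for _ in range(k-1):
        nums.remove(max(nums))
    # The maximum of the remaining list is the k-th largest element
    return max(nums)
```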
The loop removes the current maximum (k-1) times, as controlled by range(k-1). After the loop terminates, the maximum of the remaining list is the k-th largest element. This is what we return to the user.
Discussion: This algorithm has runtime complexity O(k*n) compared to the runtime complexity of O(n log n) of the sorting method. So, if k<log(n), this is asymptotically (ignoring constant factors) the more efficient algorithm. However, for k>log(n), this algorithm would be worse!
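To get a feeling for how small log(n) is, the following back-of-the-envelope sketch computes, for a few list sizes n, the largest integer k with k < log2(n), i.e., the largest k for which O(k*n) beats O(n log n). It ignores constant factors, so the actual crossover on a real machine may differ; the helper name largest_k_favoring_remove_max is hypothetical.

```python
import math

def largest_k_favoring_remove_max(n):
    """Largest integer k with k < log2(n) (constant factors ignored)."""
    # ceil(x) - 1 is the largest integer strictly smaller than x
    return math.ceil(math.log2(n)) - 1

for n in [10, 1_000, 1_000_000]:
    print(n, largest_k_favoring_remove_max(n))
# 10 3
# 1000 9
# 1000000 19
```

Even for a list with a million elements, the removal-based approach only wins asymptotically for k up to 19.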
Can we do better?
Hybrid Solution to Get Best of Both Worlds
In the previous discussion, we’ve observed that if k>log(n), we should use the algorithm based on sorting and if k<log(n), we should use the algorithm based on repeatedly removing the maximum. For k=log(n), it doesn’t really matter. Fortunately, we can use this simple check at the beginning of our code to determine the best algorithm to execute.
```python
import math

def find_k_largest_sort(nums, k):
    sorted_nums = sorted(nums, reverse=True)
    return sorted_nums[k-1]

def find_k_largest_remove_max(nums, k):
    nums = list(nums)  # copy, so removing elements doesn't alter the caller's list
    for _ in range(k-1):
        nums.remove(max(nums))
    return max(nums)

def find_k_largest_element(nums, k):
    n = len(nums)
    # Dispatch to the asymptotically cheaper algorithm: O(n log n) vs. O(k*n)
    if k > math.log(n, 2):
        return find_k_largest_sort(nums, k)
    else:
        return find_k_largest_remove_max(nums, k)
```
The code shows the function find_k_largest_element that executes the sorting-based algorithm if k > log(n) and the removal-based algorithm otherwise.
Discussion: By combining both algorithms this way, the overall runtime complexity becomes O(min(k, log(n)) * n), which is never worse than either O(n * log(n)) or O(n * k).
Can we do even better?
Best Solution with Sorted List of Top k Elements
The main problem of the removal-based algorithm is that each max() computation scans the whole list. This is partly redundant work. Let's explore an alternative idea based on a small window of candidate elements that largely removes the overhead of computing the maximum repeatedly.
The idea of the following algorithm is to maintain a window of the k largest elements seen so far, in sorted order. Initially, you fill the window with the first k elements of the list, sorted. Then you process the remaining (n-k) elements one at a time: an element enters the window only if it is larger than the window's minimum, which then leaves the window. The trick is that, because the window is sorted, its minimum sits at position 0, so checking it has O(1) constant runtime complexity.
Here’s an example run of the algorithm:
You start with the list [5, 1, 3, 8, 7, 9, 2], k = 3, and the sorted window [1, 3, 5]. In each iteration, you check whether the current element is larger than the minimum at position 0 of the sorted window. For the elements 8, 7, and 9, this is indeed the case. In these instances, you remove the previous minimum from the window and then add the new element with a sorted insert operation. After one complete pass, the window holds the k largest elements [7, 8, 9], and its minimum, 7, is the k-th largest element.
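To make the example run concrete, here is a small tracing sketch (the function name window_snapshots is hypothetical) that records the sorted window after each remaining element has been processed. It uses the standard library's bisect.insort() for the sorted insert, just like the solution further below:

```python
import bisect

def window_snapshots(nums, k):
    """Return a copy of the sorted window after each step (illustration only)."""
    window = sorted(nums[:k])
    snapshots = [list(window)]
    for element in nums[k:]:
        if element > window[0]:
            window.pop(0)                   # drop the current minimum
            bisect.insort(window, element)  # sorted insert of the new element
        snapshots.append(list(window))
    return snapshots

for snapshot in window_snapshots([5, 1, 3, 8, 7, 9, 2], 3):
    print(snapshot)
# [1, 3, 5]
# [3, 5, 8]
# [5, 7, 8]
# [7, 8, 9]
# [7, 8, 9]
```

The last element, 2, is not larger than the window's minimum 7, so the final window stays unchanged and 7 is returned as the 3rd largest element.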
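Runtime analysis: Sorting the initial window costs O(k log k). Each of the remaining (n-k) elements is compared once against the window's minimum and, if it enters the window, needs O(log k) comparisons to find its position. So the algorithm performs O(n log k) comparisons, which is the best we have accomplished so far. Keep in mind, though, that with a plain Python list, removing the first element and inserting into the middle both shift up to k elements, so each window update also costs O(k) element moves in the worst case. A binary heap (Python's heapq module) avoids this and keeps each update at O(log k).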
Let’s have a look at the code:
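```python
import bisect

def find_k_largest_element(nums, k):
    # Window of the k largest elements seen so far, in ascending order
    window = sorted(nums[:k])
    for element in nums[k:]:
        if element > window[0]:
            # Remove minimum from window
            window.pop(0)
            # Sorted insert of new element
            bisect.insort(window, element)
    # The smallest of the k largest elements is the k-th largest overall
    return window[0]
```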
The code uses the bisect.insort() function from Python's standard library to perform the sorted insert operation into the window. You should still know how sorted insert actually works, even though in a coding interview you can usually assume access to basic standard library functionality. Here's a brief recap of the idea of sorted insert:
Concept Sorted Insert: To insert an element into a sorted list, you peek at the middle element of the list and check whether it is larger or smaller than the element you want to insert. If it is larger, all elements to its right are larger as well, and you can skip them. If the middle element is smaller, all elements to its left are smaller as well, and you can skip them. You repeat this, halving the range of candidate positions each time, until you find the right position to insert the new element.
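The halving procedure can be sketched in a few lines. The function name sorted_insert is hypothetical; the loop follows the same convention as bisect.insort(), which places a new element to the right of any equal elements:

```python
def sorted_insert(window, element):
    """Insert element into the ascending list window in place, keeping it sorted."""
    lo, hi = 0, len(window)
    while lo < hi:
        mid = (lo + hi) // 2
        if window[mid] > element:
            hi = mid      # mid and everything to its right are larger: skip them
        else:
            lo = mid + 1  # mid and everything to its left are <= element: skip them
    window.insert(lo, element)

window = [3, 5, 8]
sorted_insert(window, 7)
print(window)  # [3, 5, 7, 8]
```

The while loop runs O(log k) times for a window of k elements; the final list.insert() call is where the element shifting mentioned in the runtime analysis happens.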
As sorted insert repeatedly halves the interval, it only takes O(log k) comparisons to find the insert position in a sorted list with k elements (physically inserting into a Python list still shifts the elements behind that position). This is the core idea of the whole algorithm, so make sure you understand it!
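If you want the O(log k) bound per update to also hold for the data movement, you can swap the sorted list for a min-heap from the standard heapq module. This is a different data structure than the sorted window used above, shown here as a sketch; the function name find_k_largest_heap is hypothetical. The heap keeps the k largest elements seen so far, and heap[0] is always the smallest of them:

```python
import heapq

def find_k_largest_heap(nums, k):
    # Min-heap holding the k largest elements seen so far
    heap = list(nums[:k])
    heapq.heapify(heap)  # O(k); heap[0] is the smallest kept element
    for element in nums[k:]:
        if element > heap[0]:
            # Pop the current minimum and push the new element in one O(log k) step
            heapq.heapreplace(heap, element)
    # The smallest of the k largest elements is the k-th largest overall
    return heap[0]

print(find_k_largest_heap([5, 1, 3, 8, 7, 9, 2], 3))  # 7
print(find_k_largest_heap([3, 42, 30, 1, 32, 100, 44, 13, 28, 99, 100000], 4))  # 44
```

The total cost is O(k + (n-k) log k) = O(n log k). For a quick one-liner, heapq.nlargest(k, nums) returns the k largest elements in descending order, so its last element is the k-th largest.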