Idiomatic Rust? Implementing binary search

It's time to implement another well-known algorithm in Rust - this week's choice is Binary Search.

As always, this post will aim to highlight idiomatic Rust patterns, as well as provide interactive visualizations.


Note: This ended up being a part 1, I added more improvements in part 2 here


Binary search

The purpose of this algorithm is to determine whether a given search value exists in a sorted array and, if it does, at which index. You can read more about it on the Wikipedia page, and I've also created the following visualization to go with it:

binary search visualization

First Attempt in Rust

As you can see from above, the implementation I chose uses a high and a low cursor that begin at either end of the array. We then enter a while loop and, on each iteration, find the middle point between the two cursors and compare the item there to the search target.

Each time that comparison is made we decide 1 of 3 things:

  • if middle matches the search target, we're done
  • if middle is lower than the search target, we discard everything before it, halving the list. We then move our low cursor to the start of the remainder of the list
  • likewise, if middle is higher than the search target, everything after it is discarded and the high cursor moves to the relevant place.

So this was my attempt:

rust
///
/// Perform a binary search on a sorted list
///
pub fn binary_search(k: i32, items: &[i32]) -> Option<i32> {
    let mut low: i32 = 0;
    let mut high: i32 = items.len() as i32 - 1;

    while low <= high {
        let middle = (((high + low) / 2) as f64).floor() as i32;
        if let Some(current) = items.get(middle as usize) {
            if *current == k {
                return Some(middle);
            }
            if *current > k {
                high = middle - 1
            }
            if *current < k {
                low = middle + 1
            }
        }
    }
    None
}

#[test]
fn test_binary_search() {
    let items = vec![1, 2, 3, 4, 5];
    assert_eq!(Some(0), binary_search(1, &items));
    assert_eq!(Some(1), binary_search(2, &items));
    assert_eq!(Some(2), binary_search(3, &items));
    assert_eq!(Some(3), binary_search(4, &items));
    assert_eq!(Some(4), binary_search(5, &items));
    assert_eq!(None, binary_search(0, &items));
    assert_eq!(None, binary_search(90, &items));
    assert_eq!(None, binary_search(9000000, &items));

    let items = vec![2, 4, 6, 80, 90, 120, 180, 900, 2000, 4000, 5000, 60000];
    assert_eq!(None, binary_search(1, &items));

    assert_eq!(None, binary_search(1, &[]));
}

Seasoned Rust developers may notice a few things that I totally over-engineered here, mostly because of my background in JavaScript. Let's take a look.

A couple of fixes...

Fix 1: Removing the call to floor

Part of this algorithm requires you to continuously find the mid-point between your remaining items - this becomes the new 'middle'. We do this by adding the high and low values, and then dividing by 2. So if high=5 and low=0, then it's (5 + 0) / 2 = 2.5.

The issue here is that we need to use a whole number to perform the next lookup. We can't do items.get(2.5) as this just doesn't make any sense.

In a higher-level language, such as JavaScript, you would prevent index access with a non-integer by manually flooring the value after the division:

js
// Safe to use in an index access
const middle = Math.floor((high + low) / 2);

Which is what I originally tried to replicate in Rust, with the following:

rust
fn binary_search() {
    // snip...
    let middle = (((high + low) / 2) as f64).floor() as i32;
    // snip...
}

Not only did it look gross - it even felt wrong to me when I wrote it.

I was using my JavaScript implementation as reference, but converted to Rust it just seemed overly complex.

So as always I reached out to the Twitter/Rust community and thanks to a reply by the ever-helpful Basile Henry, I realised that in Rust this flooring is not actually needed at all.

Division in Rust is achieved with the Div trait, and as the documentation states...

rust
impl Div<i32> for i32 {
    // This operation rounds towards zero,
    // truncating any fractional part of the exact result.
}

... and since in my original example the high and low variables were both of type i32, a division between those 2 primitives has to yield a value of the same type. To make this possible, i32 / i32 discards the fractional part automatically. Strictly speaking this truncates towards zero rather than flooring, but the two only differ for negative results, and our cursors are never negative. Nice.
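To make the rounding behaviour concrete, here is a small sketch. The `midpoint` helper is a hypothetical name introduced only for this illustration, not something from the original implementation. It shows that plain integer division already gives a usable index, and where truncation and flooring part ways.

```rust
/// Hypothetical helper: midpoint of two non-negative cursors,
/// using nothing but integer division.
pub fn midpoint(low: i32, high: i32) -> i32 {
    (high + low) / 2
}

#[test]
fn test_integer_division() {
    // The example from the text: 2.5 becomes 2, no call to floor needed.
    assert_eq!(2, midpoint(0, 5));
    // For non-negative values, truncating and flooring agree.
    assert_eq!(2.0, (5.0_f64 / 2.0).floor());
    // For negative values they differ: `/` rounds towards zero...
    assert_eq!(-2, -5 / 2);
    // ...while floor rounds towards negative infinity.
    assert_eq!(-3.0, (-5.0_f64 / 2.0).floor());
}
```

Since the cursors in a binary search never go negative, the difference doesn't matter here, but it's worth knowing about before relying on `/` elsewhere.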

Annotated with the type for clarity, the before and after would look like this:

rust
fn binary_search() {
    // before
    let middle: i32 = (((high + low) / 2) as f64).floor() as i32;
    // after
    let middle: i32 = (high + low) / 2;
}

Fix 2: less casting by using usize for cursors

Our example is searching for an i32 in a slice of i32s - but the mistake I made here was also using i32 values as the high and low cursors.

I had let mut low: i32 = 0 and let mut high: i32 = items.len() as i32 - 1; which forced me to cast to a usize when looking up each middle value since .get() needs a type that's a valid index, such as usize.

rust
fn binary_search(k: i32, items: &[i32]) {
    // before:
    //   cast needed to `usize` since `middle` is `i32`
    let mut low: i32 = 0;
    let mut high: i32 = items.len() as i32 - 1;
    if let Some(current) = items.get(middle as usize) {

    }

    // after:
    //   when middle is `usize` it's a valid index,
    //   so no casting is required
    let mut low: usize = 0;
    let mut high: usize = items.len() - 1;
    if let Some(current) = items.get(middle) {

    }
}

That's a straightforward change that further cleans up the code, removing another number conversion - but it comes with a slight trade-off, explained in the next fix.

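Fix 3: ensure usize never drops below zero

Whilst it's more idiomatic to use the usize number type for the cursors in our example, we need to ensure we never decrement one of them below zero. Doing so panics at runtime in debug builds (in release builds, where overflow checks are off by default, the value silently wraps around to usize::MAX instead):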

rust
#[test]
fn test_usize() {
    let mut n: usize = 0;
    n -= 1; // !!! thread panicked at 'attempt to subtract with overflow'
}

This happens because we've explicitly used an 'unsigned' integer type - that's what the u stands for in usize. It means the value can never be negative: zero is the smallest value it can hold.

This is not as dangerous as it sounds: our binary search algorithm just needs a couple of extra checks, one before we decrement the high cursor and one right at the beginning to ensure we have a non-empty slice to loop through in the first place.

rust
// before:
//   if unchecked, this could cause
//   a runtime panic because `middle` is
//   now a `usize`
if *current > k {
    high = middle - 1 // <-- danger
}


// after:
//   if `middle` has reached zero, end
//   the algorithm altogether and prevent
//   the runtime panic
if *current > k {
    if middle == 0 {
        return None;
    }
    high = middle - 1
}
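As an aside, the same guard can be expressed with `checked_sub`, which the standard library provides on all integer types. The sketch below is only one possible alternative, and `next_high` is a hypothetical helper name invented for this illustration.

```rust
/// Hypothetical helper (not from the original post): the next `high`
/// cursor after discarding `middle` and everything to its right.
/// Returns `None` when `middle` is already 0, meaning nothing is left.
pub fn next_high(middle: usize) -> Option<usize> {
    // `checked_sub` yields `None` instead of overflowing below zero.
    middle.checked_sub(1)
}

#[test]
fn test_next_high() {
    assert_eq!(Some(4), next_high(5));
    assert_eq!(Some(0), next_high(1));
    assert_eq!(None, next_high(0));
}
```

Inside a function that returns an `Option`, this could collapse to a single line such as `high = middle.checked_sub(1)?;`, at the cost of making the early return a little less obvious to readers who are new to the `?` operator.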

Final solution

Putting everything together, we arrive at what I would call an 'idiomatic' Rust implementation. I would love to hear of further suggestions/improvements though - so please reach out on Twitter if you have any :)




rust
///
/// Perform a binary search on a sorted list
///
pub fn binary_search(k: i32, items: &[i32]) -> Option<usize> {
    if items.is_empty() {
        return None;
    }

    let mut low: usize = 0;
    let mut high: usize = items.len() - 1;

    while low <= high {
        let middle = (high + low) / 2;
        if let Some(current) = items.get(middle) {
            if *current == k {
                return Some(middle);
            }
            if *current > k {
                if middle == 0 {
                    return None;
                }
                high = middle - 1
            }
            if *current < k {
                low = middle + 1
            }
        }
    }
    None
}
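Because the return type changed from `Option<i32>` to `Option<usize>`, it's worth re-running some tests against the final version. The block below is a sketch of how that might look: it repeats the function from above unchanged so that it compiles on its own, and the cross-check against the standard library's `binary_search` on slices is my own addition rather than part of the original test suite.

```rust
/// The final `binary_search` from above, repeated so this block compiles alone.
pub fn binary_search(k: i32, items: &[i32]) -> Option<usize> {
    if items.is_empty() {
        return None;
    }

    let mut low: usize = 0;
    let mut high: usize = items.len() - 1;

    while low <= high {
        let middle = (high + low) / 2;
        if let Some(current) = items.get(middle) {
            if *current == k {
                return Some(middle);
            }
            if *current > k {
                if middle == 0 {
                    return None;
                }
                high = middle - 1
            }
            if *current < k {
                low = middle + 1
            }
        }
    }
    None
}

#[test]
fn test_binary_search_usize() {
    let items = vec![1, 2, 3, 4, 5];
    assert_eq!(Some(0), binary_search(1, &items));
    assert_eq!(Some(4), binary_search(5, &items));
    // Searching below the smallest item exercises the `middle == 0` check.
    assert_eq!(None, binary_search(0, &items));
    assert_eq!(None, binary_search(90, &items));
    // The empty slice exercises the `is_empty` check.
    assert_eq!(None, binary_search(1, &[]));

    // Cross-check against the standard library on distinct, sorted items.
    let items = vec![2, 4, 6, 80, 90, 120, 180, 900, 2000, 4000, 5000, 60000];
    for k in 0..=60001 {
        assert_eq!(items.binary_search(&k).ok(), binary_search(k, &items));
    }
}
```

The cross-check only holds for slices without duplicates: when a value appears more than once, either implementation is free to return the index of any matching item.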