Idiomatic Rust? Implementing binary search (part 2)

This is a follow-up to the original article, where I looked at a few ways to improve my Rust implementation of binary search - with a focus on removing 'mistakes' and making it as 'idiomatic' as possible.

The most common piece of feedback I received on the first article was related to the comparison between the middle value and high/low cursors.

Wasn't too far off

In my original Tweet (even before the first article) I had presented a variant of the algorithm that used pattern matching for this comparison via Rust's match construct.

One of my first attempts

pub fn binary_search_alt(k: i32, items: &[i32]) -> Option<i32> {
    let mut low: i32 = 0;
    let mut high: i32 = items.len() as i32 - 1;

    while low <= high {
        let middle = (((high + low) / 2) as f64).floor() as i32;
        match items.get(middle as usize) {
            Some(current) if *current == k => return Some(middle),
            Some(current) if *current > k => high = middle - 1,
            Some(current) if *current < k => low = middle + 1,
            _ => {}
        }
    }
    None
}

If we apply the feedback from the first article to this implementation, we'd end up with the following - which is just a bit cleaner since it removes some casting along with the manual 'flooring' I was originally doing.

pub fn binary_search_alt(k: i32, items: &[i32]) -> Option<usize> {
    // guard: `items.len() - 1` would underflow on an empty slice
    if items.is_empty() { return None; }

    let mut low: usize = 0;
    let mut high: usize = items.len() - 1;

    while low <= high {
        let middle = (high + low) / 2;
        match items.get(middle) {
            Some(current) if *current == k => return Some(middle),
            Some(current) if *current > k => {
                // guard: `middle - 1` would underflow when middle is 0
                if middle == 0 { return None; }
                high = middle - 1
            }
            Some(current) if *current < k => { low = middle + 1; }
            _ => {}
        }
    }
    None
}

The debate begins.

This is where it became a bit complicated. I originally received some feedback on Twitter suggesting that, because I'm not using the return value of the match block in any of the arms and only use it to mutate the high/low cursors, this could be deemed not idiomatic Rust...

I fully understood that feedback, and to address it the match could be written like so:

match items.get(middle) {
    Some(current) if *current == k => { return Some(middle); },
    Some(current) if *current > k => {
        if middle == 0 { return None };
        high = middle - 1;
    },
    Some(current) if *current < k => { low = middle + 1; },
    _ => {}
}

Now each match arm has no trailing expression, possibly making it clearer to the next developer that we didn't intend on returning a value here from the match block.

Is this implementation more idiomatic? I'm not entirely sure.

More feedback

There are 2 reasons that I'm not sure about this.

1: Rust makes heavy use of pattern matching, and whilst match is technically an expression itself (and therefore can be used as a value), I don't think it's an issue if it's not always used as a value.

2: It seems like it would be more of an issue if each arm of the match had different semantics. For example, if one arm produced a value and another didn't. This would make the implementation inconsistent, and the viewpoint may hold more water, but Rust's type system will not allow this anyway, since each arm of the match must evaluate to the same type (arms that diverge, such as an early return, are compatible with any type).
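
To make the second point a bit more concrete, here is a small sketch. The function names `classify` and `tally` are made up for illustration and don't come from either post. The first uses match as a value, so every arm has to produce the same type; the second uses match purely for its side effects, so every arm evaluates to `()`, apart from the one that returns early.

```rust
/// `match` used as a value: every arm must produce a `&'static str`.
pub fn classify(current: i32, k: i32) -> &'static str {
    match current {
        c if c == k => "found",
        c if c > k => "look left",
        _ => "look right",
    }
}

/// `match` used only for side effects: the mutating arms evaluate to `()`,
/// and the early `return` diverges, which is compatible with `()`.
pub fn tally(current: i32, k: i32, above: &mut u32, below: &mut u32) -> bool {
    match current {
        c if c == k => return true,
        c if c > k => *above += 1,
        _ => *below += 1,
    }
    false
}
```

Both compile, and as far as I can tell neither shape is treated as an error or a warning by the compiler, which is roughly why I don't see the second one as a problem.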

For those 2 reasons, the rest of this post will focus on further improving our algorithm to fully utilize match - but please reach out on Twitter if you disagree, I'd love to continue the conversation 😀

A new Baseline

Improvement 1: Removing the Option

Our algorithm is searching within a slice of i32 values. On each iteration we are accessing a value from the slice with the .get() method. This was done for safety since a runtime panic can occur if you attempt an index-access with a value that would be out-of-bounds.

fn this_will_panic() {
    let items = vec![10, 40, 90];
    let fourth_item = items[3];
    println!("{:?}", fourth_item)
    // 'main' panicked at 'index out of bounds:
    // the len is 3 but the index is 3'
}

fn this_will_not_panic() {
    let items = vec![10, 40, 90];
    let first_item = items.get(0);
    println!("{:?}", first_item);
    // prints `Some(10)`
    let fourth_item = items.get(3);
    println!("{:?}", fourth_item)
    // prints `None`
}

So the second way of accessing a value may be safer, but it comes at the cost of an additional layer of indirection in the form of that Option type.

// type is Option<&i32>
let item = items.get(4);

// type is i32, but may panic
let item = items[4];

This affected our algorithm in the first article since we needed to use an if let expression to expose the current value via the Some(current) pattern.

//     ↓↓↓↓↓↓↓↓↓↓↓↓
if let Some(current) = items.get(middle) {
    if *current == k {
        // snip
    }
    if *current > k {
        // snip
    }
    if *current < k {
        // snip
    }
}

So, to remove the Option and therefore also remove .get() and Some(current), we'd need to be mathematically sure to be within the slice bounds - otherwise we'd get a runtime panic! Well it turns out that our original algorithm was actually doing all the checks we need - it always ensures the next index access never drops below zero and is always below the length of the slice.

This means that in terms of the original article, we could've replaced the if let Some(current) with a simple index access which makes the solution less syntactically noisy and is simpler overall.

// we know this won't panic since we control 'middle'
//            ↓↓↓↓↓↓↓↓↓↓↓↓↓
let current = items[middle];
if current == k {
   // snip
}
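
If you'd rather check the bounds claim than take my word for it, here's a rough sketch. `middles_visited` is a hypothetical helper, not code from either article: it runs the same inclusive-bounds loop (with the empty-slice and `middle == 0` guards) and hands back every `middle` it indexed with, so you can confirm each one is inside the slice.

```rust
/// Hypothetical helper: runs the inclusive-bounds search and returns every
/// `middle` index it touched, so the bounds claim can be checked.
pub fn middles_visited(k: i32, items: &[i32]) -> Vec<usize> {
    let mut visited = Vec::new();
    if items.is_empty() {
        return visited;
    }
    let mut low: usize = 0;
    let mut high: usize = items.len() - 1;

    while low <= high {
        // low <= middle <= high <= len - 1, so `items[middle]` cannot panic
        let middle = (high + low) / 2;
        visited.push(middle);
        let current = items[middle];
        if current == k {
            break;
        }
        if current > k {
            if middle == 0 {
                break; // `middle - 1` would underflow: nothing left to search
            }
            high = middle - 1;
        } else {
            low = middle + 1;
        }
    }
    visited
}
```

The key observation is that `high` never grows past its starting value of `len - 1`, and `middle` is always between `low` and `high`.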

Now we can be confident that a direct index access is always safe, since we control the middle value, but how does this help us with the match expression mentioned previously?

Improvement 2: Making the comparison only once

If all we did was take the match expression from the beginning of this article and remove the Option as mentioned above, we'd end up with something that looked like this:

// before
match items.get(middle) {
    Some(current) if *current == k => return Some(middle),
    Some(current) if *current > k => {
        if middle == 0 { return None };
        high = middle - 1
    },
    Some(current) if *current < k => { low = middle + 1; },
    _ => {}
}

// after: `current` is now a plain i32, so no dereference is needed
match items[middle] {
    current if current == k => { return Some(middle); },
    current if current > k => {
        if middle == 0 { return None };
        high = middle - 1
    },
    current if current < k => { low = middle + 1; },
    _ => {}
}

Which is barely even an improvement 🙁! If anything, I'd say it's actually harder to read.

It's missing the next big step, which is also another move towards more idiomatic Rust - and that's to take advantage of the fact that the Ord trait is implemented for i32 in the standard library.

The Ord trait has a .cmp(other) method which returns a variant of the Ordering enum. Its definition looks like this...

// std::cmp
pub trait Ord: Eq + PartialOrd<Self> {
    fn cmp(&self, other: &Self) -> Ordering;
    // snip
}

... and an example in isolation would look like this:

fn main() {
    let first: i32 = 1;
    let second: i32 = 3;
    let result = first.cmp(&second);
    // `Ordering` implements `Debug` but not `Display`, hence `{:?}`
    println!("result = {:?}", result);
    // prints `result = Less`; other inputs would give
    // `Greater` or `Equal` instead
}

Notice how this is consolidating 3 separate comparisons into a single method call - it removes the need to manually check equal, greater, and less in an imperative style and instead it ends up being more declarative since we're no longer defining the actual implementation.

After the call to .cmp(other) we now have a value which is equal to one of Ordering's 3 enum variants and this is where languages with Pattern Matching really shine since we can do a match on the variant and simplify our code quite a bit.

match items[middle].cmp(&k) {
    Ordering::Equal => return Some(middle),
    Ordering::Greater => {
        if middle == 0 { return None };
        high = middle - 1
    },
    Ordering::Less => low = middle + 1
}

Along with being a more declarative style, it's also much less noisy with far fewer things to mentally parse. 😇

Also, since the matching on the Ordering enum is exhaustive, we no longer need the empty _ => {} as a final catch-all match arm either. 🙏

A final thing to note here is that since we're invoking a method implemented for the Ord trait, it means that our algorithm could be made more generic in the future to include searching for any type that implements Ord, and not just i32 as in our example. This would make a great follow-on blog post, highlighting the power of traits in Rust, please reach out to me on Twitter if you'd like to see that post happen 🐦
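
As a small taste of that, here's a sketch of the same three-way match written against any type that implements Ord. The `direction` helper is hypothetical and only exists for this illustration; it isn't the generic binary search itself.

```rust
use std::cmp::Ordering;

/// Hypothetical helper: tells a search which way to go next,
/// for any type that implements `Ord`.
pub fn direction<T: Ord>(current: &T, k: &T) -> &'static str {
    match current.cmp(k) {
        Ordering::Equal => "found",
        Ordering::Greater => "go left",
        Ordering::Less => "go right",
    }
}
```

The same function accepts integers, string slices, chars and tuples, since they all implement Ord. One caveat worth knowing: f64 and f32 only implement PartialOrd, so a generic version bounded by Ord would not accept floats as-is.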

Putting it together

After applying both improvements documented so far (removing Option + doing a single comparison) we end up with the following implementation - which is clearly a big improvement 👌.

use std::cmp::Ordering;

fn binary_search(k: i32, items: &[i32]) -> Option<usize> {
    if items.is_empty() { return None }

    let mut low: usize = 0;
    let mut high: usize = items.len() - 1;

    while low <= high {
        let middle = (high + low) / 2;
        match items[middle].cmp(&k) {
            Ordering::Equal => return Some(middle),
            Ordering::Greater => {
                if middle == 0 { return None };
                high = middle - 1
            },
            Ordering::Less => low = middle + 1
        }
    }
    None
}

Final improvement: removing the checks related to usize dropping below zero.

There are 2 lines in our implementation that still feel 'over-engineered' - or rather, it feels like they could be improved, or removed.

// this is the opening check in the function
if items.is_empty() { return None }

// part of the `Greater` match arm, need to
// ensure we don't subtract below zero
if middle == 0 { return None };

To solve this, we need to address the core algorithm. If we define high as the exclusive upper bound and only ever re-assign it to either the length of the slice at the start of the function, or to a subsequent middle value, we can be sure that we'll never decrease its value below zero, and therefore we can remove both of the checks mentioned above.

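That guarantees none of the mutable usize cursors can ever drop below zero, which takes care of the middle == 0 check. The is_empty check can go too: for an empty slice, high starts at 0, so the loop condition (now low < high rather than low <= high) is false on entry and we fall straight through to None.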
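
For a bit of context on why the guard was needed in the first place: with default compiler settings, subtracting 1 from a usize of 0 panics in debug builds and wraps around in release builds. The sketch below contrasts the two ways of computing the next `high`; both function names are hypothetical and only here to illustrate the difference.

```rust
/// Hypothetical helper: the next `high` for the inclusive-bounds variant,
/// or `None` if `middle - 1` would drop below zero.
pub fn next_high_inclusive(middle: usize) -> Option<usize> {
    middle.checked_sub(1)
}

/// With an exclusive upper bound the next `high` is just `middle`,
/// so there is no subtraction that could underflow.
pub fn next_high_exclusive(middle: usize) -> usize {
    middle
}
```

`next_high_inclusive(0)` is the case the `if middle == 0` check was protecting against; the exclusive version simply has no equivalent failure case.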

Final implementation?

That leaves us with the following, which I'm starting to feel happy about after these 2 long blog posts & the feedback I received on Twitter 🙏

use std::cmp::Ordering;

fn binary_search(k: i32, items: &[i32]) -> Option<usize> {
    let mut low: usize = 0;
    let mut high: usize = items.len();

    while low < high {
        let middle = (high + low) / 2;
        match items[middle].cmp(&k) {
            Ordering::Equal => return Some(middle),
            Ordering::Greater => high = middle,
            Ordering::Less => low = middle + 1
        }
    }
    None
}
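
One way to sanity-check this version is to compare it against the standard library's own `binary_search` method on slices, which returns a `Result<usize, usize>` (calling `.ok()` turns that into an Option). This is only a sketch: `agrees_with_std` is a hypothetical helper, and it assumes the slice is sorted and has no duplicates, since with duplicates the two searches may legitimately land on different matching indexes.

```rust
use std::cmp::Ordering;

// Copied from the final implementation above.
pub fn binary_search(k: i32, items: &[i32]) -> Option<usize> {
    let mut low: usize = 0;
    let mut high: usize = items.len();

    while low < high {
        let middle = (high + low) / 2;
        match items[middle].cmp(&k) {
            Ordering::Equal => return Some(middle),
            Ordering::Greater => high = middle,
            Ordering::Less => low = middle + 1,
        }
    }
    None
}

/// Hypothetical check: does our search agree with the standard library
/// for this key? Assumes `items` is sorted with no duplicates.
pub fn agrees_with_std(k: i32, items: &[i32]) -> bool {
    binary_search(k, items) == items.binary_search(&k).ok()
}
```

Running the check with keys below, inside and above the range, plus an empty slice, covers the edge cases that the two removed guards used to handle.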

Algorithm differences visualized

This first visualization here shows how all the demos across both posts (apart from the final change) play out

Original

This next visualization shows how much simpler everything becomes when you have high as the exclusive upper bound - fewer steps are taken overall. Also note how the high cursor begins at 1 index higher than the end of the list, whereas previously it began on the last element directly.

Exclusive upper bound
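
The cursor movements can also be reproduced in code. Here's a rough sketch, using a hypothetical `trace_exclusive` function that runs the same loop as the final implementation but records `(low, middle, high)` on every iteration.

```rust
use std::cmp::Ordering;

/// Hypothetical helper: same loop as the final implementation, but it
/// records `(low, middle, high)` at every iteration.
pub fn trace_exclusive(k: i32, items: &[i32]) -> (Option<usize>, Vec<(usize, usize, usize)>) {
    let mut steps = Vec::new();
    let mut low: usize = 0;
    let mut high: usize = items.len();

    while low < high {
        let middle = (high + low) / 2;
        steps.push((low, middle, high));
        match items[middle].cmp(&k) {
            Ordering::Equal => return (Some(middle), steps),
            Ordering::Greater => high = middle,
            Ordering::Less => low = middle + 1,
        }
    }
    (None, steps)
}
```

Searching for 10 in `[10, 40, 90, 120, 150]` records `(0, 2, 5)`, `(0, 1, 2)` and `(0, 0, 1)`. Note the first `high` of 5, one past the last valid index of 4, which is the "1 index higher than the end of the list" starting position described above.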

Thanks!

A huge thanks to Wiebe Cnossen and Basile Henry for providing the feedback that inspired this follow-up post.

Does this finally seem like 'idiomatic' Rust?

I'd love to hear any feedback or alternative ways to implement this algorithm in Rust - so please reach out on Twitter if you have any thoughts 🦀

Playground Link