Rust Traits: Iterator (2020-04-27)

The Iterator trait in Rust allows you to conveniently operate over a sequence of elements. Iterators provide an expressive, functional, convenient, and performant way to do computations.

This article will focus on the mechanics of the trait and offer a deeper look into it. If you're not already familiar with traits in Rust, I recommend skimming the Rust book chapter on traits before reading this article.

The Basics

Iterators are very common in Rust, and if you've written Rust you have likely used them. Here is a basic example:

for val in 0..10 {
  // ... use val
}

0..10 is a Range, which implements Iterator, the trait at the heart of for-loops: the loop desugars into repeated calls to next() until it returns None.
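To make the desugaring a little more concrete, here is a rough sketch of a for-loop written out by hand. The function names sum_with_for and sum_desugared are made up for this illustration, and the expansion the compiler actually performs differs in small details.

```rust
// Sums a range with an ordinary for-loop.
fn sum_with_for(limit: u32) -> u32 {
    let mut total = 0;
    for val in 0..limit {
        total += val;
    }
    total
}

// Approximately what the loop above expands to.
fn sum_desugared(limit: u32) -> u32 {
    let mut total = 0;
    // A for-loop first converts its argument into an iterator...
    let mut iter = IntoIterator::into_iter(0..limit);
    // ...and then calls next() until the iterator is exhausted.
    loop {
        match iter.next() {
            Some(val) => total += val,
            None => break,
        }
    }
    total
}

fn main() {
    println!("{} {}", sum_with_for(10), sum_desugared(10)); // 45 45
}
```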

You may have written something more complex as well:

let v1: Vec<u32> = vec![1, 2, 3];
let v2: Vec<u32> = vec![4, 5, 6];

let dot_product: u32 = v1
    .iter()
    .zip(v2)
    .map(|(l, r)| l * r)
    .sum(); // 32

Either way, we're iterating and operating on some sequence of values. The Iterator trait provides convenient ways to construct, transform, and consume these sequences.

The Iterator Trait

The Iterator trait can be thought of as having three parts:

  1. The foundation: a next() function which returns some type Item if it can.
  2. Lots and lots of methods for iterator transformations (e.g. functional tools like map and filter).
  3. A function called collect(), which allows you to evaluate iterators into some collection type.

The Foundation

The foundation of Iterator is an associated type Item and a method next() which returns Option<Item>:

trait Iterator {
  type Item;
  fn next(&mut self) -> Option<Self::Item>;
  // ... several elided methods
}

Annotated with comments:

trait Iterator {
  type Item;
  //   ^-- Associated type; the type we are returning each iteration
  fn next(&mut self) -> Option<Self::Item>;
  // ^    ^             ^ returns either an Item, or nothing
  // |    | it mutates something each iteration
  // | `next` method gets somehow called each iteration in for-loops

  // ... several elided methods
}

The trait signature tells us a lot about how it works:

  1. We need to mutate the type for which we implement Iterator (something needs to do the bookkeeping).
  2. If we have a value to yield, we return Some(val); otherwise, we stop iteration by returning None.
  3. We return the same type each iteration.
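These points are easy to see by calling next() by hand. The helper below, drain_by_hand, is a made-up name for this sketch; it records every value next() returns, including the final None.

```rust
// Drains an iterator by hand, recording every value returned by next(),
// including the final None that signals the end of iteration.
fn drain_by_hand<I: Iterator>(mut iter: I) -> Vec<Option<I::Item>> {
    let mut seen = Vec::new();
    loop {
        // `iter` must be mutable because next() takes &mut self.
        let step = iter.next();
        let done = step.is_none();
        seen.push(step);
        if done {
            return seen;
        }
    }
}

fn main() {
    let steps = drain_by_hand(vec![10, 20].into_iter());
    println!("{:?}", steps); // [Some(10), Some(20), None]
}
```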

So now that we've seen the foundation, let's preview some Iterator trait methods for transforming iterators.

The Thousand Elided Methods

While it's nice that we can cleanly define a way to retrieve a single element at a time from a collection, it would be very nice to operate on the iterable itself. The Iterator trait provides a LOT of functions to conveniently work with iterators in a functional style. We can succinctly express more complex logic with these methods – for example:

let some_iterable = 0..100;
let sum: u32 = some_iterable
    .filter(|&e| e > 50)
    .map(|e| e * e)
    .sum();

vs

let some_iterable = 0..100;
let mut sum = 0;
for e in some_iterable {
    if !(e > 50) {
        continue;
    }
    let squared = e * e;
    sum += squared;
}

I personally find the first far easier to read as it requires much less parsing. This isn't always true of iterators in Rust, but most of the time it is.

Other methods we'll use in this article include Iterator::take(N), which yields at most N elements from the iterator. This is useful for infinite iterators, and is common in functional languages.
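As a small sketch of take on an infinite iterator, the function below (first_evens, a name invented for this example) draws from an unbounded range. Because iterator adapters are lazy, the range is only advanced as far as take needs.

```rust
// Returns the first `n` even numbers, drawn from an unbounded range.
fn first_evens(n: usize) -> Vec<u64> {
    (0u64..) // 0, 1, 2, ... with no upper bound
        .filter(|e| e % 2 == 0)
        .take(n) // stop after at most n elements
        .collect()
}

fn main() {
    println!("{:?}", first_evens(4)); // [0, 2, 4, 6]
}
```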

You can view the full list of iterator methods in the standard library documentation for std::iter::Iterator.

The collect() Method

While important, this article won't focus much on the mechanics of collect(). In short, this method uses the FromIterator trait to convert iterators into some collection. You'll find yourself using this often when working with iterators to convert them into tangible and convenient types.

There's a good example of collect() here.
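For a quick taste, here is a minimal sketch of collect() producing three different collection types from the same kind of pipeline. The function names are invented for this example; the target type is chosen purely by the return type annotation, which is what FromIterator makes possible.

```rust
use std::collections::HashSet;

// Collects into a Vec, preserving order.
fn squares_vec(limit: u32) -> Vec<u32> {
    (0..limit).map(|e| e * e).collect()
}

// Collects into a HashSet, which drops duplicates.
fn last_digits(limit: u32) -> HashSet<u32> {
    (0..limit).map(|e| (e * e) % 10).collect()
}

// Collects an iterator of Strings into one concatenated String.
fn shout(words: &[&str]) -> String {
    words.iter().map(|w| w.to_uppercase()).collect()
}

fn main() {
    println!("{:?}", squares_vec(4)); // [0, 1, 4, 9]
    println!("{}", last_digits(10).len()); // 6
    println!("{}", shout(&["ab", "cd"])); // ABCD
}
```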

Now that we've seen an overview of what's provided, we can implement Iterator!

Part 1: The Natural Numbers

To get more familiar with the trait, let's make a useful construct: The Natural Numbers.

To implement this, we'll need a struct holding the current value:

// Book keeping struct
struct NaturalNumbers {
    curr: u32,
}

// Start at 0 because computers
impl NaturalNumbers {
    fn new() -> Self {
        Self { curr: 0 }
    }
}

And implement Iterator by incrementing curr:

impl Iterator for NaturalNumbers {
    type Item = u32;

    fn next(&mut self) -> Option<Self::Item> {
        let ret = self.curr;
        self.curr += 1;
        Some(ret)
    }
}

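Nice! We have a struct NaturalNumbers which will yield every natural number until curr overflows u32. At that point self.curr += 1 panics in debug builds and wraps around to 0 in release builds (the default overflow behavior).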
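If the overflow bothers you, one possible alternative is to end iteration instead. The sketch below is a hypothetical variant, not part of the article's running example: CheckedNaturals and starting_at are names invented here, and it relies on u32::checked_add returning None on overflow.

```rust
// Hypothetical variant of NaturalNumbers that ends iteration instead of overflowing.
struct CheckedNaturals {
    curr: Option<u32>, // None once every u32 has been yielded
}

impl CheckedNaturals {
    fn starting_at(start: u32) -> Self {
        Self { curr: Some(start) }
    }
}

impl Iterator for CheckedNaturals {
    type Item = u32;

    fn next(&mut self) -> Option<Self::Item> {
        let ret = self.curr?; // already exhausted: keep returning None
        self.curr = ret.checked_add(1); // None when ret == u32::MAX
        Some(ret)
    }
}

fn main() {
    let tail: Vec<u32> = CheckedNaturals::starting_at(u32::MAX - 1).collect();
    println!("{:?}", tail); // [4294967294, 4294967295]
}
```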

This is certainly useful, and will serve as a bedrock for later functions. Unfortunately our terminals don't appreciate printing millions of integers, so we'll use the method Iterator::take(N) which limits the number of iterations to at most N.

We can then test NaturalNumbers with:

fn main() {
    for i in NaturalNumbers::new().take(5) {
        println!("{}", i);
    }
}

Which outputs:

~ cargo run
   Compiling iterator-article v0.1.0 (/home/david/programming/iterator-article)
    Finished dev [unoptimized + debuginfo] target(s) in 0.15s
     Running `target/debug/iterator-article`
0
1
2
3
4

You can run this example yourself on the Rust playground!

So now that we can generate a sequence of values, let's implement some familiar functional friends: map, filter, and reduce (fold).

Implementing Map

A frequent programming task is to loop over some collection and transform each element, possibly into a different type.

This occurs commonly when retrieving data from some source, and you need to bind the data in some useful construct (class / struct / etc). Or if you're crunching numbers you may want to operate on each element individually before some other step.

Either way, this pattern is so common that most languages offer the map construct – a way to provide an iterable and a function, and get the function applied to each element of the iterable returned.

For example, let's square each number in a vector. Rust offers a map() method on iterators, so we'll use that first:

Pseudo-code:

seq: 0, 1, 2, 3, ...
fn:  |e| e * e
out: 0, 1, 4, 9, ...

Rust:

let input = vec![1, 2, 3];
let squared: Vec<_> = input
    .iter()
    .map(|e| e * e)
    .collect();

So we provide a function, |e| e * e, which squares numbers, and map implicitly takes self, which is an iterator. This may not make sense right now, so let's dig deeper into building our own Map.

Things are going to get a little higher-order here, so let's outline what we'll need:

  1. We need a type Iter, which implements Iterator
  2. We need a function, which maps Iter::Item to some output type Out
    1. Syntax: Iter::Item is the associated type Item from Iter's implementation of Iterator.
    2. We can express the map function in Rust then as FnMut(Iter::Item) -> Out
      1. FnMut as we're consuming the element and may want to mutate captured variables. Feel free to use Fn if you don't want that. More on this later in the Reduce section.
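To see why FnMut is the more permissive bound, here is a minimal sketch with the same shape of signature. The names apply_all and numbered are invented for this example. The closure passed in mutates a captured counter, which an Fn bound would reject.

```rust
// Applies `f` to every element; `f` may mutate the variables it captures.
fn apply_all<Iter, F, Out>(iter: Iter, mut f: F) -> Vec<Out>
where
    Iter: Iterator,
    F: FnMut(Iter::Item) -> Out,
{
    let mut out = Vec::new();
    for e in iter {
        out.push(f(e));
    }
    out
}

// Prefixes each word with a running call count kept in a captured variable.
fn numbered(words: Vec<&str>) -> Vec<String> {
    let mut calls = 0;
    apply_all(words.into_iter(), |w| {
        calls += 1; // mutating captured state: this closure is FnMut, not Fn
        format!("{}:{}", calls, w)
    })
}

fn main() {
    println!("{:?}", numbered(vec!["a", "b"])); // ["1:a", "2:b"]
}
```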

Putting the above together we'll need a struct to store our function and iterator:

// Our Map struct
struct Map<Iter, F> {
    iter: Iter,
    f: F,
}

// We'll want to instantiate one later, so add a constructor method:
impl<Iter, F> Map<Iter, F> {
    fn new(iter: Iter, f: F) -> Self {
        Self { iter, f }
    }
}

Great, we can now tackle implementing Iterator. The first challenge is getting the types set up for our impl. As described above, we'll need three type parameters: Iter, F (the map function), and Out:

impl<Iter, F, Out> Iterator ...

But we need further guarantees as described above:

impl<Iter: Iterator, F: FnMut(Iter::Item) -> Out, Out> Iterator ...

I recommend the reader really make sure the type signature above makes sense. Rust has a tendency to hit type soup, and it is worthwhile to take a minute to understand it.

We can now implement Iterator in a straightforward way:

impl<Iter: Iterator, F: FnMut(Iter::Item) -> Out, Out> Iterator for Map<Iter, F> {
    type Item = Out;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next().map(|e| (self.f)(e))
    }
}

So we're calling next() on our stored iterator to iterate once, mapping the value with our stored function, and returning it. This is very efficient and something that rustc and LLVM love to optimize, which gives some insight into why Rust iterators are so fast.

Now that we have it, let's use it!

fn main() {
    let nat = NaturalNumbers::new().take(5);
    let seq = Map::new(nat, |e| e * e);
    for i in seq {
        println!("{}", i);
    }
}

And run it:

$ cargo run
     Compiling iterator-article v0.1.0 (/home/david/programming/iterator-article)
      Finished dev [unoptimized + debuginfo] target(s) in 0.17s
       Running `target/debug/iterator-article`
  0
  1
  4
  9
  16

Nice! We can transform sequences using our own struct. If you want to see it in action yourself, you can play with it on the rust playground.

This is certainly powerful, but it would be nice to filter elements as well. Map only has access to a single element at a time, and must produce an output for every element. We can play around with the function types passed, but most of the time we just want to drop certain elements.

Filter

Filter is an interesting abstraction, as it concerns itself with retaining elements of a sequence which satisfy some criteria, and dropping the rest. The criteria function, or predicate function, borrows a value from the iterator and returns true or false. If the predicate evaluates to true on an element, return it to the caller. If the predicate is false, forget about it and continue searching.

This abstraction is also very common in other languages, and is just as essential as Map for functional programming.

The other wrinkle is that we need to care about ownership in Rust. Map would want to own each element as it needs to transform it, but filter just needs to borrow the element. We won't cover the magic involved with the Fn family and references, but this will work:

FnMut(&Iter::Item) -> bool

Our job is then similar to Map, we need a struct and constructor:

// struct to hold the iterator and the predicate (any closure or function)
struct Filter<Iter, Predicate> {
    iter: Iter,
    pred: Predicate,
}

// And a default constructor
impl<Iter, Predicate> Filter<Iter, Predicate> {
    fn new(iter: Iter, pred: Predicate) -> Self {
        Self { iter, pred }
    }
}

Same idea as Map – store the iterator and function in a struct. Now we can implement Iterator in a similar fashion:

impl<Iter, Predicate> Iterator for Filter<Iter, Predicate>
where
    Iter: Iterator,
    Predicate: FnMut(&Iter::Item) -> bool,
{
    type Item = Iter::Item;
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(ele) = self.iter.next() {
            if (self.pred)(&ele) {
                return Some(ele);
            }
        }
        None
    }
}

We're again iterating over our underlying iterator, and then testing each element with our predicate. If it passes, we return the element. We're implicitly mutating self.iter as it's also an iterator, so no state is lost. When the caller calls next() again, we'll simply continue iterating where we left off in self.iter and repeat the process. Eventually we'll exhaust the underlying iterator and stop iteration by returning None.
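The "continue where we left off" behavior can be observed directly. The sketch below uses the standard library's filter rather than our own Filter so that it stays self-contained. The function name examined_per_step is made up, and it uses a Cell counter with inspect to record how many elements of the underlying range have been pulled after each call to next().

```rust
use std::cell::Cell;

// Pulls the first two even numbers from 1..=10 and records, for each one,
// how many elements of the underlying range had been examined so far.
fn examined_per_step() -> Vec<(u32, u32)> {
    let examined = Cell::new(0u32);
    let mut evens = (1..=10u32)
        .inspect(|_| examined.set(examined.get() + 1)) // counts every element pulled
        .filter(|e| e % 2 == 0);

    let mut steps = Vec::new();
    for _ in 0..2 {
        if let Some(e) = evens.next() {
            steps.push((e, examined.get()));
        }
    }
    steps
}

fn main() {
    // Each call to next() resumes right after the previously yielded element.
    println!("{:?}", examined_per_step()); // [(2, 2), (4, 4)]
}
```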

So now that we have it, let's use it! We'll build off of the Map example above to retain the even elements:

fn main() {
    let nat = NaturalNumbers::new().take(10);
    let seq = Map::new(nat, |e| e * e);
    let seq = Filter::new(seq, |e: &u32| *e % 2 == 0);
    for i in seq {
        println!("{}", i);
    }
}

Which when run prints out (run it on the playground here):

~ cargo run
    Finished dev [unoptimized + debuginfo] target(s) in 0.04s
     Running `target/debug/iterator-article`
0
4
16
36
64

Great! We can now selectively retain elements in a sequence. The final tool to make is reduce (also called fold) which is the most powerful tool yet.

Reduce

The motivation for reduce (fold in Rust) is pretty simple: We need a way to collapse entire sequences into some type. Map and Filter only operate on each element one at a time, not an entire sequence. How would we sum all numbers in a list?

The mechanics are pretty simple thankfully:

  1. We have a base value: the accumulator. In the summing example, this would be 0.
  2. We have a function FnMut(acc, ele) -> acc which melds the accumulator and the given element.

For example, to multiply a list of integers we will need:

  1. The accumulator, with initial value 1.
  2. The function |acc, ele| acc * ele
  3. A list [1, 2, 3]

We can view the computation with the table below:

Table 1: Final result: 6

iter | acc | ele | product
-----+-----+-----+--------
1    | 1   | 1   | 1
2    | 1   | 2   | 2
3    | 2   | 3   | 6
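The same trace can be produced in code. This sketch uses the standard library's fold (Rust's name for reduce) and a made-up helper, product_trace, which records one (acc, ele, product) row per step.

```rust
// Multiplies the elements and records one (acc, ele, product) row per step.
fn product_trace(items: &[u32]) -> (u32, Vec<(u32, u32, u32)>) {
    let mut rows = Vec::new();
    let result = items.iter().fold(1, |acc, &ele| {
        let product = acc * ele;
        rows.push((acc, ele, product));
        product // becomes `acc` for the next element
    });
    (result, rows)
}

fn main() {
    let (result, rows) = product_trace(&[1, 2, 3]);
    println!("{:?}", rows); // [(1, 1, 1), (1, 2, 2), (2, 3, 6)]
    println!("{}", result); // 6
}
```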

So the idea is to accumulate values into the accumulator. We don't need the Iterator trait just yet, so we can implement reduce with a free-standing function:

fn reduce<Acc, Iter, ReduceFn>(iterator: Iter, acc: Acc, mut reducefn: ReduceFn) -> Acc
where
    Iter: Iterator,
    ReduceFn: FnMut(Acc, Iter::Item) -> Acc,
{
    let mut acc = acc;
    for ele in iterator {
        acc = reducefn(acc, ele);
    }
    acc
}

We can now use it:

fn main() {
    let nat = NaturalNumbers::new().take(4);
    let seq = Filter::new(nat, |e: &u32| *e > 0);
    let prod = reduce(seq, 1, |acc, ele| acc * ele);
    println!("{}", prod);
}

Which outputs 1 * 1 * 2 * 3 = 6 as expected (rust playground):

~ cargo run
    Blocking waiting for file lock on build directory
   Compiling iterator-article v0.1.0 (/home/david/programming/iterator-article)
    Finished dev [unoptimized + debuginfo] target(s) in 0.33s
     Running `target/debug/iterator-article`
6

Quick note on reduce

reduce is strictly more powerful than Map and Filter as it has access to the whole collection and an accumulator. We can easily implement Filter in terms of reduce for example:

let mut empty_vec = vec![];
let bigger_than_five = reduce(
    NaturalNumbers::new().take(10),
    &mut empty_vec,
    |acc, ele| {
        if ele > 5 {
            acc.push(ele);
        }
        acc
    },
);

I would recommend playing around with this function. It's useful to internalize that reduce (fold) can produce any output type. However, keep in mind that unnecessary uses of reduce, like the example above, give up the performance optimizations available to the dedicated Iterator adapters.
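As a couple more illustrations of "any output type", here is a sketch using the standard library's fold. Both function names are invented for this example: one folds into a tuple, the other re-expresses map through fold by using a Vec as the accumulator.

```rust
// Folds a sequence into a (count, sum) pair in a single pass.
fn count_and_sum<I: Iterator<Item = u32>>(iter: I) -> (usize, u32) {
    iter.fold((0, 0), |(count, sum), ele| (count + 1, sum + ele))
}

// Map expressed through fold: the accumulator is the output vector.
fn map_via_fold<I, F, Out>(iter: I, mut f: F) -> Vec<Out>
where
    I: Iterator,
    F: FnMut(I::Item) -> Out,
{
    iter.fold(Vec::new(), |mut acc, ele| {
        acc.push(f(ele));
        acc
    })
}

fn main() {
    println!("{:?}", count_and_sum(0..5)); // (5, 10)
    println!("{:?}", map_via_fold(0..4u32, |e| e * e)); // [0, 1, 4, 9]
}
```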

Part 2: Our own Iterator Trait

The following code is certainly nice:

let nat = NaturalNumbers::new().take(4);
let squared = Map::new(nat, |e| e * e);
let seq = Filter::new(squared, |e: &u32| *e % 2 == 0);
let prod = reduce(seq, 1, |acc, ele| acc * ele);

But this is far easier to read:

let prod = NaturalNumbers::new()
    .take(4)
    .map(|e| e * e)
    .filter(|e: &u32| *e % 2 == 0)
    .reduce(1, |acc, ele| acc * ele);

The question is then: How does Iterator provide this interface?

As mentioned above, Iterator provides a whole bunch of default methods to facilitate this clean API. To better understand how this works, let's define our own Iterator trait:

trait MyIterator {
    type Item;
    fn next(&mut self) -> Option<Self::Item>;
}

And update our previous Iterator implementations:

-impl<Iter, Predicate> Iterator for Filter<Iter, Predicate>
+impl<Iter, Predicate> MyIterator for Filter<Iter, Predicate>
...

You can view the whole refactor on the Rust playground. Unfortunately, our changes don't compile as we no longer have an Iterator::take(N) method:

error[E0599]: no method named `take` found for struct `NaturalNumbers` in the current scope
   --> src/main.rs:116:37
    |
1   | struct NaturalNumbers {
    | ---------------------
    | |
    | method `take` not found for this
    | doesn't satisfy `NaturalNumbers: std::iter::Iterator`
...
116 |     let nat = NaturalNumbers::new().take(4);
    |                                     ^^^^ method not found in `NaturalNumbers`
    |
    = note: the method `take` exists but the following trait bounds were not satisfied:
            `NaturalNumbers: std::iter::Iterator`
            which is required by `&mut NaturalNumbers: std::iter::Iterator`
    = help: items from traits can only be used if the trait is implemented and in scope
    = note: the following trait defines an item `take`, perhaps you need to implement it:
            candidate #1: `std::iter::Iterator`

It looks like we'll need to implement Take ourselves. The process is very similar to before: we'll need a struct and a MyIterator implementation:

struct Take<Iter> {
    iterator: Iter,
    left: usize,
}

impl<Iter> Take<Iter> {
    fn new(iterator: Iter, left: usize) -> Self {
        Self { iterator, left }
    }
}

impl<Iter: MyIterator> MyIterator for Take<Iter> {
    type Item = Iter::Item;
    fn next(&mut self) -> Option<Self::Item> {
        if self.left > 0 {
            self.left -= 1;
            self.iterator.next()
        } else {
            None
        }
    }
}

Now that we have the struct, we need to modify MyIterator to achieve the desired API. Things will get a bit introspective, as we cannot refer to any concrete types. We instead rely on the Self language feature to specify that types which implement MyIterator will be the ones used in the method calls. We'll want to transfer ownership of iterators in these methods, so our MyIterator::take(N) signature will read:

fn take(self, amount: usize) -> Take<Self>

The other wrinkle is that this won't compile: the Rust compiler is not confident it can lay out the Take struct properly, as Self can be !Sized. This can seem obscure, but the error message is pretty good:

error[E0277]: the size for values of type `Self` cannot be known at compilation time
   --> src/main.rs:116:37
    |
90  | struct Take<Iter> {
    |             ---- required by this bound in `Take`
...
116 |     fn take(self, amount: usize) -> Take<Self> {
    |                                     ^^^^^^^^^^- help: consider further restricting `Self`: `where Self: std::marker::Sized`
    |                                     |
    |                                     doesn't have a size known at compile-time
    |
    = help: the trait `std::marker::Sized` is not implemented for `Self`
    = note: to learn more, visit <https://doc.rust-lang.org/book/ch19-04-advanced-types.html#dynamically-sized-types-and-the-sized-trait>

To better understand this error, what is the type of seq in the following?

let seq = NaturalNumbers::new()
    .take(4)
    .map(|e| e * e)
    .filter(|e: &u32| *e % 2 == 0);

The answer is Filter<Map<Take<NaturalNumbers>, fn#1>, fn#2>, where fn#1 and fn#2 stand in for the anonymous types of the two closures.

Recall that Map, Filter, and Take all take a type Iter: MyIterator by value, so each one needs to physically store that iterator in its struct memory layout. The Rust language tracks this information in the Sized trait. So if a type is Sized, Rust can properly lay out the struct. If a type is !Sized, then indirection or obscure language features are required to embed that type in the struct. The compiler has helpfully told us to add a Sized bound on Self:
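To make the Sized distinction a bit more tangible, here is a sketch using standard library types rather than our own. The function names are invented for the example. A concrete adapter chain has a size known at compile time, while a bare dyn Iterator does not and has to live behind a pointer such as Box.

```rust
use std::mem::size_of;

// A sized adapter chain: its full type, and therefore its size, is known at compile time.
fn sized_chain() -> std::iter::Take<std::ops::Range<u32>> {
    (0..10).take(4)
}

// An unsized `dyn Iterator` must be stored behind a pointer such as Box.
fn boxed_chain() -> Box<dyn Iterator<Item = u32>> {
    Box::new((0..10).take(4).map(|e| e * e))
}

fn main() {
    println!(
        "Take<Range<u32>>: {} bytes",
        size_of::<std::iter::Take<std::ops::Range<u32>>>()
    );
    println!(
        "Box<dyn Iterator>: {} bytes",
        size_of::<Box<dyn Iterator<Item = u32>>>()
    );
    // size_of::<dyn Iterator<Item = u32>>() would not compile: the type is !Sized.
    println!("{:?}", sized_chain().collect::<Vec<_>>()); // [0, 1, 2, 3]
    println!("{:?}", boxed_chain().collect::<Vec<_>>()); // [0, 1, 4, 9]
}
```

The Box version trades the statically known layout for a pointer plus dynamic dispatch, which is the kind of indirection mentioned above.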

 fn take(self, amount: usize) -> Take<Self>
+where
+    Self: std::marker::Sized,
 {
     Take::new(self, amount)
 }

This compiles and works! Let's run our main again:

fn main() {
    let nat = NaturalNumbers::new().take(4);
    let squared = Map::new(nat, |e| e * e);
    let seq = Filter::new(squared, |e: &u32| *e > 0);
    let prod = reduce(seq, 1, |acc, ele| acc * ele);
    println!("{}", prod);
}

Which outputs:

~ cargo run
    Finished dev [unoptimized + debuginfo] target(s) in 0.03s
     Running `target/debug/iterator-article`
36

We can now do the same procedure for Map and Filter. We can reuse the constructors but replace Iter with Self:

trait MyIterator {
    // elided ...

    fn map<Out, F>(self, f: F) -> Map<Self, F>
    where
        F: FnMut(Self::Item) -> Out,
        Self: std::marker::Sized,
    {
        Map::new(self, f)
    }

    fn filter<F>(self, f: F) -> Filter<Self, F>
    where
        F: FnMut(&Self::Item) -> bool,
        Self: std::marker::Sized,
    {
        Filter::new(self, f)
    }
}

Our main function is now:

fn main() {
    let seq = NaturalNumbers::new()
        .take(4)
        .map(|e| e * e)
        .filter(|e: &u32| *e > 0);
    let prod = reduce(seq, 1, |acc, ele| acc * ele);
    println!("{}", prod);
}

Which outputs 36 as before. Now we just need to implement reduce in a similar way as before:

trait MyIterator {
    // elided...

    fn reduce<Acc, ReduceFn>(mut self, acc: Acc, mut reducefn: ReduceFn) -> Acc
    where
        ReduceFn: FnMut(Acc, Self::Item) -> Acc,
        Self: std::marker::Sized,
    {
        let mut acc = acc;
        while let Some(ele) = self.next() {
            acc = reducefn(acc, ele);
        }
        acc
    }
}

And change our main function to be:

fn main() {
    let prod = NaturalNumbers::new()
        .take(4)
        .map(|e| e * e)
        .filter(|e: &u32| *e > 0)
        .reduce(1, |acc, ele| acc * ele);
    println!("{}", prod);
}

Which outputs 36 as expected (rust playground):

~ cargo run
   Compiling iterator-article v0.1.0 (/home/david/programming/iterator-article)
    Finished dev [unoptimized + debuginfo] target(s) in 0.15s
     Running `target/debug/iterator-article`
36

Conclusion

Phew, 3.6k words later we've accomplished our goal. We've recreated the Iterator trait and delved into its mechanics. I hope you've learned something from this article, as I certainly learned a lot writing it. I really like this language feature, and think it represents some of the best API design Rust offers.

Appendix: The Primes

We started our journey by defining the NaturalNumbers, so it would be cool if we could generate an infinite sequence of Primes:

struct Primes {
    seen: Vec<u32>,
    curr: u32,
}

impl Primes {
    fn new() -> Self {
        Self {
            seen: vec![],
            curr: 2,
        }
    }
}

impl Iterator for Primes {
    type Item = u32;

    fn next(&mut self) -> Option<u32> {
        for ele in self.curr.. {
            if !self.seen.iter().any(|prime| ele % prime == 0) {
                self.seen.push(ele);
                self.curr = ele + 1;
                return Some(ele);
            }
        }
        None
    }
}

Which we can use:

fn main() {
    println!("{:?}", Primes::new().take(20).collect::<Vec<_>>());
}

And this outputs the first twenty primes (rust playground):

~ cargo run
      Finished dev [unoptimized + debuginfo] target(s) in 0.19s
       Running `target/debug/iterator-article`
  [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71]

It's just that easy to generate a sequence of Primes using Iterator in Rust. The reader is encouraged to use MyIterator::reduce to achieve the same effect.
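As a starting point for that exercise, here is one possible sketch. To keep it self-contained it uses the standard library's fold instead of MyIterator::reduce, and primes_below is a name invented for this example. Note that it works over a bounded range: a fold consumes its whole input, so it would never return on an infinite sequence.

```rust
// Collects the primes below `limit` by folding over candidates and
// keeping the primes found so far in the accumulator.
fn primes_below(limit: u32) -> Vec<u32> {
    (2..limit).fold(Vec::new(), |mut seen, n| {
        // n is prime if no previously found prime divides it.
        if !seen.iter().any(|p| n % p == 0) {
            seen.push(n);
        }
        seen
    })
}

fn main() {
    println!("{:?}", primes_below(30)); // [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
}
```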