Creating an enum iterator using Macros 1.1

Yesterday I was thinking about some testing strategies for a little project I’m working on. One idea I had was to test the output of a function when applied to every possible variant of an enum. One way I could implement this is by manually naming every variant of the enum in question – probably while using a macro. However, I am terrible at keeping track of things and I would most likely add a new variant to my enum without updating my tests. I could include in the macro a line of code that matched each variant like this:

match Test::A {
    Test::A => (),
    Test::B => (),
}

This would force the compiler to remind me to update my tests whenever I update my enum, but what I really wanted to do is learn Macros 1.1. So let’s do that instead. Here, I will describe my story of learning Macros 1.1.
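For comparison, the manual approach described above might look something like the following minimal sketch. The names all_variants and describe are hypothetical and do not come from my project; the match inside all_variants exists only so the compiler rejects the function once a variant is added to Test without updating the hand-written list.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Test {
    A,
    B,
}

// Hypothetical helper (not from the project): lists every variant by hand.
fn all_variants() -> Vec<Test> {
    // Exhaustiveness guard: adding `Test::C` without a new arm here is a
    // compile error, which is the reminder to extend the list below.
    match Test::A {
        Test::A => (),
        Test::B => (),
    }
    vec![Test::A, Test::B]
}

// Hypothetical function under test.
fn describe(t: Test) -> &'static str {
    match t {
        Test::A => "first",
        Test::B => "second",
    }
}
```

A test can then loop over all_variants() and check describe on each one. The guard does not update the list for you; it only makes forgetting harder.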

Hello World

So the first thing we need to do is start a new crate for our project.

> cargo new --bin hello-world

All I want is to be able to call hello_world() on a derived type. Something like this:

#[derive(HelloWorld)]
struct Pancakes;

fn main() {
    Pancakes::hello_world();
}

With some kind of nice output, like Hello, World! My name is Pancakes. So let’s do that. Before we can even start, we need some reference material for using Macros 1.1. The best information I could find comes from dtolnay: his slides and his video presentation (if you haven’t seen this Rust Meetup, I highly recommend it; it’s quite fantastic).

Let’s go ahead and write up what we think our macro will look like from a user’s perspective. In hello-world\src\main.rs we write:

#![feature(proc_macro)] // Promised to not be required coming Rust 1.15!
#[macro_use]
extern crate hello_world_macro;

#[derive(HelloWorld)]
struct FrenchToast;

#[derive(HelloWorld)]
struct Waffles;

fn main() {
    FrenchToast::hello_world();
    Waffles::hello_world();
}

Great. So now we just need to actually write the procedural macro. Let’s start a new crate called hello-world-macro inside our hello-world project.

hello-world\ > cargo new hello-world-macro

To make sure that our hello-world crate is able to find this new crate we’ve created, we’ll add it to the Cargo.toml of hello-world:

[dependencies]
hello-world-macro = { path = "hello-world-macro" }

As for our hello-world-macro crate, we will just copy and paste some boilerplate found in dtolnay’s slides:

#![feature(proc_macro, proc_macro_lib)]
extern crate proc_macro;
extern crate syn;
#[macro_use] extern crate quote;

use proc_macro::TokenStream;

#[proc_macro_derive(HelloWorld)]
pub fn hello_world(input: TokenStream) -> TokenStream {
    // Construct a string representation of the type definition
    let s = input.to_string();
    
    // Parse the string representation
    let ast = syn::parse_macro_input(&s).unwrap();

    // Build the impl
    let gen = impl_hello_world(&ast);
    
    // Return the generated impl
    gen.parse().unwrap()
}

So there is a lot going on here. We have introduced two new crates: syn and quote. As you may have noticed, input: TokenStream is immediately converted to a String. This String is a string representation of the Rust code for the type on which we are deriving HelloWorld; that’s it. So what we really need is to be able to parse Rust code into something usable. This is where syn comes into play. syn is a nom-based Rust parser, and it’s extremely useful. The other crate we’ve introduced is quote. It’s essentially the dual of syn, as it makes serializing Rust code really easy.

The comments seem to give us a pretty good idea of our overall strategy. We are going to take a String of the Rust code for the type we are deriving, parse it using syn, construct the implementation of hello_world (using quote), then pass it back to the Rust compiler.

Great, so let’s write impl_hello_world(&ast).

fn impl_hello_world(ast: &syn::MacroInput) -> quote::Tokens {
    let name = &ast.ident;
    quote! {
        impl #name {
            fn hello_world() {
                println!("Hello, World! My name is {}", stringify!(#name));
            }
        }
    }
}

So this is where quote comes in. The ast argument is a struct that gives us a representation of our type (which can be either a struct or an enum). Check out the docs; there is some useful information there. We are able to get the name of the type using ast.ident. The quote! macro lets us write up the Rust code that we wish to return and convert it into Tokens. quote! also lets us use some really cool templating mechanics; we simply write #name and quote! will replace it with the value of the variable named name. You can even do some repetition, similar to how regular macros work. You should check out the docs for a good introduction.
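To make the substitution concrete, here is a hand-written sketch of roughly what I would expect the generated code to look like once #name has been replaced with Pancakes. This is not actual macro output. The greeting helper is hypothetical and is only there so the text can be inspected without capturing stdout; the macro in this post generates hello_world alone.

```rust
struct Pancakes;

impl Pancakes {
    // Hypothetical helper, not produced by the macro in this post.
    fn greeting() -> String {
        // `stringify!(Pancakes)` turns the identifier into the string "Pancakes".
        format!("Hello, World! My name is {}", stringify!(Pancakes))
    }

    // Mirrors the method the derive is meant to generate.
    fn hello_world() {
        println!("{}", Self::greeting());
    }
}
```

Calling Pancakes::hello_world() prints the greeting, just as the derived version should.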

So I think that’s it. Oh, well, we do need to add dependencies for syn and quote in the Cargo.toml for hello-world-macro.

[dependencies]
syn = "0.10.5"
quote = "0.3.10"

That should be it. Let’s try to compile hello-world.

error: the `#[proc_macro_derive]` attribute is only usable with crates of the `proc-macro` crate type
 --> hello-world-macro\src\lib.rs:8:3
  |
8 | #[proc_macro_derive(HelloWorld)]
  |   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Oh, so it appears that we need to declare that our hello-world-macro crate is a proc-macro crate type. How do we do this? Well, we just go find a proc-macro crate on GitHub and see what they did to their Cargo.toml, of course. It turns out we need to add this to the Cargo.toml of hello-world-macro:

[lib]
proc-macro = true

Ok so now, let’s compile hello-world. Executing cargo run now yields:

Hello, World! My name is FrenchToast
Hello, World! My name is Waffles

Sweet! That’s it! We’ve created our first custom derive using syn and quote and it was actually quite easy. Now let’s try to create something less trivial.

EnumIterator

So let’s follow the same strategy as we did before by creating a new crate and writing out how we think this enum iterator should work. I want to be able to iterate through enum variants, something like this:

#![feature(proc_macro)] // Rust 1.15 me pls
#[macro_use] extern crate enum_iter;

#[derive(Debug, EnumIterator)]
pub enum Test {
    A,
    B,
    C(u32,u32),
    D { a: String, b: bool },
}

fn main() {
    for variant in Test::enum_iter() {
        println!("{:?}", variant);
    }
}

I’d like to see some kind of output like this:

A
B
C(0, 0)
D { a: "", b: false }

Let’s just go ahead and copy and paste the boilerplate from our previous discussion and make some minor modifications:

#![feature(proc_macro, proc_macro_lib)]
extern crate proc_macro;
extern crate syn;
#[macro_use]
extern crate quote;

use syn::{ Ident, Body, Variant, VariantData };
use proc_macro::TokenStream;

#[proc_macro_derive(EnumIterator)]
pub fn enum_iterator(input: TokenStream) -> TokenStream {
    let s = input.to_string();
    let ast = syn::parse_macro_input(&s).unwrap();

    let name = &ast.ident;
    let gen = match ast.body {
        Body::Enum(ref variants) => impl_enum_iter(name, variants),
        Body::Struct(_) =>
            quote! {
                impl EnumIteratorOnlyWorksForEnumsNotStructsSorryNotSorry for #name { }
            },
    };
    gen.parse().unwrap()
}

fn impl_enum_iter(name: &Ident, variants: &[Variant]) -> quote::Tokens {
    unimplemented!()
}

The first thing I need from syn is the list of variants of the enum. But it is entirely possible that someone will write #[derive(EnumIterator)] on a struct, and we certainly don’t want that. So we just implement a trait that I don’t believe the user has defined to communicate this error; this is probably not the recommended way to go about it, but this is what I did. Let’s try it out. Let’s go back to our main crate and test it with #[derive(EnumIterator)] struct IterateMePlease;. Doing this will give us an error:

error[E0405]: unresolved trait `EnumIteratorOnlyWorksForEnumsNotStructsSorryNotSorry`

Edit: dtolnay has informed me that the idiomatic way for reporting errors from a procedural macro is panicking. Better would be

Body::Struct(_) => panic!("#[derive(EnumIterator)] is only defined for enums, not structs"),

Perfect, so now that we have done our due diligence, let’s continue. So far syn has given us the name of our enum and a list of its Variants. It’s at this point that we need to sit down and think about what we want the generated code to be. Lucky for us, we have quote! and we can just write out what the generated code should look like.

fn impl_enum_iter(name: &Ident, variants: &[Variant]) -> quote::Tokens {
    let interface = quote::Ident::from(format!("_EnumIterator{}", name));
    let match_usize = match_usize(&name, variants);
    let size = variants.len();

    quote! {
        #[derive(Debug, Default)]
        pub struct #interface {
            count: usize,
        }

        impl #name {
            fn enum_iter() -> #interface {
                #interface::default()
            }
        }

        impl #interface {
            fn from_usize(n: usize) -> #name {
                match n {
                    #(#match_usize)*
                    _ => unreachable!(), // I think
                }
            }
        }

        impl ::std::iter::Iterator for #interface {
            type Item = #name;
            fn next(&mut self) -> Option<Self::Item> {
                if self.count >= #size { return None }
                let result = #interface::from_usize(self.count);
                self.count += 1;
                Some(result)
            }
        }
    }
}

What I really want is to create a new type implementing Iterator. I don’t want its name to conflict with any names the user has created, so I will call it something like _EnumIterator#name. Unfortunately, I don’t think quote! can splice #name into the middle of an identifier like that, but no fear: quote gives us a way to create an ident from a string, quote::Ident::from(format!("_EnumIterator{}", name));. So we start by declaring our _EnumIterator{} type and implementing a method on our enum, enum_iter(), which will build one for us. The iterator will simply keep track of a count and then convert this count to the corresponding variant of the enum. This is what from_usize is for. Basically, I want something like this:

fn from_usize(n: usize) -> Test {
    match n {
        0 => Test::A,
        1 => Test::B,
        _ => unreachable!(),
    }
}

We will need to fill in those branches with #(#match_usize)*. After that, the iterator is quite easy to implement: match the current count with an enum variant, increment the count, and return that variant. That seems straightforward. So how do we create this match_usize thing? Since #(#match_usize)* is a repetition, we will need to return a Vec of the things that quote! returns (Tokens). What I need to return is a default representation of each variant. So if the variant is a tuple type like Test::A(u32, u32), I need to yield Test::A(0, 0), and if it’s a struct type like Test::A { a: String, b: i32 }, then I should yield Test::A { a: "", b: 0 }. Let’s take a look at what the finished product looks like.

fn match_usize(name: &Ident, variants: &[Variant]) -> Vec<quote::Tokens> {
    let mut result = Vec::new();

    for (idx, variant) in variants.iter().enumerate() {
        let id = &variant.ident;
        let new = match variant.data {
            VariantData::Unit => quote! {
                    #idx => #name::#id,
                },

            VariantData::Tuple(ref fields) => {
                let types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
                quote! {
                    #idx => #name::#id( #(#types::default(),)* ),
                }
            },

            VariantData::Struct(ref fields) => {
                let items: Vec<_> = fields.iter().map(|f| {
                    let ident = &f.ident;
                    let ty = &f.ty;

                    quote! {
                        #ident: #ty::default()
                    }
                }).collect();

                quote! {
                    #idx => #name::#id { #(#items,)*  },
                }
            }
        };
        result.push(new);
    }

    result
}

Edit: dtolnay mentioned that the quote! repetition syntax can take anything that implements IntoIterator, which makes .collect() unnecessary in these examples. Instead, we could write

let types = fields.iter().map(|f| &f.ty);

Edit continued: We would be better off simply using ::std::default::Default::default() in place of #types::default() to correctly handle generics. For instance, Vec<u32>::default() is invalid Rust, and that is exactly what #types::default() would produce for a Vec<u32> field.
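As a small illustration of that point, here is a hand-written sketch (the Generic enum and the defaults function are hypothetical names, not part of the project) showing that the fully qualified call works for fields with generic types, because the compiler infers the type from the position of the field.

```rust
#[derive(Debug)]
enum Generic {
    Items(Vec<u32>),
    Maybe(Option<String>),
}

// Hypothetical helper: builds a default value of each variant by hand.
fn defaults() -> Vec<Generic> {
    vec![
        // Writing `Vec<u32>::default()` here would not parse as an expression;
        // `<Vec<u32>>::default()` would, but the form below needs no type at all.
        Generic::Items(::std::default::Default::default()),
        Generic::Maybe(::std::default::Default::default()),
    ]
}
```

Since no type name appears in the call, the macro would not even need to collect the field types for tuple variants, only the number of fields.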

So basically, I am just iterating through the variants of the enum, figuring out how to write a default representation of each one, and pushing that into a vector which I then return. The unit variants are the easiest. I just need to return quote! { #idx => #name::#id, } where #id is the name of the enum variant. Let’s look at the other branches:

    VariantData::Tuple(ref fields) => {
        let types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
        quote! {
            #idx => #name::#id( #(#types::default(),)* ),
        }
    },

Here we have a tuple variant, and syn is nice enough to give us the fields of the tuple as well; thanks buddy. All we need to do is get the types of all of these fields so we can call #types::default() on each of them. So for an enum variant of the form Test::C(u32, u32), the third variant in our example, we will return 2 => Test::C(u32::default(), u32::default()),. Neat. Let’s look at the struct variant branch:

    VariantData::Struct(ref fields) => {
        let items: Vec<_> = fields.iter().map(|f| {
            let ident = &f.ident;
            let ty = &f.ty;

            quote! {
                #ident: #ty::default()
            }
        }).collect();

        quote! {
            #idx => #name::#id { #(#items,)*  },
        }
    }

So in a struct variant, we need both the field name and the type of each field. We can collect these into a Vec of quote::Tokens by iterating through fields and returning quote! { #ident: #ty::default() }. This is like creating a Vec of Strings that look like "a: String::default()". Once we have that, we can chain them together using the familiar macro repetition pattern { #(#items,)* },. So if our variant is Test::D { a: String, b: bool }, we will be returning Test::D { a: String::default(), b: bool::default() }.
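Putting the pieces together, here is a hand-written approximation of what I would expect the derive to generate for the Test enum from earlier. It is not captured macro output, so details such as literal suffixes and trailing commas may differ from what quote! really emits, but it compiles on its own and shows the shape of the result.

```rust
#[derive(Debug)]
pub enum Test {
    A,
    B,
    C(u32, u32),
    D { a: String, b: bool },
}

// The generated iterator type: "_EnumIterator" followed by the enum's name.
#[derive(Debug, Default)]
pub struct _EnumIteratorTest {
    count: usize,
}

impl Test {
    fn enum_iter() -> _EnumIteratorTest {
        _EnumIteratorTest::default()
    }
}

impl _EnumIteratorTest {
    // One arm per variant, in declaration order, each built from field defaults.
    fn from_usize(n: usize) -> Test {
        match n {
            0 => Test::A,
            1 => Test::B,
            2 => Test::C(u32::default(), u32::default()),
            3 => Test::D { a: String::default(), b: bool::default() },
            _ => unreachable!(),
        }
    }
}

impl Iterator for _EnumIteratorTest {
    type Item = Test;

    fn next(&mut self) -> Option<Self::Item> {
        // 4 is the number of variants (`#size` in the template).
        if self.count >= 4 {
            return None;
        }
        let result = _EnumIteratorTest::from_usize(self.count);
        self.count += 1;
        Some(result)
    }
}
```

The guard in next is what keeps from_usize from ever reaching the unreachable!() arm.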

And that’s it; I think. Let’s test it out. Compile our main crate and we get exactly what we expected.

A
B
C(0, 0)
D { a: "", b: false }

As you may have noticed, this macro only works if every field of every variant implements Default. Let’s see what happens if one of our variants contains a type that doesn’t.

#[derive(Debug)]
struct NoDefault;

#[derive(Debug, EnumIterator)]
pub enum Test {
    A,
    B(NoDefault),
}

If we compile this we get the error

error: no associated item named `default` found for type `NoDefault` in the current scope
 --> src\main.rs:7:17
  |
7 | #[derive(Debug, EnumIterator)]
  |                 ^^^^^^^^^^^^
  |

Beautiful. Users will love this error message. If you would like to play around with this project, you can find the source on GitHub.

Written on December 31, 2016