Creating an enum iterator using Macros 1.1
Yesterday I was thinking about some testing strategies for a little project I’m working on. One idea I had was to test the output of a function when applied to every possible variant of an enum. One way I could implement this is by manually naming every variant of the enum in question – probably while using a macro. However, I am terrible at keeping track of things and I would most likely add a new variant to my enum without updating my tests. I could include in the macro a line of code that matched each variant like this:
match Test::A {
    Test::A => (),
    Test::B => (),
}
This would force the compiler to remind me to update my tests whenever I update my enum, but what I really wanted to do was learn Macros 1.1. So let’s do that instead. Here, I will describe my story of learning Macros 1.1.
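For the record, here is a minimal sketch of that manual approach, so it is clear what the derive is meant to replace. The all_variants helper is a name I made up for illustration; the only important part is that the match has no wildcard arm.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Test {
    A,
    B,
}

// Hypothetical helper (not from the project): lists every variant by hand.
fn all_variants() -> Vec<Test> {
    // No wildcard arm here, so adding a variant `C` to `Test` without
    // adding an arm makes this match non-exhaustive and compilation fails.
    match Test::A {
        Test::A => (),
        Test::B => (),
    }
    // The list itself still has to be kept in sync by hand.
    vec![Test::A, Test::B]
}

fn main() {
    for variant in all_variants() {
        println!("{:?}", variant);
    }
}
```

The compiler only reminds me that something changed; it does not update the list for me, which is why a derive is tempting.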
Hello World
So the first thing we need to do is start a new crate for our project.
> cargo new --bin hello-world
All I want is to be able to call hello_world() on a derived type. Something like this:
#[derive(HelloWorld)]
struct Pancakes;

fn main() {
    Pancakes::hello_world();
}
With some kind of nice output, like "Hello, World! My name is Pancakes". So let’s do that. Before we can even start, we need some reference material for using Macros 1.1. The best information I could find comes from dtolnay: his slides and his video presentation (if you haven’t seen this Rust Meetup, I highly recommend it; it’s quite fantastic).
Let’s go ahead and write up what we think our macro will look like from a user’s perspective. In hello-world\src\main.rs we write:
#![feature(proc_macro)] // Promised to not be required coming Rust 1.15!

#[macro_use]
extern crate hello_world_macro;

#[derive(HelloWorld)]
struct FrenchToast;

#[derive(HelloWorld)]
struct Waffles;

fn main() {
    FrenchToast::hello_world();
    Waffles::hello_world();
}
Great. So now we just need to actually write the procedural macro. Let’s start a new crate called hello-world-macro inside our hello-world project.
hello-world\ > cargo new hello-world-macro
To make sure that our hello-world crate is able to find this new crate we’ve created, we’ll add it to our Cargo.toml:
[dependencies]
hello-world-macro = { path = "hello-world-macro" }
As for our hello-world-macro crate, we will just copy and paste some boilerplate found in dtolnay’s slides:
#![feature(proc_macro, proc_macro_lib)]

extern crate proc_macro;
extern crate syn;
#[macro_use] extern crate quote;

use proc_macro::TokenStream;

#[proc_macro_derive(HelloWorld)]
pub fn hello_world(input: TokenStream) -> TokenStream {
    // Construct a string representation of the type definition
    let s = input.to_string();

    // Parse the string representation
    let ast = syn::parse_macro_input(&s).unwrap();

    // Build the impl
    let gen = impl_hello_world(&ast);

    // Return the generated impl
    gen.parse().unwrap()
}
So there is a lot going on here. We have introduced two new crates: syn and quote. As you may have noticed, input: TokenStream is immediately converted to a String. This String is a string representation of the Rust code of the type we are deriving HelloWorld for, and that’s it. So what we really need is to be able to parse Rust code into something usable. This is where syn comes into play. syn is a nom-based Rust parser, and it’s extremely useful. The other crate we’ve introduced is quote. It’s essentially the dual of syn, as it makes serializing Rust code really easy.
The comments give us a pretty good idea of our overall strategy. We are going to take a String of the Rust code for the type we are deriving, parse it using syn, construct the implementation of hello_world (using quote), then pass it back to the Rust compiler.
Great, so let’s write impl_hello_world(&ast).
fn impl_hello_world(ast: &syn::MacroInput) -> quote::Tokens {
    let name = &ast.ident;
    quote! {
        impl #name {
            fn hello_world() {
                println!("Hello, World! My name is {}", stringify!(#name));
            }
        }
    }
}
So this is where quote comes in. The ast argument is a struct that gives us a representation of our type (which can be either a struct or an enum). Check out the docs, there is some useful information there. We are able to get the name of the type using ast.ident. The quote! macro lets us write up the Rust code that we wish to return and convert it into Tokens. quote! also lets us use some really cool templating mechanics; we simply write #name and quote! will replace it with the value of the variable named name. You can even do some repetition, similar to how regular macros work. You should check out the docs for a good introduction.
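To make the templating concrete, here is roughly what I would expect the derive to generate for Pancakes, written out by hand with #name replaced by the type’s name. This is a sketch of the expansion, not actual macro output, and the hello_world_message helper is something I added purely so the text can be checked; the real macro only generates hello_world.

```rust
struct Pancakes;

// Hand-written stand-in for what `#[derive(HelloWorld)]` would emit
// after `#name` has been substituted with `Pancakes`.
impl Pancakes {
    // Hypothetical helper, not part of the macro: builds the message
    // so it can be inspected without capturing stdout.
    fn hello_world_message() -> String {
        format!("Hello, World! My name is {}", stringify!(Pancakes))
    }

    fn hello_world() {
        println!("{}", Self::hello_world_message());
    }
}

fn main() {
    Pancakes::hello_world();
}
```

Note that stringify! turns the identifier into a string literal at compile time, which is why the generated code never needs the name at runtime.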
So I think that’s it. Oh, well, we do need to add dependencies for syn and quote in the Cargo.toml for hello-world-macro.
[dependencies]
syn = "0.10.5"
quote = "0.3.10"
That should be it. Let’s try to compile hello-world.
error: the `#[proc_macro_derive]` attribute is only usable with crates of the `proc-macro` crate type
--> hello-world-macro\src\lib.rs:8:3
|
8 | #[proc_macro_derive(HelloWorld)]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Oh, so it appears that we need to declare that our hello-world-macro crate is a proc-macro crate type. How do we do this? Well, we just go find a proc macro crate on GitHub and see what they did to their Cargo.toml, of course. It turns out we need to add this to our Cargo.toml:
[lib]
proc-macro = true
Ok, so now let’s compile hello-world. Executing cargo run now yields:
Hello, World! My name is FrenchToast
Hello, World! My name is Waffles
Sweet! That’s it! We’ve created our first custom derive using syn and quote, and it was actually quite easy. Now let’s try to create something less trivial.
EnumIterator
So let’s follow the same strategy as we did before by creating a new crate and writing out how we think this enum iterator should work. I want to be able to iterate through enum variants, something like this:
#![feature(proc_macro)] // Rust 1.15 me pls

#[macro_use] extern crate enum_iter;

#[derive(Debug, EnumIterator)]
pub enum Test {
    A,
    B,
    C(u32, u32),
    D { a: String, b: bool },
}

fn main() {
    for variant in Test::enum_iter() {
        println!("{:?}", variant);
    }
}
I’d like to see some kind of output like this:
A
B
C(0, 0)
D { a: "", b: false }
Let’s just go ahead and copy and paste the boilerplate from our previous discussion and make some minor modifications:
#![feature(proc_macro, proc_macro_lib)]

extern crate proc_macro;
extern crate syn;
#[macro_use]
extern crate quote;

use syn::{ Ident, Body, Variant, VariantData };
use proc_macro::TokenStream;

#[proc_macro_derive(EnumIterator)]
pub fn enum_iterator(input: TokenStream) -> TokenStream {
    let s = input.to_string();
    let ast = syn::parse_macro_input(&s).unwrap();
    let name = &ast.ident;
    let gen = match ast.body {
        Body::Enum(ref variants) => impl_enum_iter(name, variants),
        Body::Struct(_) =>
            quote! {
                impl EnumIteratorOnlyWorksForEnumsNotStructsSorryNotSorry for #name { }
            },
    };
    gen.parse().unwrap()
}

fn impl_enum_iter(name: &Ident, variants: &[Variant]) -> quote::Tokens {
    unimplemented!()
}
The first thing I need from syn is the list of variants of the enum. But it is entirely possible that someone will put #[derive(EnumIterator)] on a struct, and we certainly don’t want that. So we just implement a trait that I don’t believe the user has defined, to communicate this error. This is probably not the recommended way to go about it, but this is what I did. Let’s try it out. Let’s go back to our main crate and test it with #[derive(EnumIterator)] struct IterateMePlease;. Doing this will give us an error:
error[E0405]: unresolved trait `EnumIteratorOnlyWorksForEnumsNotStructsSorryNotSorry`
Edit: dtolnay has informed me that the idiomatic way for reporting errors from a procedural macro is panicking. Better would be
Body::Struct(_) => panic!("#[derive(EnumIterator)] is only defined for enums, not structs"),
Perfect, so now that we have done our due diligence, let’s continue. So far syn has given us the name of our enum and a slice of Variants for our enum. It’s at this point where we need to sit down and think about what we want the generated code to look like. Lucky for us, we have quote! and we can just write out what the generated code should look like.
fn impl_enum_iter(name: &Ident, variants: &[Variant]) -> quote::Tokens {
    let interface = quote::Ident::from(format!("_EnumIterator{}", name));
    let match_usize = match_usize(&name, variants);
    let size = variants.len();

    quote! {
        #[derive(Debug, Default)]
        pub struct #interface {
            count: usize,
        }

        impl #name {
            fn enum_iter() -> #interface {
                #interface::default()
            }
        }

        impl #interface {
            fn from_usize(n: usize) -> #name {
                match n {
                    #(#match_usize)*
                    _ => unreachable!(), // I think
                }
            }
        }

        impl ::std::iter::Iterator for #interface {
            type Item = #name;
            fn next(&mut self) -> Option<Self::Item> {
                if self.count >= #size { return None }
                let result = #interface::from_usize(self.count);
                self.count += 1;
                Some(result)
            }
        }
    }
}
What I really want is to create a new type implementing Iterator. I don’t want this name to conflict with any possible names the user has created, so I will call it something like _EnumIterator#name. Unfortunately, I don’t think this templating magic works inside the quote! macro, but no fear: quote gives us a way to create an ident from a string, namely quote::Ident::from(format!("_EnumIterator{}", name));. So we start by declaring our _EnumIterator{} type, and implement a method enum_iter() on our enum which will build us one. The iterator will simply keep track of a count and then convert this count to the corresponding variant of the enum. This is what from_usize is for. Basically, I want something like this:
fn from_usize(n: usize) -> Test {
    match n {
        0 => Test::A,
        1 => Test::B,
        _ => unreachable!(),
    }
}
We will need to fill in those branches with #(#match_usize)*. After that, the iterator is quite easy to implement. Match the current count with an enum variant, increment the count, and return that variant. That seems straightforward. So how do we create this match_usize thing? Since this is repeating, #(#match_usize)*, we will need to return a Vec of the things that quote! returns (Tokens). What I need to return is a default representation of each variant. So if the variant is a tuple type like Test::A(u32, u32), I need to yield Test::A(0, 0), and if it’s a struct type like Test::A { a: String, b: i32 }, then I should yield Test::A { a: "", b: 0 }. Let’s take a look at what the finished product looks like.
fn match_usize(name: &Ident, variants: &[Variant]) -> Vec<quote::Tokens> {
    let mut result = Vec::new();

    for (idx, variant) in variants.iter().enumerate() {
        let id = &variant.ident;
        let new = match variant.data {
            VariantData::Unit => quote! {
                #idx => #name::#id,
            },
            VariantData::Tuple(ref fields) => {
                let types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
                quote! {
                    #idx => #name::#id( #(#types::default(),)* ),
                }
            },
            VariantData::Struct(ref fields) => {
                let items: Vec<_> = fields.iter().map(|f| {
                    let ident = &f.ident;
                    let ty = &f.ty;
                    quote! {
                        #ident: #ty::default()
                    }
                }).collect();
                quote! {
                    #idx => #name::#id { #(#items,)* },
                }
            }
        };
        result.push(new);
    }

    result
}
Edit: dtolnay mentioned that the quote! repetition syntax can take anything that implements IntoIterator, which makes .collect() unnecessary in these examples. Instead, we could write
let types = fields.iter().map(|f| &f.ty);
Edit continued: We would be better off simply using ::std::default::Default::default() in place of #types::default() to correctly handle generics. For instance, with a field of type Vec<u32>, #types::default() would expand to Vec<u32>::default(), which is invalid Rust.
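Here is a small sketch of why that matters, using a made-up enum called Generic with a Vec<u32> field. The two helper functions are my own names for illustration; they show two spellings that do parse in expression position, while the commented-out line shows the one that does not.

```rust
#[derive(Debug, PartialEq)]
enum Generic {
    // A tuple variant whose field type has generic parameters.
    List(Vec<u32>),
}

// Spelling suggested in the edit above: let inference pick the type.
fn via_trait_path() -> Generic {
    Generic::List(::std::default::Default::default())
}

// Alternative spelling: wrap the type in angle brackets so the
// generic arguments are not parsed as comparison operators.
fn via_qualified_type() -> Generic {
    Generic::List(<Vec<u32>>::default())
}

fn main() {
    // This is what `#types::default()` would produce, and it does not parse:
    // let broken = Generic::List(Vec<u32>::default());
    println!("{:?}", via_trait_path());
    println!("{:?}", via_qualified_type());
}
```

Both working spellings produce the same empty Vec; the Default::default() form has the advantage that the macro does not need to paste the type into the expression at all.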
So basically, I am just iterating through the variants of the enum, figuring out how to write out a default representation of each one, and pushing that into a vector which I will return. The unit variants are the easiest ones. I just need to return quote! { #idx => #name::#id, } where #id is the name of the enum variant. Let’s look at the other branches:
VariantData::Tuple(ref fields) => {
    let types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
    quote! {
        #idx => #name::#id( #(#types::default(),)* ),
    }
},
Here we have a tuple type, and syn is nice enough to give us the fields of the tuple as well; thanks buddy. All we need to do is get the types of all of these fields so we can call #types::default(), on each of them for our return. So for an enum variant of the form Test::C(u32, u32), which sits at index 2 in Test, we will return 2 => Test::C(u32::default(), u32::default()),. Neat. Let’s look at the struct variants branch:
VariantData::Struct(ref fields) => {
    let items: Vec<_> = fields.iter().map(|f| {
        let ident = &f.ident;
        let ty = &f.ty;
        quote! {
            #ident: #ty::default()
        }
    }).collect();
    quote! {
        #idx => #name::#id { #(#items,)* },
    }
}
So in a struct variant, we need both the field name and the type of each field. We can collect these into a Vec of quote::Tokens by iterating through the fields and returning quote! { #ident: #ty::default() }. This is like creating a Vec of Strings that look like "a: String::default()". Once we have that, we can chain them together using the familiar macro repetition pattern { #(#items,)* },. So for our variant Test::D { a: String, b: bool }, we will be returning Test::D { a: String::default(), b: bool::default() }.
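To spell out the "Vec of Strings" analogy, here is a string-only sketch of the struct-variant branch. It is not how the macro works (the real code builds quote::Tokens, not Strings), and struct_variant_arm and its parameters are names I invented for this illustration.

```rust
// String-based analogy of the struct-variant branch.
// `fields` is a list of (field name, field type) pairs.
fn struct_variant_arm(idx: usize, name: &str, variant: &str, fields: &[(&str, &str)]) -> String {
    // One entry per field, like `quote! { #ident: #ty::default() }`.
    let items: Vec<String> = fields
        .iter()
        .map(|&(ident, ty)| format!("{}: {}::default()", ident, ty))
        .collect();

    // Mimics `#(#items,)*`: every item is followed by a comma.
    let body = items
        .iter()
        .map(|item| format!("{},", item))
        .collect::<Vec<_>>()
        .join(" ");

    // Mimics `#idx => #name::#id { ... },`.
    format!("{} => {}::{} {{ {} }},", idx, name, variant, body)
}

fn main() {
    let arm = struct_variant_arm(3, "Test", "D", &[("a", "String"), ("b", "bool")]);
    println!("{}", arm);
}
```

The trailing comma after the last field is harmless in a struct expression, which is why the repetition can put a comma after every item instead of only between them.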
And that’s it; I think. Let’s test it out. Compile our main crate and we get exactly what we expected.
A
B
C(0, 0)
D { a: "", b: false }
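If you want to see the moving parts without any macro machinery, here is what I would expect the derive to expand to for Test, written out by hand. Treat it as a sketch of the generated code rather than real macro output: I typed it manually, hard-coded the variant count as 4, and used Default::default() for the fields as suggested in the edit above.

```rust
#[derive(Debug)]
pub enum Test {
    A,
    B,
    C(u32, u32),
    D { a: String, b: bool },
}

// Hand-written stand-in for the generated iterator type.
#[derive(Debug, Default)]
pub struct _EnumIteratorTest {
    count: usize,
}

impl Test {
    fn enum_iter() -> _EnumIteratorTest {
        _EnumIteratorTest::default()
    }
}

impl _EnumIteratorTest {
    // Maps a position to the default representation of that variant.
    fn from_usize(n: usize) -> Test {
        match n {
            0 => Test::A,
            1 => Test::B,
            2 => Test::C(Default::default(), Default::default()),
            3 => Test::D { a: Default::default(), b: Default::default() },
            _ => unreachable!(),
        }
    }
}

impl Iterator for _EnumIteratorTest {
    type Item = Test;

    fn next(&mut self) -> Option<Self::Item> {
        // 4 is the number of variants (`#size` in the macro).
        if self.count >= 4 {
            return None;
        }
        let result = _EnumIteratorTest::from_usize(self.count);
        self.count += 1;
        Some(result)
    }
}

fn main() {
    for variant in Test::enum_iter() {
        println!("{:?}", variant);
    }
}
```

The guard in next is what keeps the unreachable!() arm honest: from_usize is never called with an index past the last variant.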
As you may have noticed, this macro only works if every field of every variant has a type that implements Default. Let’s see what happens if one of our variants has a field that doesn’t.
#[derive(Debug)]
struct NoDefault;

#[derive(Debug, EnumIterator)]
pub enum Test {
    A,
    B(NoDefault),
}
If we compile this we get the error
error: no associated item named `default` found for type `NoDefault` in the current scope
--> src\main.rs:7:17
|
7 | #[derive(Debug, EnumIterator)]
| ^^^^^^^^^^^^
|
Beautiful. Users will love this error message. If you would like to play around with this project, you can find the source on github.