Rust で自動メモ化マクロ
作ってしまいました.かなり雑.Rust 1.15.1 で動きます.
macro_rules! __memoize_internal { (@vec $FRetT:ty) => { Option<$FRetT> }; (@vec $FRetT:ty, $memo_param_head:ident $(, $memo_param:ident)*) => { Vec<__memoize_internal!(@vec $FRetT $(, $memo_param)*)> }; (@indices [$($t:tt)*]) => { [$($t)*] }; ( @indices [$($t:tt)*] $memo_param_head:ident $(, $memo_param:ident $(($memo_param_expr:expr))*)* ) => { __memoize_internal!( @indices [$($t)* $memo_param_head as usize,] $($memo_param $(($memo_param_expr))*),* ) }; ( @indices [$($t:tt)*] $memo_param_head:ident($memo_param_expr_head:expr) $(, $memo_param:ident $(($memo_param_expr:expr))*)* ) => { __memoize_internal!( @indices [$($t)* $memo_param_expr_head(&$memo_param_head),] $($memo_param $(($memo_param_expr))*),* ) }; (@get $memo:expr, $indices:expr, $i:expr) => { $memo }; (@get $memo:expr, $indices:expr, $i:expr, $memo_param_head:ident $(, $memo_param:ident)*) => { __memoize_internal!( @get $memo.and_then(|memo| memo.get($indices[$i])), $indices, $i + 1 $(, $memo_param)* ) }; (@resize $memo:expr, $indices:expr, $i:expr) => {}; (@resize $memo:expr, $indices:expr, $i:expr, $memo_param_head:ident $(, $memo_param:ident)*) => { if $memo.len() < $indices[$i] + 1 { $memo.resize($indices[$i] + 1, <_>::default()); } __memoize_internal!( @resize $memo[$indices[$i]], $indices, $i + 1 $(, $memo_param)* ) }; (@index $memo:expr, $indices:expr, $i:expr) => { $memo }; (@index $memo:expr, $indices:expr, $i:expr, $memo_param_head:ident $(, $memo_param:ident)*) => { __memoize_internal!( @index $memo[$indices[$i]], $indices, $i + 1 $(, $memo_param)* ) }; } macro_rules! memoize { ( #[memo_param($($memo_param:ident $(($memo_param_expr:expr))*),+)] $(#[$meta:meta])* fn $f:ident($memo:ident: _ $(, $f_param:ident: $FParamT:ty)*) -> $FRetT:ty $f_body:block ) => { $(#[$meta])* fn $f( $memo: &mut __memoize_internal!(@vec $FRetT $(, $memo_param)+) $(, $f_param: $FParamT)+ ) -> $FRetT { let indices = __memoize_internal!(@indices [] $($memo_param $(($memo_param_expr))*),+); if let Some(&Some(ref ret)) = __memoize_internal!(@get Some(&$memo), indices, 0 $(, $memo_param)+) { return ret.clone(); } let ret = $f_body; __memoize_internal!(@resize $memo, indices, 0 $(, $memo_param)+); __memoize_internal!(@index $memo, indices, 0 $(, $memo_param)+) = Some(ret.clone()); ret } }; }
これは何
例えば,$n$ 番目のフィボナッチ数を求める関数:
fn fib(n: usize) -> u64 { match n { 0 => 1, 1 => 1, n => fib(n - 1) + fib(n - 2), } }
これをメモ化したいです.真面目にやるとこんな実装になります:
fn fib(memo: &mut Vec<Option<u64>>, n: usize) -> u64 { if let Some(&Some(fib_n)) = memo.get(n) { return fib_n; } let fib_n = match n { 0 => 1, 1 => 1, n => fib(memo, n - 2) + fib(memo, n - 1), }; if memo.len() < n + 1 { memo.resize(n + 1, None); } memo[n] = Some(fib_n); fib_n }
これを毎回書くのはちょっとつらいです.また,引数が増えるとさらに面倒になります.という訳で,このボイラープレートを自動生成するのが今回紹介する memoize
マクロです.
今回の場合は次のように実装できます:
memoize! { #[memo_param(n)] fn fib(memo: _, n: usize) -> u64 { match n { 0 => 1, 1 => 1, n => fib(memo, n - 1) + fib(memo, n - 2), } } }
元のコードとほとんど形が変わっていませんね.
使い方
こんな風に使います.これは竹内関数の実装ですが,説明のためにいろいろと無駄な書き方をしています.雰囲気を感じ取ってくだされば.
memoize! { #[memo_param(x_y(|&(x, y)| x as usize * 1000 + y as usize), z)] fn tarai(memo: _, x_y: (u64, u64), z: u64, hoge: i32) -> u64 { let (x, y) = x_y; assert!(y < 1000); if x <= y { y } else { let p = tarai(memo, (x - 1, y), z, hoge); let q = tarai(memo, (y - 1, z), x, hoge); let r = tarai(memo, (z - 1, x), y, hoge); tarai(memo, (p, q), r, hoge) } } }
#[memo_param(...)]
でメモ化する際の引数 (関数の返り値が依存する引数) を指定します.実行時に決まる定数 (問題の入力や,前計算で生成した何らかのテーブルなど) を引数として渡している場合はここで指定しないようにします.
メモはすべて Vec
に取るようになっているので1,引数は (ある程度小さい) usize
の添字に変換できる必要があります.この変換には as usize
を用いています.変換を明示的に指定したい場合は #[memo_param(x(f))]
の構文を使うことができます.f
が変換する関数で,概ね fn(&T) -> usize
2 (ただし T
は x
の型) であればよいです.
memo
の名前は何でも良いですが,必ず第一引数としてください.この型はマクロが生成するので _
とします (&mut Vec<Vec< ... <Option<T>> ... >>
のような型になります (ただし T
は関数の返り値)).
関数の返り値の型は Clone
を実装している必要があります.
使用例
ABC099 C - Strange Bank に対する解答を挙げます.
const MAX_N: u64 = 100000; let coins = { let mut vec = vec![1]; fn create_table(base: u64, vec: &mut Vec<u64>) { let mut crt = base; while crt <= MAX_N { vec.push(crt); crt *= base; } } create_table(6, &mut vec); create_table(9, &mut vec); vec.sort(); vec }; memoize! { #[memo_param(price)] fn count(memo: _, price: u64, coins: &[u64]) -> usize { if price == 0 { 0 } else { coins .iter() .filter(|&&coin| coin <= price) .map(|&coin| 1 + count(memo, price - coin, coins)) .min() .unwrap() } } } let n = scan!(u64); // 入力 let ans = count(&mut vec![], n, &coins); println!("{}", ans);