kuretchi's blog

kuretchi's blog

競技プログラミングなどなど...

Rust で自動メモ化マクロ

作ってしまいました.かなり雑.Rust 1.15.1 で動きます.

macro_rules! __memoize_internal {
  (@vec $FRetT:ty) => {
    Option<$FRetT>
  };
  (@vec $FRetT:ty, $memo_param_head:ident $(, $memo_param:ident)*) => {
    Vec<__memoize_internal!(@vec $FRetT $(, $memo_param)*)>
  };
  (@indices [$($t:tt)*]) => {
    [$($t)*]
  };
  (
    @indices [$($t:tt)*]
    $memo_param_head:ident
    $(, $memo_param:ident $(($memo_param_expr:expr))*)*
  ) => {
    __memoize_internal!(
      @indices [$($t)* $memo_param_head as usize,] $($memo_param $(($memo_param_expr))*),*
    )
  };
  (
    @indices [$($t:tt)*]
    $memo_param_head:ident($memo_param_expr_head:expr)
    $(, $memo_param:ident $(($memo_param_expr:expr))*)*
  ) => {
    __memoize_internal!(
      @indices [$($t)* $memo_param_expr_head(&$memo_param_head),]
      $($memo_param $(($memo_param_expr))*),*
    )
  };
  (@get $memo:expr, $indices:expr, $i:expr) => {
    $memo
  };
  (@get $memo:expr, $indices:expr, $i:expr, $memo_param_head:ident $(, $memo_param:ident)*) => {
    __memoize_internal!(
      @get $memo.and_then(|memo| memo.get($indices[$i])), $indices, $i + 1 $(, $memo_param)*
    )
  };
  (@resize $memo:expr, $indices:expr, $i:expr) => {};
  (@resize $memo:expr, $indices:expr, $i:expr, $memo_param_head:ident $(, $memo_param:ident)*) => {
    if $memo.len() < $indices[$i] + 1 {
      $memo.resize($indices[$i] + 1, <_>::default());
    }
    __memoize_internal!(
      @resize $memo[$indices[$i]], $indices, $i + 1 $(, $memo_param)*
    )
  };
  (@index $memo:expr, $indices:expr, $i:expr) => {
    $memo
  };
  (@index $memo:expr, $indices:expr, $i:expr, $memo_param_head:ident $(, $memo_param:ident)*) => {
    __memoize_internal!(
      @index $memo[$indices[$i]], $indices, $i + 1 $(, $memo_param)*
    )
  };
}

macro_rules! memoize {
  (
    #[memo_param($($memo_param:ident $(($memo_param_expr:expr))*),+)]
    $(#[$meta:meta])*
    fn $f:ident($memo:ident: _ $(, $f_param:ident: $FParamT:ty)*) -> $FRetT:ty $f_body:block
  ) => {
    $(#[$meta])*
    fn $f(
      $memo: &mut __memoize_internal!(@vec $FRetT $(, $memo_param)+)
      $(, $f_param: $FParamT)+
    ) -> $FRetT {
      let indices = __memoize_internal!(@indices [] $($memo_param $(($memo_param_expr))*),+);

      if let Some(&Some(ref ret)) =
        __memoize_internal!(@get Some(&$memo), indices, 0 $(, $memo_param)+)
      {
        return ret.clone();
      }

      let ret = $f_body;

      __memoize_internal!(@resize $memo, indices, 0 $(, $memo_param)+);
      __memoize_internal!(@index $memo, indices, 0 $(, $memo_param)+) = Some(ret.clone());

      ret
    }
  };
}

これは何

例えば,$n$ 番目のフィボナッチ数を求める関数:

fn fib(n: usize) -> u64 {
  match n {
    0 => 1,
    1 => 1,
    n => fib(n - 1) + fib(n - 2),
  }
}

これをメモ化したいです.真面目にやるとこんな実装になります:

fn fib(memo: &mut Vec<Option<u64>>, n: usize) -> u64 {
  if let Some(&Some(fib_n)) = memo.get(n) {
    return fib_n;
  }

  let fib_n = match n {
    0 => 1,
    1 => 1,
    n => fib(memo, n - 2) + fib(memo, n - 1),
  };

  if memo.len() < n + 1 {
    memo.resize(n + 1, None);
  }

  memo[n] = Some(fib_n);

  fib_n
}

これを毎回書くのはちょっとつらいです.また,引数が増えるとさらに面倒になります.という訳で,このボイラープレートを自動生成するのが今回紹介する memoize マクロです.

今回の場合は次のように実装できます:

memoize! {
  #[memo_param(n)]
  fn fib(memo: _, n: usize) -> u64 {
    match n {
      0 => 1,
      1 => 1,
      n => fib(memo, n - 1) + fib(memo, n - 2),
    }
  }
}

元のコードとほとんど形が変わっていませんね.

使い方

こんな風に使います.これは竹内関数の実装ですが,説明のためにいろいろと無駄な書き方をしています.雰囲気を感じ取ってくだされば.

memoize! {
  #[memo_param(x_y(|&(x, y)| x as usize * 1000 + y as usize), z)]
  fn tarai(memo: _, x_y: (u64, u64), z: u64, hoge: i32) -> u64 {
    let (x, y) = x_y;
    assert!(y < 1000);
    if x <= y {
      y
    } else {
      let p = tarai(memo, (x - 1, y), z, hoge);
      let q = tarai(memo, (y - 1, z), x, hoge);
      let r = tarai(memo, (z - 1, x), y, hoge);
      tarai(memo, (p, q), r, hoge)
    }
  }
}

#[memo_param(...)] でメモ化する際の引数 (関数の返り値が依存する引数) を指定します.実行時に決まる定数 (問題の入力や,前計算で生成した何らかのテーブルなど) を引数として渡している場合はここで指定しないようにします.

メモはすべて Vec に取るようになっているので1,引数は (ある程度小さい) usize の添字に変換できる必要があります.この変換には as usize を用いています.変換を明示的に指定したい場合は #[memo_param(x(f))] の構文を使うことができます.f が変換する関数で,概ね fn(&T) -> usize2 (ただし Tx の型) であればよいです.

memo の名前は何でも良いですが,必ず第一引数としてください.この型はマクロが生成するので _ とします (&mut Vec<Vec< ... <Option<T>> ... >> のような型になります (ただし T は関数の返り値)).

関数の返り値の型は Clone を実装している必要があります.

使用例

ABC099 C - Strange Bank に対する解答を挙げます.

const MAX_N: u64 = 100000;

let coins = {
  let mut vec = vec![1];
  fn create_table(base: u64, vec: &mut Vec<u64>) {
    let mut crt = base;
    while crt <= MAX_N {
      vec.push(crt);
      crt *= base;
    }
  }
  create_table(6, &mut vec);
  create_table(9, &mut vec);
  vec.sort();
  vec
};

memoize! {
  #[memo_param(price)]
  fn count(memo: _, price: u64, coins: &[u64]) -> usize {
    if price == 0 {
      0
    } else {
      coins
        .iter()
        .filter(|&&coin| coin <= price)
        .map(|&coin| 1 + count(memo, price - coin, coins))
        .min()
        .unwrap()
    }
  }
}

let n = scan!(u64); // 入力
let ans = count(&mut vec![], n, &coins);

println!("{}", ans);

  1. 改良の余地あり.

  2. クロージャでもいけます.何か変なものをキャプチャしていると険しいかも.この辺は適当 (詳しくは実装を見てください).