kuretchi's blog

kuretchi's blog

競技プログラミングなどなど...

Rust で自動メモ化マクロ

作ってしまいました.かなり雑.Rust 1.15.1 で動きます.

macro_rules! __memoize_internal {
  (@vec $FRetT:ty) => {
    Option<$FRetT>
  };
  (@vec $FRetT:ty, $memo_param_head:ident $(, $memo_param:ident)*) => {
    Vec<__memoize_internal!(@vec $FRetT $(, $memo_param)*)>
  };
  (@indices [$($t:tt)*]) => {
    [$($t)*]
  };
  (
    @indices [$($t:tt)*]
    $memo_param_head:ident
    $(, $memo_param:ident $(($memo_param_expr:expr))*)*
  ) => {
    __memoize_internal!(
      @indices [$($t)* $memo_param_head as usize,] $($memo_param $(($memo_param_expr))*),*
    )
  };
  (
    @indices [$($t:tt)*]
    $memo_param_head:ident($memo_param_expr_head:expr)
    $(, $memo_param:ident $(($memo_param_expr:expr))*)*
  ) => {
    __memoize_internal!(
      @indices [$($t)* $memo_param_expr_head(&$memo_param_head),]
      $($memo_param $(($memo_param_expr))*),*
    )
  };
  (@get $memo:expr, $indices:expr, $i:expr) => {
    $memo
  };
  (@get $memo:expr, $indices:expr, $i:expr, $memo_param_head:ident $(, $memo_param:ident)*) => {
    __memoize_internal!(
      @get $memo.and_then(|memo| memo.get($indices[$i])), $indices, $i + 1 $(, $memo_param)*
    )
  };
  (@resize $memo:expr, $indices:expr, $i:expr) => {};
  (@resize $memo:expr, $indices:expr, $i:expr, $memo_param_head:ident $(, $memo_param:ident)*) => {
    if $memo.len() < $indices[$i] + 1 {
      $memo.resize($indices[$i] + 1, <_>::default());
    }
    __memoize_internal!(
      @resize $memo[$indices[$i]], $indices, $i + 1 $(, $memo_param)*
    )
  };
  (@index $memo:expr, $indices:expr, $i:expr) => {
    $memo
  };
  (@index $memo:expr, $indices:expr, $i:expr, $memo_param_head:ident $(, $memo_param:ident)*) => {
    __memoize_internal!(
      @index $memo[$indices[$i]], $indices, $i + 1 $(, $memo_param)*
    )
  };
}

macro_rules! memoize {
  (
    #[memo_param($($memo_param:ident $(($memo_param_expr:expr))*),+)]
    $(#[$meta:meta])*
    fn $f:ident($memo:ident: _ $(, $f_param:ident: $FParamT:ty)*) -> $FRetT:ty $f_body:block
  ) => {
    $(#[$meta])*
    fn $f(
      $memo: &mut __memoize_internal!(@vec $FRetT $(, $memo_param)+)
      $(, $f_param: $FParamT)+
    ) -> $FRetT {
      let indices = __memoize_internal!(@indices [] $($memo_param $(($memo_param_expr))*),+);

      if let Some(&Some(ref ret)) =
        __memoize_internal!(@get Some(&$memo), indices, 0 $(, $memo_param)+)
      {
        return ret.clone();
      }

      let ret = $f_body;

      __memoize_internal!(@resize $memo, indices, 0 $(, $memo_param)+);
      __memoize_internal!(@index $memo, indices, 0 $(, $memo_param)+) = Some(ret.clone());

      ret
    }
  };
}

これは何

例えば,$n$ 番目のフィボナッチ数を求める関数:

fn fib(n: usize) -> u64 {
  match n {
    0 => 1,
    1 => 1,
    n => fib(n - 1) + fib(n - 2),
  }
}

これをメモ化したいです.真面目にやるとこんな実装になります:

fn fib(memo: &mut Vec<Option<u64>>, n: usize) -> u64 {
  if let Some(&Some(fib_n)) = memo.get(n) {
    return fib_n;
  }

  let fib_n = match n {
    0 => 1,
    1 => 1,
    n => fib(memo, n - 2) + fib(memo, n - 1),
  };

  if memo.len() < n + 1 {
    memo.resize(n + 1, None);
  }

  memo[n] = Some(fib_n);

  fib_n
}

これを毎回書くのはちょっとつらいです.また,引数が増えるとさらに面倒になります.という訳で,このボイラープレートを自動生成するのが今回紹介する memoize マクロです.

今回の場合は次のように実装できます:

memoize! {
  #[memo_param(n)]
  fn fib(memo: _, n: usize) -> u64 {
    match n {
      0 => 1,
      1 => 1,
      n => fib(memo, n - 1) + fib(memo, n - 2),
    }
  }
}

元のコードとほとんど形が変わっていませんね.

使い方

こんな風に使います.これは竹内関数の実装ですが,説明のためにいろいろと無駄な書き方をしています.雰囲気を感じ取ってくだされば.

memoize! {
  #[memo_param(x_y(|&(x, y)| x as usize * 1000 + y as usize), z)]
  fn tarai(memo: _, x_y: (u64, u64), z: u64, hoge: i32) -> u64 {
    let (x, y) = x_y;
    assert!(y < 1000);
    if x <= y {
      y
    } else {
      let p = tarai(memo, (x - 1, y), z, hoge);
      let q = tarai(memo, (y - 1, z), x, hoge);
      let r = tarai(memo, (z - 1, x), y, hoge);
      tarai(memo, (p, q), r, hoge)
    }
  }
}

#[memo_param(...)] でメモ化する際の引数 (関数の返り値が依存する引数) を指定します.実行時に決まる定数 (問題の入力や,前計算で生成した何らかのテーブルなど) を引数として渡している場合はここで指定しないようにします.

メモはすべて Vec に取るようになっているので1,引数は (ある程度小さい) usize の添字に変換できる必要があります.この変換には as usize を用いています.変換を明示的に指定したい場合は #[memo_param(x(f))] の構文を使うことができます.f が変換する関数で,概ね fn(&T) -> usize2 (ただし Tx の型) であればよいです.

memo の名前は何でも良いですが,必ず第一引数としてください.この型はマクロが生成するので _ とします (&mut Vec<Vec< ... <Option<T>> ... >> のような型になります (ただし T は関数の返り値)).

関数の返り値の型は Clone を実装している必要があります.

使用例

ABC099 C - Strange Bank に対する解答を挙げます.

const MAX_N: u64 = 100000;

let coins = {
  let mut vec = vec![1];
  fn create_table(base: u64, vec: &mut Vec<u64>) {
    let mut crt = base;
    while crt <= MAX_N {
      vec.push(crt);
      crt *= base;
    }
  }
  create_table(6, &mut vec);
  create_table(9, &mut vec);
  vec.sort();
  vec
};

memoize! {
  #[memo_param(price)]
  fn count(memo: _, price: u64, coins: &[u64]) -> usize {
    if price == 0 {
      0
    } else {
      coins
        .iter()
        .filter(|&&coin| coin <= price)
        .map(|&coin| 1 + count(memo, price - coin, coins))
        .min()
        .unwrap()
    }
  }
}

let n = scan!(u64); // 入力
let ans = count(&mut vec![], n, &coins);

println!("{}", ans);

  1. 改良の余地あり.

  2. クロージャでもいけます.何か変なものをキャプチャしていると険しいかも.この辺は適当 (詳しくは実装を見てください).

ラッパーの参照と参照のラッパー

何かをラップした構造体を作ることがよくある.例えば,何らかの 1 ワードの値 hoge を常に引き回す必要があるとき:

struct Value {}

impl Value {
  fn f0(&self, hoge: usize) {}
  fn f1(&self, hoge: usize, fuga: i32) {}
  // ...
}

こんな構造体を作って,Deref を実装する.

struct WithHoge<T> {
  hoge: usize,
  inner: T,
}

impl<T> Deref for WithHoge<T> {
  type Target = T;

  fn deref(&self) -> &T {
    &self.inner
  }
}

しめしめ.これでさっきの例はこう書けるぞ.

impl WithHoge<Value> {
  fn f0(&self) {}
  fn f1(&self, fuga: i32) {}
  // ...
}

と思いきや,これは適切ではない.正しくはこう.

impl WithHoge<&Value> {
  fn f0(self) {}
  fn f1(self, fuga: i32) {}
  // ...
}

Rust の多次元 Vec を初期化するマクロ

小ネタ.

Rust で多次元 Vec (dp[0][1][2] のように使えるもの) を作りたい.例えば None で初期化された $2 \times 3 \times 4$ の Vec<Vec<Vec<Option<T>>>> を作るときはこう書く.

vec![vec![vec![None; 4]; 3]; 2]

うーん.という訳でこんなマクロ.

macro_rules! nested_vec {
  ($e:expr; $n:expr) => {
    vec![$e; $n]
  };
  ($e:expr; $n:expr $(; $m:expr)+) => {
    vec![nested_vec!($e $(; $m)+); $n]
  };
}

こんな風に書ける.

nested_vec![None; 2; 3; 4]

"遅延伝播" の一般化

よく混乱するので,自分用のメモとして分かりやすくまとめておく.

遅延伝播*1についての説明は省略.

要件

半群 $(S, \bullet)$,モノイド $(O, \circ, e)$ と,(外部) 二項演算 $\triangleleft : S \rightarrow O \rightarrow S$ について,

  • $\forall s \in S. \ s \triangleleft e = s$
  • $\forall s \in S. \ \forall p, q \in O. \ s \triangleleft (p \circ q) = (s \triangleleft p) \triangleleft q$
  • $\forall s, t \in S. \ \forall p \in O. \ (s \bullet t) \triangleleft p = (s \triangleleft p) \bullet (t \triangleleft p)$

を満足すればよい.

$S$ は (連続) 部分列の畳み込み, $O$ と $\triangleleft$ は部分列に対する一様な作用を表現する.

具体例

区間最小値 & 区間加算

  • $(S, \bullet) := (\mathbb{R}, \min)$
  • $(O, \circ, e) := (\mathbb{R}, +, 0)$
  • $\triangleleft := +$
type S = Double
type O = Double

opS :: S -> S -> S
opS = min

opO :: O -> O -> O
opO = (+)

act :: S -> O -> S
act = (+)

要件を確認する.

  • $(\mathbb{R}, \min)$ は半群の要件を満たす.
  • $(\mathbb{R}, +, 0)$ はモノイドの要件を満たす.
  • $\forall s \in \mathbb{R}. \ s + 0 = s$
  • $\forall s \in \mathbb{R}. \ \forall p, q \in \mathbb{R}. \ s + (p + q) = (s + p) + q$
  • $\forall s, t \in \mathbb{R}. \ \forall p \in \mathbb{R}. \ \min(s, t) + p = \min(s + p, t + p)$

区間和 & 区間加算

作用が分配的でないときは,区間の長さを持つ必要がある.

  • $S := \mathbb{R} \times \mathbb{N}$
  • $\forall (l, n), (r, m) \in S. \ (l, n) \bullet (r, m) := (l + r, n + m)$
  • $(O, \circ, e) := (\mathbb{R}, +, 0)$
  • $\forall (s, n) \in S. \ \forall p \in O. \ (s, n) \triangleleft p := (s + np, n)$
type S = (Double, Int)
type O = Double

opS :: S -> S -> S
(l, n) `opS` (r, m) = (l + r, n + m)

opO :: O -> O -> O
opO = (+)

act :: S -> O -> S
(s, n) `act` p = (s + fromIntegral n * p, n)

要件を確認する.

  • $(\mathbb{R} \times \mathbb{N}, \bullet)$ は半群の要件を満たす ($\because$ 半群 $(\mathbb{R}, +)$, $(\mathbb{N}, +)$ の直積).
  • $(\mathbb{R}, +, 0)$ はモノイドの要件を満たす.
  • $\forall (s, n) \in \mathbb{R} \times \mathbb{N}.$
    • $(s + n \times 0, n) = (s, n)$
  • $\forall (s, n) \in \mathbb{R} \times \mathbb{N}. \ \forall p, q \in \mathbb{R}.$
    • $(s + n(p + q), n) = ((s + np) + nq, n)$
  • $\forall (s, n), (t, m) \in \mathbb{R} \times \mathbb{N}. \ \forall p \in \mathbb{R}.$
    • $((s + t) + (n + m)p, n + m) = ((s + np) + (t + mp), n + m)$

区間和 & 区間代入

代入は右零半群で表現できる.これに単位元を添加しモノイドに拡張する.

  • $S := \mathbb{R} \times \mathbb{N}$
  • $\forall (l, n), (r, m) \in S. \ (l, n) \bullet (r, m) := (l + r, n + m)$
  • 適当な $e \notin \mathbb{R}$ を取って,
    • $O := \mathbb{R} \cup \{e\}$
    • $\forall l, r \in O. \ l \circ r := \begin{cases} l & (r = e) \\ r & (\text{otherwise}) \end{cases}$
  • $\forall (s, n) \in S. \ \forall p \in O.$
    • $(s, n) \triangleleft p := \begin{cases} (s, n) & (p = e) \\ (np, n) & (\text{otherwise}) \end{cases}$
type S = (Double, Int)
type O = Maybe Double

opS :: S -> S -> S
(l, n) `opS` (r, m) = (l + r, n + m)

opO :: O -> O -> O
l `opO` Nothing = l
_ `opO` r = r

act :: S -> O -> S
(s, n) `act` Nothing = (s, n)
(_, n) `act` (Just p) = (fromIntegral n * p, n)

これは要件を満たす.証明は省略.

参考

*1:遅延評価とも呼ばれる.遅延セグメント木などを参照のこと.

yukicoder - No.585 工夫のないパズル

https://yukicoder.me/problems/no/585

コンテスト中、A 問題を通した後ずっとこれを実装していて、結局間に合わず 1 完になった。でも ★4 AC は嬉しい。

概要

4 × 4 のスライドパズルを解く問題。いわゆる 15 パズル と似ているが、空きマスがない代わりに行か列を一つ選んでスライドすることができるというルール。右か下に任意マス分スライドするのを 1 回として、100 回以内の自由な操作で揃えればよい。

解説

おもむろに実験すると、最初の $3$ 行 (列でもいいけれど) は簡単にできそうということが分かる。後々の実装を楽にするために整理しておこう。

$i$ 行目 (0-indexed) までが揃っているとして ($0$ 行目は $-1$ 行目まで揃っていると考える)、次の $i + 1$ 行目を $i$ 行目までを崩さず揃えるには、基本的に下のようにすればよい。

今注目している移動させたい数字のマスの初期座標を $(sr, sc)$、移動先を $(tr, tc)$ とし (ただし、$(行, 列)$)、$round(i) = (i \% 4 + 4) \% 4$ とすると、

基本操作

  1. 列 $tc$ を $sr - tr$ 下へスライド
  2. 行 $sr$ を $round(tc - sc)$ 右へスライド
  3. 列 $tc$ を $4 - (sr - tr)$ 下へスライド

f:id:Kuretchi:20171028091512p:plain

ここでは、アルファベットの代わりに 0-15 の数字を使うことにする。今興味がない数字は省略。

ヤバそうなケースもしっかり考えておく。例えば、移動元と移動先が同じ行にある場合は、次のようにする。

移動元と移動先の行が同じ場合 ($sr = tr$)

  1. 列 $sc$ を $1$ 下へスライド
  2. 列 $tc$ を $1$ 下へスライド
  3. 行 $sr + 1 (= tr + 1)$ を $round(tc - sc)$ 右へスライド
  4. 列 $sc$ を $3$ 下へスライド
  5. 列 $tc$ を $3$ 下へスライド

f:id:Kuretchi:20171028091636p:plain

移動元と移動先が同じ列にある場合は、右へずらしてあげるとよい。

移動元と移動先の列が同じ場合 ($sc = tc$)

  1. 行 $sr$ を $1$ 右へスライド
  2. 基本操作 ($sc$ の値を更新しておくこと)

f:id:Kuretchi:20171028091700p:plain

以上の操作で、最初の $3$ 行を揃えることができる。

さて、最後の行を適当に入れ替えたい。簡単のために、隣り合うマスを交換する操作のみを考える。結論から言うと、次の操作で可能。

$r = sr (= tr)$、$c = round(sc + 1)$ とすると、

隣り合うマスの交換

  1. 列 $c$ を下へ $1$ スライド
  2. 行 $r$ を右へ $1$ スライド
  3. 列 $c$ を下へ $3$ スライド
  4. 行 $r$ を右へ $2$ スライド
  5. 列 $c$ を下へ $1$ スライド
  6. 行 $r$ を右へ $1$ スライド
  7. 列 $c$ を下へ $3$ スライド
  8. 行 $r$ を右へ $1$ スライド

f:id:Kuretchi:20171028132131p:plain

この操作によって、自明に列のマスを任意の順番に並べ替えることができるため、以上の考察ですべてのマスを揃えることができた。

ちなみに 適当にググって出てきたページ にも同じようなことが書いてある。

操作の回数については、以上の解法をそのまま実装すると少しギリギリだが、最適化の余地は十分にあるので、いい感じにえいすれば比較的余裕だと思われる。

実装例: C# / Haskell

ソートと要素のインデックス

ここらへんでハマってとてもつらかったので書いておきます。

ソート後のインデックスから、ソート前のインデックスを知りたい

f:id:Kuretchi:20170728115804p:plain:w600

方法1

インデックスとのペアのコレクションを作り、要素をキーとしてソートする。

char arr[5] = { 'D', 'A', 'C', 'E', 'B' };

pair<char, int> arridx[5];
for (int i = 0; i < 5; i++) arridx[i] = make_pair(arr[i], i);
// arridx => { (D, 0), (A, 1), (C, 2), (E, 3), (B, 4) }

sort(begin(arridx), end(arridx),
    [](pair<char, int>& a, pair<char, int>& b) {
        return a.first < b.first; });
// arridx => { (A, 1), (B, 4), (C, 2), (D, 0), (E, 3) }

方法2

インデックスだけのコレクションを作り、ソートしたいコレクションの対応する要素をキーとしてソートする。

char arr[5] = { 'D', 'A', 'C', 'E', 'B' };

int idx[5];
for (int i = 0; i < 5; i++) idx[i] = i;
// idx => { 0, 1, 2, 3, 4 }

sort(begin(idx), end(idx),
    [&arr](int& a, int& b) { return arr[a] < arr[b]; });
// idx => { 1, 4, 2, 0, 3 }

このソートされたインデックスを使って、ソートされた順番で要素にアクセスできる。

for (int i = 0; i < 5; i++) cout << arr[idx[i]] << " ";
// => "A B C D E "

ソート前のインデックスから、ソート後のインデックスを知りたい

f:id:Kuretchi:20170728115811p:plain:w600

コレクションが与えられた後、そのインデックスがクエリとして与えられるが、事前に元のコレクションをソートしておく必要がある場合など。上図では、例えば要素 D について、ソート前のインデックス 0 からソート後のインデックス 3 を得ることができる。

前述の、ソートされたインデックスのコレクションから生成する。

方法1

int after_idx[5];
for (int i = 0; i < 5; i++) after_idx[arridx[i].second] = i;
// after_idx => { 3, 0, 2, 4, 1 }

方法2

int after_idx[5];
for (int i = 0; i < 5; i++) after_idx[idx[i]] = i;
// after_idx => { 3, 0, 2, 4, 1 }