【ABC248C】
包除原理で上限つきの重複組合せを数える
(Common Lisp)

関連記事

1. 組み合わせと包除原理

N 個のサイコロを振る。
各サイコロの目は 1 以上 M 以下の整数。
出た目の合計が K 以下になる組み合わせの数を 998244353 で割った余りで求めよ。

C – Dice Sum

この問題「Dice Sum」を「N 人に飴を配る問題」として解こうと思ったのですが、「一人あたり最大 M 個まで」という上限を、組合せの式でどうやって計算したらよいか困りました。
そこで、包除原理の方法を知ることになりました。

1.1. 単純なケースを仕切り法で数える

K個の飴を N人に配る場合の組み合わせを考えます。
K 個の飴を i人目の人に配る数を aia_i1ai1 \leq a_i)とします。
すると、 N個の数 aia_i の合計は K 以下になります。

ただ、問題では「合計が K 以下」なので、余りが出ることがあります。
そこで、余りを受け取るダミー r0r \geq 0 を 1人加えます。
rr は「K までの余り」を表すので、a1++aNKa_1 + \cdots + a_N \leq Ka1++aN+r=Ka_1 + \cdots + a_N + r = K という等式に置き換えられます。

a1+a2++aN+r=K(ai1, r0)a_1 + a_2 + \cdots + a_N + r = K \quad (a_i \geq 1,\ r \geq 0)

ダミーが増えたので、N+1人で分配することになります。

KK 個の飴を一列に並べると、飴と飴の間は K1K – 1 箇所あります。
N+1 個の区画に分けるためには N 本の仕切りを置きます。
仕切りの間には必ず飴が 1 個以上あるので ai1a_i \geq 1 が自動的に満たされます。

ただし、ダミー rr は特別で、 0 でもよいです。
なので、KK 個目の右端にも仕切りが置けます。
つまり、仕切りを置ける場所は合計 KK 箇所になるので、場合の数は

(KN)\binom{K}{N}
(comb/naive K N)

です。

(nr)=n!r!(nr)!\binom{n}{r} = \frac{n!}{r!(n-r)!} は、階乗の積で計算できます。

(defun fact/naive (n)
  (loop for i from 1 to n product i))

(defun comb/naive (n r)
  (if (or (< n 0) (< r 0) (> r n))
      0
      (/ (fact/naive n)
         (* (fact/naive r)
            (fact/naive (- n r))))))Code language: JavaScript (javascript)

範囲外の組合せは 0 として扱うことで、数式上 0 になる項をコードでも自然に表現できます。

2. 上限がある場合——包除原理

包除原理(inclusion-exclusion principle)は、複数の条件を満たす場合の数を数えるとき、「全体から条件を破るものを引き、引きすぎた分を足し戻す」操作を繰り返す数え方です。
「含める・除く」を交互に繰り返すことからこの名前がついています。

(KN)\binom{K}{N} では、上限 aiMa_i \leq M を考慮していませんでした。
この上限なしの全体を使って、そこから「上限を破っているケース」を引けば答えが出ます。

違反とは aiM+1a_i \geq M + 1 になることです。

ある 1 人が違反していると仮定すると、その人は ai1a_i \geq 1 という前提の上にさらに M 個余分に持っていればよいことになります。
なので「その人にあらかじめ M 個追加で渡しておいた状態で、残りを上限なしで配る」、そうすれば、その場合の数が「その人が違反している場合の数」になります。

このとき、配れる飴の数は、

KMK – M

これを、同じ要領で配るので、場合の数は (KMN)\binom{K-M}{N} です。
ただ、ダミー以外の N 人のうち違反者が誰かは (N1)\binom{N}{1} 通りあるので、「誰か 1 人が違反している場合」は、

(N1)(KMN)\binom{N}{1} \binom{K-M}{N}

と考えられます。

2.1. 重複してカウントされている

これを全体から引けばよさそうに見えますが、2 人以上が同時に違反しているケースが二重に含まれています。

たとえば a1a_1a2a_2 が両方違反しているケースは、「a1a_1 が違反」でも「a2a_2 が違反」でもカウントされるので、余分に引かれてしまいます。

そこで、「2 人同時に違反している場合」を足し戻します。
今度は、2 人にそれぞれ M 個追加で先渡しすると、残りは

K2MK – 2M

です。
j=2j = 2 人の選び方は (N2)\binom{N}{2} 通りあるので、足し戻す量は

(N2)(K2MN)\binom{N}{2} \binom{K-2M}{N}

です。
さらに、3 人同時の場合はまた二重になるので引く、という操作を繰り返します。

これを一般化します。
jj 人が同時に違反している場合、jj 人にそれぞれ M 個追加で先渡しすると、残りは

KjMK – jM
(- K (* j M))  ; rest

です。
jj 人の選び方は (Nj)\binom{N}{j} 通りあり、符号は引く・足す・引く……と (1)j(-1)^j で交互に入れ替わります。
KjM<NK – jM < N になると残りが足りなくて 1 個ずつ配れないので、その項は 0 になります。

すべての jj について足し合わせると、

答え=j=0N(1)j(Nj)(KjMN)\text{答え} = \sum_{j=0}^{N} (-1)^j \binom{N}{j} \binom{K – jM}{N}
(loop for j from 0 to N
      for rest = (- K (* j M))
      while (>= rest N)               ; K - jM < N の項を除く
      for term = (* (comb/naive N j) ; C(N, j)
                    (comb/naive rest N)) ; C(K - jM, N)
      sum (if (evenp j) term (- term)))  ; (-1)^jCode language: JavaScript (javascript)

これは KjM<NK – jM < N になると、配れなくなって終了です。

2.2. naive な実装と確認

これを関数にすると、

(defun count-dice-sum/naive (N M K)
  (loop for j from 0 to N
        for rest = (- K (* j M))
        while (>= rest N)
        for term = (* (comb/naive N j)
                      (comb/naive rest N))
        sum (if (evenp j) term (- term))))Code language: JavaScript (javascript)

N=1, M=3, K=3 で確かめると、j=0 のとき (comb/naive 3 1) = 3、j=1 のとき rest = 0 で (>= 0 1) が偽なので除外され、答えは 3 になります。
出目 1, 2, 3 の 3 通りと一致します。

2.3. 素朴な実装

ここまでのコードは、

(defun fact/naive (n)
  (loop for i from 1 to n product i))

(defun comb/naive (n r)
  (if (or (< n 0) (< r 0) (> r n))
      0
      (/ (fact/naive n)
         (* (fact/naive r)
            (fact/naive (- n r))))))

(defun count-dice-sum/naive (N M K)
  (loop for j from 0 to N
        for rest = (- K (* j M))
        while (>= rest N)
        for term = (* (comb/naive N j)
                      (comb/naive rest N))
        sum (if (evenp j) term (- term))))

(defun main ()
  (let ((N (read))
        (M (read))
        (K (read)))
    (init-fact-table K)
    (princ (count-dice-sum N M K))))

#-swank(main)Code language: PHP (php)

3. 階乗とその逆元をメモ化する

ただし、comb/naive は、数が大きくなると毎回 $n!$ を計算しているので遅くなり、 mod を取っていないので大きな数値になってしまいます。

そこで、あらかじめ階乗をメモしておけば、(nr)\binom{n}{r} の計算は掛け算3回で済みます。

また、割り算の mod を取るには逆元が必要です。
逆元は、pp が素数で xxpp で割り切れないときには、フェルマーの小定理より

x1xp2(modp)x^{-1} \equiv x^{p-2} \pmod{p}
(mod-pow (aref fact max-n) (- +mod+ 2) +mod+)  ; (max-n!)^(p-2) mod p

です。

ここでは $x = $ max-n! として、最大値の階乗の逆元を求めます。

累乗を(expt base exp)でそのまま計算すると、指数が大きいと天文学的な数になります。

(defun mod-pow/naive (base exp m)
  (mod (expt base exp) m))

そこで、繰り返し二乗法を使って、modを取りながら累乗していきます。

(defun mod-pow (base exp m)
  (loop with result = 1
        while (> exp 0)
        do (when (oddp exp)
             (setf result (mod (* result base) m)))
           (setf base (mod (* base base) m)
                 exp (ash exp -1))
        finally (return result)))Code language: JavaScript (javascript)

また、最大値の逆元だけ mod-pow で求め、残りは逆順に漸化式で計算します。

(inv-fact[n])=(inv-fact[n+1])×(n+1)(\text{inv-fact}[n]) = (\text{inv-fact}[n+1]) \times (n+1)
(setf (aref inv-fact i)
      (mod (* (aref inv-fact (1+ i)) (1+ i)) +mod+))

そうすれば、階乗の逆元は O(K)O(K) で求められます。

n!=(n+1)!/(n+1)n! = (n+1)! \,/\, (n+1) の両辺の逆元を取ると (n!)1=(n+1)×((n+1)!)1(n!)^{-1} = (n+1) \times ((n+1)!)^{-1} になるので、大きい方から順に埋めていきます。

(defconstant +mod+ 998244353)
(defparameter *fact* nil)
(defparameter *inv-fact* nil)

(defun init-fact-table (max-n)
  (let ((fact (make-array (1+ max-n) :initial-element 1))
        (inv-fact (make-array (1+ max-n) :initial-element 1)))
    (loop for i from 1 to max-n
          do (setf (aref fact i)
                   (mod (* (aref fact (1- i)) i) +mod+)))
    (setf (aref inv-fact max-n)
          (mod-pow (aref fact max-n) (- +mod+ 2) +mod+))
    (loop for i from (1- max-n) downto 0
          do (setf (aref inv-fact i)
                   (mod (* (aref inv-fact (1+ i)) (1+ i)) +mod+)))
    (setf *fact* fact
          *inv-fact* inv-fact)))Code language: JavaScript (javascript)

ただし、この階乗テーブルによる組合せ計算は K が法 998244353 未満であることを前提にしています。
もし、K が法以上になると階乗に 998244353 の倍数が含まれて逆元を取れなくなるため、この方法のままでは使えません。

3.1. 組み合わせ計算を効率化した

comb は、*fact**inv-fact*を使って n!r!(nr)!\frac{n!}{r!(n-r)!} を計算します。

(nr)=n!×(r!)1×((nr)!)1\binom{n}{r} = n! \times (r!)^{-1} \times ((n-r)!)^{-1}
(defun comb (n r)
  (if (or (< n 0) (< r 0) (> r n))
      0
      (mod (* (aref *fact* n)                    ; n!
              (mod (* (aref *inv-fact* r)         ; (r!)^-1
                      (aref *inv-fact* (- n r)))  ; ((n-r)!)^-1
                   +mod+))
           +mod+)))

count-dice-sum/naivecomb/naivecomb に置き換えて mod を取ります。

(defun count-dice-sum (N M K)
  (loop for j from 0 to N
        for rest = (- K (* j M))
        while (>= rest N)
        for term = (mod (* (comb N j)
                           (comb rest N))
                        +mod+)
        sum (if (evenp j) term (- term)) into answer
        finally (return (mod answer +mod+))))Code language: JavaScript (javascript)

main では式に出てくる最大の添字 (KN)\binom{K}{N} の第一引数 K まで確保します。

(defun main ()
  (let ((N (read))
        (M (read))
        (K (read)))
    (init-fact-table K)
    (princ (count-dice-sum N M K))))

#-swank(main)Code language: CSS (css)

3.2. DP との計算量の比較

一般的な動的計画法による解答だと、計算量はおよそ O(N×K×M)O(N \times K \times M) です。
「(選んだ個数, 現在の合計)」という状態をDPに記録して、各状態から M 通りを試すからです。
もし、累積和を使えば O(NK)O(NK) までは改善できます。

一方、包除原理は「上限を破っている人数 jj」でまとめて、jj の各値について組合せを計算で求めます。
テーブル初期化は O(K)O(K) です。
ループは j=0j = 0 から NN まで N+1N+1 回で、各ステップはテーブル引き O(1)O(1) なので本体は O(N)O(N) で済みます。

4. 完成したコード

(defconstant +mod+ 998244353)
(defparameter *fact* nil)
(defparameter *inv-fact* nil)

(defun mod-pow (base exp m)
  (loop with result = 1
        while (> exp 0)
        do (when (oddp exp)
             (setf result (mod (* result base) m)))
           (setf base (mod (* base base) m)
                 exp (ash exp -1))
        finally (return result)))

(defun init-fact-table (max-n)
  (let ((fact (make-array (1+ max-n) :initial-element 1))
        (inv-fact (make-array (1+ max-n) :initial-element 1)))
    (loop for i from 1 to max-n
          do (setf (aref fact i)
                   (mod (* (aref fact (1- i)) i) +mod+)))
    (setf (aref inv-fact max-n)
          (mod-pow (aref fact max-n) (- +mod+ 2) +mod+))
    (loop for i from (1- max-n) downto 0
          do (setf (aref inv-fact i)
                   (mod (* (aref inv-fact (1+ i)) (1+ i)) +mod+)))
    (setf *fact* fact
          *inv-fact* inv-fact)))

(defun comb (n r)
  (if (or (< n 0) (< r 0) (> r n))
      0
      (mod (* (aref *fact* n)
              (mod (* (aref *inv-fact* r)
                      (aref *inv-fact* (- n r)))
                   +mod+))
           +mod+)))

(defun count-dice-sum (N M K)
  (loop for j from 0 to N
        for rest = (- K (* j M))
        while (>= rest N)
        for term = (mod (* (comb N j)
                           (comb rest N))
                        +mod+)
        sum (if (evenp j) term (- term)) into answer
        finally (return (mod answer +mod+))))

(defun main ()
  (let ((N (read))
        (M (read))
        (K (read)))
    (init-fact-table K)
    (princ (count-dice-sum N M K))))

#-swank(main)Code language: PHP (php)

コードの半分は、組み合わせを計算する comb の動的計画法による実装ですね。