1. 組み合わせと包除原理
N 個のサイコロを振る。
C – Dice Sum
各サイコロの目は 1 以上 M 以下の整数。
出た目の合計が K 以下になる組み合わせの数を 998244353 で割った余りで求めよ。
この問題「Dice Sum」を「N 人に飴を配る問題」として解こうと思ったのですが、「一人あたり最大 M 個まで」という上限を、組合せの式でどうやって計算したらよいか困りました。
そこで、包除原理の方法を知ることになりました。
- E – Dice Sum(ADT 20260513_1版)
1.1. 単純なケースを仕切り法で数える
K個の飴を N人に配る場合の組み合わせを考えます。
K 個の飴を i人目の人に配る数を ()とします。
すると、 N個の数 の合計は K 以下になります。
ただ、問題では「合計が K 以下」なので、余りが出ることがあります。
そこで、余りを受け取るダミー を 1人加えます。
は「K までの余り」を表すので、 を という等式に置き換えられます。
ダミーが増えたので、N+1人で分配することになります。
個の飴を一列に並べると、飴と飴の間は 箇所あります。
N+1 個の区画に分けるためには N 本の仕切りを置きます。
仕切りの間には必ず飴が 1 個以上あるので が自動的に満たされます。
ただし、ダミー は特別で、 0 でもよいです。
なので、 個目の右端にも仕切りが置けます。
つまり、仕切りを置ける場所は合計 箇所になるので、場合の数は
(comb/naive K N)
です。
は、階乗の積で計算できます。
(defun fact/naive (n)
(loop for i from 1 to n product i))
(defun comb/naive (n r)
(if (or (< n 0) (< r 0) (> r n))
0
(/ (fact/naive n)
(* (fact/naive r)
(fact/naive (- n r))))))Code language: JavaScript (javascript)
範囲外の組合せは 0 として扱うことで、数式上 0 になる項をコードでも自然に表現できます。
2. 上限がある場合——包除原理
包除原理(inclusion-exclusion principle)は、複数の条件を満たす場合の数を数えるとき、「全体から条件を破るものを引き、引きすぎた分を足し戻す」操作を繰り返す数え方です。
「含める・除く」を交互に繰り返すことからこの名前がついています。
では、上限 を考慮していませんでした。
この上限なしの全体を使って、そこから「上限を破っているケース」を引けば答えが出ます。
違反とは になることです。
ある 1 人が違反していると仮定すると、その人は という前提の上にさらに M 個余分に持っていればよいことになります。
なので「その人にあらかじめ M 個追加で渡しておいた状態で、残りを上限なしで配る」、そうすれば、その場合の数が「その人が違反している場合の数」になります。
このとき、配れる飴の数は、
これを、同じ要領で配るので、場合の数は です。
ただ、ダミー以外の N 人のうち違反者が誰かは 通りあるので、「誰か 1 人が違反している場合」は、
と考えられます。
2.1. 重複してカウントされている
これを全体から引けばよさそうに見えますが、2 人以上が同時に違反しているケースが二重に含まれています。
たとえば と が両方違反しているケースは、「 が違反」でも「 が違反」でもカウントされるので、余分に引かれてしまいます。
そこで、「2 人同時に違反している場合」を足し戻します。
今度は、2 人にそれぞれ M 個追加で先渡しすると、残りは
です。
人の選び方は 通りあるので、足し戻す量は
です。
さらに、3 人同時の場合はまた二重になるので引く、という操作を繰り返します。
これを一般化します。
人が同時に違反している場合、 人にそれぞれ M 個追加で先渡しすると、残りは
(- K (* j M)) ; rest
です。
人の選び方は 通りあり、符号は引く・足す・引く……と で交互に入れ替わります。
になると残りが足りなくて 1 個ずつ配れないので、その項は 0 になります。
すべての について足し合わせると、
(loop for j from 0 to N
for rest = (- K (* j M))
while (>= rest N) ; K - jM < N の項を除く
for term = (* (comb/naive N j) ; C(N, j)
(comb/naive rest N)) ; C(K - jM, N)
sum (if (evenp j) term (- term))) ; (-1)^jCode language: JavaScript (javascript)
これは になると、配れなくなって終了です。
2.2. naive な実装と確認
これを関数にすると、
(defun count-dice-sum/naive (N M K)
(loop for j from 0 to N
for rest = (- K (* j M))
while (>= rest N)
for term = (* (comb/naive N j)
(comb/naive rest N))
sum (if (evenp j) term (- term))))Code language: JavaScript (javascript)
N=1, M=3, K=3 で確かめると、j=0 のとき (comb/naive 3 1) = 3、j=1 のとき rest = 0 で (>= 0 1) が偽なので除外され、答えは 3 になります。
出目 1, 2, 3 の 3 通りと一致します。
2.3. 素朴な実装
ここまでのコードは、
(defun fact/naive (n)
(loop for i from 1 to n product i))
(defun comb/naive (n r)
(if (or (< n 0) (< r 0) (> r n))
0
(/ (fact/naive n)
(* (fact/naive r)
(fact/naive (- n r))))))
(defun count-dice-sum/naive (N M K)
(loop for j from 0 to N
for rest = (- K (* j M))
while (>= rest N)
for term = (* (comb/naive N j)
(comb/naive rest N))
sum (if (evenp j) term (- term))))
(defun main ()
(let ((N (read))
(M (read))
(K (read)))
(init-fact-table K)
(princ (count-dice-sum N M K))))
#-swank(main)Code language: PHP (php)
3. 階乗とその逆元をメモ化する
ただし、comb/naive は、数が大きくなると毎回 $n!$ を計算しているので遅くなり、 mod を取っていないので大きな数値になってしまいます。
そこで、あらかじめ階乗をメモしておけば、 の計算は掛け算3回で済みます。
また、割り算の mod を取るには逆元が必要です。
逆元は、 が素数で が で割り切れないときには、フェルマーの小定理より
(mod-pow (aref fact max-n) (- +mod+ 2) +mod+) ; (max-n!)^(p-2) mod p
です。
ここでは $x = $ max-n! として、最大値の階乗の逆元を求めます。
累乗を(expt base exp)でそのまま計算すると、指数が大きいと天文学的な数になります。
(defun mod-pow/naive (base exp m)
(mod (expt base exp) m))
そこで、繰り返し二乗法を使って、modを取りながら累乗していきます。
(defun mod-pow (base exp m)
(loop with result = 1
while (> exp 0)
do (when (oddp exp)
(setf result (mod (* result base) m)))
(setf base (mod (* base base) m)
exp (ash exp -1))
finally (return result)))Code language: JavaScript (javascript)
また、最大値の逆元だけ mod-pow で求め、残りは逆順に漸化式で計算します。
(setf (aref inv-fact i)
(mod (* (aref inv-fact (1+ i)) (1+ i)) +mod+))
そうすれば、階乗の逆元は で求められます。
の両辺の逆元を取ると になるので、大きい方から順に埋めていきます。
(defconstant +mod+ 998244353)
(defparameter *fact* nil)
(defparameter *inv-fact* nil)
(defun init-fact-table (max-n)
(let ((fact (make-array (1+ max-n) :initial-element 1))
(inv-fact (make-array (1+ max-n) :initial-element 1)))
(loop for i from 1 to max-n
do (setf (aref fact i)
(mod (* (aref fact (1- i)) i) +mod+)))
(setf (aref inv-fact max-n)
(mod-pow (aref fact max-n) (- +mod+ 2) +mod+))
(loop for i from (1- max-n) downto 0
do (setf (aref inv-fact i)
(mod (* (aref inv-fact (1+ i)) (1+ i)) +mod+)))
(setf *fact* fact
*inv-fact* inv-fact)))Code language: JavaScript (javascript)
ただし、この階乗テーブルによる組合せ計算は K が法 998244353 未満であることを前提にしています。
もし、K が法以上になると階乗に 998244353 の倍数が含まれて逆元を取れなくなるため、この方法のままでは使えません。
3.1. 組み合わせ計算を効率化した
comb は、*fact*と*inv-fact*を使って を計算します。
(defun comb (n r)
(if (or (< n 0) (< r 0) (> r n))
0
(mod (* (aref *fact* n) ; n!
(mod (* (aref *inv-fact* r) ; (r!)^-1
(aref *inv-fact* (- n r))) ; ((n-r)!)^-1
+mod+))
+mod+)))
count-dice-sum/naive の comb/naive を comb に置き換えて mod を取ります。
(defun count-dice-sum (N M K)
(loop for j from 0 to N
for rest = (- K (* j M))
while (>= rest N)
for term = (mod (* (comb N j)
(comb rest N))
+mod+)
sum (if (evenp j) term (- term)) into answer
finally (return (mod answer +mod+))))Code language: JavaScript (javascript)
main では式に出てくる最大の添字 の第一引数 K まで確保します。
(defun main ()
(let ((N (read))
(M (read))
(K (read)))
(init-fact-table K)
(princ (count-dice-sum N M K))))
#-swank(main)Code language: CSS (css)
3.2. DP との計算量の比較
一般的な動的計画法による解答だと、計算量はおよそ です。
「(選んだ個数, 現在の合計)」という状態をDPに記録して、各状態から M 通りを試すからです。
もし、累積和を使えば までは改善できます。
一方、包除原理は「上限を破っている人数 」でまとめて、 の各値について組合せを計算で求めます。
テーブル初期化は です。
ループは から まで 回で、各ステップはテーブル引き なので本体は で済みます。
4. 完成したコード
(defconstant +mod+ 998244353)
(defparameter *fact* nil)
(defparameter *inv-fact* nil)
(defun mod-pow (base exp m)
(loop with result = 1
while (> exp 0)
do (when (oddp exp)
(setf result (mod (* result base) m)))
(setf base (mod (* base base) m)
exp (ash exp -1))
finally (return result)))
(defun init-fact-table (max-n)
(let ((fact (make-array (1+ max-n) :initial-element 1))
(inv-fact (make-array (1+ max-n) :initial-element 1)))
(loop for i from 1 to max-n
do (setf (aref fact i)
(mod (* (aref fact (1- i)) i) +mod+)))
(setf (aref inv-fact max-n)
(mod-pow (aref fact max-n) (- +mod+ 2) +mod+))
(loop for i from (1- max-n) downto 0
do (setf (aref inv-fact i)
(mod (* (aref inv-fact (1+ i)) (1+ i)) +mod+)))
(setf *fact* fact
*inv-fact* inv-fact)))
(defun comb (n r)
(if (or (< n 0) (< r 0) (> r n))
0
(mod (* (aref *fact* n)
(mod (* (aref *inv-fact* r)
(aref *inv-fact* (- n r)))
+mod+))
+mod+)))
(defun count-dice-sum (N M K)
(loop for j from 0 to N
for rest = (- K (* j M))
while (>= rest N)
for term = (mod (* (comb N j)
(comb rest N))
+mod+)
sum (if (evenp j) term (- term)) into answer
finally (return (mod answer +mod+))))
(defun main ()
(let ((N (read))
(M (read))
(K (read)))
(init-fact-table K)
(princ (count-dice-sum N M K))))
#-swank(main)Code language: PHP (php)
コードの半分は、組み合わせを計算する comb の動的計画法による実装ですね。