【競プロ典型 90 問】005 - Restricted Digits
米田氏が企画した「競プロ典型90問」を解いていきます。
ここでは私の試行錯誤やら解説の理解やら気づきやらをメモしていこうと思います。 競プロ初心者なのでどうか温かい目で見守ってください。
なお、本記事中のPythonコードはブラウザ上で実行可能です。 よろしければ遊んでいってください。
問題
try(1st.)
- $c$で作れる$N$桁の数字の範囲から、$B$の倍数を探して、それが$c$で作れるか確かめればよくない?
という訳で実装したのが以下。
def main(N, B, K, c): *c, = map(str, c) xmin = int(min(c) * N) xmax = int(max(c) * N) d, m = divmod(xmin, B) imin = d+1 if m else d imax = xmax // B cnt = 0 for i in range(imin, imax+1): target = set(str(B*i)) while len(target): if not target.pop() in c: break else: cnt += 1 return cnt % (10**9 + 7)テストデータ・テスト関数定義↓
# 縮小表示 test_data = [ { "in":[3, 7, 3, [1, 4, 9]], "out": 3 },{ "in":[5, 2, 3, [1, 4, 9]], "out": 81 },{ "in":[10000, 27, 7, [1, 3, 4, 6, 7, 8, 9]], "out": 989112238 },{ "in":[1000000000000000000, 29, 6, [1, 2, 4, 5, 7, 9]], "out": 853993813 },{ "in":[1000000000000000000, 957, 7, [1, 2, 3, 5, 6, 7, 9]], "out": 205384995 }, ] def test_all(f, n=float("inf")): for i, data in enumerate(test_data): if i >= n: break exp = data["out"] ans = f(*data["in"]) result = "AC" if exp == ans else "WA" print(f"{i+1} {result}: expected: {exp}, output: {ans}") # 3つ目以降のケースでは終わらない test_all(main, 2)あかんわ。
確実に駄目だろうけど提出してみます。
結果(1st.)
やっぱりあかん…
解説①
手も足も出ないので大人しく解説・想定コードを確認します。
今回は3段階の解説になっているので、一つずつ見ていきましょう。
想定コード
kyopro_educational_90/005-01.cpp at main · E869120/kyopro_educational_90
Pythonで書き直したのが以下です。
def kaisetsu1(N, B, K, c): mod = 1_000_000_007 dp = [[0]*33 for _ in range(10009)] dp[0][0] = 1 for i in range(N): for j in range(B): for k in c: nex = (10 * j + k) % B dp[i+1][nex] += dp[i][j] dp[i+1][nex] %= mod return dp[N][0] # 4以降は終わらない test_all(kaisetsu1, 3)結果
Subtask1が AC になった!
ちなみにREの内容はlist index out of range
で、dp
配列のサイズを小課題1の制約(10000 × 30)に合わせているからです。
ではdp = [[0]*B for _ in range(N+1)]
とするとどうなるかというと、タイムアウトになります。
解説の理解
正直、解説を読んでも桁DPでググってもよく分かりませんでした😇
なので想定コードを解読しながら何をやっているのか理解しようと試みます。
1つ目のテストケースを例に考えます。
1, 4, 9だけを使った3桁の数の中で、7の倍数はいくつあるか。
まずはdp
配列を作ります。
ここでdp[0][0] = 1
としている理由はすぐに明らかになります。
最も外側のループを1回ずつ回して確認していきます。
ここで桁数だけ繰り返しているので、「各桁についてなんかやってんだな」と推測できます。
さて、ここで何をやっているのかというと、以下のように100の位が1, 4, 9だった場合の余りを求めているのです。
という訳で、dp
配列のインデックス1は100の位(右から3桁目)について、余りが0~6の場合の個数を持ってるんだなと分かります。
そしてdp[0][0] = 1
の理由は、100の位を考えるときに1,000の位の余りは0であることを表しています。
ここらへんがまだちょっと分かりづらいのでループをもう一周回しましょう。
上で見た通り、100の位の余りは[1, 2, 4]の3パターンが1つずつでした。
それぞれのパターンについて、今度は10の位(右から2桁目)が1, 4, 9の場合の余りを求めます。
100の位の余りが1
余り | 0 | 1 | 2 | 3 | 4 | 5 | 6 |
数字の個数(累計) | 1 | 1 | 1 |
100の位の余りが2
余り | 0 | 1 | 2 | 3 | 4 | 5 | 6 |
数字の個数(累計) | 2 | 1 | 1 | 1 | 1 |
100の位の余りが4
余り | 0 | 1 | 2 | 3 | 4 | 5 | 6 |
数字の個数(累計) | 3 | 1 | 1 | 1 | 1 | 1 | 1 |
ここまでくるとだいたい何をやってるか分かってきましたね。
最後のループで一の位を確定します。
やることは同じで、10の位の余りが0~6のパターンで一の位(右から1桁目)が1, 4, 9の場合の余りを求めます。
10の位の余りが0
これはループの1回目と同様ですが、右から2桁目までにこのパターンになる数が3つあることがすでに分かっています。
従って一の位については、最終的に余りが[1, 2, 4]になる数がそれぞれ3つずつあることになります。
余り | 0 | 1 | 2 | 3 | 4 | 5 | 6 |
数字の個数(累計) | 3 | 3 | 3 |
10の位の余りが1
余り | 0 | 1 | 2 | 3 | 4 | 5 | 6 |
数字の個数(累計) | 1 | 3 | 3 | 4 | 1 |
10の位の余りが2
余り | 0 | 1 | 2 | 3 | 4 | 5 | 6 |
数字の個数(累計) | 2 | 4 | 3 | 1 | 4 | 1 |
…ということを繰り返して、最終的にdp[N][0]
が、7の倍数になる数字の個数ということになります。
桁DP、完全に理解した!
解説②
桁DPの基本的な実装ができても、小課題2,3は未だACできていません。
そこで2つ目の解説を見てみます。
想定コード
kyopro_educational_90/005-02.cpp at main · E869120/kyopro_educational_90
Pythonで書き直したのが以下です。
mod = 1_000_000_007 def mat_mul(a, b): m = [[0]*len(a) for _ in range(len(b[0]))] *b, = zip(*b) for i, r in enumerate(m): for j in range(len(r)): r[j] = sum([x * y % mod for x, y in zip(a[i], b[j])]) r[j] %= mod return m def mat_pow(a, n): ret = [[0]*len(a) for _ in range(len(a))] for i in range(len(a)): ret[i][i] = 1 while n: if n & 1: ret = mat_mul(ret, a) a = mat_mul(a, a) n >>= 1 return ret def kaisetsu2(N, B, K, c): A = [[0]*B for _ in range(B)] for i in range(B): for j in c: nex = (i*10 + j) % B A[i][nex] += 1 return mat_pow(A, N)[0][0] # 5は終わらない test_all(kaisetsu2, 4)結果
Subtask2が AC になった!
解説の理解
今回の解説については、1つ目が分かっていれば(自力で思いつくかはともかく)理解は容易でしょう。
実装について、行列の演算となるとnumpy
を使いたくなります。特に累乗となるとnumpy.linalg.matrix_powerが使えそう。
ですが、実際に使ってみると計算途中で数字がオーバーフローしてしまいます。
オーバーフロー例①
import numpy as np mod = 1_000_000_007 def kaisetsu2_1(N, B, K, c): A = np.zeros((B, B), dtype=np.int64) for i in range(B): for j in c: nex = (i*10 + j) % B A[i][nex] += 1 return (np.linalg.matrix_power(A, N)[0][0]) % mod # 5は終わらない test_all(kaisetsu2_1, 4)オーバーフロー例②
import numpy as np mod = 1_000_000_007 def mat_pow(a, n): ret = np.eye(*a.shape, dtype=a.dtype) while n: if n & 1: ret = ret @ a % mod a = np.linalg.matrix_power(a, 2) n >>= 1 return ret def kaisetsu2_2(N, B, K, c): A = np.zeros((B, B), dtype=np.uint64) for i in range(B): for j in c: nex = (i*10 + j) % B A[i][nex] += 1 return mat_pow(A, N)[0][0] # 5は終わらない test_all(kaisetsu2_2, 4)結局、以下のように行列の乗算の過程で余りを取る必要があり、もはやnumpy
を使う必然性がありません。
解説③
最後に追加の制約なし(小課題3)でACとなるための解説です。
想定コード
kyopro_educational_90/005-03.cpp at main · E869120/kyopro_educational_90
Pythonで書き直したのが以下です。
というか けんちょんさんの記事も参考に大分変更を加えました。
結果
長かった…
解説の理解
ダブリングの考え方についても、けんちょんさんの記事(同上)が詳しいです。(丸投げ)
競プロ典型 90 問 005 - Restricted Digits(★7) - けんちょんの競プロ精進記録
まとめ
理解はできた(と思う…)が、自力で解ける気がしない😇