米田氏が企画した「競プロ典型90問」を解いていきます。

ここでは私の試行錯誤やら解説の理解やら気づきやらをメモしていこうと思います。 競プロ初心者なのでどうか温かい目で見守ってください。

なお、本記事中のPythonコードはブラウザ上で実行可能です。 よろしければ遊んでいってください。

問題

kyopro_educational_90/005.jpg at main · E869120/kyopro_educational_90

twitter

try(1st.)

  • $c$で作れる$N$桁の数字の範囲から、$B$の倍数を探して、それが$c$で作れるか確かめればよくない?

という訳で実装したのが以下。

def main(N, B, K, c): *c, = map(str, c) xmin = int(min(c) * N) xmax = int(max(c) * N) d, m = divmod(xmin, B) imin = d+1 if m else d imax = xmax // B cnt = 0 for i in range(imin, imax+1): target = set(str(B*i)) while len(target): if not target.pop() in c: break else: cnt += 1 return cnt % (10**9 + 7)

テストデータ・テスト関数定義↓

# 縮小表示 test_data = [ { "in":[3, 7, 3, [1, 4, 9]], "out": 3 },{ "in":[5, 2, 3, [1, 4, 9]], "out": 81 },{ "in":[10000, 27, 7, [1, 3, 4, 6, 7, 8, 9]], "out": 989112238 },{ "in":[1000000000000000000, 29, 6, [1, 2, 4, 5, 7, 9]], "out": 853993813 },{ "in":[1000000000000000000, 957, 7, [1, 2, 3, 5, 6, 7, 9]], "out": 205384995 }, ] def test_all(f, n=float("inf")): for i, data in enumerate(test_data): if i >= n: break exp = data["out"] ans = f(*data["in"]) result = "AC" if exp == ans else "WA" print(f"{i+1} {result}: expected: {exp}, output: {ans}") # 3つ目以降のケースでは終わらない test_all(main, 2)

あかんわ。
確実に駄目だろうけど提出してみます。

結果(1st.)

やっぱりあかん…

提出 #32501844 - 競プロ典型 90 問

解説①

手も足も出ないので大人しく解説・想定コードを確認します。
今回は3段階の解説になっているので、一つずつ見ていきましょう。

kyopro_educational_90/005-01.jpg at main · E869120/kyopro_educational_90

想定コード

kyopro_educational_90/005-01.cpp at main · E869120/kyopro_educational_90

Pythonで書き直したのが以下です。

def kaisetsu1(N, B, K, c): mod = 1_000_000_007 dp = [[0]*33 for _ in range(10009)] dp[0][0] = 1 for i in range(N): for j in range(B): for k in c: nex = (10 * j + k) % B dp[i+1][nex] += dp[i][j] dp[i+1][nex] %= mod return dp[N][0] # 4以降は終わらない test_all(kaisetsu1, 3)

結果

Subtask1 AC になった!

提出 #32502123 - 競プロ典型 90 問

ちなみにREの内容はlist index out of rangeで、dp配列のサイズを小課題1の制約(10000 × 30)に合わせているからです。
ではdp = [[0]*B for _ in range(N+1)]とするとどうなるかというと、タイムアウトになります。

解説の理解

正直、解説を読んでも桁DPでググってもよく分かりませんでした😇
なので想定コードを解読しながら何をやっているのか理解しようと試みます。

1つ目のテストケースを例に考えます。
1, 4, 9だけを使った3桁の数の中で、7の倍数はいくつあるか。

N, B, K, c = test_data[0]["in"] N, B, K, c

まずはdp配列を作ります。
ここでdp[0][0] = 1としている理由はすぐに明らかになります。

mod = 1_000_000_007 dp = [[0]*B for _ in range(N+1)] dp[0][0] = 1 from pprint import * pp(dp)

最も外側のループを1回ずつ回して確認していきます。
ここで桁数だけ繰り返しているので、「各桁についてなんかやってんだな」と推測できます。

# for i in range(N): i = 0 for j in range(B): for k in c: nex = (10 * j + k) % B dp[i+1][nex] += dp[i][j] dp[i+1][nex] %= mod pp(dp)

さて、ここで何をやっているのかというと、以下のように100の位が1, 4, 9だった場合の余りを求めているのです。

という訳で、dp配列のインデックス1は100の位(右から3桁目)について、余りが0~6の場合の個数を持ってるんだなと分かります。

そしてdp[0][0] = 1の理由は、100の位を考えるときに1,000の位の余りは0であることを表しています。
ここらへんがまだちょっと分かりづらいのでループをもう一周回しましょう。

# for i in range(N): i = 1 for j in range(B): for k in c: nex = (10 * j + k) % B dp[i+1][nex] += dp[i][j] dp[i+1][nex] %= mod pp(dp)

上で見た通り、100の位の余りは[1, 2, 4]の3パターンが1つずつでした。
それぞれのパターンについて、今度は10の位(右から2桁目)が1, 4, 9の場合の余りを求めます。

100の位の余りが1

余り    0 1 2 3 4 5 6
数字の個数(累計) 1 1 1

100の位の余りが2

余り    0 1 2 3 4 5 6
数字の個数(累計) 2 1 1 1 1

100の位の余りが4

余り    0 1 2 3 4 5 6
数字の個数(累計) 3 1 1 1 1 1 1

ここまでくるとだいたい何をやってるか分かってきましたね。
最後のループで一の位を確定します。

# for i in range(N): i = 2 for j in range(B): for k in c: nex = (10 * j + k) % B dp[i+1][nex] += dp[i][j] dp[i+1][nex] %= mod pp(dp)

やることは同じで、10の位の余りが0~6のパターンで一の位(右から1桁目)が1, 4, 9の場合の余りを求めます。

10の位の余りが0
これはループの1回目と同様ですが、右から2桁目までにこのパターンになる数が3つあることがすでに分かっています。
従って一の位については、最終的に余りが[1, 2, 4]になる数がそれぞれ3つずつあることになります。

余り    0 1 2 3 4 5 6
数字の個数(累計) 3 3 3

10の位の余りが1

余り    0 1 2 3 4 5 6
数字の個数(累計) 1 3 3 4 1

10の位の余りが2

余り    0 1 2 3 4 5 6
数字の個数(累計) 2 4 3 1 4 1

…ということを繰り返して、最終的にdp[N][0]が、7の倍数になる数字の個数ということになります。

dp[N][0]

桁DP、完全に理解した!

解説②

桁DPの基本的な実装ができても、小課題2,3は未だACできていません。
そこで2つ目の解説を見てみます。

kyopro_educational_90/005-02.jpg at main · E869120/kyopro_educational_90

想定コード

kyopro_educational_90/005-02.cpp at main · E869120/kyopro_educational_90

Pythonで書き直したのが以下です。

mod = 1_000_000_007 def mat_mul(a, b): m = [[0]*len(a) for _ in range(len(b[0]))] *b, = zip(*b) for i, r in enumerate(m): for j in range(len(r)): r[j] = sum([x * y % mod for x, y in zip(a[i], b[j])]) r[j] %= mod return m def mat_pow(a, n): ret = [[0]*len(a) for _ in range(len(a))] for i in range(len(a)): ret[i][i] = 1 while n: if n & 1: ret = mat_mul(ret, a) a = mat_mul(a, a) n >>= 1 return ret def kaisetsu2(N, B, K, c): A = [[0]*B for _ in range(B)] for i in range(B): for j in c: nex = (i*10 + j) % B A[i][nex] += 1 return mat_pow(A, N)[0][0] # 5は終わらない test_all(kaisetsu2, 4)

結果

Subtask2 AC になった!

提出 #32514921 - 競プロ典型 90 問

解説の理解

今回の解説については、1つ目が分かっていれば(自力で思いつくかはともかく)理解は容易でしょう。

実装について、行列の演算となるとnumpyを使いたくなります。特に累乗となるとnumpy.linalg.matrix_powerが使えそう。
ですが、実際に使ってみると計算途中で数字がオーバーフローしてしまいます。

オーバーフロー例①

import numpy as np mod = 1_000_000_007 def kaisetsu2_1(N, B, K, c): A = np.zeros((B, B), dtype=np.int64) for i in range(B): for j in c: nex = (i*10 + j) % B A[i][nex] += 1 return (np.linalg.matrix_power(A, N)[0][0]) % mod # 5は終わらない test_all(kaisetsu2_1, 4)

オーバーフロー例②

import numpy as np mod = 1_000_000_007 def mat_pow(a, n): ret = np.eye(*a.shape, dtype=a.dtype) while n: if n & 1: ret = ret @ a % mod a = np.linalg.matrix_power(a, 2) n >>= 1 return ret def kaisetsu2_2(N, B, K, c): A = np.zeros((B, B), dtype=np.uint64) for i in range(B): for j in c: nex = (i*10 + j) % B A[i][nex] += 1 return mat_pow(A, N)[0][0] # 5は終わらない test_all(kaisetsu2_2, 4)

結局、以下のように行列の乗算の過程で余りを取る必要があり、もはやnumpyを使う必然性がありません。

import numpy as np mod = 1_000_000_007 def mat_mul(a, b): ret = np.zeros((a.shape[0], b.shape[1]), dtype=a.dtype) for i in range(a.shape[0]): for j in range(b.shape[1]): ret[i][j] = sum(a[i] * b[:, j] % mod) return ret % mod def mat_pow(a, n): ret = np.eye(*a.shape, dtype=a.dtype) while n: if n & 1: ret = mat_mul(ret, a) a = mat_mul(a, a) n >>= 1 return ret def kaisetsu2_3(N, B, K, c): A = np.zeros((B, B), dtype=np.uint64) for i in range(B): for j in c: nex = (i*10 + j) % B A[i][nex] += 1 return mat_pow(A, N)[0][0] # 5は終わらない test_all(kaisetsu2_3, 4)

提出 #32513582 - 競プロ典型 90 問

解説③

最後に追加の制約なし(小課題3)でACとなるための解説です。

kyopro_educational_90/005-03.jpg at main · E869120/kyopro_educational_90

想定コード

kyopro_educational_90/005-03.cpp at main · E869120/kyopro_educational_90

Pythonで書き直したのが以下です。
というか けんちょんさんの記事も参考に大分変更を加えました。

from itertools import accumulate def kaisetsu3(N, B, K, c): MOD = 1000000007 # python3.8以降 # *ten, = accumulate(range(62), lambda acc, _: (acc ** 2) % B , initial=10) # python3.7以前 *ten, = accumulate(range(10, 72), lambda acc, _: (acc ** 2) % B) def mul(x, exp, y): res = [0] * B for p in range(B): for q in range(B): nex = (ten[exp] * p + q) % B res[nex] += x[p] * y[q] res[nex] %= MOD return res # 初期化 ## doubled:1桁目(2^0)を求める doubled = [0] * B for k in c: doubled[k%B] += 1 ## ans:Nの1ビット目を見る if N & 1: ans = doubled else: ans = [0] * B ans[0] = 1 N >>= 1 # Nの2ビット目以降見ながらダブリング実施 exp = 0 while N: # 2^epx桁目から2^(exp+1)桁目を計算 doubled = mul(doubled, exp, doubled) exp += 1 if N & 1: ans = mul(ans, exp, doubled) N >>= 1 return ans[0] # 5つ目のケース(N = 10^18)は、この環境ではめっちゃ時間掛かる # test_all(kaisetsu3, 5) test_all(kaisetsu3, 4)

結果

AC!
提出 #32588657 - 競プロ典型 90 問

長かった…

解説の理解

ダブリングの考え方についても、けんちょんさんの記事(同上)が詳しいです。(丸投げ)

競プロ典型 90 問 005 - Restricted Digits(★7) - けんちょんの競プロ精進記録

まとめ

理解はできた(と思う…)が、自力で解ける気がしない😇

自由欄