# Sample cases for problem A: each entry pairs the raw stdin text ("in")
# with the exact stdout text expected for it ("out").
test_dataA = [
    {
        "in": "10\nin that case you should print yes and not no\n",
        "out": "Yes\n",
    },
    {
        "in": "10\nin diesem fall sollten sie no und nicht yes ausgeben\n",
        "out": "No\n",
    },
]


def A():
    """Placeholder for the problem A solver: does nothing and returns None."""
# NOTE(review): test_all is not defined in this chunk; presumably it runs the
# given solver against its sample data (test_dataA here) — confirm.
test_all(A)
def B():
    """Clear every cell within reach of a digit cell and print the grid.

    Reads ``R C`` from the first stdin line and then ``R`` grid rows. Each
    cell holding a digit ``d`` turns every in-bounds cell at Manhattan
    distance at most ``d`` from it (itself included) into ``"."``. All other
    cells are printed unchanged.
    """
    from copy import deepcopy
    from itertools import product
    from string import digits

    R, C = map(int, input().split())
    grid = [list(input()) for _ in range(R)]

    def get_points(center, radius):
        """Return the lattice points within Manhattan distance ``radius``."""
        cx, cy = center
        return {
            (cx + sx * dx, cy + sy * dy)
            for dx in range(radius + 1)
            for dy in range(radius + 1 - dx)
            for sx in (1, -1)
            for sy in (1, -1)
        }

    # Write into a copy so digits that get cleared are still read as digits
    # when their own turn comes.
    cleared = deepcopy(grid)
    for i, j in product(range(R), range(C)):
        cell = grid[i][j]
        if cell not in digits:
            continue
        for x, y in get_points((i, j), int(cell)):
            if 0 <= x < R and 0 <= y < C:
                cleared[x][y] = "."

    for row in cleared:
        print(*row, sep="")


test_all(B)
def B():
    """Print the grid with every cell in reach of a digit cell cleared.

    Reads ``R C`` and then ``R`` grid rows from stdin. A cell is printed as
    ``"."`` when some decimal-digit cell lies within Manhattan distance equal
    to that digit's value; otherwise the cell's own character is printed.
    Every candidate cell is checked against the whole grid.
    """
    from itertools import product

    R, C = map(int, input().split())
    grid = [list(input()) for _ in range(R)]
    everywhere = list(product(range(R), range(C)))

    def in_blast(r, c):
        """Return True when (r, c) is within range of at least one digit."""
        return any(
            grid[br][bc].isdecimal()
            and abs(r - br) + abs(c - bc) <= int(grid[br][bc])
            for br, bc in everywhere
        )

    for r in range(R):
        for c in range(C):
            print("." if in_blast(r, c) else grid[r][c], end="")
        print()


test_all(B)
def D():
    """Count substrings of the input in which every digit occurs evenly.

    Reads one line of digits from stdin. The parity of each digit's count is
    tracked over every prefix (the empty prefix included). Two prefixes with
    the same parity state delimit a substring whose digit counts are all
    even, so the printed answer is the number of unordered pairs of equal
    states.
    """
    S = input()

    # Bit d of `parity` is set when digit d has appeared an odd number of
    # times in the prefix read so far.
    parity = 0
    seen = {parity: 1}
    for ch in S:
        parity ^= 1 << int(ch)
        seen[parity] = seen.get(parity, 0) + 1

    print(sum(v * (v - 1) // 2 for v in seen.values()))


test_all(D)
$N = M = 2000$で$A_i$が全て$0$の場合、列挙する組み合わせは$_{3999}C_{2000}$で約$8 \times 10^{1201}$通り（1202桁）になってしまいます。
そりゃ無理だ😇
from math import factorial

# Scratch check: C(3999, 2000), the number of multisets the brute force in
# E() would have to enumerate when M = 2000 and there are 2000 free slots.
factorial(3999) // factorial(2000) // factorial(1999)


def E():
    """Brute-force the expected K-th smallest value, modulo 998244353.

    Reads ``N M K`` and then ``N`` integers from stdin. Entries equal to 0
    are free slots, each filled independently with a value from ``1..M``.
    Every multiset of fill values is enumerated and weighted by the number
    of ordered assignments that produce it. The K-th smallest element of
    the completed list is accumulated, and the weighted sum divided by
    ``M ** zeros`` is printed modulo 998244353.

    The enumeration visits C(M + zeros - 1, zeros) multisets, so this is
    only usable for very small inputs.
    """
    from collections import Counter, defaultdict
    from itertools import combinations_with_replacement as cr
    from math import factorial, prod

    MOD = 998244353
    N, M, K = map(int, input().split())
    # Drop the 0 entries only after parsing whole tokens: removing the "0"
    # character from the raw line would also corrupt values such as 10 or 105.
    A = [a for a in map(int, input().split()) if a != 0]
    zeros = N - len(A)

    ans = defaultdict(int)
    for com in cr(range(1, M + 1), zeros):
        # Multinomial coefficient: ordered assignments of this multiset to
        # the free slots. prod() of an empty iterable is 1, which keeps the
        # zeros == 0 case (a single empty multiset) well defined.
        size = factorial(zeros) // prod(
            factorial(v) for v in Counter(com).values()
        )
        x = sorted(A + list(com))[K - 1]
        ans[x] = (ans[x] + size) % MOD

    pat = sum(k * v % MOD for k, v in ans.items()) % MOD
    # Multiply by the modular inverse of M ** zeros, the number of equally
    # likely assignments.
    total = pow(pow(M, -1, MOD), zeros, MOD)
    print(pat * total % MOD)
# print(ans)test_all(E)