[AGC021F]Trinity

Description

$$
1 \le n \le 8000, 1 \le m \le 300
$$

Solution

首先可以知道对列的限制更紧,所以可以考虑围绕列来建立转移方程。

设 $f(i, j)$ 表示 $i$ 行 $j$ 列的矩阵,每行必染了一种颜色情况下,产生的 $A, B, C$ 的方案数。最终答案即 $\sum _{i = 0} ^n {n \choose i} f(i, m)$。

每次考虑新加一列会造成什么影响,也即将 $f(i, j)$ 转移到 $f(i + t, j + 1)$,其中 $t$ 表示出现在这一列的不属于前 $i$ 行的黑色格子数目。

(注意这里的加入是将这 $t$ 行插入到 $i$ 行里,而不是直接放到后面)

当 $t = 0$ 时,没有新加入行,所以相当于是第 $j + 1$ 列的格子可以任意染色,然后问不同的 $B _{j + 1}, C _{j + 1}$ 的方案数。当然是 $1 + i + {i \choose 2}$,分别是不染,左右端点重合和左右端点任选三种情况。(这里的左右端点分别指的是 $B, C$)

当 $t > 0$ 时,考虑建立两个虚行放在这加入的 $t$ 行的首尾端,再去将这 $t + 2$ 行分配到 $i + t + 2$ 行中去。那么在最终的分配结果中,首虚行的下一行代表着左端点,尾虚行的上一行代表着右端点。可以知道这样放出来的方案可以一一对应,故方案数为 ${i + t + 2 \choose t + 2}$。

于是可以写出转移式子。观察到是一个卷积的形式,NTT 优化即可。

$O(nm \log n)$。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#include <cstdio>
#include <cstring>
#include <iostream>

#define Dec(x) (x >= mod ? x -= mod : 0)
#define Inc(x) (x < 0 ? x += mod : 0)
#define log2(x) (31 - __builtin_clz(x))

using namespace std;

typedef long long LL;

const int maxN = 8e3 + 5;
const int maxM = 205;
const int mod = 998244353;

int n, m;
int fac[maxN], ifac[maxN];
int f[maxN], g[maxN * 2], h[maxN];

int FPM(int bas, int ind)
{
int res = 1;
while (ind)
{
if (ind & 1)
res = (LL)res * bas % mod;
bas = (LL)bas * bas % mod;
ind >>= 1;
}
return res;
}

namespace NTT
{
const int maxN = 1 << (log2(::maxN) + 2) | 1;

int omg[maxN], rev[maxN];

void NTT(int* arr, int lim, int fg)
{
for (int i = 1; i < lim; ++i)
if (rev[i] < i)
swap(arr[i], arr[rev[i]]);
for (int t = 2; t <= lim; t <<= 1)
{
int m = t >> 1, coe = lim / t;
for (int* p = arr; p != arr + lim; p += t)
for (int k = 0; k < m; ++k)
{
int tmp = LL(fg == 1 ? omg[coe * k] : omg[lim - coe * k]) * p[k + m] % mod;
p[k + m] = p[k] - tmp, Inc(p[k + m]);
p[k] += tmp, Dec(p[k]);
}
}
if (fg == -1)
{
int invlim = FPM(lim, mod - 2);
for (int i = 0; i < lim; ++i)
arr[i] = (LL)arr[i] * invlim % mod;
}
}

void Mul(const int* A, int n, const int* B, int m, int* C)
{
static int _A[maxN], _B[maxN];

int lim = 1 << (log2(n + m) + 1);
memcpy(_A, A, n * sizeof(int));
memcpy(_B, B, m * sizeof(int));
memset(_A + n, 0, (lim - n) * sizeof(int));
memset(_B + m, 0, (lim - m) * sizeof(int));
NTT(_A, lim, 1), NTT(_B, lim, 1);
for (int i = 0; i < lim; ++i)
_A[i] = (LL)_A[i] * _B[i] % mod;
NTT(_A, lim, -1);
memcpy(C, _A, (n + m - 1) * sizeof(int));
}

void Init(int n, int m)
{
int lim = 1 << (log2(n + m) + 1);
for (int i = 1, bit = log2(lim); i < lim; ++i)
rev[i] = rev[i >> 1] >> 1 | (i & 1) << (bit - 1);
omg[0] = omg[lim] = 1;
omg[1] = FPM(3, (mod - 1) / lim);
for (int i = 2; i < lim; ++i)
omg[i] = (LL)omg[i - 1] * omg[1] % mod;
}
}

inline int C(int _n, int _m)
{
if (_n < 0 or _m < 0 or _n < _m)
return 0;
return (LL)fac[_n] * ifac[_m] % mod * ifac[_n - _m] % mod;
}

int main()
{
freopen("matrix.in", "r", stdin);
freopen("matrix.out", "w", stdout);
ios::sync_with_stdio(false);
cin >> n >> m;
f[0] = 1;
fac[0] = ifac[0] = 1;
for (int i = 1; i <= n + 2; ++i)
fac[i] = (LL)fac[i - 1] * i % mod;
ifac[n + 2] = FPM(fac[n + 2], mod - 2);
for (int i = n + 1; i; --i)
ifac[i] = ifac[i + 1] * LL(i + 1) % mod;
NTT::Init(n + 1, n + 1);
for (int o = 1; o <= m; ++o)
{
for (int i = 0; i <= n; ++i)
g[i] = (LL)f[i] * ifac[i] % mod, h[i] = ifac[i + 2];
NTT::Mul(g, n + 1, h, n + 1, g);
for (int i = 1; i <= n; ++i)
{
int tmp = f[i];
f[i] = g[i] - (LL)f[i] * ifac[i] % mod * ((mod + 1) >> 1) % mod;
f[i] = (LL)f[i] * fac[i + 2] % mod;
f[i] += tmp * (i * LL(i + 1) * ((mod + 1) >> 1) % mod + 1) % mod;
Inc(f[i]), Dec(f[i]);
}
}
int ans = 0;
for (int i = 0; i <= n; ++i)
ans += (LL)f[i] * C(n, i) % mod, Dec(ans);
cout << ans << endl;
return 0;
}