题目大意
给定仅由数码组成的串 a 和非负整数 l 与 r,考察其的一个划分,若该划分中任意一个串都是一个正当整数(没有多余的前导零)且属于闭区间 [l, r],则我们称该划分为一个美丽划分。
求一共有多少个美丽划分。因为答案比较大,输出其对 998244353 取模的结果。
题目保证 a, l, r 的位数不超过 106 且 l 不超过 r 。
分析
这题很显然可以用动态规划的方法来解决,令 dp[i] 恰好在 i 位置右边结束某一串的方案数,则:
其中 valid(s) 表示子串 s 是否合法。一个比较显然的结论是如果当前串没有前导零且它的长度在 内那么该串一定合法,如果在
外则一定非法。对于长度等于边界的情况,我们需要比较字符串的大小。但是暴力的做法每次比较需要 O(n), 显然不能承受。
我的第一想法是用类似于后缀数组构造的步骤来排序所有长度为某个定值的字符串。然后获得了顺序以后再用我们对应的 l 或者 r 去求一个 lower_bound 或 upper_bound,这样的复杂度在 O(nlogn)。这道题时限只有 1s,所以很可能会超时。
然后我们发现如果我们的当前串与目标串相等那么当前一定可行。如果不等的话,设当前串为 s,目标串为 t,那么它们的大小关系就是他们第一个不等位置的大小关系。如果我们已经处理出了 a 中的所有长度为 |t| 的串与 t 的最长公共前缀长度,那么单次判断可以在 O(1) 时间完成。而这个预处理工作可以用构造 z 数组的方法在 O(|a| + |t|) 时间完成。复杂度为线性,可以接受。
另外注意特判 l 为 0 的情况。
代码
总复杂度为 O(|a| + |l| + |r|)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define FOR(i, a, b) for (int (i) = (a); (i) <= (b); (i)++)
#define ROF(i, a, b) for (int (i) = (a); (i) >= (b); (i)--)
#define REP(i, n) FOR(i, 0, (n)-1)
#define sqr(x) ((x) * (x))
#define all(x) (x).begin(), (x).end()
#define reset(x, y) memset(x, y, sizeof(x))
#define uni(x) (x).erase(unique(all(x)), (x).end());
#define BUG(x) cerr << #x << " = " << (x) << endl
#define pb push_back
#define eb emplace_back
#define mp make_pair
#define _1 first
#define _2 second
#define chkmin(a, b) a = min(a, b)
#define chkmax(a, b) a = max(a, b)
const int maxn = 1123456;
const int MOD = 998244353;
char s[maxn * 2], a[maxn], lo[maxn], hi[maxn];
int z[maxn * 2], dp[maxn];
vector<int> in[maxn], out[maxn];
inline void upd(int &a, int b) {
a += b;
if (a >= MOD) a -= MOD;
}
void get_z(int n) {
int l = 0, r = 0;
z[0] = 0;
FOR(i, 1, n - 1) {
if (r >= i) {
if (z[i - l] + i <= r) z[i] = z[i - l];
else {
int nxt = r - i;
while (r < n && s[r] == s[nxt]) r++, nxt++;
r--;
z[i] = nxt;
l = i;
}
} else {
r = i;
int nxt = 0;
while (r < n && s[r] == s[nxt]) r++, nxt++;
l = i;
r--;
z[i] = nxt;
}
}
}
int main() {
scanf("%s%s%s", a + 1, lo + 1, hi + 1);
int n = strlen(a + 1), len_l = strlen(lo + 1), len_r = strlen(hi + 1);
strcpy(s, lo + 1);
s[len_l] = '$';
strcpy(s + len_l + 1, a + 1);
get_z(len_l + n + 1);
FOR(i, 1, n - len_l + 1) if (a[i] != '0') {
int idx = len_l + i;
if (z[idx] == len_l || a[z[idx] + i] > lo[z[idx] + 1])
in[i + len_l - 1].eb(i);
else in[i + len_l].eb(i);
}
strcpy(s, hi + 1);
s[len_r] = '$';
strcpy(s + len_r + 1, a + 1);
get_z(len_r + n + 1);
if (lo[1] != '0') {
FOR(i, 1, n - len_r + 1) if (a[i] != '0') {
int idx = len_r + i;
if (z[idx] == len_r || a[z[idx] + i] < hi[z[idx] + 1])
out[i + len_r].eb(i);
else out[i + len_r - 1].eb(i);
}
} else {
FOR(i, 1, n - len_r + 1) if (a[i] != '0') {
int idx = len_r + i;
if (z[idx] == len_r || a[z[idx] + i] < hi[z[idx] + 1])
out[i + len_r].eb(i);
else out[i + len_r - 1].eb(i);
}
FOR(i, 1, n) if (a[i] == '0') {
in[i].eb(i);
out[i + 1].eb(i);
}
}
dp[0] = 1;
int way = 0;
FOR(i, 1, n) {
for (auto it : in[i]) upd(way, dp[it - 1]);
for (auto it : out[i]) upd(way, MOD - dp[it - 1]);
dp[i] = way;
}
printf("%d", dp[n]);
}