[BZOJ 2160] 拉拉队排练

题目大意

给定一个长为 $n$ 的字符串(仅由 $26$ 个小写字母组成),找出其所有长度为奇数的回文子串(位置不同的相同子串分别计算),按长度降序排列,求前 $k$ 个串的长度之积对 $19930726$ 取模的结果;不足 $k$ 个时输出 $-1$。

$1 \leqslant n \leqslant 1,000,000$

$1 \leqslant k \leqslant 1,000,000,000,000$

题目链接

BZOJ 2160

题解

回文树裸题。

对串建完回文树后,先按节点创建的逆序沿 fail 指针累加,得到每个节点(即每个本质不同的回文串)的出现次数;然后取出所有长度为奇数的节点,按长度降序排序,依次用快速幂累乘「长度的出现次数次方」,最后一个节点只取剩余的个数,凑满 $k$ 个即可。

一开始没看见「奇数」,WA 了 $6$ 遍……

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
#include <new>
// Number of entries in PalinT::str and in the PalinT node pool.
const int MAXN = 1000005;
// Alphabet size: the count of letters from 'a' through 'z'; also the width of each node's child table.
const int CHAR_SET = 'z' - 'a' + 1;
// Character subtracted from every input byte in main() to turn it into a child-table index.
const char BASE_CHAR = 'a';
// Modulus applied to every multiplication in pow() and main().
const int MOD = 19930726;
/**
 * Palindromic tree built incrementally, one symbol per insert() call.
 *
 * Nodes are carved out of the fixed array `_pool` in creation order; `_curr`
 * points one past the newest node. The first two pool slots are the roots
 * (lengths 0 and -1). Each insert() allocates at most one further node, so
 * the pool has room for MAXN - 2 inserted symbols.
 */
struct PalinT {
    // str[1..size] holds the inserted symbols; str[0] is set to -1 by the constructor.
    int str[MAXN], size;

    struct Node {
        int len;        // palindrome length; 0 and -1 for the two roots
        long long cnt;  // bumped by insert(), accumulated along fail links by count()
        Node *c[CHAR_SET], *fail;

        // Builds a node of the given length with a zero counter and no links.
        Node(int len = 0) : len(len), cnt(0), fail(NULL) {
            for (Node **slot = c; slot != c + CHAR_SET; ++slot) *slot = NULL;
        }
    } *root[2], *last, _pool[MAXN], *_curr;

    // Creates the two roots and the sentinel in front of the text.
    PalinT() : size(0), _curr(_pool) {
        Node *even = new (_curr++) Node(0);
        Node *odd = new (_curr++) Node(-1);
        root[0] = even;
        root[1] = odd;
        even->fail = odd;
        odd->fail = odd;
        last = even;
        str[0] = -1;
    }

    // Follows fail links from u until the symbol just in front of u's palindrome
    // (ending right before position `size`) equals the newest symbol str[size].
    Node *getFail(Node *u) {
        const int tail = str[size];
        for (; str[size - u->len - 1] != tail; u = u->fail) {
        }
        return u;
    }

    // Appends symbol c (an index into Node::c), creating the node for the new
    // longest palindromic suffix if it does not exist yet, and bumps its counter.
    void insert(int c) {
        str[++size] = c;
        Node *parent = getFail(last);
        Node *&slot = parent->c[c];
        if (slot == NULL) {
            Node *fresh = new (_curr++) Node(parent->len + 2);
            // The child is linked in before its fail pointer is resolved, as in the original code.
            slot = fresh;
            if (parent == root[1]) {
                fresh->fail = root[0];
            } else {
                fresh->fail = getFail(parent->fail)->c[c];
            }
        }
        last = slot;
        ++last->cnt;
    }

    // Adds every node's counter into its fail target, visiting the pool from
    // the newest node down to the very first slot.
    void count() {
        for (Node *p = _curr; p != _pool;) {
            --p;
            p->fail->cnt += p->cnt;
        }
    }
} palinT;
/**
 * Modular exponentiation by repeated squaring.
 *
 * @param a base; any long long value. It is reduced modulo MOD up front so
 *          that every intermediate product stays below MOD * MOD.
 * @param n exponent; must be non-negative, because the loop only ends once
 *          right shifts bring n down to 0.
 * @return a raised to the n-th power, reduced with `% MOD`; 1 when n is 0.
 */
long long pow(long long a, long long n) {
    // Without this reduction a * a overflows long long (undefined behaviour)
    // as soon as |a| exceeds roughly 3.04e9.
    a %= MOD;
    long long res = 1;
    for (; n; n >>= 1, a = a * a % MOD) {
        if (n & 1) res = res * a % MOD;
    }
    return res;
}
int main() {
int n;
long long k;
scanf("%d %lld\n", &n, &k);
for (int i = 0; i < n; i++) palinT.insert(getchar() - BASE_CHAR);
palinT.count();
static std::vector<std::pair<int, long long> > vec;
long long sumCnt = 0;
for (PalinT::Node *p = palinT._pool; p != palinT._curr; p++) {
if (p->len == 0 || p->len == -1) continue;
if (p->len % 2 == 0) continue;
vec.push_back(std::make_pair(p->len, p->cnt));
sumCnt += p->cnt;
}
if (k > sumCnt) return puts("-1"), 0;
std::sort(vec.begin(), vec.end());
long long ans = 1;
for (int i = vec.size() - 1; ~i; i--) {
if (k > vec[i].second) {
(ans *= pow(vec[i].first, vec[i].second)) %= MOD;
k -= vec[i].second;
} else {
(ans *= pow(vec[i].first, k)) %= MOD;
k = 0;
break;
}
}
printf("%lld\n", ans);
return 0;
}