后缀自动机题目汇总

后缀自动机题目汇总

[toc]

后缀自动机笔记

后缀自动机模板

01 求一个串的本质不同的子串的个数

1.1 解题思路

在这个字符串上创建SAM,然后统计每个节点代码的字符串的个数。对于节点i而言,其代表的字符串的个数为其最长长度减去后缀链接指向的节点代表的最长长度。

$$tr[i].len - tr[tr[i].fa].len$$

其实也可以找到所有后缀链接构成的树结构,然后找到入度为零的点,将这些点的len相加即可。

1.2 时空复杂度

时间复杂度$O(n)$,空间复杂度为$O(n)$,$n$为字符串长度。

1.3 C++代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
/**
*
* 后缀自动机模板
* https://codeforces.com/edu/course/2/lesson/2/5/practice/contest/269656/problem/A
*
*
*/

#include <iostream>
#include <vector>
using namespace std;

class SAM {
public:

typedef unsigned long long ULL;

SAM(string &s) : tr(vector<Node>(MAX_LEN << 1)),
st(vector<bool>(MAX_LEN)),
idx(1),
last(1) {
for (auto x : s) {
_extend(x - 'a');
}
}

ULL getSubstringNum() {

return _dfs(1);
}


private:

static const int MAX_LEN = 1e6 + 5;

class Node {
public:
int len, fail;
int ch[26];
};

vector<Node> tr;
int idx, last;
vector<bool> st;

// 增量的方式构造SAM
void _extend(int c) {

int p = last, np = last = ++idx;
tr[np].len = tr[p].len + 1;
// 沿着faile指针向前不断扩充新增加的子串,
// 如果发现状态没有在状态机中则练一个到最后的状态的Trie边
while (p && !tr[p].ch[c]) tr[p].ch[c] = np, p = tr[p].fail;
// p 为零,说明所有的都只能通过新增的方式来做
if (!p) tr[np].fail = 1;
// 发现某个状态在扩充的时候扩充到了状态机已经存的的一个状态
else {
int q = tr[p].ch[c];
if (tr[q].len == tr[p].len + 1) tr[np].fail = q; // 如果这个状态代表的最大长度刚好是新增加的,则直接将末尾的fa指针指向它即可。
else { // 否则的话,需要分裂节点
int nq = ++ idx;
tr[nq] = tr[q];
tr[q].fail = nq; // 维护faile指针
tr[np].fail = nq;
tr[nq].len = tr[p].len + 1; // 新创建的节点,且强制让他代表新加入的字符扩展得到的字传的长度
// 将前面到q的边都转移到nq
while (p && tr[p].ch[c] == q) tr[p].ch[c] = nq, p = tr[p].fail;
}
}
}

ULL _dfs(int cur) {
ULL ans = tr[cur].len - tr[tr[cur].fail].len;
st[cur] = true;
for (int i = 0; i < 26; i ++) {
int t = tr[cur].ch[i];
if (t && !st[t]) ans += _dfs(t);
}
return ans;
}
};

int main() {

ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);

string s;
cin >> s;
SAM sam(s);
cout << sam.getSubstringNum();
return 0;

}

02. 统计第一题中每个不同的子串出现的次数

2.1 解题思路

利用endpos的定义,直接在后缀树上进行拓扑遍历,同时注意是前缀的状态需要在构造SAM的时候额外加1.

2.2 时空复杂度

时间复杂度$O(n)$,空间复杂度$O(n)$,$n$为字符串长度

2.3 C++代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
/**
* 后缀自动机模板, 类实现
* https://www.acwing.com/problem/content/2768/
*/

#include <iostream>
#include <vector>
using namespace std;

const int ALPHA_SIZE = 26; // 字符集大小

class Node {
public:
int len, fail; // 存储当前状态代表的最长子串的长度,faile指针
int ch[ALPHA_SIZE]; // 存储孩子指针
int num; // 存储当前节点代表的字符串在原始字符串中出现的次数
};

class SAM {
public:

typedef unsigned long long ULL;

SAM(string &s) : tr(vector<Node>(MAX_LEN << 1)),
st(vector<bool>(MAX_LEN << 1)),
idx(1),
last(1),
h(vector<int>(MAX_LEN << 1, -1)),
e(vector<int>(MAX_LEN << 1)),
ne(vector<int>(MAX_LEN << 1)),
g_idx(0) {
for (auto x : s) {
_extend(x - 'a');
}
}

//求本质不同的子串的数量
ULL getSubstringNum() {
return _dfs(1);
}

// 跑完之后,可以求出每个状态代表的子串出现的次数,
// 一个状态代表的所有子串在原串中出现的次数是相同的
ULL getNumForEverySubString() {
_build_tree();
_dfs2(1);
}

// 得到一个子串长度和其出现次数乘积的最大值,其中子串出现次数要大于1
ULL getMaxLengthTimesOccurTime() {
getNumForEverySubString();
ULL ans = 0;
for (int i = 1; i <= idx; i ++) {
if (tr[i].num == 1) continue;
ans = max(ans, (ULL)(tr[i].len) * tr[i].num);
}
return ans;
}

private:

static const int MAX_LEN = 1e6 + 5;

vector<Node> tr;
int idx, last; // 表明当前用到了哪个点,以及上一个代表字符串末尾的状态的位置
vector<bool> st; // 用于dfs时判断是否遍历到了
vector<int> h, e, ne; // 用于在fail指针上建反边从而得到一棵树
int g_idx;

// 增量的方式构造SAM
void _extend(int c) {

int p = last, np = last = ++idx;
tr[np].len = tr[p].len + 1;
tr[np].num = 1;
// 沿着faile指针向前不断扩充新增加的子串,
// 如果发现状态没有在状态机中则练一个到最后的状态的Trie边
while (p && !tr[p].ch[c]) tr[p].ch[c] = np, p = tr[p].fail;
// p 为零,说明所有的都只能通过新增的方式来做
if (!p) tr[np].fail = 1;
// 发现某个状态在扩充的时候扩充到了状态机已经存的的一个状态
else {
int q = tr[p].ch[c];
if (tr[q].len == tr[p].len + 1) tr[np].fail = q; // 如果这个状态代表的最大长度刚好是新增加的,则直接将末尾的fa指针指向它即可。
else { // 否则的话,需要分裂节点
int nq = ++ idx;
tr[nq] = tr[q];
tr[nq].num = 0; // 不能把num也给复制了,新开的num要为零
tr[q].fail = nq; // 维护faile指针
tr[np].fail = nq;
tr[nq].len = tr[p].len + 1; // 新创建的节点,且强制让他代表新加入的字符扩展得到的字传的长度
// 将前面到q的边都转移到nq
while (p && tr[p].ch[c] == q) tr[p].ch[c] = nq, p = tr[p].fail;
}
}
}

// 求本质不同的字符串的个数
ULL _dfs(int cur) {
ULL ans = tr[cur].len - tr[tr[cur].fail].len;
st[cur] = true;
for (int i = 0; i < ALPHA_SIZE; i ++) {
int t = tr[cur].ch[i];
if (t && !st[t]) ans += _dfs(t);
}
return ans;
}

void _add(int a, int b) {
e[g_idx] = b, ne[g_idx] = h[a], h[a] = g_idx ++;
}

// 建反向树
void _build_tree() {
for (int i = 2; i <= idx; i ++) _add(tr[i].fail, i);
}

// 用于求每个节点代表的字符串出现的次数
void _dfs2(int cur) {
for (int i = h[cur]; ~i; i = ne[i]) {
int node = e[i];
_dfs2(node);
tr[cur].num += tr[node].num;
}
}
};

int main() {

ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);

string s;
cin >> s;
SAM sam(s);
cout << sam.getMaxLengthTimesOccurTime();
return 0;

}

03. 求多个字符串的最长公共子串

3.1 解题思路

我们在其中一个字符串s上建立SAM,然后依次遍历其它字符串t,求出ts的每个状态出现的最长子串的长度。同时注意在求完之后,每个节点的信息还需要向其后缀链接的父节点串,因为更长的串也同样包含了它的后缀。

再对每个串的每个状态取一个最小值。处理完所有串之后对所有状态再取一个最大值即可。

3.2 时空复杂度

时间复杂度:假设字符串数量为a,每个字符串的长度为n,字符串总长度为m

构造SAM$O(n)$

每个字符串在SAM上进行跳转的时候,每次最多前进一次,往后跳的次数不会超过向前跳的次数,所以时间复杂度为$O(n)$。

在每次循环中,需要遍历一下SAM的所有状态,为$O(an)$

所以总的时间复杂度为$O(an)$。

空间复杂度为$O(n)$

3.3 C++代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <iostream>
#include <cstring>
using namespace std;

const int N = 1e4 + 5;

struct Node {
int len, fa;
int ch[26];
}tr[N << 1];

int last = 1, tot = 1;

int ans[N << 1], now[N << 1];
int h[N << 1], e[N << 1], ne[N << 1], idx;

void extend(int c) {

int p = last, np = last = ++ tot;
tr[np].len = tr[p].len + 1;
while (p && !tr[p].ch[c]) tr[p].ch[c] = np, p = tr[p].fa;
if (!p) tr[np].fa = 1;
else {
int q = tr[p].ch[c];
if (tr[q].len == tr[p].len + 1) tr[np].fa = q;
else {
int nq = ++ tot;
tr[nq] = tr[q];
tr[nq].len = tr[p].len + 1;
tr[q].fa = tr[np].fa = nq;
while (p && tr[p].ch[c] == q) tr[p].ch[c] = nq, p = tr[p].fa;
}
}
}

void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}

void dfs(int cur) {

for(int i = h[cur]; ~i; i = ne[i]) {
dfs(e[i]);
now[cur] = max(now[cur], now[e[i]]);
}
}


int main() {

ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);


int n;
string s;
cin >> n;
cin >> s;
for (auto x : s) extend(x - 'a');

for (int i = 1; i <= tot; i ++) ans[i] = tr[i].len;
memset(h, -1, sizeof h);
for (int i = 2; i <= tot; i ++) add(tr[i].fa, i);
while ( -- n) {
cin >> s;
for (int i = 1; i <= tot; i ++) now[i] = 0;
int cur_loc = 1, cur_len = 0;
for (auto x : s) {
int c = x - 'a';
while (cur_loc != 1 && !tr[cur_loc].ch[c]) cur_loc = tr[cur_loc].fa, cur_len = tr[cur_loc].len;
if (tr[cur_loc].ch[c]) cur_loc = tr[cur_loc].ch[c], cur_len ++;
now[cur_loc] = max(now[cur_loc], cur_len);
}
dfs(1);
for (int i = 1; i <= tot; i ++) ans[i] = min(ans[i], now[i]);
}
int res = 0;
for (int i = 1; i <= tot; i ++) {
res = max(res, ans[i]);
}
cout << res;
return 0;

}

04 其它字符串出现在某个字符串的最长前缀

4.1 解题思路

SAM裸题,主串建立SAM,然后其它字符串在SAM上跑即可,直到到了一个不能继续走的地方。

4.2 时空复杂度

主串长度为$n$,要匹配的串的总长度为$m$,

空间复杂度$O(n)$

时间复杂度$O(n + m)$

4.3 C++代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include <iostream>
using namespace std;

int n, m;

constexpr int N = 1e7 + 5;

inline
int get(char c) {
if (c == 'E') return 0;
else if (c == 'S') return 1;
else if (c == 'W') return 2;
else return 3;
}

struct Node {
int fa, len;
int ch[4];
}tr[N << 1];
int last = 1, tot = 1;

void extend(int c) {

int p = last , np = last = ++tot;
tr[np].len = tr[p].len + 1;
while (p && !tr[p].ch[c]) tr[p].ch[c] = np, p = tr[p].fa;
if (!p) tr[np].fa = 1;
else {
int q = tr[p].ch[c];
if (tr[q].len == tr[p].len + 1) tr[np].fa = q;
else {
int nq = ++tot;
tr[nq] = tr[q];
tr[nq].len = tr[p].len + 1;
tr[q].fa = tr[np].fa = nq;
while (p && tr[p].ch[c] == q) tr[p].ch[c] = nq, p = tr[p].fa;
}
}
}

int main() {

ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);

cin >> n >> m;
string s;
cin >> s;
for (auto x : s) {
extend(get(x));
}
while (m --) {
string p;
cin >> p;
int cnt = 0, cur = 1;
for (auto x : p) {
if (tr[cur].ch[get(x)]) cur = tr[cur].ch[get(x)], cnt ++;
else break;
}
cout << cnt << '\n';
}
return 0;

}