常用数据结构

1,头文件

#include <iostream>
#include <iomanip>
#include <cmath>
#include <string>
#include <vector>
#include <algorithm>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <queue>
#include <deque>
#include <stdexcept>
#include <random>
#include <utility>

// The snippets below use unqualified standard names (vector, string, max, swap, ...).
using namespace std;

2,自定义数据结构

1, 线段树

// 1>, 普通线段树(以加法为例)

// Point-assignment / range-sum segment tree over a fixed-size int array.
// Node id has children 2*id+1 and 2*id+2; node 0 covers [0, n-1].
// Sums are accumulated in int, so callers must keep totals within int range.
class Linetree {
    vector<int> data;   // data[id]: sum of the segment covered by node id
    int n;              // number of elements

    // Builds node id, which covers [l, r], from nums.
    void init(const vector<int>& nums, int l, int r, int id) {
        if (l >= r) {
            data[id] = nums[l];
            return;
        }
        int mid = (l + r) >> 1;
        init(nums, l, mid, 2 * id + 1);
        init(nums, mid + 1, r, 2 * id + 2);
        data[id] = data[2 * id + 1] + data[2 * id + 2];
    }

    // Sets element i to val below node id (covering [l, r]) and refreshes the sums on the path.
    void update(int i, int val, int l, int r, int id) {
        if (l >= r) {
            data[id] = val;
            return;
        }
        int mid = (l + r) >> 1;
        if (i <= mid) update(i, val, l, mid, 2 * id + 1);
        else update(i, val, mid + 1, r, 2 * id + 2);
        data[id] = data[2 * id + 1] + data[2 * id + 2];
    }

    // Sum of [ql, qr] restricted to node id's segment [l, r]; the ranges must intersect.
    int dfs(int l, int r, int id, int ql, int qr) const {
        if (ql <= l && r <= qr) {
            return data[id];
        }
        int mid = (l + r) >> 1;
        if (qr <= mid) return dfs(l, mid, 2 * id + 1, ql, qr);
        if (ql > mid) return dfs(mid + 1, r, 2 * id + 2, ql, qr);
        return dfs(l, mid, 2 * id + 1, ql, qr) + dfs(mid + 1, r, 2 * id + 2, ql, qr);
    }

public:
    // Builds the tree from nums. An empty input gives an empty tree (nothing may be
    // updated or queried then) instead of touching nums[0] out of bounds.
    explicit Linetree(const vector<int>& nums)
        : data(4 * nums.size()), n(static_cast<int>(nums.size())) {
        if (n > 0) init(nums, 0, n - 1, 0);
    }

    // Sets nums[i] = val. Requires 0 <= i < n.
    void update(int i, int val) {
        update(i, val, 0, n - 1, 0);
    }

    // Returns nums[l] + ... + nums[r]. Requires 0 <= l <= r < n.
    int query(int l, int r) const {
        return dfs(0, n - 1, 0, l, r);
    }
};
// 2>, 查询单点右侧第一个比目标值大的下标(最大值线段树)

// Max segment tree over a fixed-size int array: point assignment, range-max queries,
// and "first index to the right of i whose value exceeds val" searches.
// Node id has children 2*id+1 and 2*id+2; node 0 covers [0, n-1].
class Linetree {
    vector<int> data;   // data[id]: maximum of the segment covered by node id
    int n;              // number of elements

    // Builds node id, which covers [l, r], from nums.
    void init(const vector<int>& nums, int l, int r, int id) {
        if (l >= r) {
            data[id] = nums[l];
            return;
        }
        int mid = (l + r) >> 1;
        init(nums, l, mid, 2 * id + 1);
        init(nums, mid + 1, r, 2 * id + 2);
        data[id] = max(data[2 * id + 1], data[2 * id + 2]);
    }

    // Sets element i to val below node id (covering [l, r]) and refreshes the maxima on the path.
    void update(int i, int val, int l, int r, int id) {
        if (l >= r) {
            data[id] = val;
            return;
        }
        int mid = (l + r) >> 1;
        if (i <= mid) update(i, val, l, mid, 2 * id + 1);
        else update(i, val, mid + 1, r, 2 * id + 2);
        data[id] = max(data[2 * id + 1], data[2 * id + 2]);
    }

    // Maximum of [ql, qr] restricted to node id's segment [l, r]; the ranges must intersect.
    int dfs(int l, int r, int id, int ql, int qr) const {
        if (ql <= l && r <= qr) {
            return data[id];
        }
        int mid = (l + r) >> 1;
        if (qr <= mid) return dfs(l, mid, 2 * id + 1, ql, qr);
        if (ql > mid) return dfs(mid + 1, r, 2 * id + 2, ql, qr);
        return max(dfs(l, mid, 2 * id + 1, ql, qr), dfs(mid + 1, r, 2 * id + 2, ql, qr));
    }

    // Smallest index j in [l, r] with j > i and value > val, or -1 if there is none.
    int find(int l, int r, int id, int i, int val) const {
        if (data[id] <= val) return -1;           // nothing in this segment exceeds val
        if (l >= r) return l > i ? l : -1;
        int mid = (l + r) >> 1;
        // The left half ends at mid <= i, so none of its indices lie to the right of i.
        if (i >= mid) return find(mid + 1, r, 2 * id + 2, i, val);
        int li = find(l, mid, 2 * id + 1, i, val);
        if (li != -1) return li;
        return find(mid + 1, r, 2 * id + 2, i, val);
    }

public:
    // Builds the tree from nums. An empty input gives an empty tree instead of touching
    // nums[0] out of bounds.
    explicit Linetree(const vector<int>& nums)
        : data(4 * nums.size()), n(static_cast<int>(nums.size())) {
        if (n > 0) init(nums, 0, n - 1, 0);
    }

    // Sets nums[i] = val. Requires 0 <= i < n.
    void update(int i, int val) {
        update(i, val, 0, n - 1, 0);
    }

    // Plain range query: max(nums[l..r]). Requires 0 <= l <= r < n.
    int query(int l, int r) const {
        return dfs(0, n - 1, 0, l, r);
    }

    // Returns the first index j > i with nums[j] > val, or -1 if none (i may be negative,
    // e.g. -1 searches from index 0).
    int find(int i, int val) const {
        if (n == 0 || i >= n - 1) return -1;
        return find(0, n - 1, 0, i, val);
    }
};
// 3>, 树状数组(线段树的简洁实现)

// Fenwick tree (binary indexed tree) over n ints with 0-based public indices.
// update() adds a delta to one element; query() returns prefix or range sums.
class BIT{
    vector<int> data;   // 1-based internal array; data[0] is never used
    int n;

    // Value of the lowest set bit of x.
    int min_bit(int x){
        return x & -x;
    }
public:
    explicit BIT(int n) : data(n + 1), n(n) {}

    // Adds val to element i.
    void update(int i, int val){
        int pos = i + 1;
        while (pos <= n) {
            data[pos] += val;
            pos += min_bit(pos);
        }
    }

    // Sum of elements [0, i]; 0 when i < 0.
    int query(int i){
        int sum = 0;
        // pos &= pos - 1 clears the lowest set bit, i.e. pos -= min_bit(pos).
        for (int pos = i + 1; pos > 0; pos &= pos - 1) {
            sum += data[pos];
        }
        return sum;
    }

    // Sum of elements [l, r].
    int query(int l, int r) {
        const int upToR = query(r);
        const int beforeL = query(l - 1);
        return upToR - beforeL;
    }
};
// 4>, 带懒节点区间更新的线段树

// Segment-tree node with a lazy tag. val is the segment sum (mod Linetree::mod); a pending
// tag (mul, add) means every element of the segment still has to become x*mul + add in the
// children.
struct Node {
    int val;
    int mul;
    int add;
    explicit Node(int val=0,int mul=1,int add=0):val(val),mul(mul),add(add){}
};

// Segment tree with lazy propagation supporting, modulo 1e9+7: range add, range multiply,
// point assignment and range-sum queries. All stored values and all results are normalized
// into [0, mod), including for negative inputs.
class Linetree {
    vector<Node> data;   // data[id]: node covering a segment; children are 2*id+1 and 2*id+2
    int n;               // number of elements
    static constexpr int mod = 1000000007;

    // Reduces any integer into [0, mod).
    static int norm(long long x) {
        x %= mod;
        return static_cast<int>(x < 0 ? x + mod : x);
    }

    // Applies x -> x*m + a to every element of node id's segment (len elements): updates the
    // node's sum and composes the transformation into its pending tag. m, a must be in [0, mod).
    void apply(int id, int len, int m, int a) {
        Node& node = data[id];
        node.val = static_cast<int>((1LL * node.val * m + 1LL * a * len) % mod);
        node.mul = static_cast<int>(1LL * node.mul * m % mod);
        node.add = static_cast<int>((1LL * node.add * m + a) % mod);
    }

    // Recomputes node id's sum from its two children.
    void pull(int id) {
        data[id].val = static_cast<int>((1LL * data[2 * id + 1].val + data[2 * id + 2].val) % mod);
    }

    // Pushes node id's pending tag (node covers [l, r]) to both children, then clears it.
    // The PARENT's tag is what gets applied to each child's sum.
    void push_down(int id, int l, int r) {
        const int m = data[id].mul;
        const int a = data[id].add;
        if (m == 1 && a == 0) return;
        int mid = (l + r) >> 1;
        apply(2 * id + 1, mid - l + 1, m, a);
        apply(2 * id + 2, r - mid, m, a);
        data[id].mul = 1;
        data[id].add = 0;
    }

    // Builds node id, which covers [l, r], from nums.
    void init(const vector<int>& nums, int l, int r, int id) {
        if (l >= r) {
            data[id].val = norm(nums[l]);
            return;
        }
        int mid = (l + r) >> 1;
        init(nums, l, mid, 2 * id + 1);
        init(nums, mid + 1, r, 2 * id + 2);
        pull(id);
    }

    // Sets element i to val (already normalized) below node id covering [l, r].
    void update(int i, int val, int l, int r, int id) {
        if (l >= r) {
            data[id].val = val;
            return;
        }
        push_down(id, l, r);
        int mid = (l + r) >> 1;
        if (i <= mid) update(i, val, l, mid, 2 * id + 1);
        else update(i, val, mid + 1, r, 2 * id + 2);
        pull(id);
    }

    // Sum (mod) of [ql, qr] restricted to node id's segment [l, r]; the ranges must intersect.
    int dfs(int l, int r, int id, int ql, int qr) {
        if (ql <= l && r <= qr) {
            return data[id].val;
        }
        push_down(id, l, r);
        int mid = (l + r) >> 1;
        if (qr <= mid) return dfs(l, mid, 2 * id + 1, ql, qr);
        if (ql > mid) return dfs(mid + 1, r, 2 * id + 2, ql, qr);
        return static_cast<int>(
            (1LL * dfs(l, mid, 2 * id + 1, ql, qr) + dfs(mid + 1, r, 2 * id + 2, ql, qr)) % mod);
    }

    // Multiplies every element of [ml, mr] by val (already normalized).
    void mul(int ml, int mr, int val, int l, int r, int id) {
        if (ml <= l && r <= mr) {
            apply(id, r - l + 1, val, 0);
            return;
        }
        push_down(id, l, r);
        int mid = (l + r) >> 1;
        if (ml <= mid) mul(ml, mr, val, l, mid, 2 * id + 1);
        if (mr > mid) mul(ml, mr, val, mid + 1, r, 2 * id + 2);
        pull(id);
    }

    // Adds val (already normalized) to every element of [al, ar].
    void add(int al, int ar, int val, int l, int r, int id) {
        if (al <= l && r <= ar) {
            apply(id, r - l + 1, 1, val);
            return;
        }
        push_down(id, l, r);
        int mid = (l + r) >> 1;
        if (al <= mid) add(al, ar, val, l, mid, 2 * id + 1);
        if (ar > mid) add(al, ar, val, mid + 1, r, 2 * id + 2);
        pull(id);
    }

public:
    // Builds the tree from nums. An empty input gives an empty tree (no operation may be
    // performed on it) instead of touching nums[0] out of bounds.
    explicit Linetree(const vector<int>& nums)
        : data(4 * nums.size()), n(static_cast<int>(nums.size())) {
        if (n > 0) init(nums, 0, n - 1, 0);
    }

    // nums[k] += val for every k in [l, r]. Requires 0 <= l <= r < n.
    void addAll(int l, int r, int val) {
        add(l, r, norm(val), 0, n - 1, 0);
    }

    // nums[k] *= val for every k in [l, r]. Requires 0 <= l <= r < n.
    void mulAll(int l, int r, int val) {
        mul(l, r, norm(val), 0, n - 1, 0);
    }

    // nums[i] = val. Requires 0 <= i < n.
    void update(int i, int val) {
        update(i, norm(val), 0, n - 1, 0);
    }

    // (nums[l] + ... + nums[r]) mod 1e9+7, in [0, mod). Requires 0 <= l <= r < n.
    int query(int l, int r) {
        return dfs(0, n - 1, 0, l, r);
    }
};

2, 字典树

// 1>, 普通字典树

// Node of a character trie. Its children are owned (and freed) by the Trie.
struct TrieNode{
    unordered_map<char,TrieNode*> children;   // child node for each next character
    bool isEnd;                               // true if some inserted word ends at this node
    TrieNode() {
        isEnd = false;
    }
};

// Character trie supporting insertion, exact-word lookup and prefix lookup.
// Owns all of its nodes; copying a Trie makes a deep copy (the implicit shallow copy used to
// free the same nodes twice).
class Trie{
    TrieNode* root;

    // Frees node and its entire subtree.
    static void dfsDelete(TrieNode* node){
        for(auto& entry : node->children){
            dfsDelete(entry.second);
        }
        delete node;
    }

    // Returns a newly allocated deep copy of the subtree rooted at node.
    static TrieNode* dfsCopy(const TrieNode* node){
        auto* copy = new TrieNode();
        copy->isEnd = node->isEnd;
        try {
            for(const auto& entry : node->children){
                copy->children[entry.first] = dfsCopy(entry.second);
            }
        } catch (...) {
            dfsDelete(copy);   // release the partially built copy before propagating
            throw;
        }
        return copy;
    }

    // Follows s from the root; returns the node it ends at, or nullptr if s is not a
    // prefix of any inserted word.
    const TrieNode* walk(const string& s) const {
        const TrieNode* node = root;
        for(char c : s){
            auto it = node->children.find(c);
            if(it == node->children.end()){
                return nullptr;
            }
            node = it->second;
        }
        return node;
    }
public:
    Trie() : root(new TrieNode()) {}

    Trie(const Trie& other) : root(dfsCopy(other.root)) {}

    // Copy-and-swap: `other` is a fresh copy whose destructor frees our old nodes.
    Trie& operator=(Trie other) {
        std::swap(root, other.root);
        return *this;
    }

    ~Trie() {
        dfsDelete(root);
    }

    // Adds word to the trie.
    void insert(const string& word) {
        TrieNode* node = root;
        for(char c : word){
            auto it = node->children.find(c);
            if(it == node->children.end()){
                it = node->children.emplace(c, new TrieNode()).first;
            }
            node = it->second;
        }
        node->isEnd = true;
    }

    // True if word itself was inserted.
    bool search(const string& word) const {
        const TrieNode* node = walk(word);
        return node != nullptr && node->isEnd;
    }

    // True if some inserted word starts with prefix (always true for "").
    bool startsWith(const string& prefix) const {
        return walk(prefix) != nullptr;
    }
};
// 2>, 高效查询前缀数量
// Node of a counting character trie. Its children are owned (and freed) by the Trie.
struct TrieNode{
    unordered_map<char,TrieNode*> children;   // child node for each next character
    int cnt;      // number of inserted words having the path to this node as a prefix
    int endCnt;   // number of inserted words exactly equal to the path to this node
    TrieNode() {
        cnt = 0;
        endCnt = 0;
    }
};

// Character trie that counts words: startsWith() returns how many inserted words have a
// given prefix, search() how many inserted words equal a given string (previously search
// duplicated startsWith). Owns all of its nodes; copies are deep.
class Trie{
    TrieNode* root;

    // Frees node and its entire subtree.
    static void dfsDelete(TrieNode* node){
        for(auto& entry : node->children){
            dfsDelete(entry.second);
        }
        delete node;
    }

    // Returns a newly allocated deep copy of the subtree rooted at node.
    static TrieNode* dfsCopy(const TrieNode* node){
        auto* copy = new TrieNode();
        copy->cnt = node->cnt;
        copy->endCnt = node->endCnt;
        try {
            for(const auto& entry : node->children){
                copy->children[entry.first] = dfsCopy(entry.second);
            }
        } catch (...) {
            dfsDelete(copy);   // release the partially built copy before propagating
            throw;
        }
        return copy;
    }

    // Follows s from the root; returns the node it ends at, or nullptr if no inserted word
    // starts with s.
    const TrieNode* walk(const string& s) const {
        const TrieNode* node = root;
        for(char c : s){
            auto it = node->children.find(c);
            if(it == node->children.end()){
                return nullptr;
            }
            node = it->second;
        }
        return node;
    }
public:
    Trie() : root(new TrieNode()) {}

    Trie(const Trie& other) : root(dfsCopy(other.root)) {}

    // Copy-and-swap: `other` is a fresh copy whose destructor frees our old nodes.
    Trie& operator=(Trie other) {
        std::swap(root, other.root);
        return *this;
    }

    ~Trie() {
        dfsDelete(root);
    }

    // Adds one occurrence of word.
    void insert(const string& word) {
        TrieNode* node = root;
        node->cnt++;   // the empty prefix matches every word
        for(char c : word){
            auto it = node->children.find(c);
            if(it == node->children.end()){
                it = node->children.emplace(c, new TrieNode()).first;
            }
            node = it->second;
            node->cnt++;
        }
        node->endCnt++;
    }

    // Number of inserted words exactly equal to word.
    int search(const string& word) const {
        const TrieNode* node = walk(word);
        return node ? node->endCnt : 0;
    }

    // Number of inserted words that start with prefix ("" counts every word).
    int startsWith(const string& prefix) const {
        const TrieNode* node = walk(prefix);
        return node ? node->cnt : 0;
    }
};
// 3>, 二进制字符串

// Node of a binary (bitwise) trie. Its children are owned (and freed) by the Trie.
struct TrieNode{
    vector<TrieNode*> children;   // children[b]: subtree for next bit b (0 or 1), or nullptr
    int cnt;                      // number of inserted numbers whose path passes this node
    TrieNode() {
        children.resize(2, nullptr);
        cnt = 0;
    }
};

// Binary trie over bits max_bit..0 of ints, used to find the maximum XOR of a query number
// with any inserted number. Only bits 0..30 are considered, so it is meant for non-negative
// ints. Owns all of its nodes; copies are deep.
class Trie{
    TrieNode* root;
    static constexpr int max_bit = 30;

    // Frees node and its entire subtree.
    static void dfsDelete(TrieNode* node){
        for(auto& child : node->children){
            if(child) dfsDelete(child);
        }
        delete node;
    }

    // Returns a newly allocated deep copy of the subtree rooted at node.
    static TrieNode* dfsCopy(const TrieNode* node){
        auto* copy = new TrieNode();
        copy->cnt = node->cnt;
        try {
            for(int b = 0; b < 2; b++){
                if(node->children[b]) copy->children[b] = dfsCopy(node->children[b]);
            }
        } catch (...) {
            dfsDelete(copy);   // release the partially built copy before propagating
            throw;
        }
        return copy;
    }
public:
    Trie() : root(new TrieNode()) {}

    Trie(const Trie& other) : root(dfsCopy(other.root)) {}

    // Copy-and-swap: `other` is a fresh copy whose destructor frees our old nodes.
    Trie& operator=(Trie other) {
        std::swap(root, other.root);
        return *this;
    }

    ~Trie() {
        dfsDelete(root);
    }

    // Adds num (bits max_bit..0) to the trie.
    void insert(int num) {
        TrieNode* node = root;
        for(int i = max_bit; i >= 0; i--){
            int bit = (num >> i) & 1;
            if(!node->children[bit]){
                node->children[bit] = new TrieNode();
            }
            node = node->children[bit];
            node->cnt++;
        }
    }

    // Maximum of num ^ x over all inserted x (bits 0..max_bit); 0 if nothing was inserted
    // (previously this dereferenced a null child on an empty trie).
    int search(int num) const {
        const TrieNode* node = root;
        if(!node->children[0] && !node->children[1]){
            return 0;
        }
        int ans = 0;
        for(int i = max_bit; i >= 0; i--){
            int bit = (num >> i) & 1;
            // Every inserted number has full depth, so at least one child exists here.
            if(node->children[bit ^ 1]){
                ans |= (1 << i);
                node = node->children[bit ^ 1];
            }else {
                node = node->children[bit];
            }
        }
        return ans;
    }
};

3, 并查集

// 由于并查集较单一,直接写完全结构
// Disjoint-set union over elements 0..n-1 with path compression and union by size.
// Also keeps the set of current roots (one representative per component).
class Union {
    vector<int> father;          // parent pointer; x is a root iff father[x] == x
    vector<int> UnionSize;       // component size, meaningful only at roots
    unordered_set<int> fathers;  // roots of all current components
public:
    Union(int n) : father(n), UnionSize(n, 1) {
        for (int v = 0; v < n; ++v) {
            father[v] = v;
            fathers.insert(v);
        }
    }

    // Root of x's component; afterwards every node on x's path points directly at the root.
    int find(int x){
        int root = x;
        while (father[root] != root) {
            root = father[root];
        }
        while (x != root) {
            const int next = father[x];
            father[x] = root;
            x = next;
        }
        return root;
    }

    // Merges the components of x and y; the smaller one (x's on a tie) hangs under the other.
    void join(int x,int y){
        const int rx = find(x);
        const int ry = find(y);
        if (rx == ry) {
            return;
        }
        const bool xBigger = UnionSize[rx] > UnionSize[ry];
        const int child = xBigger ? ry : rx;
        const int parent = xBigger ? rx : ry;
        father[child] = parent;
        UnionSize[parent] += UnionSize[child];
        fathers.erase(child);
    }

    // True if x and y are in the same component.
    bool isConnect(int x,int y){
        const int rx = find(x);
        return rx == find(y);
    }

    // Number of elements in x's component.
    int getUnionSize(int x){
        const int root = find(x);
        return UnionSize[root];
    }

    // Copy of the current set of component roots.
    unordered_set<int> getFathers(){
        unordered_set<int> roots = fathers;
        return roots;
    }
};

4, LCA

// 用于查询合法树的两个节点的最近的公共祖先
// Lowest-common-ancestor queries on a tree via binary lifting.
// The tree is given as an edge list over nodes 0..edge.size() (it must be a valid tree with
// exactly edge.size()+1 nodes); node 0 is the root.
class LCA {
    int max_depth;                 // number of lifting levels; 2^max_depth exceeds the node count
    vector<vector<int>> graph;     // adjacency lists
    int n;                         // number of nodes
    vector<int> depth;             // depth[v]: distance from the root
    vector<vector<int>> parents;   // parents[v][k]: 2^k-th ancestor of v, or -1 if none

    // Smallest L >= 1 with 2^L > count, so every depth difference fits in L bits
    // (the former fixed 20 levels failed for trees deeper than 2^20).
    static int levelsFor(int count){
        int levels = 1;
        while((1LL << levels) <= count){
            levels++;
        }
        return levels;
    }

    // Iterative DFS from root (no recursion, so long paths cannot overflow the call stack):
    // fills depth[] and the direct parents parents[v][0].
    void dfs(int root){
        vector<int> pending{root};
        depth[root] = 0;
        while(!pending.empty()){
            int cur = pending.back();
            pending.pop_back();
            for(int child : graph[cur]){
                if(child != parents[cur][0]){   // skip the edge back to cur's parent
                    parents[child][0] = cur;
                    depth[child] = depth[cur] + 1;
                    pending.push_back(child);
                }
            }
        }
    }

    // Fills parents[v][k] = parents[parents[v][k-1]][k-1] for all higher levels.
    void init(){
        for(int i = 1; i < max_depth; i++){
            for(int j = 0; j < n; j++){
                if(parents[j][i - 1] != -1){
                    parents[j][i] = parents[parents[j][i - 1]][i - 1];
                }
            }
        }
    }

public:
    LCA(const vector<vector<int>>& edge)
        : max_depth(levelsFor(static_cast<int>(edge.size()) + 1)),
          graph(edge.size() + 1),
          n(static_cast<int>(edge.size()) + 1),
          depth(n),
          parents(n, vector<int>(max_depth, -1)) {
        for(const auto& e : edge){
            graph[e[0]].push_back(e[1]);
            graph[e[1]].push_back(e[0]);
        }
        dfs(0);
        init();
    }

    // Lowest common ancestor of nodes x and y.
    int lca(int x, int y) const {
        if(depth[x] > depth[y]){
            std::swap(x, y);
        }
        // Lift y to x's depth by the binary digits of the depth difference.
        for(int diff = depth[y] - depth[x], k = 0; diff > 0; diff >>= 1, k++){
            if(diff & 1){
                y = parents[y][k];
            }
        }
        if(x == y){
            return x;
        }
        // Lift both while their ancestors differ; they stop just below the LCA.
        for(int k = max_depth - 1; k >= 0; k--){
            if(parents[x][k] != parents[y][k]){
                x = parents[x][k];
                y = parents[y][k];
            }
        }
        return parents[x][0];
    }
};

5, 排序数组

// c++排序数组,支持下标版本
// Node of an indexable skip list.
template <typename T>
struct SkipNode {
    T value;
    std::vector<SkipNode*> forward; // forward[i]: next node on level i (nullptr at the end)
    std::vector<int> span;          // span[i]: positions skipped by forward[i]; unused when forward[i] is nullptr

    SkipNode() : value(T()), forward(1, nullptr), span(1, 0) {}
    SkipNode(T val, int level) : value(std::move(val)), forward(level, nullptr), span(level, 0) {}
};

// Sorted multiset with index access, implemented as an indexable skip list: insert and
// operator[] take expected O(log n). Owns its nodes; copies are deep (the implicit shallow
// copy used to free the same nodes twice).
template<typename T>
class SortedArray {
    static constexpr int maxLevel = 16;
    static constexpr double p = 0.5;   // probability of promoting a node one more level

    SkipNode<T>* head;     // sentinel (rank 0) with maxLevel levels
    int n;                 // number of stored elements
    int currentLevel;      // number of levels currently in use
    std::minstd_rand rng;  // per-instance level generator (replaces the global rand() state)

    // Geometric(p) level in [1, maxLevel].
    int getRandomLevel() {
        std::bernoulli_distribution promote(p);
        int level = 1;
        while (level < maxLevel && promote(rng)) {
            level++;
        }
        return level;
    }

    // Node holding the index-th smallest element; throws out_of_range if index is invalid.
    SkipNode<T>* nodeAt(int index) const {
        if (index < 0 || index >= n) {
            throw std::out_of_range("Index out of range");
        }
        SkipNode<T>* cur = head;
        for (int i = currentLevel - 1; i >= 0; i--) {
            while (cur->forward[i] != nullptr && cur->span[i] <= index) {
                index -= cur->span[i];
                cur = cur->forward[i];
            }
        }
        // cur now has rank == original index, so its successor is the wanted element.
        return cur->forward[0];
    }
public:
    SortedArray() : head(new SkipNode<T>(T(), maxLevel)), n(0), currentLevel(1) {}

    // Deep copy: re-inserts other's elements (already sorted) into a fresh list.
    SortedArray(const SortedArray& other) : SortedArray() {
        for (const SkipNode<T>* node = other.head->forward[0]; node != nullptr; node = node->forward[0]) {
            insert(node->value);
        }
    }

    // Copy-and-swap: `other` is a fresh copy whose destructor frees our old nodes.
    SortedArray& operator=(SortedArray other) {
        std::swap(head, other.head);
        std::swap(n, other.n);
        std::swap(currentLevel, other.currentLevel);
        std::swap(rng, other.rng);
        return *this;
    }

    ~SortedArray() {
        SkipNode<T>* node = head;
        while (node) {
            SkipNode<T>* tmp = node;
            node = node->forward[0];
            delete tmp;
        }
    }

    // Inserts val before any equal elements already present.
    void insert(T val) {
        std::vector<SkipNode<T>*> update(maxLevel, nullptr);
        std::vector<int> idx(maxLevel, 0);   // idx[i]: rank of update[i] (the head has rank 0)
        SkipNode<T>* cur = head;

        for (int i = currentLevel - 1; i >= 0; i--) {
            idx[i] = (i == currentLevel - 1) ? 0 : idx[i + 1];
            // Advance to the last node on this level whose value is < val.
            while (cur->forward[i] && cur->forward[i]->value < val) {
                idx[i] += cur->span[i];
                cur = cur->forward[i];
            }
            update[i] = cur;
        }

        int level = getRandomLevel();
        if (level > currentLevel) {
            for (int i = currentLevel; i < level; i++) {
                update[i] = head;
                idx[i] = 0;
            }
            currentLevel = level;
        }

        auto* newNode = new SkipNode<T>(std::move(val), level);
        for (int i = 0; i < level; i++) {
            newNode->forward[i] = update[i]->forward[i];
            update[i]->forward[i] = newNode;
            newNode->span[i] = update[i]->span[i] + idx[i] - idx[0];
            update[i]->span[i] = idx[0] - idx[i] + 1;
        }

        // Links on higher levels now jump over one more element.
        for (int i = level; i < currentLevel; i++) {
            update[i]->span[i]++;
        }

        n++;
    }

    // index-th smallest element (0-based); throws out_of_range if index is invalid.
    // Modifying it through the returned reference must not change its sort order.
    T& operator[](int index) {
        return nodeAt(index)->value;
    }

    // Read-only access: a const SortedArray no longer hands out a mutable reference.
    const T& operator[](int index) const {
        return nodeAt(index)->value;
    }

    [[nodiscard]] int level() const {
        return currentLevel;
    }

    [[nodiscard]] int size() const {
        return n;
    }
};

6, 线性基

// 用于解决数组集合的异或问题
// Linear basis over GF(2) ("XOR basis") of the inserted integers: answers whether a value is
// the XOR of some subset of inserted values, and the maximum such XOR. Only bits 0..max_bit
// are examined, so it is intended for non-negative ints.
class LinearBasis {
    static constexpr int max_bit = 30;   // static, so the class is copy-assignable
    vector<int> basis;   // basis[i] != 0: basis vector whose highest examined set bit is i
    int n;               // number of basis vectors (rank of the inserted set)

    // Eliminates x's set bits from max_bit down using the basis, updating x in place.
    // Returns the first bit i that is set in x but has no basis vector, or -1 if x reduced
    // to zero (i.e. x lies in the span).
    int reduce(int& x) const {
        for(int i = max_bit; i >= 0; i--) {
            if(!((x >> i) & 1)) continue;
            if(!basis[i]) return i;
            x ^= basis[i];
        }
        return -1;
    }

public:
    LinearBasis() : basis(max_bit + 1, 0), n(0) {}

    // Adds x to the spanned set; grows the basis only if x is independent of it.
    void insert(int x) {
        int i = reduce(x);
        if(i >= 0){
            basis[i] = x;
            n++;
        }
    }

    // True if x is the XOR of some subset of the inserted values (0 always is).
    bool test(int x) const {
        return reduce(x) < 0;
    }

    // Number of basis vectors.
    [[nodiscard]] int size() const {
        return n;
    }

    // Maximum XOR over all subsets of the inserted values (0 when empty).
    [[nodiscard]] int find_max() const {
        int max_xor = 0;
        for(int i = max_bit; i >= 0; i--) {
            if(basis[i] && (max_xor ^ basis[i]) > max_xor){
                max_xor ^= basis[i];
            }
        }
        return max_xor;
    }
};