===
Index
简介
最近公共祖先简称 LCA(Lowest Common Ancestor)
。两个节点的最近公共祖先,就是这两个点的公共祖先里面,离根最远的那个。 为了方便,我们记某点集 S=v1,v2,…vn 的最近公共祖先为 LCA(S)。
性质
- lca(u) = u
- u 是 v 的祖先,当且仅当 lca(u, v) = u;
- 如果 u 不为 v 的祖先并且 v 不为 u 的祖先,那么 u,v 分别处于 lca(u,v) 的两棵不同子树中;
- 前序遍历中,lca(s) 出现在所有 s 中元素之前,后序遍历中 lca(s) 则出现在所有 s 中元素之后;
- 两点集并的最近公共祖先为两点集分别的最近公共祖先的最近公共祖先,即 lca(A U B) = lca(lca(A), lca(B))
- 两点的最近公共祖先必定处在树上两点间的最短路上;
- d(u,v) = h(u) + h(v) - 2h(lca(u,v)),其中 d 是树上两点间的距离,h 代表某点到树根的距离。
算法
树上倍增
// 预处理:g.complete(root) //root:根节点编号
struct LCA{
private:
int n, root, K;
vector<vector<int>> g, fa;
vector<int> dep;
public:
LCA(int n) {
init(n);
}
LCA() {}
void init(int n_) {
n = n_, K = 32 - __builtin_clz(n);;
g.resize(n), fa.resize(n, vector<int>(K)), dep.resize(n);
}
void add_edge(int a, int b) {
g[a].push_back(b);
g[b].push_back(a);
}
void dfs(int id, int fr, int d) {
fa[id][0] = fr;
dep[id] = d;
for(auto& to: g[id]) {
if (to == fr) continue;
dfs(to, id, d + 1);
}
}
void complete(int rt = 0) { //根节点编号
root = rt;
dfs(root, -1, 0);
for(int j = 0; j < K-1; ++j) for(int i = 0; i < n; ++i) {
if (fa[i][j] < 0) fa[i][j + 1] = -1;
else fa[i][j + 1] = fa[fa[i][j]][j];
}
}
int lca(int u, int v) {
if (dep[u] > dep[v]) swap(u, v);
for (int k = 0; k < K; k++) {
if ((dep[v] - dep[u]) >> k & 1) {
v = fa[v][k];
}
}
if (u == v) return u;
for (int k = K - 1; k >= 0; k--)
if (fa[u][k] != fa[v][k]) {
u = fa[u][k];
v = fa[v][k];
}
return fa[u][0];
}
int depth(int x) {
return dep[x];
}
int dist(int x, int y) {
int l = lca(x, y);
return dep[x] + dep[y] - 2 * dep[l];
}
};
- 预处理时间复杂度:O(nlog(n))
- 单次查询时间复杂度: O(log(n))
- 倍增算法是强制在线算法
欧拉序列转化为rmq问题
// N = 5e5时,空间约 170M,空间紧张时用 https://github.com/nealwu/competitive-programming/blob/master/rmq_lca/block_rmq_mask.cc
template<typename T, bool max_mode = false>
struct ST {
static int lg(unsigned x) { return x == 0 ? -1 : 31 - __builtin_clz(x);}
int n = 0;
vector<T> a;
vector<vector<int>> mat;
ST(const vector<T> &A = {}) { if (!A.empty()) build(A);}
int op(int x, int y) const {
return (max_mode ? a[y] < a[x] : a[x] < a[y]) ? x : y; // when `a[x] == a[y]`, returns y.
}
void build(const vector<T> &A) {
a = A, n = int(a.size());
int max_log = lg(n) + 1;
mat.resize(max_log);
for (int k = 0; k < max_log; k++) mat[k].resize(n - (1 << k) + 1);
for (int i = 0; i < n; i++) mat[0][i] = i;
for (int k = 1; k < max_log; k++)
for (int i = 0; i <= n - (1 << k); i++)
mat[k][i] = op(mat[k - 1][i], mat[k - 1][i + (1 << (k - 1))]);
}
int get_idx(int x, int y) const { // 0 <= x < y <= n
int k = lg(y - x);
return op(mat[k][x], mat[k][y - (1 << k)]);
}
T get_val(int x, int y) const { return a[get_idx(x, y)];}
};
struct LCA {
int n = 0, now_dfn = 0;
vector<vector<int>> g;
vector<int> fa, dep, siz, euler, dfn, ent, out, tour_ls, rev_tour_ls;
vector<int> hv_rt, hv_rt_dep, hv_rt_fa; // 后两个 vector 仅用于优化 get_kth_anc
ST<int> rmq;
bool built = false;
LCA(int n) : n(n), g(n), fa(n, -1), dep(n), siz(n), dfn(n), ent(n), out(n), tour_ls(n), hv_rt(n){}
LCA(const vector<vector<int>> &G) : LCA(int(G.size())){g = G;}
void add_edge(int a, int b) {
g[a].push_back(b);
g[b].push_back(a);
}
int degree(int v) const { return int(g[v].size()) + (built && fa[v] >= 0);}
void dfs(int u, int pa) {
fa[u] = pa, siz[u] = 1, dep[u] = pa < 0 ? 0 : dep[pa] + 1;
g[u].erase(remove(g[u].begin(), g[u].end(), pa), g[u].end());
for (int v : g[u]) {
dfs(v, u);
siz[u] += siz[v];
}
sort(g[u].begin(), g[u].end(), [&](int a, int b) {
return siz[a] > siz[b];
});
}
void dfs1(int u, bool hv) {
hv_rt[u] = hv ? hv_rt[fa[u]] : u;
dfn[u] = int(euler.size());
euler.push_back(u);
tour_ls[now_dfn] = u, ent[u] = now_dfn++;
bool hv_child = true;
for (int v : g[u]) {
dfs1(v, hv_child);
euler.push_back(u);
hv_child = false;
}
out[u] = now_dfn;
};
void build(int root = -1) {
if (0 <= root && root < n) dfs(root, -1);
for (int i = 0; i < n; i++) if (i != root && fa[i] < 0)
dfs(i, -1);
euler.reserve(2 * n);
for (int i = 0; i < n; i++)
if (fa[i] < 0) {
dfs1(i, false);
euler.push_back(-1);
}
assert(int(euler.size()) == 2 * n);
vector<int> euler_dep;
euler_dep.reserve(euler.size());
for (int u : euler) euler_dep.push_back(u < 0 ? u : dep[u]);
rmq.build(euler_dep);
hv_rt_dep.resize(n); hv_rt_fa.resize(n);
for (int i = 0; i < n; i++) {
hv_rt_dep[i] = dep[hv_rt[i]];
hv_rt_fa[i] = fa[hv_rt[i]];
}
rev_tour_ls = tour_ls;
reverse(rev_tour_ls.begin(), rev_tour_ls.end());
built = true;
}
// return <直径长度,{端点u, 端点v}>
pair<int, array<int, 2>> get_diameter() const {
pair<int, int> u_max = {-1, -1}, ux_max = {-1, -1};
pair<int, array<int, 2>> uxv_max = {-1, {-1, -1}};
for (int node : euler) {
if (node < 0) break;
u_max = max(u_max, {dep[node], node});
ux_max = max(ux_max, {u_max.first - 2 * dep[node], u_max.second});
uxv_max = max(uxv_max, {ux_max.first + dep[node], {ux_max.second, node}});
}
return uxv_max;
}
array<int, 2> get_center() const { // 树的直径的中点
pair<int, array<int, 2>> diam = get_diameter();
int length = diam.first, a = diam.second[0], b = diam.second[1];
return {get_kth_node_on_path(a, b, length / 2), get_kth_node_on_path(a, b, (length + 1) / 2)};
}
int get_lca(int a, int b) const { // return -1, if a,b不连通
a = dfn[a], b = dfn[b];
if (a > b) swap(a, b);
return euler[rmq.get_idx(a, b + 1)];
}
bool is_anc(int a, int b) const { return ent[a] <= ent[b] && ent[b] < out[a];}
bool on_path(int x, int a, int b) const {
return (is_anc(x, a) || is_anc(x, b)) && is_anc(get_lca(a, b), x);
}
int get_dist(int a, int b) const {return dep[a] + dep[b] - 2 * dep[get_lca(a, b)];}
// Returns the child of `a` that is an ancestor of `b`. Assumes `a` is a strict ancestor of `b`.
int child_anc(int a, int b) const {
assert(a != b && is_anc(a, b));
int child = euler[rmq.get_idx(dfn[a], dfn[b] + 1) + 1];
return child;
}
int get_kth_anc(int a, int k) const {
if (k > dep[a]) return -1;
int goal = dep[a] - k;
while (hv_rt_dep[a] > goal)
a = hv_rt_fa[a];
return tour_ls[ent[a] + goal - dep[a]];
}
int get_kth_node_on_path(int a, int b, int k) const {
int anc = get_lca(a, b), ls = dep[a] - dep[anc], rs = dep[b] - dep[anc];
if (k < 0 || k > ls + rs) return -1;
return k < ls ? get_kth_anc(a, k) : get_kth_anc(b, ls + rs - k);
}
// 到三个节点距离之和最小的节点(质心),lca(a, b), lca(b, c)和 lca(c, a)中最深的节点
int get_common_node(int a, int b, int c) const {
int x = get_lca(a, b), y = get_lca(b, c), z = get_lca(c, a);
return x ^ y ^ z;
}
// 给定树上k个节点子集,计算包含所有k个节点的最小子树(最多2*k-1节点) res[0].first 是子树的根
vector<pair<int, int>> compress_tree(vector<int> nodes) const {
if (nodes.empty()) return {};
auto &&comp = [&](int a, int b) { return ent[a] < ent[b]; };
sort(nodes.begin(), nodes.end(), comp);
int k = int(nodes.size());
for (int i = 0; i < k - 1; i++)
nodes.push_back(get_lca(nodes[i], nodes[i + 1]));
sort(nodes.begin() + k, nodes.end(), comp);
inplace_merge(nodes.begin(), nodes.begin() + k, nodes.end(), comp);
nodes.erase(unique(nodes.begin(), nodes.end()), nodes.end());
vector<pair<int, int>> res = { {nodes[0], -1} };
for (int i = 1; i < int(nodes.size()); i++)
res.emplace_back(nodes[i], get_lca(nodes[i], nodes[i - 1]));
return res;
}
};
使用方法:
- 创建一个n个节点的lca图。
LCA lca(N)
- 添加边
lca.add_edge(u, v);
- lca预处理
lca.build();
- 求 lca,
lca.get_lca(a,b)
如果不连通返回-1 - 判断 a是否是b的祖先节点
lca.is_anc(a, b)
- 判断 x 是否在 a,b 的简单路径上
lca.on_path(x,a,b)
- 节点a和节点b的距离
lca.get_dist(a,b)
- 求树的直径
lca.get_diameter()
返回,pair<int, array<int, 2>>
直径,节点u, 节点v - 求a,b路径上的第k个节点
lca.get_kth_node_on_path(a,b,k)
- 找到包含给定节点集合的最小子树
lca.compress_tree(vector<int> nodes)
例题
最近公共祖先
给定一棵有根多叉树,请求出指定两个点直接最近的公共祖先。
模板1
#include <bits/stdc++.h>
using namespace std;
struct LCA{}; // 模板
int main(){
int n,m,s,x,y;
cin>>n>>m>>s;
LCA g(n);
for (int i = 0; i < n; ++i) {
cin >> x >> y;
g.add_edge(x-1,y-1);
}
g.complete(s-1);
for (int i = 0; i < m; ++i) {
cin >> x >> y;
cout << g.lca(x - 1, y - 1) + 1 << "\n";
}
}
模板2
#include <bits/stdc++.h>
using namespace std;
struct LCA{}; // 模板
int main(){
int n,m,s,x,y;
cin>>n>>m>>s;
LCA g(n);
for (int i = 1; i < n; ++i) {
cin >> x >> y;
g.add_edge(x-1,y-1);
}
g.build(s-1);
for (int i = 0; i < m; ++i) {
cin >> x >> y;
cout << g.get_lca(x - 1, y - 1) + 1 << "\n";
}
}
树上与u距离为k的节点
N个节点的树,q个询问,每次询问给定u,k,找到任一距离u为k的节点,如果没有,输出-1.
- 2 <= n <= 2e5
- 1 <= q <= 2e5
- 1 <= u, k <= n
分析
设 x, y是树的直径的两个端点,则对任意节点u,x和y中至少有一个是距离u最远的节点。
对于一个查询u,设 x是距离u最远的节点,对于任意查询k,如果k大于该距离,答案为-1,否则为u到x的路径上第k个祖先。
#include <bits/stdc++.h>
using namespace std;
struct LCA{}; // 模板
int main() {
ios::sync_with_stdio(false); cin.tie(nullptr);
int n, x, y, q, u, k;
cin >> n;
LCA g(n);
for (int i = 0; i < n - 1; ++i) {
cin >> x >> y;
g.add_edge(x - 1, y - 1);
}
g.build();
pair<int, array<int, 2>> p = g.get_diameter();
x = p.second[0], y = p.second[1];
cin >> q;
while(q--) {
cin >> u >> k;
u--;
int fa = g.get_dist(u,x) > g.get_dist(u,y) ? x : y;
if(k > g.get_dist(u,fa)) cout<<"-1\n";
else cout << g.get_kth_node_on_path(u, fa, k) + 1<<'\n';
}
return 0;
}
连通路径
N个节点的树,定义节点集合是连通的,如果树上存在一条简单路径,经过集合中每个节点恰好一次,这条路径可以包含集合外的节点。 给q个询问,每个询问给一个节点集合,如果集合是连通的,输出yes,否则输出no。
- 1 <= n <= 2e5
- 1 <= q <= 1e5
- 所有询问集合大小之和不超过2e5
分析
设集合中距离最远的两个节点为u,v,集合中其它节点应该都在u,v的简单路径上。
#include <bits/stdc++.h>
using namespace std;
struct LCA{}; // 模板
int main() {
ios::sync_with_stdio(false); cin.tie(nullptr);
int n, m, x, y, k, q;
cin >> n;
LCA g(n);
for (int i = 0; i < n - 1; ++i) {
cin >> x >> y;
x--, y--;
g.add_edge(x, y);
}
g.build();
cin >> q;
while (q--) {
cin >> m;
vector<int> a(m);
for (int i = 0; i < m; ++i) {
cin >> a[i];
a[i]--;
}
bool o = 1;
sort(a.begin(), a.end(), [&](int x, int y){
return g.dep[x] > g.dep[y];
});
int u = a[0], v = a[0], d = 0;
for (int i = 1; i < m; ++i) {
int dis = g.get_dist(a[i], u);
if (dis > d) {
d = dis, v = a[i];
}
}
for (int i = 0; i < m; ++i) {
if (a[i] == u || a[i] == v) continue;
if (!g.on_path(a[i], v, u)) {
o = 0;
break;
}
}
cout << (o ? "YES\n" : "NO\n");
}
}
network
n个节点的树,有m条附加边,假设你可以删除一条树上的边和一条附加边,求能让树断裂的方案数。
- 1 <= n <= 1e5
- 1 <= m <= 2e5
分析
树上差分,对于m条边,统计树上每条边被多少个环覆盖。 设覆盖环为x:
- x > 1 : 删除树上该边无法使树断裂, ans += 0
- x = 1 : 删除树上该边和对应环的附加边才能断裂 ans += 1
- x = 0 : 删除树上该边即可使树断裂,搭配任意m条边都可以, ans += m
void ac_yyf(int tt) {
cin >> n >> m;
HLD g(n);
vector<pair<int,int>> a;
for (int i = 0, u, v; i < n - 1; i++) {
cin >> u >> v;
u--, v--;
g.add_edge(u, v);
a.emplace_back(u, v);
}
g.build();
vector<int> d(n);
for (int i = 0, u, v; i < m; ++i) {
cin >> u >> v;
u--, v--;
d[u]++, d[v]++, d[g.lca(u, v)] -= 2;
}
function<void(int, int)> dfs = [&](int u, int fa) {
for (int v : g.g[u]) if (v != fa) {
dfs(v, u);
d[u] += d[v];
}
};
dfs(0, -1);
long long ans = 0;
for (auto &[x, y] : a) {
if (g.dep[x] < g.dep[y]) swap(x, y);
if (d[x] == 1) ans ++;
else if (d[x] == 0) ans += m;
}
cout << ans << '\n';
}