===
Index
可持久化线段树
模板
// CHK(cond): `if constexpr (cond)` when compiled as C++17 or later, plain `if (cond)` otherwise.
#if __cplusplus >= 201703L
#define CHK(t) if constexpr (t)
#else
#define CHK(t) if (t)
#endif
/**
 * Object pool that hands out raw pointers to pre-constructed T objects.
 *
 * `pool` owns the objects, `buff` is the free list and `it` is the cursor into
 * `buff`: entries in [it, buff.end()) are free, entries before `it` have been
 * handed out. Recycled objects are NOT re-initialised; callers overwrite them.
 *
 * Copying is disabled: a member-wise copy would leave `it` and every pointer in
 * `buff` aimed at the source object. Moving transfers the storage and rebuilds
 * the cursor, so pointers handed out before the move stay usable.
 *
 * @tparam T            element type; must be default-constructible.
 * @tparam auto_extend  false: std::vector storage with the capacity fixed at construction;
 *                      true: std::deque storage that grows when the free list runs out.
 */
template <typename T, bool auto_extend = false>
struct ObjPool {
    using val_type = T;
    using val_ptr = T*;
    template <bool B, class U, class F> using cand_t = typename std::conditional<B, U, F>::type;
    template <typename U> using container_type = cand_t<auto_extend, std::deque<U>, std::vector<U>>;

    container_type<val_type> pool;
    container_type<val_ptr> buff;
    decltype(buff.begin()) it;

    ObjPool() : ObjPool(0) {}
    /// Creates a pool holding `siz` default-constructed objects, all of them free.
    ObjPool(int siz) : pool(siz), buff(siz) { clear(); }

    ObjPool(const ObjPool&) = delete;
    ObjPool& operator=(const ObjPool&) = delete;
    ObjPool(ObjPool&& other) : pool(), buff(), it() { take(other); }
    ObjPool& operator=(ObjPool&& other) {
        if (this != &other) take(other);
        return *this;
    }

    /// Number of objects owned by the pool.
    int capacity() const { return static_cast<int>(pool.size()); }
    /// Number of objects currently handed out.
    int size() const { return static_cast<int>(it - buff.begin()); }

    /// Returns a free object; grows the pool first when auto_extend is set.
    val_ptr alloc() {
        CHK(auto_extend) ensure();
        assert(it != buff.end() && "ObjPool exhausted: construct it with a larger capacity");
        return *it++;
    }
    /// Puts `t` back on the free list.
    void free(val_ptr t) {
        assert(it != buff.begin() && "ObjPool::free called more often than alloc");
        *--it = t;
    }
    /// Marks every object as free again (objects keep their old contents).
    void clear() {
        const int siz = static_cast<int>(pool.size());
        it = buff.begin();
        for (int i = 0; i < siz; i++) buff[i] = &pool[i];
    }
    /// Grows the pool by size + 1 objects when no free object is left.
    /// Only meaningful for the deque-backed variant: growing a vector would
    /// invalidate every pointer handed out so far.
    void ensure() {
        if (it != buff.end()) return;
        const int siz = static_cast<int>(buff.size());
        for (int i = siz; i <= siz * 2; ++i) {
            // emplace_back() only returns a reference since C++17, so take the address via back()
            pool.emplace_back();
            buff.push_back(&pool.back());
        }
        // push_back invalidated the old cursor; the first new entry sits at index siz
        it = buff.begin() + siz;
    }

private:
    /// Steals the storage of `other` and re-derives the cursor from its offset.
    void take(ObjPool& other) {
        const auto used = other.it - other.buff.begin();
        pool = std::move(other.pool);
        buff = std::move(other.buff);
        it = buff.begin() + used;
        other.pool.clear();
        other.buff.clear();
        other.it = other.buff.begin();
    }
};
template <typename T, T(*op)(T, T), T(*e)()>
struct PersistentSegTree {
struct Node;
using val_type = T;
using node_type = Node;
using node_ptr = node_type*;
struct Node {
static inline ObjPool<node_type> _pool;
node_ptr _ch[2]{ nullptr, nullptr };
val_type _dat;
Node() : _dat(e()) {}
static node_ptr clone(node_ptr node) { return &(*_pool.alloc() = *node);}
static void upd(node_ptr node) { node->_dat = op(node->_ch[0]->_dat, node->_ch[1]->_dat);}
static bool is_leaf(node_ptr node) { return not node->_ch[0];}
static node_ptr build(const std::vector<val_type>& dat) {
function<node_ptr(int, int)> dfs = [&](int l, int r) {
node_ptr res = _pool.alloc();
if (r - l == 1) res->_dat = dat[l];
else {
int m = (l + r) >> 1;
res->_ch[0] = dfs(l, m), res->_ch[1] = dfs(m, r);
upd(res);
}
return res;
};
return dfs(0, dat.size());
}
static val_type get_all(node_ptr node) { return node ? node->_dat : e();}
static val_type get(node_ptr node, int tl, int tr, int ql, int qr) {
if (tr <= ql or qr <= tl) return e();
if (ql <= tl and tr <= qr) return node->_dat;
int tm = (tl + tr) >> 1;
return op(get(node->_ch[0], tl, tm, ql, qr), get(node->_ch[1], tm, tr, ql, qr));
}
template <bool do_upd, typename F>
static node_ptr search_node(node_ptr node, int siz, int i, F &&f) {
static vector<node_ptr> path;
node_ptr res = node;
CHK(do_upd) res = clone(res);
node_ptr cur = res;
for (int l = 0, r = siz; r - l > 1;) {
CHK(do_upd) path.push_back(cur);
int m = (l + r) >> 1;
if (i < m) {
CHK(do_upd) cur->_ch[0] = clone(cur->_ch[0]);
cur = cur->_ch[0], r = m;
} else {
CHK(do_upd) cur->_ch[1] = clone(cur->_ch[1]);
cur = cur->_ch[1], l = m;
}
}
f(cur);
CHK(do_upd) {
while (path.size()) upd(path.back()), path.pop_back();
return res;
} else return nullptr;
}
static val_type get(node_ptr node, int siz, int i) {
val_type res;
search_node</*do_upd =*/false>(node, siz, i, [&](node_ptr i_th_node) { res = i_th_node->_dat; });
return res;
}
template <typename F>
static node_ptr apply(node_ptr node, int siz, int i, F&& f) {
return search_node</* do_upd = */true>(node, siz, i, [&](node_ptr i_th_node) { i_th_node->_dat = f(i_th_node->_dat); });
}
static node_ptr set(node_ptr node, int siz, int i, const val_type& dat) {
return apply(node, siz, i, [&](const val_type&) { return dat; });
}
template <typename F>
static int max_right(node_ptr node, int siz, int l, F&& f) {
assert(f(e()));
function<int(node_ptr, int, int, val_type)> dfs = [&](node_ptr cur, int tl, int tr, val_type& sum) {
if (tr <= l) return tr;
if (l <= tl) {
val_type nxt_sum = op(sum, cur->_dat);
if (f(nxt_sum)) {
sum = std::move(nxt_sum);
return tr;
}
if (tr - tl == 1) return tl;
}
int tm = (tl + tr) >> 1;
int res_l = dfs(cur->_ch[0], tl, tm, sum);
return res_l != tm ? res_l : dfs(cur->_ch[1], tm, tr, sum);
};
val_type sum = e();
return dfs(node, 0, siz, sum);
}
template <typename F>
static int min_left(node_ptr node, int siz, int r, F&& f) {
assert(f(e()));
function<int(node_ptr, int, int, val_type)> dfs = [&](node_ptr cur, int tl, int tr, val_type& sum) {
if (r <= tl) return tl;
if (tr <= r) {
val_type nxt_sum = op(cur->_dat, sum);
if (f(nxt_sum)) {
sum = move(nxt_sum);
return tl;
}
if (tr - tl == 1) return tr;
}
int tm = (tl + tr) >> 1;
int res_r = dfs(cur->_ch[1], tm, tr, sum);
return res_r != tm ? res_r : dfs(cur->_ch[0], tl, tm, sum);
};
val_type sum = e();
return dfs(node, 0, siz, sum);
}
template <typename OutputIter>
static void dump(node_ptr node, OutputIter it) {
if (not node) return;
function<void(node_ptr)> dfs = [&](node_ptr cur) {
if (is_leaf(cur)) *it++ = cur->_dat;
else dfs(cur->_ch[0]), dfs(cur->_ch[1]);
};
dfs(node);
}
static std::vector<val_type> dump(node_ptr node) {
vector<val_type> res;
dump(node, std::back_inserter(res));
return res;
}
};
PersistentSegTree() : _n(0), _root(nullptr) {}
explicit PersistentSegTree(int n) : PersistentSegTree(vector<val_type>(n, e())) {}
PersistentSegTree(const vector<val_type>& dat) : _n(dat.size()), _root(node_type::build(dat)) {}
static void init_pool(int siz) { node_type::_pool = ObjPool<node_type>(siz);}
static void clear_pool() { node_type::_pool.clear();}
val_type get_all() { return node_type::get_all(_root);}
val_type get(int l, int r) { // a[l,..r-1]
assert(0 <= l and l <= r and r <= _n);
return node_type::get(_root, 0, _n, l, r);
}
val_type operator()(int l, int r) { return get(l, r);}
val_type get(int i) { assert(0 <= i and i < _n); return node_type::get(_root, _n, i);}
val_type operator[](int i) { return get(i);}
template <typename F>
PersistentSegTree apply(int i, F&& f) { assert(0 <= i and i < _n);
return PersistentSegTree(_n, node_type::apply(_root, _n, i, std::forward<F>(f)));
}
PersistentSegTree set(int i, const val_type& v) { assert(0 <= i and i < _n);
return PersistentSegTree(_n, node_type::set(_root, _n, i, v));
}
template <typename F> int max_right(int l, F&& f) { assert(0 <= l and l <= _n);
return node_type::max_right(_root, _n, l, std::forward<F>(f));
}
template <bool(*pred)(val_type)> static int max_right(int l) { return max_right(l, pred);}
template <typename F> int min_left(int r, F&& f) { assert(0 <= r and r <= _n);
return node_type::min_left(_root, _n, r, std::forward<F>(f));
}
template <bool(*pred)(val_type)> static int min_left(int r) { return min_left(r, pred);}
template <typename OutputIter>
void dump(OutputIter it) { node_type::dump(_root, it);}
vector<val_type> dump() { return node_type::dump(_root);}
private:
int _n;
node_ptr _root;
PersistentSegTree(int n, node_ptr root) : _n(n), _root(root) {}
};
// Value type kept in every segment-tree node.
using S = int;

// Combines two partial results by adding them.
S op(S x, S y) {
    S total = x;
    total += y;
    return total;
}

// Neutral element of `op`.
S e() {
    return S{};
}
// Persistent segment tree specialised with value type S, combine function op and empty-range value e.
using Seg = PersistentSegTree<S, op, e>;
使用方法
可持久化线段树:支持回退,访问之前版本的线段树。
思想:前缀和思想,保存每次插入操作时的历史版本,每个版本在上一版本的基础上添加(log(n)+1)个新的节点。假设有m个版本的线段树,每个线段树有n个叶子节点,初始建树需要开2n-1
个节点,有m次插入,每次插入最多增加 logn + 1
个节点,总空间大小为 2*n + m * (log(n) + 1)
个节点。
- 定义n棵线段树
vector<Seg> segs(n + 1);
- 对第t个版本的线段树的第i个元素执行+x的操作。
int t, i, x;
segs[t] = segs[t].apply(i, [x](S e) { return e + x; });
或者
int t, i, x;
segs[t] = segs[t].set(i, segs[t].get(i) + x);
静态区间第k小
给定长度为n的数组,m次查询,每次查询区间[l,r]内的第k小数。
- 1 <= n <= m <= 2e5
- -1e9 <= a[i] <= 1e9
- 1 <= l <= r <= n
- 1 <= k <= r - l + 1
// PersistentSegTree
// Discrete
// Value type kept in every segment-tree node (here: an occurrence count).
using S = int;

// Combines two partial results by adding them.
S op(S x, S y) {
    S total = x;
    total += y;
    return total;
}

// Neutral element of `op`.
S e() {
    return S{};
}
// Persistent segment tree specialised with value type S, combine function op and empty-range value e.
using Seg = PersistentSegTree<S, op, e>;
// Static range k-th smallest: version i of the tree counts, per compressed
// value, how often it occurs among a[0..i-1]; a query binary-searches the
// answer using the difference of two versions.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int n, m;
    cin >> n >> m;
    vector<int> a(n);
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
    }
    Discrete<int> v(a);
    for (int &x : a)
        x = v(x);
    int q = v.size();
    // Size the node pool from the input instead of a fixed constant: building
    // takes 2q - 1 nodes and each of the n updates clones one root-to-leaf
    // path of at most height + 1 nodes. (With n = q = 2e5 this exceeds 4e6.)
    int height = 0;
    while ((1 << height) < q) ++height;
    Seg::init_pool(2 * q + n * (height + 1));
    vector<Seg> segs(n + 1);
    segs[0] = Seg(q);
    for (int i = 1; i <= n; ++i) {
        segs[i] = segs[i - 1];
        segs[i] = segs[i].set(a[i - 1], segs[i][a[i - 1]] + 1);
    }
    for (int i = 0; i < m; ++i) {
        int l, r, k;
        cin >> l >> r >> k;
        l--;
        int lo = 0, hi = q - 1, ans = hi;
        while (lo <= hi) {
            int md = (lo + hi) / 2;
            // cnt: number of elements in positions [l, r) whose compressed value is <= md
            int cnt = segs[r].get(0, md + 1) - segs[l].get(0, md + 1);
            if (cnt >= k) {
                ans = md;
                hi = md - 1;
            } else {
                lo = md + 1;
            }
        }
        cout << v[ans] << '\n';
    }
    return 0;
}
矩阵操作
n行m列矩阵,初始值全为0,q次查询。
- 1 l r x 将第l列到第r列到所有元素都加x
- 2 i x 第i行的所有元素赋值为x
-
3 i j 输出第i行第j列元素
- 1 <= n, m, q <= 2e5
- 1 <= x <= 1e9
- l, r 都在满足条件范围内
可持久化线段树
q次操作维护q+1个版本的线段树,每个版本的线段树维护长度为m+1的数组,表示每一列被加了多少的差分数组。 last[n]维护在某个版本第i行被修改为元素x,在求第i行第j列元素时,找到第i行上次修改的数值,然后累加在上次修改版本后的版本中这一列增加的元素和,求和后就是最终结果。
// Matrix operations: version t of the tree stores the column difference array
// after the first t operations; last[i] remembers the version and value of the
// most recent assignment to row i.
// NOTE(review): the lambdas take long long and x can reach 1e9, so this
// presumably expects Seg to be instantiated with a 64-bit S - confirm.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, m, q;
    std::cin >> n >> m >> q;

    vector<pair<int,int> > last(n);
    Seg::init_pool(10000000);
    vector<Seg> segs(q + 1);
    segs[0] = Seg(m + 1);

    for (int t = 1; t <= q; ++t) {
        int op;
        cin >> op;
        segs[t] = segs[t - 1];

        if (op == 1) {
            // add x to columns l..r: +x at l-1 and -x at r in the difference array
            int l, r, x;
            cin >> l >> r >> x;
            --l;
            segs[t] = segs[t].apply(l, [x](long long e) { return e + x; });
            segs[t] = segs[t].apply(r, [x](long long e) { return e - x; });
            continue;
        }

        if (op == 2) {
            // row i is overwritten with x as of the previous version
            int i, x;
            std::cin >> i >> x;
            --i;
            last[i] = { t - 1, x };
            continue;
        }

        // query cell (i, j): assigned value plus everything added to column j since then
        int i, j;
        std::cin >> i >> j;
        --i, --j;
        auto [tl, x] = last[i];
        cout << x + segs[t - 1].get(0, j + 1) - segs[tl].get(0, j + 1) << '\n';
    }
    return 0;
}
矩阵和
二维平面上有n个点,坐标(x,y),权重w,q次查询,每次查询给定
-
l d r u 求满足 l <= x < r, d <= y < u 的权重和。
- 1 <= n, q <= we5
- 0 <= x, y, w <= 1e9
- 0 <= l < r <= 1e9
- 0 <= d < u <= 1e9
分析
对坐标离散化,设有n行m列,维护n+1个版本线段树,利用前缀和思想,计算[l,r]版本内[d,u]权重之和。
int main() {
ios::sync_with_stdio(false); cin.tie(nullptr);
Seg::init_pool(5000000);
int n, q;
cin >> n >> q;
vector<int> vx, vy;
vx.reserve(n); vy.reserve(n);
vector<array<int, 3>> a(n);
for (int i = 0; i < n; ++i) {
int x, y, w;
cin >> x >> y >> w;
a[i] = {x, y, w};
vx.push_back(x);
vy.push_back(y);
}
Discrete<int> dx(vx);
Discrete<int> dy(vy);
vector<array<int, 4>> qs(q);
for (int i = 0; i < q; ++i) {
int l, d, r, u;
cin >> l >> d >> r >> u;
qs[i] = {l, d, r, u};
}
int nx = dx.size(), ny = dy.size();
vector<vector<pair<int,int>>> p(nx);
for (auto &[x, y, w] : a) {
p[dx(x)].push_back({dy(y), w});
}
vector<Seg> segs(nx + 1);
segs[0] = Seg(ny);
for (int i = 1; i <= nx; ++i) {
segs[i] = segs[i - 1];
if (!p[i - 1].size()) continue;
for (auto &[x, y] : p[i - 1]) {
long long w = y;
segs[i] = segs[i].apply(x, [w](S e){return e + w;});
}
}
for (auto &[l, d, r, u] : qs) {
l = dx(l), r = dx(r), d = dy(d), u = dy(u);
long long ans = segs[r].get(d, u) - segs[l].get(d, u);
cout << ans << '\n';
}
return 0;
}
带修改区间中k出现数量
长度为n的数组,m次操作,
- C x p 将下标为x的数修改为p
-
Q l r k 查询[l,r]区间内k出现的次数
- 1 <= n, m <= 1e5
- 1 <= a[i], k <= 1e9
分析
将a中元素与所有查询出现的k进行离散化,每个元素唯一一个版本的线段树,修改操作时,将a[x]版本的线段树x位置减一,p版本x位置加一。查询即为查询k版本的线段树[l,r]区间和。
int main() {
int n, m;
cin >> n >> m;
vector<int> a(n);
for (int i = 0; i < n; ++i) {
cin >> a[i];
}
auto b = a;
vector<vector<int>> qs(m);
char op;
for (int i = 0; i < m; ++i) {
cin >> op;
if (op == 'C') {
int x, v;
cin >> x >> v;
qs[i] = {x, v};
} else {
int l, r, k;
cin >> l >> r >> k;
b.push_back(k);
qs[i] = {l, r, k};
}
}
Discrete<int> v(b);
int m1 = v.size();
vector<Seg> segs(m1);
segs[0] = Seg(n);
for (int i = 1; i < m1; ++i) {
segs[i] = segs[i - 1];
}
for (int i = 0; i < n; ++i) {
a[i] = v(a[i]);
segs[a[i]] = segs[a[i]].apply(i, [](int e){return e + 1;});
}
for (int i = 0; i < m; ++i) {
if (qs[i].size() == 2) {
int p = qs[i][0] - 1, val = qs[i][1];
int x = v(val);
segs[a[p]] = segs[a[p]].apply(p, [](int e){return e - 1;});
segs[x] = segs[x].apply(p, [](int e){return e + 1;});
a[p] = x;
} else {
int l = qs[i][0] - 1, r = qs[i][1], x = v(qs[i][2]);
cout << segs[x].get(l, r) << '\n';
}
}
}
树上路径权值第k小
n个节点的树,每个点有一个权值。有 m 个询问,每次给你 u,v,k,你需要回答 (u ^ last) 到 v 这两个节点间第 k 小的点权。其中 last是上一个询问的答案,定义其初始为 0,数据保证每次(u ^ last)在1-n之间。
- 1 <= n, m <= 1e5
- 1 <= a[i] <= 1e9
分析
静态区间第k小的树上版本
// k-th smallest weight on a tree path: version u + 1 of the tree counts the
// compressed weights on the path from the root (vertex 0) to vertex u.
int main() {
    Seg::init_pool(2000000); // reserve node storage up front
    int n, m;
    cin >> n >> m;
    HLD g(n);
    vector<int> a(n);
    for (int i = 0; i < n; ++i) cin >> a[i];
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        cin >> u >> v;
        u--; v--;
        g.add_edge(u, v);
    }
    g.build();
    Discrete<int> d(a);
    int q = d.size();
    vector<Seg> segs(n + 1);
    segs[0] = Seg(q);
    function<void(int, int)> dfs = [&](int u, int fa) {
        // fa == -1 for the root, so its parent version is the empty tree segs[0]
        segs[u + 1] = segs[fa + 1];
        segs[u + 1] = segs[u + 1].apply(d(a[u]), [](int x){return x + 1;});
        for (int v : g.g[u]) if (v != fa) {
            dfs(v, u);
        }
    };
    dfs(0, -1);
    // u and v are 1-based here; returns the compressed index of the k-th smallest weight
    auto get = [&](int u, int v, int k) {
        // the LCA does not depend on the binary-search pivot, so look it up once
        const int p = g.lca(u - 1, v - 1);
        int l = 0, r = q - 1, ans = r;
        while (l <= r) {
            int md = (l + r) / 2;
            // weights <= md on the path: root..u + root..v - root..lca - root..parent(lca)
            int x = segs[u].get(0, md + 1) + segs[v].get(0, md + 1) - segs[p + 1].get(0, md + 1);
            if (p != 0) x -= segs[g.fa[p] + 1].get(0, md + 1);
            if (x >= k) {
                ans = md;
                r = md - 1;
            } else l = md + 1;
        }
        return ans;
    };
    for (int i = 0, u, v, k, lst = 0; i < m; ++i) {
        cin >> u >> v >> k;
        u = u ^ lst;  // queries are encoded with the previous answer
        lst = d[get(u, v, k)];
        cout << lst << '\n';
    }
}
可持久化懒标记线段树
模板
// CHK(cond): `if constexpr (cond)` when compiled as C++17 or later, plain `if (cond)` otherwise.
#if __cplusplus >= 201703L
#define CHK(t) if constexpr (t)
#else
#define CHK(t) if (t)
#endif
/**
 * Object pool that hands out raw pointers to pre-constructed T objects.
 *
 * `pool` owns the objects, `buff` is the free list and `it` is the cursor into
 * `buff`: entries in [it, buff.end()) are free, entries before `it` have been
 * handed out. Recycled objects are NOT re-initialised; callers overwrite them.
 *
 * Copying is disabled: a member-wise copy would leave `it` and every pointer in
 * `buff` aimed at the source object. Moving transfers the storage and rebuilds
 * the cursor, so pointers handed out before the move stay usable.
 *
 * @tparam T            element type; must be default-constructible.
 * @tparam auto_extend  false: std::vector storage with the capacity fixed at construction;
 *                      true: std::deque storage that grows when the free list runs out.
 */
template <typename T, bool auto_extend = false>
struct ObjPool {
    using val_type = T;
    using val_ptr = T*;
    template <bool B, class U, class F> using cand_t = typename std::conditional<B, U, F>::type;
    template <typename U> using container_type = cand_t<auto_extend, std::deque<U>, std::vector<U>>;

    container_type<val_type> pool;
    container_type<val_ptr> buff;
    decltype(buff.begin()) it;

    ObjPool() : ObjPool(0) {}
    /// Creates a pool holding `siz` default-constructed objects, all of them free.
    ObjPool(int siz) : pool(siz), buff(siz) { clear(); }

    ObjPool(const ObjPool&) = delete;
    ObjPool& operator=(const ObjPool&) = delete;
    ObjPool(ObjPool&& other) : pool(), buff(), it() { take(other); }
    ObjPool& operator=(ObjPool&& other) {
        if (this != &other) take(other);
        return *this;
    }

    /// Number of objects owned by the pool.
    int capacity() const { return static_cast<int>(pool.size()); }
    /// Number of objects currently handed out.
    int size() const { return static_cast<int>(it - buff.begin()); }

    /// Returns a free object; grows the pool first when auto_extend is set.
    val_ptr alloc() {
        CHK(auto_extend) ensure();
        assert(it != buff.end() && "ObjPool exhausted: construct it with a larger capacity");
        return *it++;
    }
    /// Puts `t` back on the free list.
    void free(val_ptr t) {
        assert(it != buff.begin() && "ObjPool::free called more often than alloc");
        *--it = t;
    }
    /// Marks every object as free again (objects keep their old contents).
    void clear() {
        const int siz = static_cast<int>(pool.size());
        it = buff.begin();
        for (int i = 0; i < siz; i++) buff[i] = &pool[i];
    }
    /// Grows the pool by size + 1 objects when no free object is left.
    /// Only meaningful for the deque-backed variant: growing a vector would
    /// invalidate every pointer handed out so far.
    void ensure() {
        if (it != buff.end()) return;
        const int siz = static_cast<int>(buff.size());
        for (int i = siz; i <= siz * 2; ++i) {
            // emplace_back() only returns a reference since C++17, so take the address via back()
            pool.emplace_back();
            buff.push_back(&pool.back());
        }
        // push_back invalidated the old cursor; the first new entry sits at index siz
        it = buff.begin() + siz;
    }

private:
    /// Steals the storage of `other` and re-derives the cursor from its offset.
    void take(ObjPool& other) {
        const auto used = other.it - other.buff.begin();
        pool = std::move(other.pool);
        buff = std::move(other.buff);
        it = buff.begin() + used;
        other.pool.clear();
        other.buff.clear();
        other.it = other.buff.begin();
    }
};
template <class T, T(*op)(T, T), T(*e)(), class F, T(*tag)(F, T), F(*merge)(F, F), F(*id)()>
struct PersisLazySegTree {
struct Node;
using val_type = T;
using op_type = F;
using node_type = Node;
using node_ptr = node_type*;
struct Node {
static inline ObjPool<node_type> _pool;
node_ptr _ch[2]{ nullptr, nullptr };
val_type _dat;
op_type _laz;
Node() : _dat(e()), _laz(id()) {}
static node_ptr clone(node_ptr node) { return &(*_pool.alloc() = *node);}
static void upd(node_ptr node) {node->_dat = op(node->_ch[0]->_dat, node->_ch[1]->_dat);}
template <bool do_clone = true>
static node_ptr push(node_ptr node) {
node_ptr res = node;
CHK(do_clone) res = clone(res);
res->_ch[0] = apply_all(res->_ch[0], res->_laz);
res->_ch[1] = apply_all(res->_ch[1], res->_laz);
res->_laz = id();
CHK(do_clone) return res;
return nullptr;
}
static bool is_leaf(node_ptr node) { return not node->_ch[0];}
static node_ptr build(const std::vector<val_type>& dat) {
function<node_ptr(int, int)> dfs = [&](int l, int r) {
node_ptr res = _pool.alloc();
if (r - l == 1) {
res->_dat = dat[l]; res->_laz = id();
} else {
int m = (l + r) >> 1;
res->_ch[0] = dfs(l, m), res->_ch[1] = dfs(m, r);
upd(res);
res->_laz = id();
}
return res;
};
return dfs(0, dat.size());
}
static val_type get_all(node_ptr node) { return node ? node->_dat : e();}
static val_type get(node_ptr node, int tl, int tr, int ql, int qr, const op_type &f = id()) {
if (tr <= ql or qr <= tl) return e();
if (ql <= tl and tr <= qr) return tag(f, node->_dat);
int tm = (tl + tr) >> 1;
op_type g = merge(f, node->_laz);
return op(get(node->_ch[0], tl, tm, ql, qr, g), get(node->_ch[1], tm, tr, ql, qr, g));
}
static node_ptr apply_all(node_ptr node, const op_type &f) {
if (not node) return nullptr;
node_ptr res = clone(node);
res->_dat = tag(f, res->_dat); res->_laz = merge(f, res->_laz);
return res;
}
static node_ptr apply(node_ptr node, int tl, int tr, int ql, int qr, const op_type &f) {
if (tr <= ql or qr <= tl) return node;
if (ql <= tl and tr <= qr) return apply_all(node, f);
node_ptr res = push(node);
int tm = (tl + tr) >> 1;
res->_ch[0] = apply(res->_ch[0], tl, tm, ql, qr, f);
res->_ch[1] = apply(res->_ch[1], tm, tr, ql, qr, f);
upd(res);
return res;
}
template <typename Func>
static node_ptr upd_leaf(node_ptr node, int siz, int i, Func &&f) {
static vector<node_ptr> path;
node_ptr res = clone(node);
node_ptr cur = res;
for (int l = 0, r = siz; r - l > 1;) {
path.push_back(cur);
push</*do_clone = */false>(cur);
int m = (l + r) >> 1;
if (i < m) {
cur = cur->_ch[0]; r = m;
} else {
cur = cur->_ch[1]; l = m;
}
}
cur->_dat = f(cur->_dat);
while (path.size()) upd(path.back()), path.pop_back();
return res;
}
static val_type get(node_ptr node, int siz, int i) {
op_type f = id();
node_ptr cur = node;
for (int l = 0, r = siz; r - l > 1;) {
f = merge(f, cur->_laz);
int m = (l + r) >> 1;
if (i < m) {
cur = cur->_ch[0]; r = m;
} else {
cur = cur->_ch[1]; l = m;
}
}
return tag(f, cur->_dat);
}
template <typename Func>
static node_ptr apply(node_ptr node, int siz, int i, Func&& f) {
return upd_leaf(node, siz, i, [&](const val_type &v) { return f(v); });
}
static node_ptr set(node_ptr node, int siz, int i, const val_type& dat) {
return apply(node, siz, i, [&](const val_type&) { return dat; });
}
template <typename G>
static int max_right(node_ptr node, int siz, int l, G&& g) {
assert(g(e()));
function<int(node_ptr, int, int, val_type, op_type)>
dfs = [&](node_ptr cur, int tl, int tr, val_type& sum, const op_type &f) {
if (tr <= l) return tr;
if (l <= tl) {
val_type nxt_sum = op(sum, tag(f, cur->_dat));
if (g(nxt_sum)) {
sum = move(nxt_sum); return tr;
}
if (tr - tl == 1) return tl;
}
int tm = (tl + tr) >> 1;
op_type g = merge(f, cur->_laz);
int res_l = dfs(cur->_ch[0], tl, tm, sum, g);
return res_l != tm ? res_l : dfs(cur->_ch[1], tm, tr, sum, g);
};
val_type sum = e();
return dfs(node, 0, siz, sum, id());
}
template <typename G>
static int min_left(node_ptr node, int siz, int r, G&& g) {
assert(g(e()));
function<void(node_ptr, int, int, val_type, op_type)>
dfs = [&](node_ptr cur, int tl, int tr, val_type& sum, const op_type &f) {
if (r <= tl) return tl;
if (tr <= r) {
val_type nxt_sum = op(tag(f, cur->_dat), sum);
if (g(nxt_sum)) {
sum = move(nxt_sum); return tl;
}
if (tr - tl == 1) return tr;
}
int tm = (tl + tr) >> 1;
op_type g = merge(f, cur->_laz);
int res_r = dfs(cur->_ch[1], tm, tr, sum, g);
return res_r != tm ? res_r : dfs(cur->_ch[0], tl, tm, sum, g);
};
val_type sum = e();
return dfs(node, 0, siz, sum, id());
}
template <typename OutputIter>
static void dump(node_ptr node, OutputIter it) {
if (not node) return;
function<void(node_ptr, op_type)> dfs = [&](node_ptr cur, const op_type &f) {
if (is_leaf(cur)) {
*it++ = tag(f, cur->_dat);
} else {
*it++ = tag(f, cur->_dat);
dfs(cur->_ch[0], merge(cur->_laz, f)), dfs(cur->_ch[1], merge(cur->_laz, f));
}
};
dfs(node, id());
}
static vector<val_type> dump(node_ptr node) {
vector<val_type> res;
dump(node, std::back_inserter(res));
return res;
}
};
PersisLazySegTree() : _n(0), _root(nullptr) {}
explicit PersisLazySegTree(int n) : PersisLazySegTree(std::vector<val_type>(n, e())) {}
PersisLazySegTree(const vector<val_type>& dat) : _n(dat.size()), _root(node_type::build(dat)) {}
static void init_pool(int siz) { node_type::_pool = ObjPool<node_type>(siz);}
static void clear_pool() { node_type::_pool.clear();}
val_type get_all() { return node_type::get_all(_root);}
val_type get(int l, int r) { assert(0 <= l and l <= r and r <= _n); return node_type::get(_root, 0, _n, l, r);}
val_type operator()(int l, int r) { return get(l, r);}
PersisLazySegTree apply_all(const op_type &f) { return PersisLazySegTree(_n, node_type::apply_all(_root, f));}
PersisLazySegTree apply(int l, int r, const op_type &f) {
return PersisLazySegTree(_n, node_type::apply(_root, 0, _n, l, r, f));
}
val_type get(int i) { assert(0 <= i and i < _n); return node_type::get(_root, _n, i);}
val_type operator[](int i) { return get(i);}
template <typename Func>
PersisLazySegTree apply(int i, Func&& f) {
assert(0 <= i and i < _n);
return PersisLazySegTree(_n, node_type::apply(_root, _n, i, std::forward<F>(f)));
}
PersisLazySegTree set(int i, const val_type& v) {
assert(0 <= i and i < _n);
return PersisLazySegTree(_n, node_type::set(_root, _n, i, v));
}
template <typename G>
int max_right(int l, G&& g) {
assert(0 <= l and l <= _n);
return node_type::max_right(_root, _n, l, std::forward<G>(g));
}
template <bool(*g)(val_type)> static int max_right(int l) { return max_right(l, g);}
template <typename G>
int min_left(int r, G&& g) {
assert(0 <= r and r <= _n);
return node_type::min_left(_root, _n, r, std::forward<G>(g));
}
template <bool(*g)(val_type)> static int min_left(int r) { return min_left(r, g);}
template <typename OutputIter>
void dump(OutputIter it) { node_type::dump(_root, it);}
vector<val_type> dump() { return node_type::dump(_root);}
private:
int _n;
node_ptr _root;
PersisLazySegTree(int n, node_ptr root) : _n(n), _root(root) {}
};
// Value type stored in the nodes.
using S = long long;
// Type of the lazily applied action (an amount to add).
using F = long long;

// Combine function: keeps the left operand and ignores the right one.
// NOTE(review): with this op, e() is not a neutral element, so presumably only
// single-element reads are meaningful with this configuration - confirm.
S op(S x, S) {
    return x;
}
// Value used for empty ranges.
S e() {
    return 0;
}
// Applies action f to value x: adds f.
S tag(F f, S x) {
    return f + x;
}
// Composition of two actions: the amounts add up.
F merge(F f, F g) {
    return f + g;
}
// Action that changes nothing. Returns F (not S): it is passed as the F(*)() template argument.
F id() {
    return 0;
}
// Persistent lazy segment tree over values S and actions F, wired to the functions op, e, tag, merge and id.
using Seg = PersisLazySegTree<S, op, e, F, tag, merge, id>;
可持久化数组
模板
// CHK(cond): `if constexpr (cond)` when compiled as C++17 or later, plain `if (cond)` otherwise.
#if __cplusplus >= 201703L
#define CHK(t) if constexpr (t)
#else
#define CHK(t) if (t)
#endif
/**
 * Object pool that hands out raw pointers to pre-constructed T objects.
 *
 * `pool` owns the objects, `buff` is the free list and `it` is the cursor into
 * `buff`: entries in [it, buff.end()) are free, entries before `it` have been
 * handed out. Recycled objects are NOT re-initialised; callers overwrite them.
 *
 * Copying is disabled: a member-wise copy would leave `it` and every pointer in
 * `buff` aimed at the source object. Moving transfers the storage and rebuilds
 * the cursor, so pointers handed out before the move stay usable.
 *
 * @tparam T            element type; must be default-constructible.
 * @tparam auto_extend  false: std::vector storage with the capacity fixed at construction;
 *                      true: std::deque storage that grows when the free list runs out.
 */
template <typename T, bool auto_extend = false>
struct ObjPool {
    using val_type = T;
    using val_ptr = T*;
    template <bool B, class U, class F> using cand_t = typename std::conditional<B, U, F>::type;
    template <typename U> using container_type = cand_t<auto_extend, std::deque<U>, std::vector<U>>;

    container_type<val_type> pool;
    container_type<val_ptr> buff;
    decltype(buff.begin()) it;

    ObjPool() : ObjPool(0) {}
    /// Creates a pool holding `siz` default-constructed objects, all of them free.
    ObjPool(int siz) : pool(siz), buff(siz) { clear(); }

    ObjPool(const ObjPool&) = delete;
    ObjPool& operator=(const ObjPool&) = delete;
    ObjPool(ObjPool&& other) : pool(), buff(), it() { take(other); }
    ObjPool& operator=(ObjPool&& other) {
        if (this != &other) take(other);
        return *this;
    }

    /// Number of objects owned by the pool.
    int capacity() const { return static_cast<int>(pool.size()); }
    /// Number of objects currently handed out.
    int size() const { return static_cast<int>(it - buff.begin()); }

    /// Returns a free object; grows the pool first when auto_extend is set.
    val_ptr alloc() {
        CHK(auto_extend) ensure();
        assert(it != buff.end() && "ObjPool exhausted: construct it with a larger capacity");
        return *it++;
    }
    /// Puts `t` back on the free list.
    void free(val_ptr t) {
        assert(it != buff.begin() && "ObjPool::free called more often than alloc");
        *--it = t;
    }
    /// Marks every object as free again (objects keep their old contents).
    void clear() {
        const int siz = static_cast<int>(pool.size());
        it = buff.begin();
        for (int i = 0; i < siz; i++) buff[i] = &pool[i];
    }
    /// Grows the pool by size + 1 objects when no free object is left.
    /// Only meaningful for the deque-backed variant: growing a vector would
    /// invalidate every pointer handed out so far.
    void ensure() {
        if (it != buff.end()) return;
        const int siz = static_cast<int>(buff.size());
        for (int i = siz; i <= siz * 2; ++i) {
            // emplace_back() only returns a reference since C++17, so take the address via back()
            pool.emplace_back();
            buff.push_back(&pool.back());
        }
        // push_back invalidated the old cursor; the first new entry sits at index siz
        it = buff.begin() + siz;
    }

private:
    /// Steals the storage of `other` and re-derives the cursor from its offset.
    void take(ObjPool& other) {
        const auto used = other.it - other.buff.begin();
        pool = std::move(other.pool);
        buff = std::move(other.buff);
        it = buff.begin() + used;
        other.pool.clear();
        other.buff.clear();
        other.it = other.buff.begin();
    }
};
/**
 * Persistent array stored as a (1 << lg_ary)-ary tree in which every node
 * holds one element.
 *
 * Element 0 lives in the root. To reach element `id`, repeatedly step into
 * child `id & mask` and continue with `(id - 1) >> lg_ary` until the index is 0.
 * set() clones only the nodes on that path, so versions share the rest.
 * Nodes are taken from the static pool Node::pool; call init_pool() first.
 *
 * @tparam T       element type.
 * @tparam lg_ary  log2 of the number of children per node.
 */
template <typename T, int lg_ary = 4>
struct PersistentArr {
    struct Node;
    using node_type = Node;
    using node_ptr = node_type*;
    using val_type = T;
    using pool_type = ObjPool<node_type>;

    struct Node {
        static inline pool_type pool{};
        static const int mask = (1 << lg_ary) - 1;
        node_ptr _ch[1 << lg_ary]{};
        val_type _val;

        Node(const val_type& val = val_type{}) : _val(val) {}

        /// Copies `node` into a freshly allocated pool slot.
        static node_ptr clone(node_ptr node) { return &(*pool.alloc() = *node); }
        /// Allocates a child-less node holding `val`.
        static node_ptr new_node(const val_type& val) { return &(*pool.alloc() = node_type(val)); }

        /// Reference to element `id` of the version rooted at `node`.
        static val_type& get(node_ptr node, int id) {
            while (id) {
                node = node->_ch[id & mask];
                id = (id - 1) >> lg_ary;
            }
            return node->_val;
        }

        /// Root of a new version whose element `id` is `val`; only the path is cloned.
        static node_ptr set(node_ptr node, int id, const val_type& val) {
            node_ptr res = clone(node);
            node_ptr cur = res;
            while (id) {
                const int slot = id & mask;
                cur->_ch[slot] = clone(cur->_ch[slot]);
                cur = cur->_ch[slot];
                id = (id - 1) >> lg_ary;
            }
            cur->_val = val;
            return res;
        }

        /// Overwrites element `id` in place (visible in every version sharing
        /// that node) and returns the previous value.
        static val_type mut_set(node_ptr node, int id, const val_type& val) {
            return std::exchange(get(node, id), val);
        }

        /// Builds the tree for `init`; returns nullptr for an empty vector.
        static node_ptr build(const std::vector<val_type>& init) {
            const int n = init.size();
            if (n == 0) return nullptr;
            // the node holding index `id` at stride p has children id + p*d for d = 1 .. 2^lg_ary - 1
            // in slot d, and id + p*2^lg_ary in slot 0
            std::function<void(node_ptr, int, int)> dfs = [&](node_ptr cur, int id, int p) {
                const int np = p << lg_ary;
                int nid = id + p;
                for (int d = 1; d < 1 << lg_ary; ++d, nid += p) {
                    if (nid < n) dfs(cur->_ch[d] = new_node(init[nid]), nid, np);
                    else return;
                }
                if (nid < n) dfs(cur->_ch[0] = new_node(init[nid]), nid, np);
            };
            node_ptr root = new_node(init[0]);
            dfs(root, 0, 1);
            return root;
        }

        /// Collects all elements of the version rooted at `node`, indexed by position.
        static std::vector<val_type> dump(node_ptr node) {
            if (not node) return {};
            std::vector<val_type> res;
            std::function<void(node_ptr, int, int)> dfs = [&](node_ptr cur, int id, int p) {
                if (int(res.size()) <= id) res.resize(id + 1);
                res[id] = cur->_val;  // the value of the visited node, not of the root
                const int np = p << lg_ary;
                int nid = id + p;
                for (int d = 1; d < 1 << lg_ary; ++d, nid += p) {
                    if (cur->_ch[d]) dfs(cur->_ch[d], nid, np);
                    else return;
                }
                if (cur->_ch[0]) dfs(cur->_ch[0], nid, np);
            };
            dfs(node, 0, 1);
            return res;
        }
    };

    /// Replaces the shared node pool by a fresh one with `capacity` nodes.
    static void init_pool(int capacity) { node_type::pool = pool_type(capacity); }

    PersistentArr() = default;
    /// Array of length n filled with `val`.
    explicit PersistentArr(int n, const val_type& val = val_type{}) : PersistentArr(std::vector<val_type>(n, val)) {}
    PersistentArr(const std::vector<val_type>& init) : _n(init.size()), _root(node_type::build(init)) {}

    int size() const { return _n; }

    const val_type& get(int id) const {
        assert(0 <= id and id < _n);
        return node_type::get(_root, id);
    }
    /// New version in which element `id` is `new_val`; this version is unchanged.
    PersistentArr set(int id, const val_type& new_val) const {
        assert(0 <= id and id < _n);
        return PersistentArr{ _n, node_type::set(_root, id, new_val) };
    }
    /// Overwrites element `id` in place and returns the old value.
    val_type mut_set(int id, const val_type& new_val) {
        assert(0 <= id and id < _n);
        return node_type::mut_set(_root, id, new_val);
    }
    /// Version with its own root node, so that mut_set(0, ...) on it does not affect this one.
    PersistentArr clone() const {
        if (not _root) return PersistentArr { _n, _root };
        return PersistentArr{ _n, node_type::clone(_root) };
    }
    std::vector<val_type> dump() const { return node_type::dump(_root); }

private:
    // initialised here so that a default-constructed array is a valid empty one
    int _n = 0;
    node_ptr _root = nullptr;
    explicit PersistentArr(int n, node_ptr root) : _n(n), _root(root) {}
};
// Persistent int array with lg_ary = 1, i.e. every node has 1 << 1 = 2 child slots.
using Arr = PersistentArr<int, 1>;
维护可持久化数组
维护一个长度为n的数组,m次操作:
- 在某个历史版本上修改某一个位置上的值
- 访问某个历史版本上的某一位置的值, 同时生成一个新的版本
可持久化线段树方法
// PersistentSegTree
// Persistent array via the persistent segment tree: operation i works on
// version v and produces version i.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    Seg::init_pool(24000000); // the pool has to be large enough for every version

    int n, m;
    cin >> n >> m;
    vector<int> a(n);
    for (int &value : a) {
        cin >> value;
    }

    vector<Seg> segs(m + 1);
    segs[0] = Seg(a);
    for (int i = 1; i <= m; ++i) {
        int v, t, pos;
        cin >> v >> t >> pos;
        pos--;
        if (t != 1) {
            // read: the new version is identical to version v
            segs[i] = segs[v];
            cout << segs[v].get(pos) << '\n';
            continue;
        }
        int x;
        cin >> x;
        // set() returns the new version; it must be stored, version v stays unchanged
        segs[i] = segs[v].set(pos, x);
    }
    return 0;
}
可持久化数组方法
// Persistent array via PersistentArr: operation i works on version v and
// produces version i.
int main() {
    ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    Arr::init_pool(20000000);

    int n, m;
    cin >> n >> m;
    vector<int> a(n);
    for (int &value : a) {
        cin >> value;
    }

    vector<Arr> segs(m + 1);
    segs[0] = Arr(a);
    for (int i = 1; i <= m; ++i) {
        int v, t, pos;
        cin >> v >> t >> pos;
        pos--;
        if (t != 1) {
            // read: the new version is identical to version v
            segs[i] = segs[v];
            cout << segs[v].get(pos) << '\n';
            continue;
        }
        int x;
        cin >> x;
        // write: store the version returned by set(), version v stays unchanged
        segs[i] = segs[v].set(pos, x);
    }
    return 0;
}