jls 线段树浅析
这里的 jls 是指 蒋老师。
最近复习了一下线段树,想着抄一下 jls 的模板,却发现我以前对于线段树的理解太浅薄,有点无法理解 jls 的代码,因此经过一段时间的研究,有了自己的一点见解。
不过由于 jls 的代码的特点为:线段树维护的区间下标从 0 开始,修改/查询的区间为:[左闭,右开)形式。 因此我进行了一点魔改,使得下标从 1 开始,并且把区间形式弄成了 [左闭,右闭] 的形式。
本文需要一点面向对象的思想,一点点离散数学的前置知识,一点点 c++。不过本文不涉及关于算法复杂度的具体证明之类的,主要做的是对于模板的实现以及浅析。
以下内容仅代表个人观点,如有错误,欢迎批评指正。
首先给出代码模板:主要分成两个,但是实际上区别仅在于是由维护了 LazyTag。
不带 LazyTag
template <class Info>
struct SegTree {
int n;
vector <Info> tr;
SegTree() {}
SegTree(int _n): n(_n), tr((n + 5) * 4) {}
SegTree(int _n, const vector<Info> a): SegTree(_n) {
function <void(int, int, int)> build = [&](int u, int l, int r) {
if (l == r) {
tr[u] = a[r];
return;
}
int mid = l + r >> 1;
build(u * 2, l, mid);
build(u * 2 + 1, mid + 1, r);
pushup(u);
};
build(1, 1, n);
}
void pushup(int u) {
tr[u] = tr[u * 2] + tr[u * 2 + 1];
}
void modify(int u, int l, int r, int x, const Info &v) {
if (l == r) {
tr[u] = v;
return;
}
int mid = l + r >> 1;
if (x <= mid) {
modify(u * 2, l, mid, x, v);
} else {
modify(u * 2 + 1, mid + 1, r, x, v);
}
pushup(u);
}
Info query(int u, int l, int r, int ql, int qr) {
if (l > qr || r < ql) {
return Info();
}
if (ql <= l && r <= qr) {
return tr[u];
}
int mid = l + r >> 1;
return query(u * 2, l, mid, ql, qr) + query(u * 2 + 1, mid + 1, r, ql, qr);
}
void modify(int x, const Info &v) {
modify(1, 1, n, x, v);
}
Info query(int l, int r) {
return query(1, 1, n, l, r);
}
};
struct Info {
};
Info operator+(const Info &l, const Info &r) {
return {
};
}
带 LazyTag
template <class Info, class Tag>
struct SegTree {
int n;
vector <Info> tr;
vector <Tag> tag;
SegTree() {}
SegTree(int _n): n(_n), tr((n + 5) * 4), tag((n + 5) * 4) {}
SegTree(int _n, const vector<Info> &a): SegTree(_n) {
function <void(int, int, int)> build = [&](int u, int l, int r) {
if (l == r) {
tr[u] = a[r];
return;
}
int mid = l + r >> 1;
build(u * 2, l, mid);
build(u * 2 + 1, mid + 1, r);
pushup(u);
};
build(1, 1, n);
}
void pushup(int u) {
tr[u] = tr[u * 2] + tr[u * 2 + 1];
}
void apply(int u, const Tag &v) {
tr[u].apply(v);
tag[u].merge(v);
}
void pushdown(int u) {
apply(u * 2, tag[u]);
apply(u * 2 + 1, tag[u]);
tag[u] = Tag();
}
void modify(int u, int l, int r, int x, const Info &v) {
if (l == r) {
tr[u] = v;
return;
}
pushdown(u);
int mid = l + r >> 1;
if (x <= mid) {
modify(u * 2, l, mid, x, v);
} else {
modify(u * 2 + 1, mid + 1, r, x, v);
}
pushup(u);
}
Info query(int u, int l, int r, int ql, int qr) {
if (l > qr || r < ql) {
return Info();
}
if (ql <= l && r <= qr) {
return tr[u];
}
pushdown(u);
int mid = l + r >> 1;
return query(u * 2, l, mid, ql, qr) + query(u * 2 + 1, mid + 1, r, ql, qr);
}
void rangeModify(int u, int l, int r, int ql, int qr, const Tag &v) {
if (l > qr || r < ql) {
return;
}
if (ql <= l && r <= qr) {
apply(u, v);
return;
}
pushdown(u);
int mid = l + r >> 1;
rangeModify(u * 2, l, mid, ql, qr, v);
rangeModify(u * 2 + 1, mid + 1, r, ql, qr, v);
pushup(u);
}
void modify(int x, const Info &v) {
modify(1, 1, n, x, v);
}
Info query(int l, int r) {
return query(1, 1, n, l, r);
}
void rangeModify(int l, int r, const Tag &v) {
rangeModify(1, 1, n, l, r, v);
}
};
struct Tag {
void merge(const Tag &o) {
}
};
struct Info {
void apply(const Tag &o) {
}
};
Info operator+(const Info &l, const Info &r) {
return {
};
}
原理以及代码解释
也主要分成两个部分。
常用定义
首先,线段树是一种可以快速维护一类 幺半群 信息的数据结构。
什么是幺半群呢?幺半群是一个存在单位元(幺元)的半群。
但是这里我们不对此做详细说明,大家只需要知道幺半群的一些 性质以及定义 即可。
线段树符号定义
现在看不懂没关系,看到下面再回过头来看就懂了。
n
表示序列长度tr[]
表示线段树存储的数组a[]
表示长度为 n 的 信息序列(u, l, r)
表示当前为 u 号节点,其表示的区间范围为 [l,r] 左右皆为闭(ql, qr)
(queryLeft -> ql
)表示询问区间 [ql,qr] 左右皆为闭Info v
表示某一个节点的信息,注意是直接可以替换掉某个节点,不是表示信息的增量!Tag v
表示对于某一个节点的 LazyTag
幺半群定义
考虑定义了二元运算 ∘:S→S(注意这里蕴含了 S 对运算 ∘ 封闭)的非空集合 S,若满足如下公理:
- 结合律: ∀a,b,c∈S ,有 (a∘b)∘c=a∘(b∘c)
- 单位元(幺元): \existe∈S, 使得 ∀a∈S, 有 a∘e=e∘a
则三元组 ⟨S,∘,e⟩ 被称为 幺半群.
这两个定义的重要性在于第一个给出了我们能使用线段树来维护的依据,第二个使得我们可以进行代码上的简化.
普通线段树
对于一个给定的长度为 n 序列 S, 可以在时间复杂度为 O(logn) 的时间进行单点修改, 以及区间的查询.
为了简化描述,我们假定运算 ∘ 以及 幺元 e 的计算为常数级别的时间 O(1)。
实际上,在复杂度计算时需要增加这些计算的时间,比如说当维护矩阵乘法之类的,此时不能简单地认为是 O(1)。
我们举个具体的例子来进行代码的详解。
比方说这样一个题目:
给定一个长度为 n 的序列 a,执行 q 次操作:
1 x y
,表示令 ax←y2 l r
,表示查询 ∑ri=lai
对于整数集合以及 +
运算,这是满足我们上面对于幺半群的定义的,即:
- 存在结合律
- 存在幺元 0
因此我们可以考虑用线段树去维护这个序列。
节点信息
首先我们定义一类 信息 Info 表示每一个线段树节点需要存储的具体信息。
可以这样写:
struct Info {
long long val = 0; // 表示当前节点所存储的和
};
注意,我这里没有写默认构造函数,而是给中的值直接赋上了初始值。
可能眼见的同学们已经发现了,我赋上的初始值都是 幺元。这有什么具体作用呢?这个问题我们在下文继续讲述。
那么考虑当前节点怎么由两个 子节点 转移过来。
往常我们的写法可能是写一个 pushup
函数,传入一个 u 然后根据线段树堆式存储的性质来找到两个儿子来计算。但是实际上,大多数情况下,我们不关心当前节点具体是哪两个节点,因此我们把这种 由某两个节点合并成一个节点 的这种思想抽象出来,然后再加之实现,这就是所谓的面向对象的思想。
我们不妨为 Info 重载 operator+
对应我们幺半群定义中的二元运算符 ∘。
这也就是需要将 pushup
函数写成 +
对于这个问题而言,具体来说是:
// 通过 节点 l,r 计算出他们合并后的节点并且返回。
Info operator+(const Info &l, const Info &r) {
return {
l.val + r.val
};
}
对应数学符号就是返回一个 l∘r
建树
考虑完当前线段树节点需要维护的信息以及对应的操作运算后,我们接下来考虑一些比较模板化,或者说 绝大多数情况下不需要更改的 的函数。
建树的复杂度是 2T(n2)+O(1)=O(n)。
如果我们当前已经整理出来了长度为 n 的信息序列 info[n] ,那么对于每个 叶子节点 u 表示的区间为 [i,i],我们只需要进行 tru←infoi。就完成对于叶子节点的赋值了。
void build(int u, int l, int r) {
// 到叶子节点
if (l == r) {
// 直接赋值
tr[u] = info[r];
return;
}
int mid = l + r >> 1;
build(u * 2, l, mid);
build(u * 2 + 1, mid + 1, r);
pushup(u); // 等价于 tr[u] = tr[u * 2] + tr[u * 2 + 1];
};
build(1, 1, n);
由于我们已经重载了运算符,也就是任给我们两个节点 l,r 我们都能返回其合并成一个节点后的信息 u,那么如果我们已知的是 u 的两个子节点,那么
pushup
就很简单了
void pushup(int u) {
tr[u] = tr[u * 2] + tr[u * 2 + 1];
}
单点修改
函数参数:void modify(int u, int l, int r, int pos, const Info &v)
参数看不懂可以看上面,这里多了一个参数 pos
表示当前需要修改的节点的下标,具体到这个问题上就是 操作一的 x。
那么和平常一样分类讨论即可。
- 如果到叶子节点,直接对当前节点 u←v
- 否则如果有
pos <= mid
修改左边 - 否则修改右边
- 最后记得
pushup
void modify(int u, int l, int r, int x, const Info &v) {
if (l == r) {
tr[u] = v;
return;
}
int mid = l + r >> 1;
if (x <= mid) {
modify(u * 2, l, mid, x, v);
} else {
modify(u * 2 + 1, mid + 1, r, x, v);
}
pushup(u);
}
我觉得有必要再强调几个点:
const Info &v
这样写是因为直接传递引用这样会减少点常数,并且防止无意间非法地修改了 v- v 是要直接替代掉树中节点的,v 是要直接替代掉树中节点的,v 是要直接替代掉树中节点的。v 不是一个增量!
那么修改函数看上去也是平平无奇的。但是比较有意思的是查询函数。
区间查询
函数参数:Info query(int u, int l, int r, int ql, int qr)
返回的是一个节点
对于询问,我往常的写法是:
LL query(int u, int l, int r) {
if(tr[u].l >= l && tr[u].r <= r) {
return tr[u].val;
}
int mid = tr[u].l + tr[u].r >> 1;
LL v = 0;
if (r <= mid) {
return query(u * 2, l, r);
} else if (l >= mid + 1) {
return query(u * 2 + 1, l, r);
} else {
LL left = query(u * 2, l, r);
LL right = query(u * 2 + 1, l, r);
return left + right;
}
return v;
}
分别表示:
- 如果是当前有
ql <= l && r <= qr (l, r 表示 u 节点所表示的区间范围)
那么此时直接返回当前节点的信息。 - 对于
mid
中点讨论- 如果询问区间
[ql, qr]
完全包含于左节点的话,那就返回对左孩子询问的结果 - 否则如果完全包含于右节点的话,那就返回对右孩子询问的结果
- 否则如果询问区间横跨了中点,那就把左右两个节点的结果都取出来,然后合并之后返回
- 如果询问区间
可以发现这样做是一定正确的,但是我们能不能进行一点点的简化呢?
答案是可以的。此时就需要用到我们对于 幺元 的定义了。
再复习一下,在一个幺半群中,幺元可以表示成 ∀a∈S,\existe s.t. e∘a=a∘e。说人话就是对于一个节点 u,通过我们重载的 +
,幺元 e 都能使得 e+u=u+e 也就是任何一个节点与我们的幺元 e 相并就是另外的那个节点。
我们再来重新审视一下我们为什么在进行这样麻烦的分类讨论,实际上就是希望每次询问的时候都是 合法的(树区间都坐落于询问区间)。与此相对应的就是非法的,用符号表示就是 l > qr || r < ql
那么我们不妨直接规定 非法区间对应的节点就是幺元(当然只是我们在脑海中规定,并不真实地改变我们线段树中的其他节点),这样的好处就在于如果访问到一些非法的节点直接可以返回 幺元,而根据我们幺元的定义,这样的返回结果是不会影响我们的查询结果的。并且也能保证复杂度的正确性,因为我们遇到了非法的节点就直接返回了,不会再向下遍历。
而返回幺元的时候还记得我们前面直接对于 struct Info
中的信息直接赋值的默认值吗?如果我们直接将其中的变量赋值成默认值那么我们就直接返回一个 Info()
就能表示幺元啦。
具体实现也非常的简单:
Info query(int u, int l, int r, int ql, int qr) {
if (l > qr || r < ql) {
return Info();
}
if (ql <= l && r <= qr) {
return tr[u];
}
int mid = l + r >> 1;
return query(u * 2, l, mid, ql, qr) + query(u * 2 + 1, mid + 1, r, ql, qr);
}
不得不说,这个思想是我当初学线段树的时候没有见到过的,因此感觉令我大受震撼。
一些其他细节
我们在封装成模板的时候,可以另写一个函数使得我们在结构体外部调用,这样就不需要每次都写一次 query(1, 1, n, l, r)...
之类的。
具体来说就是:
void modify(int x, const Info &v) {
modify(1, 1, n, x, v);
}
Info query(int l, int r) {
return query(1, 1, n, l, r);
}
然后对于上面模板的使用就是
// 注意 n 表示真实的序列长度,info_i 表示 第 i 个叶子的信息,下标从 1 开始
SegTree <Info> st(n, info);
// 这里不需要再自己写一遍 build,在构造函数里面建了树
// 查询 [l, r] 的信息并,左右都是确界
auto u = st.query(l, r);
// 将第 x 个叶子 直接修改成 y
st.modify(x, {y});
未完待续
懒标记线段树。
太强了,催更~