被迫营业,学习了树状数组
树状数组
粗:
树状数组是一个基于二进制的数据结构,我们将每一个值存为区间右端点的一个“lowbit”长度的区间中,就构成了树状数组(真的是很粗。。。)
细:简介
首先,树状数组这个数据结构可以动态维护区间和,支持单点修改区间查询(其实还支持区间修改单点查询或者都一起支持,这个后续再讲,因为是扩展内容)
在学习快速幂算法的过程中,我们可以知道一个事实:
任意正整数都可以表示为 2 的整数次幂相加(也就是二进制分解)
那么,6=4+2(2^2+2^1),14=8+4+2(2^3+2^2+2^1)
那么我们可以将 1~6 这个区间表示为 1~4 并 5~6(具体为什么不是 12 并 3~6 一会再说)
1~14 可以表示为 1~8 并 9~12 并 13~14
我们就可以进行存储。
那么。。。如何存储或查询呢?
6 可以表示为(110)2,14 可以表示为(1110)2,我们可以通过一个小操作实现取出 110 最右边(最小的)那个 1,就是lowbit。
lowbit 运算
lowbit 操作,涉及到了一些位运算的知识:
我们不妨设 n>0,n 的第 k 位是 1,0~k-1 位是 0(1000000……)
那么我们先将 n 取反,此时第 k 位变成了 0,0~k-1 都是 1
我们再将 n 加上 1,那么最后的那 k 位就变回了原来的样子
但是:第 k+1 位往后,就变成了与之前相反的数
我们可以想到按位 and 操作:只要有一个不是 0,就把这一位变成 0,都是 1 变成 1
k+1 位往后是都相反的,那说明那几位中按位 and 运算一定会全部变成 0
第 k 位从始至终都是 1,0~k-1 位从始至终都是 0
那么这个数只剩下了 2^k-1
表示为:n&(~n+1)
然而在补码的表示下,~n=-1-n,那么~n+1==-n
所以 lowbit 运算最后就变成了 n&-n
查询和存储
存储(或者说是单点修改)
在讲完了 lowbit 操作之后,就到了实现
查询的时候,我们还是以 n 来举例子
我们想要存储 1-n 的前缀和,就是分步走
7 存到 t[7]之后
我们就需要修改它的父节点
通过 lowbit 操作我们可以知道,我们现在存储的节点的区间长度是 lowbit(n)
我们想要到我们的父节点,父节点的区间长度应是我们的二倍(或者是和之前的进行拼接成为更大的区间)
所以我们应该让 lowbit 变大。
由我们的计算机存储方式二进制可以得到,lowbit 的这一位加上 1,会变成 0,然而更高的哪一位会加一(直到不能进位)(两个同样长的挨着的区间分别存储是不优的,因为我们是在存储前缀和,直接详见可以得到答案就不需要浪费空间进行存储),这正好顺应了我们想要存区间的需求,那我们就让 n 加上 lowbit(n)便是他的父节点
void add(int x, int y)
{
for (; x <= n; x += x & -x)
t[x] += y;
}
查询
存储其实和查询有一点反操作的感觉
我们考虑一个区间 1~7
7=(111)2
那么,我们就可以知道,sum(1~7)=sum(1-4)+sum(5-6)+sum(7);
我们进行 7 的存储,就可以分步走:
首先将 sum(7)存储进 sum(1-7),7-=1 现在为 6
将 sum(5-6)存进 sum(1-7),6-=2 现在为 4
再将 sum(1-4)存进 sum(1-7,4-=4 现在为 0,结束
所以我们 sum(n)存的区间的长度就可以是 lowbit(n)
那样的话 sum(n-lowbit(n))存的就是下一位(100000……)直到变成 0
这就是树状数组的存储
int ask(int x)
{
int ans = 0;
for (; x; x -= x & -x)
ans += t[x];
return ans;
}
例题
洛谷 P3374,P3368 是模板题,可以去看看
P1966 火柴排队 题目链接
这道题倒是不难(我个人感觉应该到不了蓝色。。。)(我太菜了)
根据题目所说的是,两根火柴的距离定义为(a-b)^2,那么就让第 i 长的和第 i 长的放在一起,这样会是最小(贪心)
因为如果这样的话就相当于是火柴的对应关系固定的死死的,那么,我们就将火柴以编号命名(离散化但不完全离散化),然后求逆序对数
由于我们想让第 1 对应第 1,就相当于 a 数组 2 3 1 4,我将它变为 2->1,3->2,1->3,4->4,序列变为 1 2 3 4
然而 b 数组对应之后变成个 1 4 2 3,就相当于是冒泡排序求最少交换多少次,这个次数就是序列的逆序对数
(所以我说很简单嘛)
(代码)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define maxn 200010
#define mod 99999997
int n;
int a[maxn], b[maxn], c1[maxn], c2[maxn];
map<int, int> q;
int t[maxn];
void add(int x, int y)
{
for (; x <= n; x += x & -x)
t[x] += y;
}
int ask(int x)
{
int ans = 0;
for (; x; x -= x & -x)
ans += t[x];
return ans;
}
signed main()
{
ios::sync_with_stdio(false);
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i], c1[i] = a[i];
for (int i = 1; i <= n; i++)
cin >> b[i], c2[i] = b[i];
sort(c1 + 1, c1 + 1 + n);
sort(c2 + 1, c2 + 1 + n);
for (int i = 1; i <= n; i++)
{
a[i] = lower_bound(c1 + 1, c1 + 1 + n, a[i]) - c1;
b[i] = lower_bound(c2 + 1, c2 + 1 + n, b[i]) - c2;
q[b[i]] = i;
}
for (int i = 1; i <= n; i++)
a[i] = q[a[i]];
int ans = 0;
for (int i = 1; i <= n; i++)
{
add(a[i], 1);
ans += ask(n) - ask(a[i]);
}
cout << ans % mod << endl;
return 0;
}
P5677 配对统计 题目链接
这道题,其实也不难,但是我想了很长时间
(其实就是理解错题意了。。。)
我们考虑首先对于好的配对进行预处理,求出每一组好的配对
然后进行记录。记录的过程中,我们可以将配对放在一个数组中。
配对分几种情况:
首先,如果说我们将序列上的所有数字放在一条数轴上,一个数字对应的好的配对一定在它的两侧,那么分三种情况:
左边距离小,右边距离小,两边一样
就可以进行处理。两边一样的就有两组好的配对是当前数字的。
我们处理完配对之后,将它们在存储时改为左端点右端点单增(我只需要统计个数又不统计在哪)
我们对配对排序(以右端点升序排序),再将询问进行升序排序(同样是右端点)
我们枚举每一个询问,将右端点比询问的右端点小的配对的左端点放入树状数组(很绕口—)
我们统计左端点在 l~r 区间之内的个数,就是配对在 l~r 区间的个数(右端点比询问的右端点小,比自己的左端点大,所以只要我的左端点大于询问的左端点就可以是一个配对了)
上代码
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define maxn 600010
int n, m, t[maxn];
pair<int, int> q[maxn];
int cnt;
struct no
{
int q1, q2, pos;
} question[maxn];
bool qcmp(no x, no y)
{
if (x.q2 == y.q2)
return x.q1 < y.q1;
return x.q2 < y.q2;
}
int ans;
void add(int pos)
{
while (pos <= n)
t[pos]++, pos += (pos & -pos);
}
int ask(int num)
{
int sum = 0;
while (num > 0)
sum += t[num], num -= (num & -num);
return sum;
}
struct node
{
int num, pos;
} a[maxn];
bool cmp(node x, node y)
{
return x.num < y.num;
}
void addq(int x, int y)
{
q[++cnt].first = min(x, y);
q[cnt].second = max(x, y);
}
bool comp(pair<int, int> x, pair<int, int> y)
{
if (x.second == y.second)
return x.first < y.first;
return x.second < y.second;
}
signed main()
{
ios::sync_with_stdio(false);
cin >> n >> m;
if(n==1)
{
cout<<0<<endl;
return 0;
}
for (int i = 1; i <= n; i++)
cin >> a[i].num, a[i].pos = i;
sort(a + 1, a + 1 + n, cmp);
addq(a[1].pos, a[2].pos);
addq(a[n].pos, a[n - 1].pos);
for (int i = 2; i < n; i++)
{
if (a[i].num - a[i - 1].num < a[i + 1].num - a[i].num)
addq(a[i].pos, a[i - 1].pos);
else if (a[i].num - a[i - 1].num == a[i + 1].num - a[i].num)
addq(a[i].pos, a[i - 1].pos), addq(a[i].pos, a[i + 1].pos);
else
addq(a[i].pos, a[i + 1].pos);
}
sort(q + 1, q + 1 + cnt, comp);
for (int i = 1; i <= m; i++)
cin >> question[i].q1 >> question[i].q2, question[i].pos = i;
sort(question + 1, question + 1 + m, qcmp);
int j = 1;
for (int i = 1; i <= m; i++)
{
while (q[j].second <= question[i].q2 && j <= cnt)
add(q[j].first),j++;
ans += (long long)(j - 1 - ask(question[i].q1 - 1)) * question[i].pos;
}
cout << ans << endl;
return 0;
}