二维区间topk问题
二维区间topk问题原始来源是https://churchillnavigation.com/challenge/。去年年前我在reddit上看到这个比赛之后,跃跃欲试,期间可以说是茶饭不思啊!但是,我要说但是,最后成绩排名只是33。平均用时比第一名慢33倍,这简直是智商上的压制。期间也暴露了我写代码上的一些问题,努力改正!不过学到更多的是程序优化是算法优化和体系结构的综合,而不仅仅是理论上的渐进复杂度而已。如果不用SSE/AVX指令集,任何解决方案都打不进前10。主办方用的算法基本是理论上最优的,然而只排到了第18名,运行时间差不多是我的一半。通过研究这些代码,也发现了一些好玩的东西,因此写这篇文章来记录一下。半年之后才开始动笔,真是懒惰!
naive topk问题
所谓的naive topk问题可以表述为:
从n个数中找出其最大的k个数
言简意赅。解决方案也是非常直接,基本可用的方法用3种:
-
堆排序:用一个大小为k的最小堆,所需时间为\(\mathcal{O}(nlogk)\),空间为\(k\),其实额外空间都不需要,直接原地排就可以。
-
部分快排:只对索引小于k的部分进行排序,期望运行时间为\(\mathcal{O}(n)\),额外空间不需要,直接原地排。这里之所以是期望运行时间,是因为快排的划分不可控,只好取其平均情况。
-
中位数排序法:直接选出第k大的元素,然后划分。这个算法类似于快排,但是其运行时间是确定为\(\mathcal{O}(n)\)的,只不过常数比较大而已。
就这样,naive topk 问题基本解决了。
一维区间topk
所谓的一维区间topk问题可以表述为
在给定的大小为\(n\)的数组a中,寻找\(a[i]-a[j]\)之间的前\(k\)大的数。
这次我们就不能像前面的naive topk直接上排序,因为这样会直接毁掉原来的序列。naive topk毁掉序列信息倒没啥关系,反正只操作一次。而区间topk基本是会执行多次查询的,毁掉了原来的序列信息后面的查询就全是错的了。为此,我们需要一个另外的结构来记录中间信息,这种结构就叫做划分树。通过使用划分树,我们可以在\(\mathcal{O}(logn)\)的时间内获得区间\(i-j\)之间的第\(k\)大的数。通过执行\(k\)次这样的查询,我们就可以获得前\(k\)大的数。
划分树的基本思想就是对于某个区间,把它划分成两个子区间,左边区间的数小于右边区间的数。查找的时候通过记录进入左子树的数的个数,确定下一个查找区间,最后范围缩小到1,就找到了。
划分树的建立
建树的过程比较简单,对于区间\([l,r]\),首先通过对原数组的排序找到这个区间的中位数\(a[mid]\),小于\(a[mid]\)的数划入他的左子树\([l,mid-1]\),大于它的划入右子树\([mid,r]\)。同时,对于第\(i\)个数,记录在\([l,i]\)区间内有多少数被划入左子树。最后,对它的左子树区间\([l,mid-1]\)和右子树区间\([mid,r]\)递归的继续建树就可以了。
建树的时候要注意对于被分到同一子树的元素,元素间的相对位置不能改变。
划分树的查询
查找的过程中主要问题就是确定将要查找的区间。这个问题有些麻烦。为此我们定义一个辅助函数\(tree_find\),其定义如下:
查找深度为h,在大区间[st,ed]中找小区间[s,e]中的第k元素。
我们的想法是,先判断[s,e]中第k元素在[st,ed]的哪个子树中,然后找出对应的小区间和k,递归的进行查找,直到小区间的s=e为止。
那如何解决这个问题呢?这时候前面记录的进入左子树的元素个数就派上用场了。通过之前的记录可以知道,在区间\([st,s-1]\)中有\(el[h,s-1]\)进入左子树,记它为\(l\)。同理区间\([st,e]\)中有\(el[h,e]\)个数进去左子树,记它为\(r\)。所以,我们知道区间小区间\([s,e]\)中有\((r-l)\)个数进入左子树。那么如果\((r-l)\ge k\),那么就在左子树中继续查找,否则就在右子树中继续查找。
接着解决查找的小区间的问题。
-
如果接下来要查找的是左子树,那么小区间应该是[st+([st,s-1]区间进入左子树的个数),st+([st,e]区间内进入左子树的个数)-1],即区间\([st+l,st+r-1]\)。显然,这里\(k\)不用变。
-
如果接下来要查找的是右子树,那么小区间应该是[mid+([st,s-1]区间中进入右子树的个数),mid+([st,e]区间进入右子树的个数)-1,即区间\([mid+(s-st-l),mid+(e-st-r)]\)。显然,这里\(k\)要减去区间里已经进入左子树的个数,即\(k\)变为\(k-(r-l)\)。
于是递归继续查找直到\(s=e\)即可。
复杂度分析
很明显,这个划分树中的所有节点的度数要么是2,要么是0。因此树节点有\(2*n-1\)个,同时树高是极度平衡的,最大树高为\(log(2n)\)。由于每一个节点都需要记录其在不同深度中的左子树计数,所以空间需求为\(nlog(n)\)。而查询就是从头节点递归下降到叶子节点的过程,所以时间为\(log(n)\)。即我们可以在\(nlog(n)\)额外空间的支持下,以\(log(n)\)的时间内得到区间第\(k\)大。
查询优化
上述查询过程针对的是查询第\(k\)个数的情况,而我们所需要的是前\(k\)个。如果只是暴力的做\(k\)次这样的查询的话,挺浪费时间的。我们可以在这里做一点优化,即在树的节点之中存储节点所代表区间的数组及相应的el信息,而不是在全局中建立el数组。这样我们对于前\(k\)的查询可以修改为:
-
在进入左字数的时候,如果\(r-l=k\),则直接将当前部分数组的\(l-r\)之间的数作为结果返回
-
在进入右子树的时候,直接将左子树的第\(l\)到第\(r\)个数添加到结果中
在经过这样的修改之后,我们可以在\(log(n)+k\)的时间内获得区间前\(k\)个数, so easy 。其实这个思想与之后要谈的fractional cascading相似。
二维区间topk
二维区间topk难度相对于一维区间topk来说,难度增加了很多。其问题可以简述如下:
在平面上有\(n\)个不同的点,每个点都有其相应的坐标\(x,y\)以及其优先级\(z\)。设计一个高效的方案来支持区间查询:获得由点\((x_1,y_1)\)和点\((x_2,y_2)\)所确立的矩形区域(\(x_1<x_2,y_1<y_2\))中的优先级最大的前\(k\)个点。
在我参与这个比赛的过程中,尝试了如下几种实现:
-
线性扫描:将所有的点按照优先值来排序,然后线性扫描排序后的点,如果在矩形内,则添加到结果。当扫描到末尾或者结果数到达k之后返回结果。
-
kd树:将所有点建立一个kd树。查询时,首先建立一个大小为k的最小堆。每次查询到一个节点,扫描其左右子节点,如果在查询区间内,则加入到待查询队列中,就这样递归分裂查询。到最后下降到叶子节点时,扫描叶子节点中所存储的数据,判断其是否在目标区间内,然后加入到结果堆中。空间复杂度\(\mathcal{O}(n)\),查询时间\(\mathcal{O}(\sqrt{n})\)加上区间内节点数。总的来说kd树需要得到所有在区间内的点,极大的拖慢了速度。
-
range tree:将所有节点组合为range tree,不过也是到小于32个点时就不再分裂。查询的时候按照range tree的结构,搜索所有相交的\(x\)区间,然后再在这些区间中寻找\(y\)区间内的点,找到所有点之后依次加入存储结果的k小堆中,最后返回。其空间复杂度为\(\mathcal{O}(nlogn)\),时间复杂度\(\mathcal{O}(logn*logn)\)加上在目标区间内的节点个数。由此可见range tree也需要获得所有区间内的点,速度也是比较慢。
-
fractional cascading:这个数据结构是在range tree上做的一个改良,可以看作range tree和划分树的合体。即y级上的点不再是按照树的形式组织,而是直接以数组方式组织。每个x级的节点内部存储所有按照y排序的节点数组,同时对于数组中的每个节点,存储在其左边且进入左子树的节点数,类似于划分树。这样下降查询的时候,我们只需要在头节点对y进行二分查询,剩下的时候直接按照左节点计数即可,不再需要二分查询。在这种修改下,其空间复杂度为\(nlogn\),搜索时间复杂度为\(logn\)加上区间内的点。
-
priority search tree: 优先搜索树可以当作kd树的变形版,优先搜索树中会存储优先级最高的32个点,这些预先存储的点不会在子树中继续出现。搜索时我们用队列来存储所有需要访问的节点,以最小堆来存储中间结果。在递归查询的时候,首先判断当前节点所代表的区间是否在目标区间内,如果不在直接返回,没有必要递归。如果在区间内,则扫描内部预先存储的32个点。如果其优先级都小于堆的头节点,则没有必要将左右节点加入查询队列。否则,扫描这32个点,在目标区间内的加入结果堆中。其总空间需求是\(\mathcal{O}(n)\)的,而查询时复杂度上限与kd树等同\(\mathcal{O}(\sqrt{n})\)。但是实际过程中,我们并不需要查询所有在区间内的点,只需要获得前\(k\)个即停止,这点与前面所说的各种树完全不同。实际测试的时候,也发现这种结构是最快的。
AVX扫描加速
在上一节所提到的各种策略中,我们论证了其优缺点,综合来说pst是最优的。但是从最后公布的结果中可以看出,pst也只排到18名,比第一名慢了将近10倍。为什么会有这样的结果?因为我们的算法只是在理论上最优,而算法的实际运行是要在硬件上跑的,罔顾体系结构的算法基本都会败的一踏涂地。而所谓的体系结构在程序运行效率上的影响主要包括:专有指令和缓存系统。数组类型的扫描访问总是比树类型的递归访问的缓存局部性好,举个例子来说就像快排与堆排。同样的复杂度\(\mathcal{O}(nlogn)\),快排总是比堆排序快很多倍,其优势就在于对于缓存系统的友好性。而专有指令的存在可以更好的利用缓存系统和指令级并行,其带来的加速比也是很大的。下面我们就来探讨一下AVX指令集如何运用到在区间查询中的使用。
AVX指令集
AVX,全称为Advanced Vector Extensions,是SIMD指令的典范。其指令的操作数可以为128位、256位、512位,但是512位只有在个别xeno处理器才有,我们在本文中只使用256位的AVX2.在C++编译器中,通过引用头文件immintrin.h
来利用AVX2。在该头文件中,主要类型定义如下。
typedef union __declspec(intrin_type) _CRT_ALIGN(32) __m256 { float m256_f32[8]; } __m256; typedef struct __declspec(intrin_type) _CRT_ALIGN(32) __m256d { double m256d_f64[4]; } __m256d; typedef union __declspec(intrin_type) _CRT_ALIGN(32) __m256i { __int8 m256i_i8[32]; __int16 m256i_i16[16]; __int32 m256i_i32[8]; __int64 m256i_i64[4]; unsigned __int8 m256i_u8[32]; unsigned __int16 m256i_u16[16]; unsigned __int32 m256i_u32[8]; unsigned __int64 m256i_u64[4]; } __m256i;
在我们所遇到的数据中,坐标为单精度浮点数,所以我们这里只展示单精度相关的操作。这里用到的AVX相关操作主要 实际测试中使用该函数线性扫描比简单的线性扫描快10倍左右。这里我们来解释一下其中所用到的各个指令
-
_mm256_broadcast_ss
这个函数是将单精度浮点数复制8次,然后填充到_mm256
的内存区域中; -
_mm256_load_ps
这个函数是将特定的_mm256
的值复制到另外一个_mm256
的内存区域中,可以认为是赋值运算符; -
_mm256_and_ps
这个函数基本就是位操作的256位扩充版,不再解释; -
_mm256_cmp_ps
这个函数是用来比较256位的值,其比较时是通过将其中每一个32浮点单独比较,如果结果为真则对应的字节位填充为1,否则填充为0.可使用的比较操作类型有很多种,我们这里只使用了_CMP_LE_OQ
和_CMP_GE_OQ
,分别代表小于等于和大于等于。 -
_mm256_testz_ps
这个函数有两个操作数,其作用是比较两个操作数bitand操作之后是否每一位都为0,如果都是0返回true,否则返回false。
点与区间关系判断
通过利用上述的AVX相关指令,我们定义一个搜索函数:
point_index Solution::search_linear(const Rect rect, const point_index count, Point *out_points) { __m256 rect_lx = _mm256_broadcast_ss(&rect.lx); __m256 rect_hx = _mm256_broadcast_ss(&rect.hx); __m256 rect_ly = _mm256_broadcast_ss(&rect.ly); __m256 rect_hy = _mm256_broadcast_ss(&rect.hy); point_index n = 0; point_index i = 0; point_index end = m_avx_count; for(; i < (end - 7); i+=8) { __m256 x = _mm256_load_ps(m_x_coord.data() + i); __m256 y = _mm256_load_ps(m_y_coord.data() + i); __m256 x_in = _mm256_and_ps(_mm256_cmp_ps(rect_lx, x, _CMP_LE_OQ), _mm256_cmp_ps(rect_hx, x, _CMP_GE_OQ)); __m256 y_in = _mm256_and_ps(_mm256_cmp_ps(rect_ly, y, _CMP_LE_OQ), _mm256_cmp_ps(rect_hy, y, _CMP_GE_OQ)); if (!_mm256_testz_ps(x_in, y_in)) { __m256 mask = _mm256_and_ps(x_in, y_in); //这里需要单独比较这8个位置,代码冗长,就不贴了 } } //处理未内存对齐的数据 return (point_index)n; }
所以上述代码所做的工作就是:将坐标的\(x,y\)分量单独存储,然后分别用AVX指令集来比较是否有点在区间内。如果没有则直接返回,如果有,则单独的提取其中的八个点,判断是否在区间内并加入到结果之中。
其实,我们还有另外的一种方法来使用AVX指令集。对于判断点\((x,y)\)是否在区间\((lx,ly)-(ux,uy)\)内,上面的函数所做的判断是
对于这个操作,我们可以转换为
同样的,我们可以转换为
这样,我们可以将一个点的坐标扩充为128位,然后就可以对于一个点调用_mm128_cmp_ps
做单独的判断了。如果是double类型的就更好了,直接利用_mm256_cmp_ps
。但是这样会增加一倍的使用空间,所带来的优势就是不需要处理对齐问题和多重判断问题。具体的测试目前还比较缺乏,待会补充。
区间关系判断
这是一个拓展问题,如何判断两个二维区间之间的关系:相交、相离、包含。我们首先以一维区间入手,两个区间分别为\((a_1,b_1),(a_2,b_2)\)。
-
两个区间相交时,\(a_1\le a_2 \le b_1 \le b_2 || a_2\le a_1 \le b_2\le b_1\)
-
两个区间相离时,\(a_1> b_2|| a_2> b_2\)
-
两个区间包含时,\(a_1\le a_2 \land b_2\le b_1|| a_2\le a_1 \land b_1\le b_2\)
由上述判断条件可以看出,并没有特别好的AVX指令来做这些操作。对于二维区间,判断则更加麻烦。
-
区间相交,要求x轴相交且y轴相交;
-
区间相离,要求x轴相离或y轴相离;
-
区间包含,要求x轴包含且y轴包含,同时包含方向一致。
暂时没有想到特别好的利用AVX的方法。
多级窗口
在第二名所采取的方案中,还提到了多级窗口这个概念。其设计是这样的,第一级为4096个点,按照优先级排序。之后的每一级的点数目为前面一级的两倍,选取剩下点中优先级最高的一组点。然后有两个备份,分别按照x排序和y排序。在搜索时,我们首先在第一级窗口中寻找符合条件的点,找到k个之后返回。如果合适的点不足k个,则在之后的窗口中分别查询符合x区间条件的点的数目和y区间条件的点的数目,选择其中较小的那个备份中线性探查。在这个线性探查的时候,只需要探查x区间或y区间,因为我们对于另外的一半区间已经做过判断了。
在实现时发现执行热点是标准库中的二分查找函数。为了让二分查找的空间减少,我们采取了Fractional Cascading里面的方法,上级窗口中保留下级窗口中大于当前位置的节点的索引,对于x和y分别保留一份这样的索引。这个索引可以作为下一步搜索区间的大致估计,我们可以利用这几个索引调用几次小规模的二分查找,从而获得线性扫描的精确区间。这样我们就可以更一步的缩减运行时间了。