2D Topk Search

二维区间topk问题

二维区间topk问题原始来源是https://churchillnavigation.com/challenge/。去年年前我在reddit上看到这个比赛之后,跃跃欲试,期间可以说是茶饭不思啊!但是,我要说但是,最后成绩排名只是33。平均用时比第一名慢33倍,这简直是智商上的压制。期间也暴露了我写代码上的一些问题,努力改正!不过学到更多的是程序优化是算法优化和体系结构的综合,而不仅仅是理论上的渐进复杂度而已。如果不用SSE/AVX指令集,任何解决方案都打不进前10。主办方用的算法基本是理论上最优的,然而只排到了第18名,运行时间差不多是我的一半。通过研究这些代码,也发现了一些好玩的东西,因此写这篇文章来记录一下。半年之后才开始动笔,真是懒惰!

naive topk问题

所谓的naive topk问题可以表述为:

从n个数中找出其最大的k个数

言简意赅。解决方案也是非常直接,基本可用的方法用3种:

  • 堆排序:用一个大小为k的最小堆,所需时间为\(\mathcal{O}(nlogk)\),空间为\(k\),其实额外空间都不需要,直接原地排就可以。

  • 部分快排:只对索引小于k的部分进行排序,期望运行时间为\(\mathcal{O}(n)\),额外空间不需要,直接原地排。这里之所以是期望运行时间,是因为快排的划分不可控,只好取其平均情况。

  • 中位数排序法:直接选出第k大的元素,然后划分。这个算法类似于快排,但是其运行时间是确定为\(\mathcal{O}(n)\)的,只不过常数比较大而已。

就这样,naive topk 问题基本解决了。

一维区间topk

所谓的一维区间topk问题可以表述为

在给定的大小为\(n\)的数组a中,寻找\(a[i]-a[j]\)之间的前\(k\)大的数。

这次我们就不能像前面的naive topk直接上排序,因为这样会直接毁掉原来的序列。naive topk毁掉序列信息倒没啥关系,反正只操作一次。而区间topk基本是会执行多次查询的,毁掉了原来的序列信息后面的查询就全是错的了。为此,我们需要一个另外的结构来记录中间信息,这种结构就叫做划分树。通过使用划分树,我们可以在\(\mathcal{O}(logn)\)的时间内获得区间\(i-j\)之间的第\(k\)大的数。通过执行\(k\)次这样的查询,我们就可以获得前\(k\)大的数。

划分树的基本思想就是对于某个区间,把它划分成两个子区间,左边区间的数小于右边区间的数。查找的时候通过记录进入左子树的数的个数,确定下一个查找区间,最后范围缩小到1,就找到了。

划分树的建立

建树的过程比较简单,对于区间\([l,r]\),首先通过对原数组的排序找到这个区间的中位数\(a[mid]\),小于\(a[mid]\)的数划入他的左子树\([l,mid-1]\),大于它的划入右子树\([mid,r]\)。同时,对于第\(i\)个数,记录在\([l,i]\)区间内有多少数被划入左子树。最后,对它的左子树区间\([l,mid-1]\)和右子树区间\([mid,r]\)递归的继续建树就可以了。

建树的时候要注意对于被分到同一子树的元素,元素间的相对位置不能改变。

划分树的查询

查找的过程中主要问题就是确定将要查找的区间。这个问题有些麻烦。为此我们定义一个辅助函数\(tree_find\),其定义如下:

查找深度为h,在大区间[st,ed]中找小区间[s,e]中的第k元素。

我们的想法是,先判断[s,e]中第k元素在[st,ed]的哪个子树中,然后找出对应的小区间和k,递归的进行查找,直到小区间的s=e为止。

那如何解决这个问题呢?这时候前面记录的进入左子树的元素个数就派上用场了。通过之前的记录可以知道,在区间\([st,s-1]\)中有\(el[h,s-1]\)进入左子树,记它为\(l\)。同理区间\([st,e]\)中有\(el[h,e]\)个数进去左子树,记它为\(r\)。所以,我们知道区间小区间\([s,e]\)中有\((r-l)\)个数进入左子树。那么如果\((r-l)\ge k\),那么就在左子树中继续查找,否则就在右子树中继续查找。

接着解决查找的小区间的问题。

  • 如果接下来要查找的是左子树,那么小区间应该是[st+([st,s-1]区间进入左子树的个数),st+([st,e]区间内进入左子树的个数)-1],即区间\([st+l,st+r-1]\)。显然,这里\(k\)不用变。

  • 如果接下来要查找的是右子树,那么小区间应该是[mid+([st,s-1]区间中进入右子树的个数),mid+([st,e]区间进入右子树的个数)-1,即区间\([mid+(s-st-l),mid+(e-st-r)]\)。显然,这里\(k\)要减去区间里已经进入左子树的个数,即\(k\)变为\(k-(r-l)\)

于是递归继续查找直到\(s=e\)即可。

复杂度分析

很明显,这个划分树中的所有节点的度数要么是2,要么是0。因此树节点有\(2*n-1\)个,同时树高是极度平衡的,最大树高为\(log(2n)\)。由于每一个节点都需要记录其在不同深度中的左子树计数,所以空间需求为\(nlog(n)\)。而查询就是从头节点递归下降到叶子节点的过程,所以时间为\(log(n)\)。即我们可以在\(nlog(n)\)额外空间的支持下,以\(log(n)\)的时间内得到区间第\(k\)大。

查询优化

上述查询过程针对的是查询第\(k\)个数的情况,而我们所需要的是前\(k\)个。如果只是暴力的做\(k\)次这样的查询的话,挺浪费时间的。我们可以在这里做一点优化,即在树的节点之中存储节点所代表区间的数组及相应的el信息,而不是在全局中建立el数组。这样我们对于前\(k\)的查询可以修改为:

  • 在进入左字数的时候,如果\(r-l=k\),则直接将当前部分数组的\(l-r\)之间的数作为结果返回

  • 在进入右子树的时候,直接将左子树的第\(l\)到第\(r\)个数添加到结果中

在经过这样的修改之后,我们可以在\(log(n)+k\)的时间内获得区间前\(k\)个数, so easy 。其实这个思想与之后要谈的fractional cascading相似。

二维区间topk

二维区间topk难度相对于一维区间topk来说,难度增加了很多。其问题可以简述如下:

在平面上有\(n\)个不同的点,每个点都有其相应的坐标\(x,y\)以及其优先级\(z\)。设计一个高效的方案来支持区间查询:获得由点\((x_1,y_1)\)和点\((x_2,y_2)\)所确立的矩形区域(\(x_1<x_2,y_1<y_2\))中的优先级最大的前\(k\)个点。

在我参与这个比赛的过程中,尝试了如下几种实现:

  • 线性扫描:将所有的点按照优先值来排序,然后线性扫描排序后的点,如果在矩形内,则添加到结果。当扫描到末尾或者结果数到达k之后返回结果。

  • kd树:将所有点建立一个kd树。查询时,首先建立一个大小为k的最小堆。每次查询到一个节点,扫描其左右子节点,如果在查询区间内,则加入到待查询队列中,就这样递归分裂查询。到最后下降到叶子节点时,扫描叶子节点中所存储的数据,判断其是否在目标区间内,然后加入到结果堆中。空间复杂度\(\mathcal{O}(n)\),查询时间\(\mathcal{O}(\sqrt{n})\)加上区间内节点数。总的来说kd树需要得到所有在区间内的点,极大的拖慢了速度。

  • range tree:将所有节点组合为range tree,不过也是到小于32个点时就不再分裂。查询的时候按照range tree的结构,搜索所有相交的\(x\)区间,然后再在这些区间中寻找\(y\)区间内的点,找到所有点之后依次加入存储结果的k小堆中,最后返回。其空间复杂度为\(\mathcal{O}(nlogn)\),时间复杂度\(\mathcal{O}(logn*logn)\)加上在目标区间内的节点个数。由此可见range tree也需要获得所有区间内的点,速度也是比较慢。

  • fractional cascading:这个数据结构是在range tree上做的一个改良,可以看作range tree和划分树的合体。即y级上的点不再是按照树的形式组织,而是直接以数组方式组织。每个x级的节点内部存储所有按照y排序的节点数组,同时对于数组中的每个节点,存储在其左边且进入左子树的节点数,类似于划分树。这样下降查询的时候,我们只需要在头节点对y进行二分查询,剩下的时候直接按照左节点计数即可,不再需要二分查询。在这种修改下,其空间复杂度为\(nlogn\),搜索时间复杂度为\(logn\)加上区间内的点。

  • priority search tree: 优先搜索树可以当作kd树的变形版,优先搜索树中会存储优先级最高的32个点,这些预先存储的点不会在子树中继续出现。搜索时我们用队列来存储所有需要访问的节点,以最小堆来存储中间结果。在递归查询的时候,首先判断当前节点所代表的区间是否在目标区间内,如果不在直接返回,没有必要递归。如果在区间内,则扫描内部预先存储的32个点。如果其优先级都小于堆的头节点,则没有必要将左右节点加入查询队列。否则,扫描这32个点,在目标区间内的加入结果堆中。其总空间需求是\(\mathcal{O}(n)\)的,而查询时复杂度上限与kd树等同\(\mathcal{O}(\sqrt{n})\)。但是实际过程中,我们并不需要查询所有在区间内的点,只需要获得前\(k\)个即停止,这点与前面所说的各种树完全不同。实际测试的时候,也发现这种结构是最快的。

AVX扫描加速

在上一节所提到的各种策略中,我们论证了其优缺点,综合来说pst是最优的。但是从最后公布的结果中可以看出,pst也只排到18名,比第一名慢了将近10倍。为什么会有这样的结果?因为我们的算法只是在理论上最优,而算法的实际运行是要在硬件上跑的,罔顾体系结构的算法基本都会败的一踏涂地。而所谓的体系结构在程序运行效率上的影响主要包括:专有指令和缓存系统。数组类型的扫描访问总是比树类型的递归访问的缓存局部性好,举个例子来说就像快排与堆排。同样的复杂度\(\mathcal{O}(nlogn)\),快排总是比堆排序快很多倍,其优势就在于对于缓存系统的友好性。而专有指令的存在可以更好的利用缓存系统和指令级并行,其带来的加速比也是很大的。下面我们就来探讨一下AVX指令集如何运用到在区间查询中的使用。

AVX指令集

AVX,全称为Advanced Vector Extensions,是SIMD指令的典范。其指令的操作数可以为128位、256位、512位,但是512位只有在个别xeno处理器才有,我们在本文中只使用256位的AVX2.在C++编译器中,通过引用头文件immintrin.h来利用AVX2。在该头文件中,主要类型定义如下。

typedef union __declspec(intrin_type) _CRT_ALIGN(32) __m256 {
    float m256_f32[8];
} __m256;

typedef struct __declspec(intrin_type) _CRT_ALIGN(32) __m256d {
    double m256d_f64[4];
} __m256d;

typedef union  __declspec(intrin_type) _CRT_ALIGN(32) __m256i {
    __int8              m256i_i8[32];
    __int16             m256i_i16[16];
    __int32             m256i_i32[8];
    __int64             m256i_i64[4];
    unsigned __int8     m256i_u8[32];
    unsigned __int16    m256i_u16[16];
    unsigned __int32    m256i_u32[8];
    unsigned __int64    m256i_u64[4];
} __m256i;

在我们所遇到的数据中,坐标为单精度浮点数,所以我们这里只展示单精度相关的操作。这里用到的AVX相关操作主要 实际测试中使用该函数线性扫描比简单的线性扫描快10倍左右。这里我们来解释一下其中所用到的各个指令

  • _mm256_broadcast_ss 这个函数是将单精度浮点数复制8次,然后填充到_mm256的内存区域中;

  • _mm256_load_ps这个函数是将特定的_mm256的值复制到另外一个_mm256的内存区域中,可以认为是赋值运算符;

  • _mm256_and_ps这个函数基本就是位操作的256位扩充版,不再解释;

  • _mm256_cmp_ps 这个函数是用来比较256位的值,其比较时是通过将其中每一个32浮点单独比较,如果结果为真则对应的字节位填充为1,否则填充为0.可使用的比较操作类型有很多种,我们这里只使用了_CMP_LE_OQ_CMP_GE_OQ,分别代表小于等于和大于等于。

  • _mm256_testz_ps 这个函数有两个操作数,其作用是比较两个操作数bitand操作之后是否每一位都为0,如果都是0返回true,否则返回false。

点与区间关系判断

通过利用上述的AVX相关指令,我们定义一个搜索函数:

point_index Solution::search_linear(const Rect rect, const point_index count, Point *out_points)
{
    __m256 rect_lx = _mm256_broadcast_ss(&rect.lx);
    __m256 rect_hx = _mm256_broadcast_ss(&rect.hx);
    __m256 rect_ly = _mm256_broadcast_ss(&rect.ly);
    __m256 rect_hy = _mm256_broadcast_ss(&rect.hy);

    point_index n = 0;
    point_index i = 0;

    point_index end = m_avx_count;
    for(; i < (end - 7); i+=8) {
        __m256 x = _mm256_load_ps(m_x_coord.data() + i);
        __m256 y = _mm256_load_ps(m_y_coord.data() + i);

        __m256 x_in = _mm256_and_ps(_mm256_cmp_ps(rect_lx, x, _CMP_LE_OQ),
                                    _mm256_cmp_ps(rect_hx, x, _CMP_GE_OQ));
        __m256 y_in = _mm256_and_ps(_mm256_cmp_ps(rect_ly, y, _CMP_LE_OQ),
                                    _mm256_cmp_ps(rect_hy, y, _CMP_GE_OQ));

        if (!_mm256_testz_ps(x_in, y_in))
        {
            __m256 mask = _mm256_and_ps(x_in, y_in);
         //这里需要单独比较这8个位置,代码冗长,就不贴了
        }
    }

//处理未内存对齐的数据
    return (point_index)n;
}

所以上述代码所做的工作就是:将坐标的\(x,y\)分量单独存储,然后分别用AVX指令集来比较是否有点在区间内。如果没有则直接返回,如果有,则单独的提取其中的八个点,判断是否在区间内并加入到结果之中。

其实,我们还有另外的一种方法来使用AVX指令集。对于判断点\((x,y)\)是否在区间\((lx,ly)-(ux,uy)\)内,上面的函数所做的判断是

\begin{equation}x\ge lx \wedge x \le ux \wedge y\ge ly \wedge y\le uy\end{equation}

对于这个操作,我们可以转换为

\begin{equation}x\ge lx \wedge -x\ge -lx \wedge y\ge ly \wedge -y\ge -uy\end{equation}

同样的,我们可以转换为

\begin{equation}x\le ux \wedge -x\le -lx \wedge y\le uy \wedge -y\le -ly\end{equation}

这样,我们可以将一个点的坐标扩充为128位,然后就可以对于一个点调用_mm128_cmp_ps做单独的判断了。如果是double类型的就更好了,直接利用_mm256_cmp_ps。但是这样会增加一倍的使用空间,所带来的优势就是不需要处理对齐问题和多重判断问题。具体的测试目前还比较缺乏,待会补充。

区间关系判断

这是一个拓展问题,如何判断两个二维区间之间的关系:相交、相离、包含。我们首先以一维区间入手,两个区间分别为\((a_1,b_1),(a_2,b_2)\)

  • 两个区间相交时,\(a_1\le a_2 \le b_1 \le b_2 || a_2\le a_1 \le b_2\le b_1\)

  • 两个区间相离时,\(a_1> b_2|| a_2> b_2\)

  • 两个区间包含时,\(a_1\le a_2 \land b_2\le b_1|| a_2\le a_1 \land b_1\le b_2\)

由上述判断条件可以看出,并没有特别好的AVX指令来做这些操作。对于二维区间,判断则更加麻烦。

  • 区间相交,要求x轴相交且y轴相交;

  • 区间相离,要求x轴相离或y轴相离;

  • 区间包含,要求x轴包含且y轴包含,同时包含方向一致。

暂时没有想到特别好的利用AVX的方法。

多级窗口

在第二名所采取的方案中,还提到了多级窗口这个概念。其设计是这样的,第一级为4096个点,按照优先级排序。之后的每一级的点数目为前面一级的两倍,选取剩下点中优先级最高的一组点。然后有两个备份,分别按照x排序和y排序。在搜索时,我们首先在第一级窗口中寻找符合条件的点,找到k个之后返回。如果合适的点不足k个,则在之后的窗口中分别查询符合x区间条件的点的数目和y区间条件的点的数目,选择其中较小的那个备份中线性探查。在这个线性探查的时候,只需要探查x区间或y区间,因为我们对于另外的一半区间已经做过判断了。

在实现时发现执行热点是标准库中的二分查找函数。为了让二分查找的空间减少,我们采取了Fractional Cascading里面的方法,上级窗口中保留下级窗口中大于当前位置的节点的索引,对于x和y分别保留一份这样的索引。这个索引可以作为下一步搜索区间的大致估计,我们可以利用这几个索引调用几次小规模的二分查找,从而获得线性扫描的精确区间。这样我们就可以更一步的缩减运行时间了。

Published:
2015-09-13 22:00
Category:
Tag: