K-DTree Summary

Author Avatar
空気浮遊 2021年07月22日
  • 在其它设备中阅读本文章

关于 K -DTree 的一些总结。仅用于实践,不打算细究原理。

平衡二叉树。以点构建超平面分割区域。

Reference

https://www.luogu.com.cn/blog/van/qian-tan-pian-xu-wen-ti-yu-k-d-tree
https://oi-wiki.org/ds/kdt
https://www.zybuluo.com/l1ll5/note/967681

构建方式

2-DTree

第一层以 $x$ 轴为关键字,第二层以 $y$ 轴为关键字,随后交替。

对此时用于构建 2 -DTree 的点集 $S$ 找到关键字下从小至大排序的中位数点 $S_{mid}$,关键字小于 $S_{mid}$ 的分左边,大的分右边。以此类推构建平衡二叉树。

具体实现用nth_element(*begin, *mid, *end)

K-DTree

跟 2 - D 差不多。推广一下应该就行了。

应用

平面最近邻(欧几里得距离)

构建估价函数估算点到当前区域内点距离的下界,当遍历 2 -DTree 的时候估算的下界 $ > ans$ 则可退出不再遍历搜索(剪枝)。

用分类讨论得到的估算函数为 $\sum \max^2 (areaL_{k}-dot_{k}, 0) + \max^2 (dot_{k} - areaR_{k}, 0)$。

下方是依据例题 LP1429 所实现的模板。

#include<bits/stdc++.h>
using namespace std;

const int N=2e5+10;
struct Point { double x, y; Point(double x=0, double y=0):x(x), y(y) {} } po[N];
bool cmp1(const Point &x, const Point &y) { return x.x<y.x; }
bool cmp2(const Point &x, const Point &y) { return x.y<y.y; }

struct Node *nil;
struct Node {
    Point v; int id;
    Node *ls, *rs;
    Node() { ls=rs=nil; }
} *rt=nil;
void init() {
    nil=new Node(), nil->ls=nil->rs=nil;
    rt=nil;
}
void build(Node *&x, int l, int r, int t) {
    if(l>r) return;
    if(x==nil) x=new Node();
    int mid=(l+r)>>1;
    nth_element(po+l, po+mid, po+1+r, t==0?cmp1:cmp2);
    x->v=po[mid], x->id=mid;
    build(x->ls, l, mid-1, t^1), build(x->rs, mid+1, r, t^1);
}
#define p2(x) ((x)*(x))
inline double dis(Point a, Point b) { return sqrt(p2(a.x-b.x)+p2(a.y-b.y)); }
void query(Node *x, double xl, double xr, double yl, double yr, int t, Point v, int id, double &ans) {
    if(x==nil) return;
    if(p2(max(xl-v.x, 0.0))+p2(max(v.x-xr, 0.0))+p2(max(yl-v.y, 0.0))+p2(max(v.y-yr, 0.0))>ans*ans) return; // 估价函数
    if(x->id!=id) ans=min(ans, dis(x->v, v));
    query(x->ls, xl, t==0?x->v.x:xr, yl, t==1?x->v.y:yr, t^1, v, id, ans);
    query(x->rs, t==0?x->v.x:xl, xr, t==1?x->v.y:yl, yr, t^1, v, id, ans);
}

int main() {
    int n;
    scanf("%d", &n);
    for(int i=1; i<=n; i++) scanf("%lf%lf", &po[i].x, &po[i].y);
    double ans=1e21;
    init();
    build(rt, 1, n, 0);
    for(int i=1; i<=n; i++) {
        query(rt, 0, 1e9, 0, 1e9, 0, po[i], i, ans);
    }
    printf("%.4lf\n", ans);
    return 0;
}