虚树 Virtual-Tree

Author Avatar
空気浮遊 2018年05月29日
  • 在其它设备中阅读本文章

虚树算数据结构?也许算吧。

总览

虚树用于处理一类询问总点数较小的问题。
具体而言,某类问题给定 m 个询问(一般 1e5 级别),每个询问查询 ki 个点。一般而言,处理每次询问的复杂度是 $O(n)$ 的。但是由于 $\sum k_i$ 可能有特殊限制,所以我们可以构建虚树将每次查询的点都给提到一个树上,再在上面进行本次询问的查询,这样每次询问的复杂度就被压到了 $O(k_i)$(或者再乘上某个 log/ 常数)
巧的是,虚树上的边恰好可以代表原树上两点的一条链。

具体而言问题被拆成了两个部分:

  1. 如何构建虚树
  2. 如何在虚树上得到某次询问的答案

建立虚树

提供两个博客:
https://blog.sengxian.com/algorithms/virtual-tree
https://www.cnblogs.com/zzqsblog/p/5560645.html

建议查看第一个。查看第一个足以弄懂所有东西。

我会对第一个博客的方法进行补充。

所谓右链维护,是指持续维护树上的链,在链上进行点的插入 / 删除方法。
比如我现在一条链一直往下插,然后我觉得插够了,需要回到链的某个中间点分支出去。那么我们就把栈中的点持续弹出(此时被弹出的点的父亲都应当已经决定),停到我们要分支的点,然后再往栈里插点。这样子的话,每次弹出栈中的一个点,它的父亲都能被决定(栈中的上一个点),由此得出整个树的形状。

在虚树上进行查询

虚树构建出来之后,我们一般在上面进行树形 dp。
这要求我们对虚树的边权进行处理。
比如对于例题:[SDOI2011]消耗战 https://www.luogu.org/problemnew/show/P2495
虚树的边权为实际两点链上的最小边权。
在此上进行 dp,就可以 a 掉该题。

本题的树状 dp 可以自行思考。

大概虚树的查询都是差不多的,重点在于如何快速将链上的信息 / 其它信息压缩到虚树的边上的信息中。

对于虚树和该题仍然存在诸多细节,因此可以查看下方的代码来加深理解:

// Code by ajcxsu
// Problem: continous battle

#include<bits/stdc++.h>
using namespace std;
typedef long long ll; 
int n;

template<typename T> void gn(T &x) {
    char ch=getchar();
    x=0;
    while(ch<'0' || ch>'9') ch=getchar();
    while(ch>='0' && ch<='9') x=x*10+ch-'0', ch=getchar();
}

const int N=3e5, M=1e6+10;
int h[N], vt[N], to[M], nexp[M], W[M], p=1; // vt -> virtual tree
inline void ins(int a, int b, int w) { nexp[p]=h[a], h[a]=p, to[p]=b, W[p]=w, p++; }
inline void vins(int a, int b, int w) {    nexp[p]=vt[a], vt[a]=p, to[p]=b, W[p]=w, p++; }

const int OP=20;
int dfn[N], dep[N], gup[N][OP], mi[N][OP], idx;
void dfs(int x, int k) {
    dep[x]=k, dfn[x]=++idx;
    for(int u=h[x];u;u=nexp[u])
        if(!dfn[to[u]]) {
            gup[to[u]][0]=x;
            mi[to[u]][0]=W[u];
            dfs(to[u], k+1);
        }
}
void ini() {
    dfs(1,1);
    for(int j=1;j<OP;j++)
    for(int i=1;i<=n;i++)
        gup[i][j]=gup[gup[i][j-1]][j-1],
        mi[i][j]=min(mi[i][j-1], mi[gup[i][j-1]][j-1]);
}
int lca(int s, int t, bool isw=0) {
    int w=0x3f3f3f3f;
    if(dep[s]<dep[t]) swap(s,t);
    for(int j=OP-1;j>=0;j--)
        if(dep[gup[s][j]]>=dep[t]) w=min(w, mi[s][j]), s=gup[s][j];
    if(s!=t) {
        for(int j=OP-1;j>=0;j--)
            if(gup[s][j]!=gup[t][j])
                w=min(w, min(mi[s][j], mi[t][j])), 
                s=gup[s][j], t=gup[t][j];
        w=min(w, min(mi[s][0], mi[t][0]));
        s=gup[s][0], t=gup[t][0];
    }
    return !isw?s:w;
}


bool cmp(const int &a, const int &b) { return dfn[a]<dfn[b]; }

int lis[N], ki;
int stk[N], sz; // 栈点的深度一定递增,一定维护一条链
bool key[N];

ll dfs2(int x, ll w) {
    if(w==0x3f3f3f3f) w=0x7fffffffffffll;
    ll ret=0;
    int siz=0;
    for(int u=vt[x];u;u=nexp[u])
        ret+=dfs2(to[u], W[u]), siz++;
    if(key[x]) ret=0x7fffffffffffll;
    else if(siz==1 && w<0x3f3f3f3f) ret=0;
    vt[x]=0, key[x]=0;
    ret=min(ret, w);
    return ret;
}
int bp;
ll solve() {
    p=bp;
    sort(lis, lis+ki, cmp);
    int d;
    stk[++sz]=0; // 加入超级根节点 
    for(int i=0;i<ki;i++) {
        d=lca(stk[sz], lis[i]); // 查找栈顶与d的lca 
        while(sz-1>0 && dep[stk[sz-1]]>=dep[d]) vins(stk[sz-1], stk[sz], lca(stk[sz-1], stk[sz], true)), sz--;
        // 如果这条链上有深度>=lca的点,则链需要被退回,并在退回的同时加上需要加上的边 
        // 直到这条链上最后只剩一个深度>=lca的点。
        // 如果这条链上本来就没有深度>=lca 或者 只有一个, 那么这个循环不会被执行。 
        if(stk[sz]!=d) vins(d, stk[sz], lca(d, stk[sz], true)), stk[sz]=d;
        // 如果栈顶不是lca的话,说明链被退回。最后退一次链即可,并在链中加入lca。
        // 如果栈顶原本就是lca的话,说明链会被拉长,那么不会存在深度>=lca的点,并且不会执行这个if语句 
        stk[++sz]=lis[i]; // 延长这条链 
    }
    if(stk[2]!=1) vins(stk[1], stk[2], lca(1, stk[2], true)); // 由于是与根节点断开,因此虚树必须含有根节点 
    else vins(stk[1], stk[2], 0x3f3f3f3f); // 如果已经含有根节点,因为根节点的父亲不能断开,所以置为INF 
    for(int i=2;i<sz;i++) vins(stk[i], stk[i+1], lca(stk[i], stk[i+1], true));
    sz=0;
    // 由于当一个点被退出的时候才会加边
    // 最终处理完之后我们还需要把栈里的所有的点都加入虚树。 
    return dfs2(0,0x7fffffffffffll);
}


int main() {
    gn(n);
    int u,v,w;
    for(int i=1;i<n;i++) gn(u), gn(v), gn(w), ins(u,v,w), ins(v,u,w);
    ini();
    bp=p;
    int m;
    gn(m);
    while(m--) {
        gn(ki);
        for(int i=0;i<ki;i++) gn(lis[i]), key[lis[i]]=1;
        printf("%lld\n",solve());
    }
    return 0;
}
    ASL
    ASL  2019-01-21, 19:42

    Orz

      ajcxsu
      ajcxsu  2019-01-21, 19:43

      orz yyj tql

    Ayyj
    Ayyj  2019-01-21, 19:45

    @ASL faq