题目链接
(资料图片仅供参考)
这题的方法口糊一下没有很难,没达到3500的水准。但是写起来才发现是真的恶心(主要是容易写错),没写过这么累的题,可能难度就体现在这里吧。
计数的时候是要分类讨论的,但是核心算法都一样:启发式合并,线段树合并。把\(m^2\)对路径分成以下三类,分别统计合法的:
两条路径的LCA不同(路径的LCA指的是两个端点的LCA)。发现这两条路径的LCA必须是祖先和后代的关系,不然两条路径不可能有重合。
比如图中的红蓝两条路径就属于这一类,考虑在×处(下面两个端点的LCA)把它们统计进答案。可以在dfs的同时用线段树合并维护所有 有端点在子树内的路径的LCA的深度。在合并两个儿子的时候,把线段树中值的数量较小的拿出来,遍历其中所有的元素,并在大的那个儿子的线段树中询问得到能和当前元素匹配的数量。这部分的复杂度是\(O(nlog^2n)\)。由于n和m同阶,都用n表示了。
两条路径的LCA相同,且它们重合的部分分布在LCA的两个子树中。像下面这样:
这种情况和下面的一种情况都需要把所有LCA为x的路径都放到点x处,统一处理它们之间产生的贡献。假设现在处理LCA为root的所有的路径。把这些路径的端点以及root都拿出来建一棵虚树。为了避免重复计数,对于任意两条需要被计数的路径,我们都在它们在原树中dfs序较小的两个端点的LCA处统计,比如上面图中的×处。还是用线段树合并+启发式合并,但这次线段树中只维护每条路径dfs序较小的那个端点的信息。令当前点为pos,在遍历较小的儿子线段树中的一条路径(x,y)时,假设x在pos子树内,y在root的另外一个子树内,则如果我们沿着x→y的方向走k步到点z,那么合法的匹配路径的端点都在z的子树内。同样可以在线段树上查询来统计。
两条路径的LCA相同,且它们重合的部分分布在LCA的一个子树中。
这种情况的统计方法和上面是类似的。为了保证重合部分只在一个子树内,需要一次额外dfs对每个点求出它在root的哪个子树里。
总时间复杂度\(O(nlog^2n)\)。
调试太痛苦了
点击查看代码
#include #define rep(i,n) for(int i=0;i#define fi first#define se second#define mpr make_pair#define pb push_backvoid fileio(){ #ifdef LGS freopen("in.txt","r",stdin); freopen("out.txt","w",stdout); #endif}void termin(){ #ifdef LGS std::cout<<"\n\nEXECUTION TERMINATED"; #endif exit(0);}using namespace std;LL n,q,t,fa[150010][23],dep[150010],dfn[150010],ed[150010],ans=0,X[150010],Y[150010],LCA[150010];vector g[150010],tg[150010],dford;LL ll=0;void dfsPre(int pos,int par,int d){ fa[pos][0]=par;dep[pos]=d;dford.pb(pos); dfn[pos]=ll++; rep(i,g[pos].size()) if(g[pos][i]!=par) dfsPre(g[pos][i],pos,d+1); ed[pos]=ll-1;}int getLCA(int x,int y){ for(int i=19;i>=0;--i) if(fa[x][i]>0&&dep[fa[x][i]]>=dep[y]) x=fa[x][i]; for(int i=19;i>=0;--i) if(fa[y][i]>0&&dep[fa[y][i]]>=dep[x]) y=fa[y][i]; if(x==y) return x; for(int i=19;i>=0;--i) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return fa[x][0];}LL getAnces(LL x,LL y){rep(i,20) if(y&(1<>1; if(to<=mid) ls[ret]=newTree(lb,mid,to); else rs[ret]=newTree(mid+1,ub,to); return ret; } LL upd(LL k,LL lb,LL ub,LL to) { if(k==0) k=newNode(); ++dat[k]; if(lb==ub) return k; LL mid=(lb+ub)>>1; if(to<=mid) ls[k]=upd(ls[k],lb,mid,to); else rs[k]=upd(rs[k],mid+1,ub,to); return k; } vector res; void getAll(LL k,LL lb,LL ub) { if(k==0) return; if(lb==ub) { rep(i,dat[k]) res.pb(lb); return; } LL mid=(lb+ub)>>1; getAll(ls[k],lb,mid);getAll(rs[k],mid+1,ub); } vector getAll(LL root) { res.clear(); getAll(root,0,n2-1); return res; } LL qry(LL k,LL lb,LL ub,LL tlb,LL tub) { if(k==0||ub>1,tlb,tub)+qry(rs[k],((lb+ub)>>1)+1,ub,tlb,tub); } LL merge(LL a,LL b) { if(a==0||b==0) return a|b; dat[a]+=dat[b]; ls[a]=merge(ls[a],ls[b]);rs[a]=merge(rs[a],rs[b]); return a; }}namespace part1{ vector v[150010]; LL combine(LL a,LL b,LL curdep) { if(a==0||b==0) return a|b; if(st::dat[a] vec=st::getAll(b); rep(i,vec.size()) { if(vec[i]>curdep-t) continue; LL v1=st::qry(a,0,st::n2-1,0,vec[i]-1),v2=st::qry(a,0,st::n2-1,vec[i]+1,curdep-t); ans+=v1+v2; } a=st::merge(a,b); return a; } LL dfs(LL pos,LL par) { LL ret=0; rep(i,v[pos].size()) { LL nxt=st::newTree(0,st::n2-1,v[pos][i]); ret=combine(ret,nxt,dep[pos]); } rep(i,g[pos].size()) if(g[pos][i]!=par) { LL nxt=dfs(g[pos][i],pos); ret=combine(ret,nxt,dep[pos]); } return ret; } void countDiffLCA() { rep(i,q) { v[X[i]].pb(dep[LCA[i]]); v[Y[i]].pb(dep[LCA[i]]); } st::init(n); dfs(1,0); }}namespace part2{ vector pths[150010]; LL curroot,rootdep; vector realver; void buildVT(vector vers) { realver.clear(); rep(i,vers.size()) tg[vers[i]].clear(); sort(vers.begin(),vers.end());vers.erase(unique(vers.begin(),vers.end()),vers.end()); sort(vers.begin(),vers.end(),[](LL xx,LL yy){return dfn[xx] stk;stk.push(vers[0]); realver=vers; repn(i,vers.size()-1) { LL pos=vers[i],lca=getLCA(pos,stk.top()); if(lca==stk.top()) stk.push(pos); else { while(dep[stk.top()]>dep[lca]) { int pp=stk.top();stk.pop(); int nn=stk.top();if(dep[nn]1) { int pp=stk.top();stk.pop(); tg[stk.top()].pb(pp); } } vector v[150010]; LL fr[150010]; LL walk(LL curpos,LL to,LL stp) { LL rd=dep[getLCA(curpos,to)]; LL tot=dep[curpos]+dep[to]-rd*2; if(tot vec=st::getAll(b);rep(i,vec.size()) vec[i]=dford[vec[i]]; rep(i,vec.size()) { LL walkdist=max(t,dep[curpos]-rootdep+1),to=walk(curpos,vec[i],walkdist); if(to==-1) continue; LL vv=st::qry(a,0,st::n2-1,dfn[to],ed[to]); ans+=vv; } a=st::merge(a,b); return a; } LL dfsTwo(LL pos) { LL ret=0; rep(i,v[pos].size()) { LL nxt=st::newTree(0,st::n2-1,dfn[v[pos][i]]); if(pos!=curroot) ret=combineTwo(ret,nxt,pos); } rep(i,tg[pos].size()) { LL nxt=dfsTwo(tg[pos][i]); if(pos!=curroot) ret=combineTwo(ret,nxt,pos); } return ret; } void dfsMarkFr(LL pos,LL mk) { if(mk==-1&&pos!=curroot) mk=dfn[pos]; fr[pos]=mk; rep(i,tg[pos].size()) dfsMarkFr(tg[pos][i],mk); } LL combineOne(LL a,LL b) { if(a==0||b==0) return a|b; if(st::dat[a] vec=st::getAll(b); rep(i,vec.size()) { if(vec[i]==dfn[curroot]) { ans+=st::dat[a]; continue; } LL v1=st::qry(a,0,st::n2-1,0,vec[i]-1),v2=st::qry(a,0,st::n2-1,vec[i]+1,st::n2-1); ans+=v1+v2; } a=st::merge(a,b); return a; } LL dfsOne(LL pos) { LL ret=0; rep(i,v[pos].size()) { LL nxt=st::newTree(0,st::n2-1,fr[v[pos][i]]); if(dep[pos]-rootdep>=t) ret=combineOne(ret,nxt); } rep(i,tg[pos].size()) { LL nxt=dfsOne(tg[pos][i]); if(dep[pos]-rootdep>=t) ret=combineOne(ret,nxt); } return ret; } void countSameLCA() { rep(i,q) { if(dfn[X[i]]>dfn[Y[i]]) swap(X[i],Y[i]); pths[LCA[i]].pb(mpr(X[i],Y[i])); } repn(root,n) if(pths[root].size()) { curroot=root;rootdep=dep[root]; vector vers={root}; rep(i,pths[root].size()) vers.pb(pths[root][i].fi),vers.pb(pths[root][i].se); buildVT(vers); rep(i,realver.size()) v[realver[i]].clear(); rep(i,pths[root].size()) if(pths[root][i].fi!=root&&pths[root][i].se!=root) v[pths[root][i].fi].pb(pths[root][i].se); st::init(n); dfsTwo(root); dfsMarkFr(root,-1);fr[root]=dfn[root]; rep(i,realver.size()) v[realver[i]].clear(); rep(i,pths[root].size()) v[pths[root][i].fi].pb(pths[root][i].se),v[pths[root][i].se].pb(pths[root][i].fi); st::init(n); dfsOne(root); } }}int main(){ fileio(); cin>>n>>q>>t; LL x,y; rep(i,n-1) { scanf("%lld%lld",&x,&y); g[x].pb(y);g[y].pb(x); } dfsPre(1,0,0); rep(i,20) repn(j,n) fa[j][i+1]=fa[fa[j][i]][i]; rep(i,q) { scanf("%lld%lld",&X[i],&Y[i]); LCA[i]=getLCA(X[i],Y[i]); } part1::countDiffLCA(); part2::countSameLCA(); cout<