RT
哪里写假了吗。。
应该是O((n+m)logn)的呀。。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#include<cmath>
#include<map>
#include<bitset>
#include<set>
#include<queue>
using namespace std;
typedef long long ll;
int T,n,m;
int tot,head[100010];
int rt,sum,mx,sz[100010];
int dep[100010],num[100010],dc[100010],q[100010],ans[100010];
int idx,st[100010],l,que[100010];
bool vis[100010];
vector<int> vec[100010];
struct EDGE
{
int nxt,to;
}edge[200010];
inline void add(int u,int v)
{
edge[++tot].nxt=head[u];
edge[tot].to=v;
head[u]=tot;
}
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<3)+(x<<1)+(ch^48);
ch=getchar();
}
return x*f;
}
void getrt(int x,int f)
{
sz[x]=1;
int son=0;
for(int i=head[x];i;i=edge[i].nxt)
{
int y=edge[i].to;
if(y==f||vis[y]) continue;
getrt(y,x);
sz[x]+=sz[y];
son=max(son,sz[y]);
}
son=max(son,sum-sz[x]);
if(mx>son) rt=x,mx=son;
}
void getdep(int x,int f)
{
dep[x]=dep[f]+1;
st[++idx]=x;
++dc[dep[x]];
if(!vec[x].empty()) que[++l]=x;
for(int i=head[x];i;i=edge[i].nxt)
{
int y=edge[i].to;
if(y==f||vis[y]) continue;
getdep(y,x);
}
}
void calc(int x)
{
dep[x]=0;
num[0]=1;
vector<int> tmp,qs;
for(int i=head[x];i;i=edge[i].nxt)
{
int y=edge[i].to;
if(vis[y]) continue;
idx=0,l=0;
getdep(y,x);
for(int j=1;j<=l;++j)
{
int p=que[j];
qs.push_back(p);
for(int k=0;k<vec[p].size();++k)
{
int t=vec[p][k];
if(q[t]>=dep[p]) ans[t]-=dc[q[t]-dep[p]];
}
}
for(int j=1;j<=idx;++j)
{
dc[dep[st[j]]]=0;
tmp.push_back(st[j]);
++num[dep[st[j]]];
}
}
if(!vec[x].empty()) qs.push_back(x);
for(int i=0;i<qs.size();++i)
{
int p=qs[i];
for(int k=0;k<vec[p].size();++k)
{
int t=vec[p][k];
if(q[t]>=dep[p]) ans[t]+=num[q[t]-dep[p]];
}
}
for(int i=0;i<tmp.size();++i)
{
--num[dep[tmp[i]]];
}
}
void getans(int x)
{
vis[x]=1;
calc(x);
for(int i=head[x];i;i=edge[i].nxt)
{
int y=edge[i].to;
if(vis[y]) continue;
rt=0,mx=sum=sz[y];
getrt(y,x);
getans(rt);
}
}
inline void clear_all()
{
tot=0;
for(int i=1;i<=n;++i)
{
head[i]=0;
dep[i]=0;
vis[i]=0;
vec[i].clear();
}
}
inline void solve()
{
n=read(),m=read();
clear_all();
for(int i=2;i<=n;++i)
{
int u=read(),v=read();
add(u,v),add(v,u);
}
for(int i=1;i<=m;++i)
{
int x=read();
q[i]=read();
vec[x].push_back(i);
ans[i]=0;
}
rt=0,sum=mx=n;
getrt(1,0);
getans(rt);
for(int i=1;i<=m;++i)
{
printf("%d\n",ans[i]);
}
}
int main()
{
T=read();
while(T--)
solve();
return 0;
}