求调点分治
查看原帖
求调点分治
363006
wangyibo201026楼主2024/11/21 18:13

RT，不确定自己点分治的复杂度是否正确，第 99 个测试点死活过不去，多项式板子应该没有问题：

#include <bits/stdc++.h>

using namespace std;

// Short-hand macros used throughout the file.
// NOTE: `int` is redefined as `long long` for everything below this line,
// which is why the entry point has to be declared as `signed main`.
#define int long long
#define fir first
#define sec second
// A polynomial is stored as its coefficient vector (index = exponent).
#define poly vector<int>
#define mkp make_pair 
#define pb push_back
// Inclusive ascending / descending loops: i = l..r and i = r..l.
#define lep( i, l, r ) for ( int i = ( l ); i <= ( r ); ++ i )
#define rep( i, r, l ) for ( int i = ( r ); i >= ( l ); -- i )

typedef unsigned long long ull;
typedef long long ll;
typedef long double ld;
// Because of the macro above this is pair<long long, long long>.
typedef pair < int, int > pii;

char _c; bool _f;
// Reads one integer from stdin into x. Any non-digit characters before the first
// digit are skipped; a '-' among them makes the value negative.
// If stdin runs out before a digit is seen, x is left as 0 instead of spinning forever.
// _c keeps the last character read (the terminator, or EOF cast to char).
template < class type > inline void read ( type &x ) {
  _f = 0, x = 0;
  // Keep the character as an int: EOF must stay distinguishable from byte 0xFF,
  // and isdigit() is only defined for unsigned-char values and EOF.
  int ch = getchar ();
  while ( ch != EOF && !isdigit ( ch ) ) {
    if ( ch == '-' ) _f = 1;
    ch = getchar ();
  }
  while ( ch != EOF && isdigit ( ch ) ) {
    x = x * 10 + ( ch - '0' );
    ch = getchar ();
  }
  _c = static_cast < char > ( ch );
  if ( _f ) { x = -x; }
}

// Lower x to y unless x <= y already holds.
template < class type > inline void chkmin ( type &x, type y ) { if ( !( x <= y ) ) x = y; }
// Raise x to y unless x >= y already holds.
template < class type > inline void chkmax ( type &x, type y ) { if ( !( x >= y ) ) x = y; }

// Capacity for per-vertex arrays; vertices are numbered from 1 (Solve shifts input by +1).
const int N = 50005;
// Modulus and primitive root used by the NTT below.
const int mod = 998244353;
const int G = 3;

// n: vertex count. rt: best centroid candidate found by getrt.
// cnt: component size getrt assumes when computing the "piece above" a vertex.
// maxi: largest distance recorded by init.
int n, rt, cnt, maxi;
// Final answer, accumulated in Solve.
double ans;
// vis: vertex already used as a centroid (set in dfs). siz: subtree sizes from getrt.
// f: largest piece left if the vertex were removed; f[0] is the sentinel compared first.
int vis[N], siz[N], f[N];
// Pair counts indexed by distance, built by calc (c[0] gains one per centroid).
vector < int > c;

// Binary exponentiation: returns a^b reduced modulo `mod` (for b >= 0).
int quickpow(int a,int b){
  int acc = 1;
  for ( ; b; b >>= 1, a = a * a % mod ) {
    if ( b & 1 ) acc = acc * a % mod;
  }
  return acc;
}
int ri[1<<18];
// Fills ri[i] (1 <= i < n) with i's bit-reversal over `lim` bits; ri[0] stays 0.
// n is expected to equal 1 << lim.
void init_NTT(int n,int lim){
  int i = 1;
  while ( i < n ) {
    const int top_bit = ( i & 1 ) << ( lim - 1 );
    ri[i] = top_bit | ( ri[i >> 1] >> 1 );
    ++ i;
  }
}

// In-place iterative number-theoretic transform modulo `mod`.
// op == 1 runs the forward transform with root G; any other op uses G^{-1},
// and op == -1 additionally multiplies by 1 / f.size().
// Expects f.size() to be a power of two with ri[] already prepared by init_NTT.
void NTT(poly &f,int op){
  const int sz = f.size();
  // Bit-reversal permutation.
  for ( int i = 0; i < sz; ++ i ) {
    if ( i < ri[i] ) swap ( f[i], f[ri[i]] );
  }
  const int base = ( op == 1 ) ? G : quickpow ( G, mod - 2 );
  for ( int half = 1; ( half << 1 ) <= sz; half <<= 1 ) {
    const int step = half << 1;
    const int root = quickpow ( base, ( mod - 1 ) / step );
    for ( int blk = 0; blk < sz; blk += step ) {
      int w = 1;
      for ( int j = 0; j < half; ++ j ) {
        int &lo = f[blk + j];
        int &hi = f[blk + j + half];
        const int u = lo;
        const int v = 1ll * hi * w % mod;
        lo = ( u + v >= mod ) ? u + v - mod : u + v;
        hi = ( u - v < 0 ) ? u - v + mod : u - v;
        w = 1ll * w * root % mod;
      }
    }
  }
  if ( op == -1 ) {
    const int inv_sz = quickpow ( sz, mod - 2 );
    for ( int i = 0; i < sz; ++ i ) f[i] = 1ll * f[i] * inv_sz % mod;
  }
}

// Product of two polynomials with coefficients in [0, mod), computed with NTT.
// The result is trimmed to its true length a.size() + b.size() - 1 instead of the
// padded power-of-two length. Callers that accumulate products then do not keep
// carrying (and re-copying) trailing zeros. The transform length is the smallest
// power of two >= that length, which is enough to avoid cyclic wrap-around.
// An empty operand gives an empty (zero) polynomial.
poly operator *(poly a,poly b){
  if ( a.empty () || b.empty () ) return poly ();
  const int need = (int) a.size () + (int) b.size () - 1;
  int n = 1, lim = 0;
  while ( n < need ) n <<= 1, lim ++;
  init_NTT ( n, lim );
  a.resize ( n ), b.resize ( n );
  NTT ( a, 1 ), NTT ( b, 1 );
  for ( int i = 0; i < n; i ++ ) a[i] = 1ll * a[i] * b[i] % mod;
  NTT ( a, -1 );
  a.resize ( need );
  return a;
}

// Coefficient-wise sum modulo `mod`; the shorter operand is treated as zero-padded
// and every coefficient of the result receives one conditional reduction.
poly operator +(poly a,poly b){
  const size_t len = ( a.size () < b.size () ) ? b.size () : a.size ();
  a.resize ( len );
  for ( size_t i = 0; i < len; ++ i ) {
    const int rhs = ( i < b.size () ) ? b[i] : 0;
    a[i] = ( a[i] + rhs >= mod ) ? a[i] + rhs - mod : a[i] + rhs;
  }
  return a;
}

int head[N], tot;

// Forward-star adjacency storage: head[u] is u's newest edge index (0 = none),
// and edges[e].next links to the previous edge of the same vertex.
struct Node {
  int to, next;
} edges[N << 1];

// Records the directed edge u -> v at the front of u's edge list.
void add ( int u, int v ) {
  Node &e = edges[++ tot];
  e.to = v;
  e.next = head[u];
  head[u] = tot;
}

void getrt ( int x, int fa ) {
  siz[x] = 1, f[x] = 0;
  for ( int i = head[x]; i; i = edges[i].next ) {
    if ( edges[i].to != fa && !vis[edges[i].to] ) {
      getrt ( edges[i].to, x );
      siz[x] += siz[edges[i].to];
      chkmax ( f[x], siz[edges[i].to] );
    }
  }
  chkmax ( f[x], cnt - siz[x] );
  if ( f[x] < f[rt] ) {
    rt = x;
  }
}

vector < int > hhx;

// Depth-first walk of the unvisited vertices reachable from x without going back to fa.
// Appends each vertex's distance to hhx (x itself gets `dis`) and raises maxi to the
// largest distance seen.
void init ( int x, int fa, int dis ) {
  hhx.push_back ( dis );
  if ( dis > maxi ) maxi = dis;
  for ( int e = head[x]; e; e = edges[e].next ) {
    const int y = edges[e].to;
    if ( y == fa || vis[y] ) {
      continue;
    }
    init ( y, x, dis + 1 );
  }
}

void calc ( int x ) {
  poly tmp;
  int flag = 0;
  for ( int i = head[x]; i; i = edges[i].next ) {
    if ( !vis[edges[i].to] ) {
      hhx.clear (), maxi = 0;
      init ( edges[i].to, x, 1 );
      vector < int > hhx2 ( maxi + 1 );
      for ( int dis : hhx ) {
        hhx2[dis] ++;
      }
      if ( !flag ) {
        tmp = hhx2;
        flag = 1;
      }
      else {
        c = c + ( tmp * hhx2 );
        tmp = tmp + hhx2;
      }
    }
  }
  if ( !tmp.size () ) {
    tmp.push_back ( 1 );
  }
  else {
    tmp[0] ++;
  }
  c = c + tmp;
}

// Recomputes siz[] over the unvisited vertices reachable from x without going back
// through fa, rooted at x; returns siz[x].
static int getsiz ( int x, int fa ) {
  siz[x] = 1;
  for ( int i = head[x]; i; i = edges[i].next ) {
    int to = edges[i].to;
    if ( to != fa && !vis[to] ) {
      siz[x] += getsiz ( to, x );
    }
  }
  return siz[x];
}

// Centroid-decomposition driver; x is the centroid of its still-unvisited component.
// It marks x, counts the paths through it (calc), then recurses into the centroid of
// each component that remains after removing x.
void dfs ( int x ) {
  vis[x] = true;
  calc ( x );
  // siz[] was last filled by the getrt pass that located x, which was rooted at
  // another vertex. For the neighbour on that old root's side, siz[] then counts x's
  // subtree as well, not that neighbour's component. Re-root the sizes at x so that
  // cnt below is exactly each child component's size.
  getsiz ( x, 0 );
  for ( int i = head[x]; i; i = edges[i].next ) {
    if ( !vis[edges[i].to] ) {
      cnt = siz[edges[i].to], rt = 0;
      getrt ( edges[i].to, 0 );
      dfs ( rt );
    }
  }
}

// Reads n and the n - 1 edges (0-based endpoints), then runs the centroid
// decomposition starting from the centroid of the whole tree. Prints the sum over
// d of c[d] / (d + 1), with every d >= 1 term weighted by 2, to 4 decimals.
void Solve () {
  ios :: sync_with_stdio ( false );
  cin.tie ( 0 ), cout.tie ( 0 );
  cin >> n;
  for ( int e = 1; e <= n - 1; ++ e ) {
    int a, b;
    cin >> a >> b;
    // Shift to 1-based ids: vertex 0 serves as the "no parent" / sentinel index.
    ++ a, ++ b;
    add ( a, b );
    add ( b, a );
  }
  cnt = n;
  f[0] = n;
  getrt ( 1, 0 );
  dfs ( rt );
  for ( size_t d = 0; d < c.size (); ++ d ) {
    const double weight = ( d == 0 ) ? 1.0 : 2.0;
    ans += 1.0 * c[d] / ( d + 1 ) * weight;
  }
  cout << fixed << setprecision ( 4 ) << ans;
}

// Entry point; declared `signed` because `int` is a macro for `long long` in this file.
// When compiled with -Djudge, stdin/stdout/stderr are redirected to local files.
signed main () {
#if defined ( judge )
  freopen ( "Code.in", "r", stdin );
  freopen ( "Code.out", "w", stdout );
  freopen ( "Code.err", "w", stderr );
#endif
  Solve ();
  return 0;
}
2024/11/21 18:13
加载中...