萌新(相对)(不是妹子)树链剖分求助,WA,20分
查看原帖
萌新(相对)(不是妹子)树链剖分求助,WA,20分
120855
清雅流水楼主2020/8/19 15:57

RT,标准思路

#include <bits/stdc++.h>
using namespace std ;
int n , m , r , p ;
int fir[100010] , nxt[200010] , zx[200010] , bs ;
int yl[100010] , bh[100010] , a[100010] , ds , fa[100010] , si[100010] , dep[100010] ;
int son[100010] , c[400010] , lz[400010] , top[100010] , bsi[100010] , bdep[100010] ;
int bfa[100010] ;
void build( int wz , int l , int r )
{
	if( l == r )
	{
		a[l] %= p ;
		c[wz] = a[l] ;
		return ;
	}
	int m = ( l + r ) / 2 ;
	build( wz * 2 , l , m ) ;
	build( wz * 2 + 1 , m + 1 , r ) ;
	c[wz] = c[wz * 2] + c[wz * 2 + 1] ;
	c[wz] %= p ;
}
void pd( int wz , int l , int r )
{
	if( l == r )
	{
		lz[wz] = 0 ;
		return ;
	}
	int m = ( l + r ) / 2 ;
	c[wz * 2] += lz[wz] * ( m - l + 1 ) ;
	c[wz * 2] %= p ;
	lz[wz * 2] += lz[wz] ;
	lz[wz * 2] %= p ;
	c[wz * 2 + 1] += lz[wz] * ( r - m ) ;
	c[wz * 2 + 1] %= p ;
	lz[wz * 2 + 1] += lz[wz] ;
	lz[wz * 2 + 1] %= p ;
	lz[wz] = 0 ;
}
void gg( int wz , int l , int r , int ll , int rr , int z )
{
	if( ll < l )
	{
		ll = l ;
	}
	if( rr > r )
	{
		rr = r ;
	}
	if( l == ll && r == rr )
	{
		c[wz] += ( r - l + 1 ) * z ;
		c[wz] %= p ;
		lz[wz] += z ;
		lz[wz] %= p ;
		return ;
	}
	if( lz[wz] != 0 )
	{
		pd( wz , l , r ) ;
	}
	int m = ( l + r ) / 2 ;
	if( ll <= m )
	{
		gg( wz * 2 , l , m , ll , rr , z ) ;
	}
	if( rr > m )
	{
		gg( wz * 2 + 1 , m + 1 , r , ll , rr , z ) ;
	}
	c[wz] = c[wz * 2] + c[wz * 2 + 1] ;
	c[wz] %= p ;
}
int fh( int wz , int l , int r , int ll , int rr )
{
	if( ll < l )
	{
		ll = l ;
	}
	if( rr > r )
	{
		rr = r ;
	}
	if( l == ll && r == rr )
	{
		return c[wz] ;
	}
	if( lz[wz] != 0 )
	{
		pd( wz , l , r ) ;
	}
	int m = ( l + r ) / 2 , ans = 0 ;
	if( ll <= m )
	{
		ans += fh( wz * 2 , l , m , ll , rr ) ;
	}
	if( rr > m )
	{
		ans += fh( wz * 2 + 1 , m + 1 , r , ll , rr ) ;
	}
	return ans % p ;
}
void dfs1( int wz )
{
	si[wz] = 1 ;
	int max = -1 ;
	for( int i = fir[wz] ; i != -1 ; i = nxt[i] )
	{
		if( zx[i] == fa[wz] ) continue ;
		dep[zx[i]] = dep[wz] + 1 ;
		fa[zx[i]] = wz ;
		dfs1( zx[i] ) ;
		si[wz] += si[zx[i]] ;
		if( si[zx[i]] > si[max] || max == -1 )
		{
			max = zx[i] ;
		}
	}
	son[wz] = max ;
}
void dfs2( int wz , int ftp )
{
//	cout<< wz << ' ' << ds << endl ;
	bh[wz] = ds ;
	bsi[ds] = si[wz] ;
	bdep[ds] = dep[wz] ;
	if( fa[wz] == -1 ) bfa[ds] = -1 ;
	else bfa[ds] = bh[fa[wz]] ;
	a[ds] = yl[wz] ;
	top[ds] = bh[ftp] ;
	ds ++ ;
	if( son[wz] == -1 )
	{
		return ;
	}
	dfs2( son[wz] , ftp ) ;
	for( int i = fir[wz] ; i != -1 ; i = nxt[i] )
	{
		if( zx[i] == fa[wz] || zx[i] == son[wz] )
		{
			continue ;
		}
		dfs2( zx[i] , zx[i] ) ;
	}
}
void c1( int x , int y , int z )
{
	x = bh[x] ;
	y = bh[y] ;
	z %= p ;
	while( top[x] != top[y] )
	{
		if( bdep[top[x]] < bdep[top[y]] )
		{
			swap( x , y ) ;
		}
		gg( 1 , 0 , ds - 1 , top[x] , x , z ) ;
		x = bfa[top[x]] ;
	}
	if( bdep[x] < bdep[y] )
	{
		swap( x , y ) ;
	}
	gg( 1 , 0 , ds - 1 , y , x , z ) ;
}
int c2( int x , int y )
{
	x = bh[x] ;
	y = bh[y] ;
//	cout<< x << ' ' << y << endl ;
	int ans = 0 ;
	while( top[x] != top[y] )
	{
//		cout<< x << ' ' << y << endl ;
		if( bdep[top[x]] < bdep[top[y]] )
		{
			swap( x , y ) ;
		}
		ans += fh( 1 , 0 , ds - 1 , top[x] , x ) ;
		ans %= p ;
		x = bfa[top[x]] ;
	}
	if( bdep[x] < bdep[y] )
	{
		swap( x , y ) ;
	}
	ans += fh( 1 , 0 , ds - 1 , y , x ) ;
	return ans % p ;
}
void c3( int x , int z )
{
	gg( 1 , 0 , ds - 1 , bh[x] , bh[x] + si[x] - 1 , z ) ;
}
int c4( int x )
{
	return fh( 1 , 0 , ds - 1 , bh[x] , bh[x] + si[x] - 1 ) % p ;
}
int main ()
{
//	freopen( "P3384_2.in" , "r" , stdin ) ;
	memset( son , -1 , sizeof(son) ) ;
	memset( fir , -1 , sizeof(fir) ) ;
	cin >> n >> m >> r >> p ;
	p = 1000000 ;
	r -- ;
	for( int i = 0 ; i < n ; i ++ )
	{
		cin >> yl[i] ;
	}
	int t1 , t2 , t3 , t4 ;
	for( int i = 1 ; i < n ; i ++ )
	{
		cin >> t1 >> t2 ;
		t1 -- ;
		t2 -- ;
		nxt[bs] = fir[t1] ;
		fir[t1] = bs ;
		zx[bs] = t2 ;
		bs ++ ;
		nxt[bs] = fir[t2] ;
		fir[t2] = bs ;
		zx[bs] = t1 ;
		bs ++ ;
	}
	fa[r] = -1 ;
	dep[r] = 0 ;
	dfs1( r ) ;
	dfs2( r , r ) ;
	build( 1 , 0 , ds - 1 ) ;
	for( int i = 0 ; i < m ; i ++ )
	{
		cin >> t1 ;
		if( t1 == 1 )
		{
			cin>> t2 >> t3 >> t4 ;
			t2 -- ;
			t3 -- ;
			c1( t2 , t3 , t4 ) ;
		}
		else if( t1 == 2 )
		{
			cin >> t2 >> t3 ;
			t2 -- ;
			t3 -- ;
			cout<< c2( t2 , t3 ) << endl ;
		}
		else if( t1 == 3 )
		{
			cin >> t2 >> t3 ;
			t2 -- ;
			c3( t2 , t3 ) ;
		}
		else if( t1 == 4 )
		{
			cin >> t2 ;
			t2 -- ;
			cout<< c4( t2 ) << endl ;
		}
	}
	return 0 ;
}
2020/8/19 15:57
加载中...