是的,只 WA 了一个点,我要炸了。
写了 7.3KB。
最后找合法强连通分量的那部分我不敢保证完全正确,但感觉没有问题:逻辑是在缩点后的 DAG 上做一次拓扑排序,一个含颜色的强连通分量当且仅当它能到达的分量中没有其他含颜色的分量时,才被计入答案。当然,除了这个不确定的地方,其他地方也可能有错。
至于代码中的 array_fs,把它当作 vector<int>[](用链式前向星实现的邻接表)即可,只是遍历顺序与插入顺序相反。
// Problem: P7215 [JOISC 2020] 首都
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P7215
// Memory Limit: 488 MB
// Time Limit: 2500 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
//#include <bits/extc++.h>
// Declares T, reads the number of test cases into it and loops that many
// times; use as `multiple_cases(T) { ... }`.
#define multiple_cases(T) signed T; cin >> T; while (T--)
// Stringizes its argument. When reached through vdebug/pdebug the argument is
// macro-expanded first (it is not an operand of # in those bodies).
#define variable_name(x) #x
// Streams `text = value`. The value is parenthesized so that arguments using
// operators of lower precedence than << (e.g. `a & b`, `c ? d : e`) are
// evaluated as a whole instead of binding to the stream expression.
#define vdebug(x) variable_name(x) << " = " << (x)
// Streams `*text = pointee`.
#define pdebug(x) "*" << variable_name(x) << " = " << *(x)
// Expands to the begin/end iterator pair of a container.
#define all(x) (x).begin(), (x).end()
// Streams `a[b...c] = ` followed by the inclusive slice a[b..c] copied into a
// vector (an operator<< for vector must be in scope where it is used).
#define adebug(a, b, c) #a << "[" << (b) << "..." << (c) << "] = " << vector<typename decay<decltype(*((a) + (b)))>::type>((a) + (b), (a) + (c) + 1)
// Redirects stdin/stdout to the files <a>.in / <a>.out.
#define file_io(a) (freopen(#a ".in", "r", stdin), freopen(#a ".out", "w", stdout))
using namespace std;
//using namespace __gnu_cxx;
//using namespace __gnu_pbds;
const int mod = 998244353;
// Computes a^b modulo `mod` by binary exponentiation and returns a value in [0, mod).
// A negative exponent is mapped to b mod (mod - 1) in (0, mod - 1], which by
// Fermat's little theorem (mod is prime) gives the modular inverse power for
// any a not divisible by mod.
// `a` is first reduced to its canonical residue and every product is taken in
// long long, so a narrow T (e.g. int) or an argument outside [0, mod) can no
// longer overflow the intermediate products.
template<typename T, typename Tb>
T quickpower(T a, Tb b) {
    if (b < 0) b = b % (mod - 1) + mod - 1;
    long long base = static_cast<long long>(a % mod);
    if (base < 0) base += mod;  // canonical residue for negative a
    long long result = 1;
    while (b) {
        if (b & 1) {
            result = result * base % mod;
        }
        base = base * base % mod;
        b >>= 1;
    }
    return static_cast<T>(result);
}
// Computes a^b by binary exponentiation using T's own operator*= (no modulus),
// so it works for any type with *= and construction from 1 (integers,
// floating point, matrix types, ...).
// Precondition: b >= 0. For a negative signed b, `b >>= 1` never reaches 0.
// The base is squared only while further exponent bits remain. Every square
// that is computed is therefore a factor of the result, so for integers no
// intermediate overflows unless the result itself does. The former version
// squared once more at the end, e.g. computing 2^32 for auto_quickpower(2, 30).
template<typename T, typename Tb>
T auto_quickpower(T a, Tb b) {
    T c = 1;
    while (b) {
        if (b & 1) {
            c *= a;
        }
        b >>= 1;
        if (b) {
            a *= a;
        }
    }
    return c;
}
namespace quick_io {
#if defined(__GNUC__) && defined(__SIZEOF_INT128__)
#ifndef quick_io_can_use_int128
#define quick_io_can_use_int128
#define mark_can_use_in_quick_io
#endif
#endif
#ifdef quick_io_can_use_int128
#define quick_io_int_length_limit 128
#define quick_io_int_radix_type size_t
#define quick_io_length_type size_t
// is_digit[ch] is true iff the ASCII character ch is a digit of the current radix.
bool is_digit[128];
// Radix used by the 128-bit stream operators below (set via quick_io_change_int_radix).
quick_io_int_radix_type cur_int_radix = 0;
// Digit value -> character: 0-9 -> '0'-'9', 10-35 -> 'A'-'Z', 36-61 -> 'a'-'z'.
char __quick_io_d2c(size_t a) {
    if (a <= 9) {
        return '0' + a;
    } else if (a <= 35) {
        return 'A' + a - 10;
    } else {
        return 'a' + a - 36;
    }
}
// Character -> digit value, the inverse of __quick_io_d2c. Only meaningful for
// characters accepted by is_digit.
size_t __quick_io_c2d(char a) {
    if (a >= '0' && a <= '9') {
        return a - '0';
    } else if (a >= 'A' && a <= 'Z') {
        return a - 'A' + 10;
    } else /*if (a >= 'a' && a <= 'z')*/ {
        return a - 'a' + 36;
    }
}
// Whether `ch`, an int as returned by istream::get(), is a digit of the current
// radix. Unlike indexing is_digit with a (possibly signed) char, this is safe
// for EOF and for bytes >= 128.
bool quick_io_is_digit_char(int ch) {
    return ch >= 0 && ch < 128 && is_digit[ch];
}
// The digits of `v` in the current radix, most significant first ("0" for 0).
// Built in a std::string, so even 128 binary digits cannot overflow a buffer.
string quick_io_digits_of(unsigned __int128 v) {
    string digits;
    do {
        digits.push_back(__quick_io_d2c(size_t(v % cur_int_radix)));
        v /= cur_int_radix;
    } while (v);
    reverse(digits.begin(), digits.end());
    return digits;
}
// Shared ending of the 128-bit readers. `last` is the character that stopped
// the digit scan (already extracted), or EOF.
// - A real character is pushed back.
// - EOF reached right after the number keeps eofbit but clears the failbit
//   that get() raised, so the read still counts as successful.
// - No digit read at all sets failbit.
void quick_io_finish_read(istream &A, int last, bool got_digit) {
    if (last == EOF) {
        A.clear(A.rdstate() & ~ios_base::failbit);
    } else {
        A.unget();
    }
    if (!got_digit) {
        A.setstate(ios_base::failbit);
    }
}
// Constructing one of these switches the radix used by the 128-bit stream
// operators and rebuilds is_digit. Expected radix range: 2..62 (larger values
// map digits past 'z' and eventually past the end of is_digit).
struct quick_io_change_int_radix {
    quick_io_change_int_radix(quick_io_int_radix_type new_int_radix) {
        if (cur_int_radix == new_int_radix) return;
        cur_int_radix = new_int_radix;
        memset(is_digit, 0, sizeof(is_digit));
        for (size_t i = 0; i < new_int_radix; i++) {
            is_digit[size_t(__quick_io_d2c(i))] = true;
        }
    }
} set_radix(10);
// Character-at-a-time 128-bit stream operators (slow; meant for a few values).
// Writes b in the current radix with a leading '-' when negative.
ostream &operator<<(ostream &A, __int128 b) {
    const bool negative = b < 0;
    unsigned __int128 magnitude = b;
    if (negative) {
        magnitude = -magnitude;  // unsigned negation: well-defined even for -2^127
    }
    string text = quick_io_digits_of(magnitude);
    if (negative) {
        text.insert(text.begin(), '-');
    }
    return A << text;
}
// Skips to the first sign or digit, then reads an optionally signed number in
// the current radix. Sets b = 0 and failbit if no digit is found.
istream &operator>>(istream &A, __int128 &b) {
    b = 0;
    if (!A) {
        return A;
    }
    int c = A.get();
    while (c != EOF && c != '-' && c != '+' && !quick_io_is_digit_char(c)) {
        c = A.get();
    }
    bool negative = false;
    if (c == '-' || c == '+') {
        negative = c == '-';
        c = A.get();
    }
    // Accumulate in unsigned arithmetic: wraps instead of overflowing, and
    // lets the minimum value -2^127 be read exactly.
    unsigned __int128 magnitude = 0;
    bool got_digit = false;
    while (quick_io_is_digit_char(c)) {
        magnitude = magnitude * cur_int_radix + __quick_io_c2d(char(c));
        got_digit = true;
        c = A.get();
    }
    b = negative ? -magnitude : magnitude;
    quick_io_finish_read(A, c, got_digit);
    return A;
}
// Writes b in the current radix.
ostream &operator<<(ostream &A, unsigned __int128 b) {
    return A << quick_io_digits_of(b);
}
// Skips to the first digit (signs are skipped like any other character), then
// reads an unsigned number in the current radix. Sets b = 0 and failbit if no
// digit is found.
istream &operator>>(istream &A, unsigned __int128 &b) {
    b = 0;
    if (!A) {
        return A;
    }
    int c = A.get();
    while (c != EOF && !quick_io_is_digit_char(c)) {
        c = A.get();
    }
    bool got_digit = false;
    while (quick_io_is_digit_char(c)) {
        b = b * cur_int_radix + __quick_io_c2d(char(c));
        got_digit = true;
        c = A.get();
    }
    quick_io_finish_read(A, c, got_digit);
    return A;
}
#ifdef mark_can_use_in_quick_io
#undef quick_io_can_use_int128
#undef mark_can_use_in_quick_io
#endif
#endif
// Forward declarations, so that nested containers (e.g. vector<set<int>>)
// find each other's printers.
template<typename T>
ostream &operator<<(ostream &A, const vector<T> &b);
template<typename T>
ostream &operator<<(ostream &A, const deque<T> &b);
template<typename T1, typename T2>
ostream &operator<<(ostream &A, const pair<T1, T2> &b);
template<typename T>
ostream &operator<<(ostream &A, const set<T> &b);
template<typename T>
ostream &operator<<(ostream &A, const multiset<T> &b);
template<typename T, typename T2>
ostream &operator<<(ostream &A, const map<T, T2> &b);
template<typename T, typename T2>
ostream &operator<<(ostream &A, const multimap<T, T2> &b);
template<typename T>
ostream &operator<<(ostream &A, const unordered_set<T> &b);
template<typename T>
ostream &operator<<(ostream &A, const unordered_multiset<T> &b);
template<typename T, typename T2>
ostream &operator<<(ostream &A, const unordered_map<T, T2> &b);
template<typename T, typename T2>
ostream &operator<<(ostream &A, const unordered_multimap<T, T2> &b);
// Streams [first, last) as `open e1,e2,... close`. Defined after the
// declarations above, so elements that are themselves pairs or containers
// resolve to those overloads.
template<typename It>
ostream &quick_io_print_range(ostream &A, It first, It last, const char *open, const char *close) {
    A << open;
    for (It it = first; it != last; ++it) {
        if (it != first) {
            A << ",";
        }
        A << *it;
    }
    return A << close;
}
// Sequences print as [a,b,c].
template<typename T>
ostream &operator<<(ostream &A, const vector<T> &b) {
    return quick_io_print_range(A, b.begin(), b.end(), "[", "]");
}
template<typename T>
ostream &operator<<(ostream &A, const deque<T> &b) {
    return quick_io_print_range(A, b.begin(), b.end(), "[", "]");
}
// Pairs print as (first,second).
template<typename T1, typename T2>
ostream &operator<<(ostream &A, const pair<T1, T2> &b) {
    return A << '(' << b.first << ',' << b.second << ')';
}
// Sets and maps print as {a,b,c} in iteration order (map elements as pairs).
template<typename T>
ostream &operator<<(ostream &A, const set<T> &b) {
    return quick_io_print_range(A, b.begin(), b.end(), "{", "}");
}
template<typename T>
ostream &operator<<(ostream &A, const multiset<T> &b) {
    return quick_io_print_range(A, b.begin(), b.end(), "{", "}");
}
template<typename T, typename T2>
ostream &operator<<(ostream &A, const map<T, T2> &b) {
    return quick_io_print_range(A, b.begin(), b.end(), "{", "}");
}
template<typename T, typename T2>
ostream &operator<<(ostream &A, const multimap<T, T2> &b) {
    return quick_io_print_range(A, b.begin(), b.end(), "{", "}");
}
template<typename T>
ostream &operator<<(ostream &A, const unordered_set<T> &b) {
    return quick_io_print_range(A, b.begin(), b.end(), "{", "}");
}
template<typename T>
ostream &operator<<(ostream &A, const unordered_multiset<T> &b) {
    return quick_io_print_range(A, b.begin(), b.end(), "{", "}");
}
template<typename T, typename T2>
ostream &operator<<(ostream &A, const unordered_multimap<T, T2> &b) {
    return quick_io_print_range(A, b.begin(), b.end(), "{", "}");
}
template<typename T, typename T2>
ostream &operator<<(ostream &A, const unordered_map<T, T2> &b) {
    return quick_io_print_range(A, b.begin(), b.end(), "{", "}");
}
// Reads a pair as two whitespace-separated values.
template<typename T1, typename T2>
istream &operator>>(istream &A, pair<T1, T2> &b) {
    return A >> b.first >> b.second;
}
// Prints [b, e) to cout separated by s (no trailing separator).
template<typename T>
void print_array(T b, T e, string s = " ") {
    while (b != e) {
        cout << *b;
        b++;
        if (b != e) {
            cout << s;
        }
    }
}
// Prints b[1..n] (1-indexed) to cout separated by s; prints nothing for n == 0
// (the previous version printed b[0] in that case).
template<typename T>
void auto_print(T &b, size_t n, string s = " ") {
    if (n == 0) return;
    for (size_t i = 1; i < n; i++) {
        cout << b[i] << s;
    }
    cout << b[n];
}
// Prints every element of b to cout, each followed by s (trailing s included).
template<typename T>
void auto_print(T &b, string s = " ") {
    for (auto i : b) {
        cout << i << s;
    }
}
// Prints the n elements starting at iterator b to cout, separated by s.
template<typename T>
void print_n(T b, size_t n, string s = " ") {
    if (n == 0) return;
    cout << *b;
    for (size_t i = 1; i < n; i++) {
        b++;
        cout << s << *b;
    }
}
// Reads values from cin into [b, e).
template<typename T>
void read_array(T b, T e) {
    while (b != e) {
        cin >> *b;
        b++;
    }
}
// Reads b[1..n] (1-indexed) from cin.
template<typename T>
void auto_read(T &b, size_t n) {
    for (size_t i = 1; i <= n; i++) {
        cin >> b[i];
    }
}
// Reads n values from cin starting at iterator b; reads nothing for n == 0
// (the previous version read one value even then).
template<typename T>
void read_n(T b, size_t n) {
    if (n == 0) return;
    cin >> *b;
    for (size_t i = 1; i < n; i++) {
        b++;
        cin >> *b;
    }
}
// Formats any streamable value through an ostringstream.
template <typename T>
string to_string(const T& value) {
    ostringstream oss;
    oss << value;
    return oss.str();
}
// Renders the index list of an array_debug call as "[i][j]...".
template <typename Array>
string array_debug_impl(const Array& a) {
    return "";
}
template <typename Array, typename First, typename... Rest>
string array_debug_impl(const Array& a, First first, Rest... rest) {
    string current = "[" + to_string(first) + "]";
    string next = array_debug_impl(a, rest...);
    return current + next;
}
// get_nested(a, i, j, ...) == a[i][j]..., preserving the value category.
template <typename T>
decltype(auto) get_nested(T&& first) {
    return forward<T>(first);
}
template <typename T, typename U, typename... Args>
decltype(auto) get_nested(T&& first, U&& second, Args&&... args) {
    return get_nested(forward<T>(first)[forward<U>(second)], forward<Args>(args)...);
}
// Returns "[i][j]... = value" for the element a[i][j]...
template <typename Array, typename... Args>
string array_debug(const Array& a, Args... args) {
    string indices_str = array_debug_impl(a, args...);
    ostringstream value_oss;
    value_oss << get_nested(a, args...);
    return indices_str + " = " + value_oss.str();
}
// Lets an expression that already yields a stream be chained, e.g.
// `cout << (cerr << x)`: returns the argument stream b and leaves a untouched.
ostream &operator<<(ostream &a, ostream &b) {
    return b;
}
#if 1 // use debug macros
#define aedebug(first, ...) #first << array_debug(first, __VA_ARGS__)
#define spc << " " <<
#endif
#undef quick_io_int_length_limit
#undef quick_io_int_radix_type
#undef quick_io_length_type
}
using namespace quick_io;
// Linked-list ("forward star") adjacency storage. For every key it keeps a
// singly linked list of values threaded through the parallel `data`/`nxt`
// containers, with `head[key]` pointing at the most recently pushed node.
// Node id 0 is the end-of-list sentinel, so ids start at 1 and `cnt` is the
// last id handed out.
// The containers are neither sized nor zeroed here ("without initialization"):
// the owner must supply zeroed heads and room for every node id.
template<typename T_key,
         typename T_value,
         typename head_container = vector<size_t>,
         typename nxt_container = vector<size_t>,
         typename data_container = vector<T_value>,
         typename index_type = size_t>
struct basic_fs {
    data_container data;  // data[id]: value stored in node id
    nxt_container nxt;    // nxt[id]: next node of the same list, 0 = end
    head_container head;  // head[key]: first node of key's list, 0 = empty
    index_type cnt;       // number of nodes allocated so far
    basic_fs() : cnt(0) {}
    // Lightweight view of one key's list, obtained through operator[].
    struct data_t {
        T_key key;
        basic_fs *base;
        // Forward iterator over one list. There is no operator!=: range-for
        // comparisons go through the implicit conversion to index_type&.
        struct iterator {
            index_type pos;
            data_t *base;
            T_value &operator*() {
                basic_fs &owner = *base->base;
                return owner.data[pos];
            }
            iterator &operator++() {
                basic_fs &owner = *base->base;
                pos = owner.nxt[pos];
                return *this;
            }
            operator index_type &() {
                return pos;
            }
        };
        iterator begin() {
            iterator first = {base->head[key], this};
            return first;
        }
        iterator end() {
            iterator sentinel = {0, this};
            return sentinel;
        }
        // Prepends `value`; iteration therefore yields values newest-first.
        void push_back(const T_value &value) {
            basic_fs &owner = *base;
            const index_type id = ++owner.cnt;
            owner.data[id] = value;
            owner.nxt[id] = owner.head[key];
            owner.head[key] = id;
        }
        bool empty() const {
            return !base->head[key];
        }
    };
    data_t operator[](const T_key &x) {
        data_t view = {x, this};
        return view;
    }
};
// Fixed-capacity basic_fs backed by plain arrays: `index_range` list heads
// (keys 0..index_range-1) and `value_size` node slots (ids 1..value_size-1).
// A global or static instance starts zero-filled, as basic_fs requires.
template<typename T, size_t index_range, size_t value_size = index_range>
using array_fs = basic_fs<size_t, T, int[index_range], int[value_size], T[value_size], int>;
// #define int long long
// Kept as macros on purpose: they accept mixed argument types
// (e.g. min(int, long long)), which std::min/std::max reject.
// NOTE: min/max/abs/lowbit evaluate their arguments more than once; do not
// pass expressions with side effects.
#define min(x,y) (((x)<(y))?(x):(y))
#define max(x,y) (((x)>(y))?(x):(y))
// Lowest set bit of x.
#define lowbit(x) ((x)&-(x))
#define abs(x) (((x)<(0))?(-(x)):(x))
// Forwards to std::swap. The former XOR trick `a^=b^=a^=b` zeroed the value
// when both arguments named the same object, only worked for integers, and
// was unsequenced (undefined behaviour) before C++17. A macro's own name is
// not re-expanded inside its replacement, so this calls the real std::swap.
#define swap(a,b) ::std::swap(a, b)
#define INF 1e18
#define sos ostream
#define sis istream
#define soss ostringstream
#define siss istringstream
#if 1 // short type aliases
// Integer shorthands, including fixed-width and GCC 128-bit types.
using ll = long long;
using ull = unsigned long long;
using i32 = std::int32_t;
using u32 = std::uint32_t;
using i64 = std::int64_t;
using u64 = std::uint64_t;
using i128 = __int128_t;
using u128 = __uint128_t;
// Container shorthands.
using pii = std::pair<int, int>;
using vi = std::vector<int>;
using vvi = std::vector<vi>;
using vpii = std::vector<pii>;
#endif
// Stream the answer is written to (an alias of cout).
ostream &cans = cout;
// #define cout cerr
#define MAXN 1000005
#define MAXM 10000005
// n: number of tree vertices; k: number of colors; city[v]: color of vertex v.
int n;
int k;
int city[MAXN];
// e: tree adjacency lists; pos[c]: the vertices whose color is c.
array_fs<signed, MAXN> e;
array_fs<signed, MAXN> pos;
// blocked[v]: v has already been used as a centroid.
bool blocked[MAXN];
namespace tarjan {
array_fs<signed, MAXM> e;
int scc_cnt, cnt, scc[MAXM], scc_sz[MAXM];
int dfn[MAXM], low[MAXM], st[MAXM], stamp;
bool flag[MAXM], inst[MAXM];
void tarjan(int id)
{
dfn[id] = low[id] = ++stamp;
st[++st[0]] = id;
inst[id] = 1;
for (int i : e[id])
{
if (!dfn[i])
{
tarjan(i);
low[id] = min(low[id], low[i]);
}
else if (inst[i])
{
low[id] = min(low[id], low[i]);
}
}
if (dfn[id] == low[id])
{
scc_cnt++;
while (st[st[0]] != id)
{
scc[st[st[0]]] = scc_cnt;
inst[st[st[0]--]] = 0;
}
scc[st[st[0]]] = scc_cnt;
inst[st[st[0]--]] = 0;
}
}
// void dfs_imp(int id, bool fl)
// {
// if (flag[id])
// {
// return;
// }
// if (fl)
// {
// flag[id] = 1;
// }
// for (int i : e2[id])
// {
// // cout << id << " -> " << i << endl;
// dfs_imp(i, 1);
// }
// }
}
namespace toposort {
// Condensed graph over SCC ids 1..scc_cnt with edges REVERSED: for an original
// edge a -> b between different SCCs, e[scc(b)] holds scc(a). main adds one
// entry per such edge, so duplicates occur.
// dp[i] = val[i] + sum of dp over the SCCs that i reaches in one step (with
// multiplicity). It is computed in Kahn order starting from the original sinks
// and clamped at INT_MAX. Only the test dp[i] == val[i] is used: it holds
// exactly when no SCC reachable from i has val > 0.
int dp[MAXM], *val = tarjan::scc_sz;
int ind[MAXM];
array_fs<signed, MAXM> e;
using tarjan::scc_cnt;
// a + b for non-negative a and b, clamped to INT_MAX.
// The sums above count paths and can grow exponentially. With plain int
// addition they overflowed (undefined behaviour), and a wrapped total could
// equal val[i] and mark an SCC as valid by accident. Clamping keeps the
// dp[i] == val[i] test exact, because val[i] (at most k) stays below INT_MAX.
int saturating_add(int a, int b)
{
    return a > INT_MAX - b ? INT_MAX : a + b;
}
// Fills dp and sets tarjan::flag[i] = (dp[i] == val[i]) for every SCC i.
void toposort()
{
    for (int i = 1; i <= scc_cnt; i++)
    {
        for (int j : e[i])
        {
            ind[j]++;
        }
    }
    queue<int> q;
    for (int i = 1; i <= scc_cnt; i++)
    {
        if (!ind[i])
        {
            q.push(i);
        }
        dp[i] = val[i];
    }
    while (!q.empty())
    {
        int t = q.front();
        int v = dp[t];
        q.pop();
        for (int i : e[t])
        {
            dp[i] = saturating_add(dp[i], v);
            if (!--ind[i])
            {
                q.push(i);
            }
        }
    }
    for (int i = 1; i <= scc_cnt; i++)
    {
        tarjan::flag[i] = val[i] == dp[i];
    }
}
}
namespace LCA {
// LCA through an Euler tour plus a sparse table (shallowest vertex on a range).
// dfn[v]: 1-based tour position of v's first appearance.
// ST[p][j]: once main has called build_ST(ST + 1, stamp, cmp), the shallowest
//           vertex among tour positions p .. p + 2^j - 1 (level 0 is the tour).
// dep[v]: depth of v, with the dfs root at depth 1.
// stamp: current tour length.
int dfn[MAXN], ST[MAXN][20], dep[MAXN], stamp;
// floor(log2(k)) for k >= 1; k == 0 is not allowed (the builtin is undefined there).
size_t floor_log2(size_t k) {
    const unsigned long long widened = k;
    return sizeof(unsigned long long) * CHAR_BIT - 1 - __builtin_clzll(widened);
}
// Sparse-table query over the inclusive range between l and r (either order).
// Combines the two overlapping power-of-two windows with `cmp`.
template<typename T, typename T2>
int query(T &&st, size_t l, size_t r, T2 &&cmp) {
    if (r < l) swap(l, r);
    if (l == r) return st[l][0];
    const size_t lg = floor_log2(r - l + 1);
    const size_t right_start = r + 1 - (size_t(1) << lg);
    return cmp(st[l][lg], st[right_start][lg]);
}
// Builds levels 1.. of a sparse table whose level 0 holds n entries at rows
// 0..n-1: st[j][lvl] = cmp(st[j][lvl - 1], st[j + 2^(lvl-1)][lvl - 1]).
template<typename T, typename T2>
void build_ST(T &&st, size_t n, T2 &&cmp) {
    for (size_t lvl = 1, width = 2; width <= n; ++lvl, width <<= 1) {
        const size_t half = width >> 1;
        for (size_t j = 0; j + width <= n; ++j) {
            st[j][lvl] = cmp(st[j][lvl - 1], st[j + half][lvl - 1]);
        }
    }
}
// Returns the shallower of a and b (b on ties).
int cmp(int a, int b)
{
    return dep[a] < dep[b] ? a : b;
}
// Appends the Euler tour of id's subtree: id on entry and again after each
// child returns. `fa` is skipped as the parent; `d` is id's depth.
void dfs(int id, int fa, int d = 1)
{
    dep[id] = d;
    dfn[id] = ++stamp;
    ST[stamp][0] = id;
    for (int to : e[id])
    {
        if (to == fa) continue;
        dfs(to, id, d + 1);
        ++stamp;
        ST[stamp][0] = id;
    }
}
// Lowest common ancestor of x and y: the shallowest tour vertex between their
// first appearances.
int LCA(int x, int y)
{
    const size_t from = dfn[x];
    const size_t to = dfn[y];
    return query(ST, from, to, cmp);
}
}
namespace edge_adder {
// State for the component currently split at a centroid (see core_decomposition):
// top[v]: the centroid's neighbor whose subtree contains v.
// node_id[v]: tarjan node that reaches every vertex node on the path from v up
//             to the centroid (the centroid itself uses its vertex node k + v).
int top[MAXN], node_id[MAXN];
// Walks the unblocked subtree of `id`, entered from its parent `fa`. Each
// vertex gets a fresh tarjan node with edges to the parent's path node and to
// its own vertex node k + id, and is tagged with top t.
void dfs(int id, int fa, int t)
{
    top[id] = t;
    const int self = ++tarjan::cnt;
    node_id[id] = self;
    tarjan::e[self].push_back(node_id[fa]);
    tarjan::e[self].push_back(k + id);
    for (int next : e[id])
    {
        if (next == fa || blocked[next]) continue;
        dfs(next, id, t);
    }
}
}
namespace core_decomposition {
// Centroid decomposition that connects each color node to the path nodes built
// by edge_adder, so that a color reaches every tree vertex on its paths.
using edge_adder::node_id;
// Finds a centroid of the unblocked component containing `id` (component size `sz`).
// Returns {subtree size of id, 0} while no centroid has been found below id, or
// {0, c} once one is found. c is the first vertex in DFS post-order whose
// unblocked subtree size reaches (sz + 1) / 2. Its child subtrees are all
// smaller than that, and the rest of the component has at most sz / 2 vertices.
pii get_core(int id, int fa, int sz)
{
// cout << __PRETTY_FUNCTION__ spc id spc fa spc sz << endl;
int s = 1;
for (int i : e[id])
{
if (i != fa && !blocked[i])
{
pii sres = get_core(i, id, sz);
if (sres.second)
{
// A centroid was already found deeper down: propagate it unchanged.
return sres;
}
s += sres.first;
}
}
// cout << s << endl;
if (s >= (sz + 1) / 2)
{
return {0, id};
}
return {s, 0};
}
// Number of vertices in the unblocked component reachable from id without going through fa.
int get_size(int id, int fa)
{
// cout << __PRETTY_FUNCTION__ spc id spc fa << endl;
int s = 1;
for (int i : e[id])
{
if (i != fa && !blocked[i])
{
s += get_size(i, id);
}
}
return s;
}
// Pairs deferred to the next level, bucketed by the centroid's neighbor
// (edge_adder::top) whose subtree contains both endpoints.
vpii E__[MAXN];
// Processes the unblocked component containing x. Each pair (v, w) in E is a
// path v..w inside this component on behalf of color city[v]; main passes
// (vertex, LCA of its color). After choosing the centroid c and building path
// nodes for the component, each pair is handled as follows:
//  - v == c: color -> path node of w (the path w..c);
//  - w == c: color -> path node of v;
//  - v and w in the same subtree of c: deferred to that subtree;
//  - otherwise the path passes through c: color -> path nodes of v and of w.
// Recursion only enters subtrees that received deferred pairs.
// NOTE(review): E must stay a by-value copy. Inside solve(i, E__[i]), the
// E__[...].clear() below can hit E__[i] itself when i is adjacent to the new
// centroid.
void solve(int x, vpii E)
{
// cout << "solve" spc x spc E << endl;
x = get_core(x, x, get_size(x, x)).second;
// cout << vdebug(x) << endl;
blocked[x] = 1;
// The centroid's "path to itself" node is simply its own vertex node.
edge_adder::node_id[x] = x + k;
edge_adder::top[x] = x;
// Build path nodes for each neighboring subtree and reset its deferral bucket.
for (int i : e[x])
{
if (!blocked[i])
{
edge_adder::dfs(i, x, i);
E__[i].clear();
}
}
// cout << vdebug(E) << endl;
for (pii &i : E)
{
if (i.first == x)
{
tarjan::e[city[i.first]].push_back(node_id[i.second]);
// cout << i.second << " sec--> " << x << endl;
}
else if (i.second == x)
{
tarjan::e[city[i.first]].push_back(node_id[i.first]);
// cout << i.first << " fir--> " << x << endl;
}
else if (edge_adder::top[i.first] == edge_adder::top[i.second])
{
// Path stays inside one subtree: resolve it at a deeper level.
E__[edge_adder::top[i.first]].push_back(i);
}
else
{
// Path crosses the centroid: it is the union of v..c and w..c.
tarjan::e[city[i.first]].push_back(node_id[i.first]);
tarjan::e[city[i.first]].push_back(node_id[i.second]);
// cout << city[i.first] << " -> " << node_id[i.first] << endl;
// cout << city[i.first] << " -> " << node_id[i.second] << endl;
// cout << vdebug(x) << " " << vdebug(i) << endl;
}
}
for (int i : e[x])
{
if (!blocked[i] && !E__[i].empty())
{
solve(i, E__[i]);
}
}
}
}
// Entry point for P7215. Reads the tree and the vertex colors, builds the
// color reachability graph, and prints (number of colors in the smallest
// valid SCC) - 1.
signed main() {
    // The input is about 3n integers; untie iostreams from C stdio for speed.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> k;
    for (int i = 1, u, v; i < n; i++)
    {
        cin >> u >> v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    // Euler tour from vertex 1, then the sparse table over tour positions 1..stamp.
    LCA::dfs(1, 1);
    LCA::build_ST(LCA::ST + 1, LCA::stamp, LCA::cmp);
    // Vertex node k + i points to the color of vertex i; pos[c] collects color c's vertices.
    for (int i = 1; i <= n; i++)
    {
        cin >> city[i];
        pos[city[i]].push_back(i);
        tarjan::e[k + i].push_back(city[i]);
    }
    // Pair every vertex of a color with the LCA of all vertices of that color.
    // The union of these vertex-to-LCA paths is the subtree spanning the color.
    vpii E;
    for (int i = 1; i <= k; i++)
    {
        if (pos[i].empty())
        {
            continue;
        }
        int lca = *pos[i].begin();
        for (int j : pos[i])
        {
            lca = LCA::LCA(lca, j);
        }
        for (int j : pos[i])
        {
            E.push_back({j, lca});
        }
    }
    // Ids above n + k are handed out to path nodes during the decomposition.
    tarjan::cnt = n + k;
    core_decomposition::solve(1, E);
    for (int i = 1; i <= tarjan::cnt; i++)
    {
        if (!tarjan::dfn[i])
        {
            tarjan::tarjan(i);
        }
    }
    // Condensation with reversed edges: edge i -> j across SCCs becomes scc[j] -> scc[i].
    for (int i = 1; i <= tarjan::cnt; i++)
    {
        for (int j : tarjan::e[i])
        {
            if (tarjan::scc[j] != tarjan::scc[i])
            {
                toposort::e[tarjan::scc[j]].push_back(tarjan::scc[i]);
            }
        }
    }
    // Value of an SCC = number of colors (ids 1..k) it contains.
    for (int i = 1; i <= k; i++)
    {
        tarjan::scc_sz[tarjan::scc[i]]++;
    }
    toposort::toposort();
    // A valid SCC contains colors and reaches no other SCC that contains one.
    int ansans = 1e9;
    for (int i = 1; i <= tarjan::scc_cnt; i++)
    {
        if (tarjan::flag[i] && tarjan::scc_sz[i])
        {
            ansans = min(ansans, tarjan::scc_sz[i]);
        }
    }
    cout << ansans - 1 << '\n';
    return 0;
}