后缀数组（SA）+ 根号分治：除 Subtask #1 外都能通过，100pts
查看原帖
后缀数组（SA）+ 根号分治：除 Subtask #1 外都能通过，100pts
231022
1234567890regis楼主2025/2/3 11:21

https://www.luogu.com.cn/record/200999134

乱搞做法++。

#include <bits/stdc++.h>
#define int long long
using namespace std;

const int MAXN = 1e6 + 7;
// Global solver state; all string-indexed arrays are 1-based over s[1..n].
int sa[MAXN];            // sa[k]: start index of the k-th smallest suffix
int rk[MAXN];            // rk[i]: rank of suffix i (inverse of sa)
int h[MAXN];             // h[k]: LCP of suffixes sa[k] and sa[k - 1]
int x[MAXN], y[MAXN];    // scratch key/rank arrays used by init()
int b[MAXN];             // counting-sort buckets used by init()
int st[MAXN][20];        // sparse table over h
int lg2[MAXN];           // lg2[v] = floor(log2(v))
int n, m = 127, q, xfd;  // length, bucket bound, query count, scan threshold
string s;                // input text, padded with a leading space in main()

// Builds the suffix array of s[1..n] by prefix doubling with counting sort,
// then rk (inverse of sa) and the height array via Kasai's method.
//   sa[k] : starting index of the k-th smallest suffix (k = 1..n)
//   rk[i] : rank of suffix i, so rk[sa[k]] == k
//   h[k]  : LCP of suffixes sa[k] and sa[k - 1] for k >= 2; h[1] is never written
// Side effects: overwrites b, x, y, and sets the global m to the number of
// distinct ranks after each unfinished doubling round.
// NOTE(review): the first pass indexes b by the raw char value with m = 127,
// so this assumes s[1..n] holds 7-bit ASCII; confirm against the input spec.
void init()
{
	// Round 0: counting sort of the suffixes by their first character.
	for (int i = 1; i <= n; i++) b[x[i] = s[i]]++;
	for (int i = 1; i <= m; i++) b[i] += b[i - 1];
	for (int i = n; i >= 1; i--) sa[b[x[i]]--] = i;
	// Doubling: sa/x order the suffixes by their first w characters; each round
	// re-sorts by the pair (x[i], x[i + w]) to order by the first 2w characters.
	for (int w = 1; w <= n; w <<= 1)
	{
		int tmp = 0;
		memset(b, 0, sizeof(b));
		// Second key: suffixes starting after n - w have an empty second half, so they come first.
		for (int i = n - w + 1; i <= n; i++) y[++tmp] = i;
		// Remaining suffixes in order of their second half, read off the current sa.
		for (int i = 1; i <= n; i++) if (sa[i] > w) y[++tmp] = sa[i] - w;
		// Stable counting sort of y by the first key x.
		for (int i = 1; i <= n; i++) b[x[i]]++;
		for (int i = 1; i <= m; i++) b[i] += b[i - 1];
		for (int i = n; i >= 1; i--) sa[b[x[y[i]]]--] = y[i];
		// y now holds the previous ranks; rebuild x as the new ranks.
		swap(x, y);
		x[sa[1]] = tmp = 1;
		// Same new rank iff both halves had equal ranks. Entries past n are never
		// written by this function, so they stay 0 and act as an empty half.
		for (int i = 2; i <= n; i++)
			x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + w] == y[sa[i - 1] + w] ? tmp : ++tmp);
		// All n ranks distinct: sa is final.
		if (tmp == n) break;
		// Bucket range for the next round.
		m = tmp;
	}
	// Kasai: walk suffixes in text order, reusing cur (minus one) as a lower
	// bound for the LCP with the preceding suffix in sa order.
	int cur = 0;
	for (int i = 1; i <= n; i++) rk[sa[i]] = i;
	for (int i = 1; i <= n; i++)
	{
		if (rk[i] == 1) {
			// The smallest suffix has no predecessor in sa; h[1] is left untouched.
			// cur is not reset here: it is decremented before the next comparison.
//			cur = 0;
			continue;
		}
		if (cur) cur--;
		int j = sa[rk[i] - 1];
		while (i + cur <= n && j + cur <= n && s[i + cur] == s[j + cur]) cur++;
		h[rk[i]] = cur;
	}
}

// LCP of the suffixes with ranks l and r (argument order does not matter):
// the minimum of h over ranks min(l, r) + 1 .. max(l, r), answered in O(1)
// with two overlapping sparse-table windows. Equal ranks denote the same
// suffix, which is reported as 2147483647 (effectively unbounded).
int query(int l, int r)
{
	if (r < l) swap(l, r);
	if (r == l) return 2147483647;
	const int from = l + 1;   // first h index inside the range
	const int k = lg2[r - l]; // window exponent; the range holds r - l entries
	return min(st[from][k], st[r + 1 - (1ll << k)][k]);
}

// Returns true iff the length-x prefix of suffix l2 (s[l2 .. l2 + x - 1])
// occurs inside s[l1 .. r1], i.e. some start i in [l1, r1 - x + 1] satisfies
// LCP(suffix i, suffix l2) >= x.
//   - Few candidate starts (<= xfd): test each one with an O(1) LCP query.
//   - Otherwise: query() is a range-min over h, so the LCP with rank rk[l2]
//     can only shrink when moving away from it in sa order. Only the nearest
//     admissible suffix on each side of rk[l2] therefore needs checking.
//     Finding them is a linear scan over ranks.
bool check(int l1, int r1, int l2, int x)
{
	const int last = r1 - x + 1; // last start position whose match fits in [l1, r1]
	if (last - l1 + 1 <= xfd)
	{
		for (int i = l1; i <= last; i++)
			if (query(rk[i], rk[l2]) >= x)
				return true;
		return false;
	}
	const int self = rk[l2];
	const auto admissible = [&](int pos) { return l1 <= sa[pos] && sa[pos] <= last; };
	// Nearest admissible rank at or above rk[l2]. If none exists, skip the query
	// instead of querying the out-of-range rank n + 1.
	int up = self;
	while (up <= n && !admissible(up)) up++;
	if (up <= n && query(up, self) >= x) return true;
	// Nearest admissible rank at or below rk[l2]. If none exists, skip the
	// query instead of querying rank 0.
	int down = self;
	while (down >= 1 && !admissible(down)) down--;
	if (down >= 1 && query(down, self) >= x) return true;
	return false;
}

// Reads n, q and the text, builds the suffix structures, then answers each
// query (l1, r1, l2, r2). The answer is the longest x such that
// s[l2 .. l2 + x - 1] lies within [l2, r2] and occurs inside s[l1 .. r1].
signed main()
{
	// Unsynchronised iostreams: the program uses only cin/cout (no C stdio).
	// It reads four integers and writes one line per query.
	ios::sync_with_stdio(false);
	cin.tie(nullptr);

	// lg2[v] = floor(log2(v)) for sparse-table window sizes.
	for (int i = 2; i < MAXN; i++) lg2[i] = lg2[i / 2] + 1;
	cin >> n >> q >> s;
	s = ' ' + s;       // shift to 1-indexed: the text is s[1..n]
	xfd = 5 * sqrt(n); // candidate-count threshold used by check()
	init();

	// Sparse table over h: st[i][j] = min(h[i .. i + 2^j - 1]).
	for (int i = 1; i <= n; i++) st[i][0] = h[i];
	for (int j = 1; j < 20; j++)
	{
		const int half = 1ll << (j - 1);
		// Build only windows lying fully inside [1, n]. Iterating i all the way
		// to n would read st[i + half] beyond the array when n is near MAXN.
		for (int i = 1; i + 2 * half - 1 <= n; i++)
			st[i][j] = min(st[i][j - 1], st[i + half][j - 1]);
	}

	while (q--)
	{
		int l1, r1, l2, r2;
		cin >> l1 >> r1 >> l2 >> r2;
		// If a prefix of length x occurs, every shorter prefix occurs too, so
		// check() is monotone in x and the answer can be binary searched.
		int lo = 0, hi = min(r1 - l1 + 1, r2 - l2 + 1);
		while (lo < hi)
		{
			const int mid = (lo + hi + 1) >> 1;
			if (check(l1, r1, l2, mid)) lo = mid;
			else hi = mid - 1;
		}
		cout << lo << '\n';
	}
}
2025/2/3 11:21
加载中...