可爱萌新求助 80 分 KD-Tree
  • 板块P4148 简单题
  • 楼主Mogeko
  • 当前回复2
  • 已保存回复2
  • 发布时间2022/1/10 14:38
  • 上次更新2023/10/28 12:33:25
查看原帖
可爱萌新求助 80 分 KD-Tree
119316
Mogeko楼主2022/1/10 14:38
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define RN 500005

typedef int I;
typedef char C;
typedef long long L;

#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define SWAP(T, a, b) { T t = a; a = b; b = t; }

// K-D Tree

typedef struct
{
	I n[2];
}
KDInfo;

typedef struct
{
	KDInfo pos;
	KDInfo lbnd;
	KDInfo rbnd;
	I      val;
	I      sum;
	I      cnt;
	I      dim;
	I      ch[2];
}
KDNode;

KDNode kdpool[RN];
I      kdcnt;

#define kdpos(x)  kdpool[x].pos
#define kdlbnd(x) kdpool[x].lbnd
#define kdrbnd(x) kdpool[x].rbnd
#define kdval(x)  kdpool[x].val
#define kdsum(x)  kdpool[x].sum
#define kdcnt(x)  kdpool[x].cnt
#define kddim(x)  kdpool[x].dim
#define kdlch(x)  kdpool[x].ch[0]
#define kdrch(x)  kdpool[x].ch[1]

KDInfo kdbuf[RN];
I      kdind[RN];

I chooseKD(I l, I r)
{
	I mx = 0, mxvar = 0;
	for (I i = 0; i < 2; i++)
	{
		L sum = 0, sum2 = 0;
		for (I j = l; j <= r; j++)
		{
			sum += 1ll * kdbuf[j].n[i];
			sum2 += 1ll * kdbuf[j].n[i] * kdbuf[j].n[i];
		}
		if ((r - l + 1) * sum2 - sum * sum > mxvar) 
			mxvar = (r - l + 1) * sum2 - sum * sum, mx = i;
	}
	return mx;
}

I sortKD(I l, I r, I k)
{
	I mx = chooseKD(l, r);
	while (1)
	{
		I i = l, j = r;
		I pivot = kdbuf[l + rand() % (r - l + 1)].n[mx];
		do
		{
			while (kdbuf[i].n[mx] < pivot) i++;
			while (pivot < kdbuf[j].n[mx]) j--;
			if (i <= j)
			{
				SWAP(KDInfo, kdbuf[i], kdbuf[j]);
				SWAP(I, kdind[i], kdind[j]);
				i++, j--;
			}
		}
		while (i <= j);
		if (k >= i) l = i;
		else if (k <= j) r = j;
		else break;
	}
	return mx;
}

static inline void upKD(I x)
{
	kdcnt(x) = kdcnt(kdlch(x)) + 1 + kdcnt(kdrch(x));
	kdsum(x) = kdsum(kdlch(x)) + kdval(x) + kdsum(kdrch(x));
}

static inline void refreshKD(I x)
{
	for (I i = 0; i < 2; i++)
	{
		I val1 = kdpos(x).n[i], val2 = kdpos(x).n[i];
		if (kdlch(x))
		{
			val1 = MIN(val1, kdlbnd(kdlch(x)).n[i]);
			val2 = MAX(val2, kdrbnd(kdlch(x)).n[i]);
		}
		if (kdrch(x))
		{
			val1 = MIN(val1, kdlbnd(kdrch(x)).n[i]);
			val2 = MAX(val2, kdrbnd(kdrch(x)).n[i]);
		}
		kdlbnd(x).n[i] = val1;
		kdrbnd(x).n[i] = val2;
	}
}

I buildKD(I l, I r)
{
    if (l > r) return 0;
	I mid = l + ((r - l) >> 1);
	I dim = sortKD(l, r, mid);
	I x = kdind[mid];

	kddim(x) = dim;
	kdpos(x) = kdbuf[mid];
	kdlch(x) = buildKD(l, mid - 1);
	kdrch(x) = buildKD(mid + 1, r);
	refreshKD(x);
	upKD(x);
	return x;
}

void piaKD(I x, I l)
{
	if (!x) return;
	piaKD(kdlch(x), l);
	kdbuf[l + kdcnt(kdlch(x))] = kdpos(x);
	kdind[l + kdcnt(kdlch(x))] = x;
	piaKD(kdrch(x), l + kdcnt(kdlch(x)) + 1);
}

I maintainKD(I x)
{
	if (kdcnt(kdlch(x)) > kdcnt(kdrch(x)) * 4 + 4
	 || kdcnt(kdrch(x)) > kdcnt(kdlch(x)) * 4 + 4)
	{
		piaKD(x, 1);
		return buildKD(1, kdcnt(x));
	}
	return x;
}

static inline C insideKD(KDInfo x, KDInfo y, KDInfo l, KDInfo r)
{
	I ret = 1;
	ret &= x.n[0] >= l.n[0] && y.n[0] <= r.n[0];
	ret &= x.n[1] >= l.n[1] && y.n[1] <= r.n[1];
	return ret;
}

I addKD(I x, KDInfo pos, I val)
{
	if (!x || insideKD(pos, pos, kdpos(x), kdpos(x)))
	{
		if (!x) x = ++kdcnt, kdpos(x) = kdlbnd(x) = kdrbnd(x) = pos;
		kdval(x) += val, kdsum(x) += val, kdcnt(x) = 1;
		return x;
	}
	if (kdpos(x).n[kddim(x)] > pos.n[kddim(x)])
		kdlch(x) = addKD(kdlch(x), pos, val);
	else
		kdrch(x) = addKD(kdrch(x), pos, val);
    refreshKD(x);
	upKD(x);
	return x;
}

I queryKD(I x, KDInfo lbnd, KDInfo rbnd)
{
	if (insideKD(kdlbnd(x), kdrbnd(x), lbnd, rbnd))
		return kdsum(x);
	I l = kdlch(x), r = kdrch(x), sum = 0;
	if (insideKD(kdpos(x), kdpos(x), lbnd, rbnd))
	    sum += kdval(x);
	if (l && lbnd.n[kddim(x)] <= kdpos(x).n[kddim(x)])
		sum += queryKD(l, lbnd, rbnd);
	if (r && rbnd.n[kddim(x)] >= kdpos(x).n[kddim(x)])
		sum += queryKD(r, lbnd, rbnd);
	return sum;
}

void debugKD(I x, I dep)
{
	if (!x || dep > 10) return;
	debugKD(kdlch(x), dep + 1);
	debugKD(kdrch(x), dep + 1);
	printf("%d %d %d %d %d %d %d %d %d %d %d\n", kdcnt(x), kdval(x), kdsum(x), kdlch(x), kdrch(x), kdpos(x).n[0], kdpos(x).n[1], kdlbnd(x).n[0], kdlbnd(x).n[1], kdrbnd(x).n[0], kdrbnd(x).n[1]);
}

// Main

int main(void)
{
	I n, root = 0, lastans = 0;
	scanf("%d", &n);
	while (1)
	{
		I opr;
		scanf("%d", &opr);
		if (opr == 1)
		{
			I a, b, c;
			scanf("%d%d%d", &a, &b, &c);
			a ^= lastans, b ^= lastans, c ^= lastans;
			root = addKD(root, (KDInfo){{a, b}}, c);
			root = maintainKD(root);
		}
		else if (opr == 2)
		{
			I a, b, c, d;
			scanf("%d%d%d%d", &a, &b, &c, &d);
			a ^= lastans, b ^= lastans, c ^= lastans, d ^= lastans;
			printf("%d\n", lastans = queryKD(root, (KDInfo){{a, b}}, (KDInfo){{c, d}}));
		}
		else break;
		//debugKD(root, 0);
	}
	return 0;
}

2022/1/10 14:38
加载中...