85 pts 求条 WA on #1 #6 #7
查看原帖
85 pts 求条 WA on #1 #6 #7
602372
Brilliant11001（楼主） 2024/11/22 21:31

思路就是 $dp_{i,j,k}$ 表示到达第 $i$ 个加油站（并已决定是否在该站加油）后，当前速度是 $2^j\times 3^k$ 的最短时间。

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <vector>

using namespace std;

// Capacity for stations; index 0 stands for the start at position 0.
constexpr int N = 100010;
// Tolerance used by comp() for floating-point equality.
constexpr double eps = 1e-7;
using ll = long long;
// n: number of stations, m: number of queries (both read in main).
int n, m;

// One station: position d, time cost tim of refuelling there, and its oil type
// (main handles types 1..4).
struct node {
	int d, tim, oil;
};
node sta[N];

// dp[i][j][k]: minimal time to leave station i with speed 2^j * 3^k.
double dp[N][32][16];
// nums[i]: position of station i, kept in a plain array for upper_bound.
int nums[N];
// sum[c][i]: once main builds prefix sums, the number of type-c stations among 1..i.
int sum[5][N];
// pow2[j] = 2^j and pow3[k] = 3^k, filled in main.
ll pow2[32], pow3[16];

inline double get_tim(int x, int y, double v) {
	return 1.0 * (y - x) / v;
}

// Approximate equality: true iff a and b differ by strictly less than eps
// (false whenever either argument is NaN).
inline bool comp(double a, double b) {
	return fabs(a - b) < eps;
}

int main() {
// 	freopen("ship4.in", "r", stdin);
// 	freopen("ship.out", "w", stdout);
	scanf("%d%d", &n, &m);
	pow2[0] = pow3[0] = 1;
	for(int i = 1; i <= 31; i++)
		pow2[i] = pow2[i - 1] * 2;
	for(int i = 1; i <= 15; i++)
		pow3[i] = pow3[i - 1] * 3;
	bool flag = true;
	for(int a, b, c, i = 1; i <= n; i++) {
		scanf("%d%d%d", &a, &b, &c);
		sta[i] = {a, b, c};
		flag &= (c == 1);
		nums[i] = a;
		sum[c][i]++;
	}
	if(flag) {
		int dist;
		while(m--) {
			scanf("%d", &dist);
			printf("%d\n", dist);
		}
		return 0;
	}
	for(int i = 1; i <= 4; i++)
		for(int j = 1; j <= n; j++)
			sum[i][j] += sum[i][j - 1];
	 for(int i = 0; i <= n; i++)
	     for(int j = 0; j < 32; j++)
	         for(int k = 0; k < 16; k++)
	             dp[i][j][k] = 1e18;
	dp[0][0][0] = 0;
	for(int i = 1; i <= n; i++) {
		dp[i][0][0] = sta[i].d;
		for(int j = 0; j <= min(31, sum[2][i] + 2 * sum[4][i]); j++)
			for(int k = 0; k <= min(15, min(i, sum[3][i])); k++) {
				if(!j && !k) continue;
				//不加油
				if(j <= sum[2][i - 1] + 2 * sum[4][i - 1] && k <= sum[3][i - 1])
					dp[i][j][k] = dp[i - 1][j][k] + get_tim(sta[i - 1].d, sta[i].d, 1.0 * pow2[j] * pow3[k]);
				//加油
				if(sta[i].oil == 1) continue;
				if(sta[i].oil == 2 && j)
					dp[i][j][k] = min(dp[i][j][k], dp[i - 1][j - 1][k] + get_tim(sta[i - 1].d, sta[i].d, 1.0 * pow2[j - 1] * pow3[k]) + sta[i].tim);
				else if(sta[i].oil == 3 && k)
					dp[i][j][k] = min(dp[i][j][k], dp[i - 1][j][k - 1] + get_tim(sta[i - 1].d, sta[i].d, 1.0 * pow2[j] * pow3[k - 1]) + sta[i].tim);
				else if(sum[4][i] && j >= 2)
					dp[i][j][k] = min(dp[i][j][k], dp[i - 1][j - 2][k] + get_tim(sta[i - 1].d, sta[i].d, 1.0 * pow2[j - 2] * pow3[k]) + sta[i].tim);
			}
	}
	int dist;
	while(m--) {
 		scanf("%d", &dist);
		int reald = upper_bound(nums + 1, nums + n + 1, dist) - nums - 1;
		double res = 1e18;
		for(int i = 0; i <= min(31, sum[2][reald] + 2 * sum[4][reald]); i++)
			for(int j = 0; j <= min(15, min(reald, sum[3][reald])); j++)
				res = min(res, dp[reald][i][j] + get_tim(sta[reald].d, dist, 1.0 * pow2[i] * pow3[j]));
		printf("%.10lf\n", res);
	}
	return 0;
}
2024/11/22 21:31
加载中...