属实是属于赛后清醒,赛时暴力维护两、三个点的情况,三分求答案(真不知道当时怎么想的,,,都用了三分了知道是凸函数还不用公式直接求最大值属实是nt行为)。
正解:数学+贪心
贪心:因为有正负,维护两三个点为最优,那么怎么找到平均值最大的点呢,我们在存图的时候存的不是与他相邻的点,而是与他相邻的点的权值。这样三个点的情况,我们就可以以x为中间点,找到相邻的另外两个点就可以了。
对于取另外两个点,有两种情况:
一、最大与次大
二、最小与次小(负数时可能会更优)
对于上面的情况,我们只需要对x的相邻权值sort就可以
代码如下
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;
const int N = 200010;
const double eps = 1e-6;
int n;
double w[N];
vector<int> v[N];
double res = -1e18;
double check(int u, int j) // 维护两个点时候的最大值
{
double a = -2;
double b = (u + j);
double x = b * 1.0 / (-2 * a);
return (a * x * x + b * x) / 2;
}
double checkk(int u, int j, int v) // 维护三个点时候的最大值
{
double a = -3;
double b = (u + j + v);
double x = b * 1.0 / (-2 * a);
return (a * x * x + b * x) / 3;
}
int main()
{
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
cin >> n;
for (int i = 1; i <= n; i ++) cin >> w[i];
for (int i = 1; i < n; i ++)
{
int a, b;
cin >> a >> b;
v[a].push_back(w[b]); // 以权值建图
v[b].push_back(w[a]);
}
for (int i = 1; i <= n; i ++) // 对相邻权值进行排序
{
sort(v[i].begin(), v[i].end(), greater<int>());
}
for (int i = 1; i <= n; i ++) // 取两个点的情况
{
int sz = v[i].size();
res = max(res, check(w[i], v[i][0]));
res = max(res, check(w[i], v[i][sz - 1]));
}
for (int i = 1; i <= n; i ++) // 取三个点的情况
{
int sz = v[i].size();
if (sz < 2) continue;
for (int j = 0; j < sz; j ++)
{
res = max(res, checkk(w[i], v[i][0], v[i][1]));
res = max(res, checkk(w[i], v[i][sz - 1], v[i][sz - 2]));
}
}
cout.setf(ios::fixed);
cout << fixed << setprecision(6) << res << '\n';
return 0;
}
三分的代码也贴一下毕竟写了(用阳寿过题,一样代码交了三发 -> 2AC 1TLE)
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;
const int N = 200010;
const double eps = 1e-5;
int n;
double w[N];
vector<int> v[N];
double res = -1e18;
double f(double x, double a, double b)
{
return ((-x * x + a * x) + (-x * x + b * x)) * 1.0 / 2;
}
double ff(double x, double a, double b, double c)
{
return ((-x * x + a * x) + (-x * x + b * x) + (-x * x + c * x)) * 1.0 / 3;
}
double check(int u, int j)
{
double l = -100000, r = 100000;
if (l > r) swap(l, r);
while (r - l > eps)
{
double lmid = l + (r - l) / 3;
double rmid = l + (r - l) / 3 * 2;
double lans = f(lmid, u, j), rans = f(rmid, u, j);
if (lans >= rans) r = rmid;
else l = lmid;
}
return f(l, u, j);
}
double checkk(int u, int j, int v)
{
double l = -100000, r = 100000;
if (l > r) swap(l, r);
while (r - l > eps)
{
double lmid = l + (r - l) / 3;
double rmid = l + (r - l) / 3 * 2;
double lans = ff(lmid, u, j, v), rans = ff(rmid, u, j, v);
if (lans >= rans) r = rmid;
else l = lmid;
}
return ff(l, u, j, v);
}
int main()
{
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
cin >> n;
for (int i = 1; i <= n; i ++) cin >> w[i];
for (int i = 1; i < n; i ++)
{
int a, b;
cin >> a >> b;
v[a].push_back(w[b]);
v[b].push_back(w[a]);
}
for (int i = 1; i <= n; i ++)
{
sort(v[i].begin(), v[i].end(), greater<int>());
}
for (int i = 1; i <= n; i ++)
{
int sz = v[i].size();
res = max(res, check(w[i], v[i][0]));
res = max(res, check(w[i], v[i][sz - 1]));
}
for (int i = 1; i <= n; i ++)
{
int sz = v[i].size();
if (sz < 2) continue;
for (int j = 0; j < sz; j ++)
{
res = max(res, checkk(w[i], v[i][0], v[i][1]));
res = max(res, checkk(w[i], v[i][sz - 1], v[i][sz - 2]));
}
}
cout.setf(ios::fixed);
cout << fixed << setprecision(6) << res << '\n';
return 0;
}