#include <bits/stdc++.h>
using namespace std;
// Short type names used throughout the solution.
using ui = unsigned int;
using ll = long long;
#define all(x) (x).begin(), (x).end()

// N: capacity of every per-vertex array (vertex count plus slack).
// p: modulus applied to the answer and to the DP products in dfs.
const int N = 3020, p = 998244353;

// Adjacency lists of the tree (main fills indices 1..n).
vector<int> e[N];
// Raw DP storage. main points f[v] at F[v] + N, so f[v][s] may be indexed by
// a signed sum s in [-N, N).
int F[N][N * 2];
int *f[N];
int c[N];   // colour of each vertex, as read by main
int cnt[N]; // number of vertices carrying each colour
int a[N];   // weight of each vertex for the colour being processed (+1 / -1)
int siz[N]; // subtree size computed by dfs, clamped to tot
int res;    // running answer, kept reduced modulo p
int tot;    // vertex count of the colour currently being processed
bool ed[N]; // set while a vertex is on the active dfs path
// Tree knapsack for the colour currently being processed (the tree is rooted
// at the vertex main starts from).
//
// After dfs(u) returns, f[u][s] for -siz[u] <= s <= siz[u] counts the connected
// vertex sets that contain u, lie inside u's DFS subtree, and whose weights
// a[] add up to s, modulo p. Sums are only tracked inside [-tot, tot]. main sets
// tot to the number of +1 vertices, so a set whose running sum drops below
// -tot holds more than tot vertices of weight -1 and can never become
// non-negative. Dropping such sets therefore keeps every count for s >= 0 exact.
//
// Side effect: adds to res (mod p) the number of those sets with sum >= 1.
// Each connected set is counted once, at its top-most vertex.
void dfs(int u)
{
    ed[u] = 1; // on the active path: keeps us from walking back to the parent
    siz[u] = 1;
    f[u][a[u]] = 1; // the singleton set {u}

    for (int v : e[u])
    {
        if (ed[v])
            continue;
        dfs(v);

        // Scratch row indexable from -N to N-1. It is shared by every call,
        // which is safe because it is only used after the recursive call returns.
        static int G[N * 2];
        int *g = G + N;

        const int reach = siz[u] + siz[v];
        // Start from f[u]: the sets that do not enter v's subtree ...
        copy(f[u] - reach, f[u] + reach + 1, g - reach);
        // ... then add every way of joining a set rooted at u with one rooted at v.
        for (int x = -siz[u]; x <= siz[u]; ++x)
            for (int y = -siz[v]; y <= siz[v]; ++y)
                g[x + y] = (g[x + y] + (ll)f[u][x] * f[v][y]) % p;
        copy(g - reach, g + reach + 1, f[u] - reach);

        siz[u] = min(reach, tot);
    }

    for (int s = 1; s <= siz[u]; ++s)
    {
        res += f[u][s];
        if (res >= p)
            res -= p;
    }
    ed[u] = 0;
}
// Reads n, the colour of each vertex and the n-1 tree edges, then prints the
// sum over every colour of the number of connected vertex sets in which that
// colour is a strict majority, modulo p.
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);

    int n;
    cin >> n;
    for (int v = 1; v <= n; ++v)
    {
        cin >> c[v];
        ++cnt[c[v]];
    }
    for (int k = 1; k < n; ++k)
    {
        int x, y;
        cin >> x >> y;
        e[x].push_back(y);
        e[y].push_back(x);
    }

    // Shift every DP row so f[v][s] accepts negative sums s.
    for (int v = 1; v <= n; ++v)
        f[v] = F[v] + N;

    // One tree knapsack per colour that actually occurs: weight its vertices
    // +1 and all others -1, then count sets with positive total.
    for (int col = 1; col <= n; ++col)
    {
        if (!cnt[col])
            continue;
        tot = cnt[col];
        for (int v = 1; v <= n; ++v)
        {
            fill_n(f[v] - tot, tot * 2 + 1, 0); // reset the window dfs reads
            a[v] = (c[v] == col) ? 1 : -1;
        }
        dfs(1);
    }
    cout << res << endl;
}
// NOTE(review): this fragment does not belong to this program and is disabled
// so the file compiles. As written it:
//   1. redeclares `dfs`, which is already the function defined above;
//   2. puts a `[&]` capture-default on a lambda at namespace scope, which is
//      ill-formed;
//   3. refers to names this file never declares (l, h, ne, tr, W, T, r);
//   4. initialises an int from `e[i]`, which is a vector<int> here.
// It appears to record entry/exit timestamps l[u] / r[u] over an edge-list
// graph (h / ne / e). Restore it only together with the declarations it needs,
// under a name that does not collide with dfs.
#if 0
function<void(int, int)> dfs = [&](int u, int fa)
{
    l[u] = ++ T;
    for (int i = h[u]; i != -1; i = ne[i])
    {
        int v = e[i];
        tr[T] = W[i];
        if (v == fa) continue;
        dfs(v, u);
    }
    r[u] = T;
};
#endif