ACM十月训练

ABC 426 G

题意

\(q\) 次询问,每次询问给定区间 \([l,r]\),求该区间的 01 背包最大价值。

\(q\le 2\times 10^5,n\le 2\times 10^4, c\le 500\)

题解

考虑一种离线的分治做法:首先考虑两个 01 背包如何合并,我们可以将区间分成 \([l, m]\)\([m + 1, r]\),分别记录后缀 dp 值和前缀 dp 值,这样最后就能枚举容量找最大值。假设当前处理区间是 \([l, r]\),对于所有问题 \([i,j]\),我们分成三类处理:

  1. \(j\le m\),放到 \([l, m]\) 分治。
  2. \(i> m\),放到 \([m + 1, r]\) 分治。
  3. 其余区间,即 \(m\in [i, j]\)​,通过上述背包合并解决。

这种方法似乎叫猫树分治。

时间复杂度:\(O(nk\log n)\)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include <bits/stdc++.h>

#define int long long
using namespace std;

using i64 = long long;
using u64 = unsigned long long;
using u32 = unsigned;

using u128 = unsigned __int128;
using i128 = __int128;
using ld = long double;

signed main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);

int n;
cin >> n;
vector<int> w(n + 1), v(n + 1);
for (int i = 1; i <= n; i++) {
cin >> w[i] >> v[i];
}
int q;
cin >> q;
vector<array<int, 3>> query;
for (int i = 1; i <= q; i++) {
int l, r, c;
cin >> l >> r >> c;
query.push_back({l, r, c});
}

vector<int> ans(q);
constexpr int K = 500;
vector dpl(n + 1, vector<int>(K + 1)), dpr(n + 1, vector<int>(K + 1));
auto update = [&](const vector<int> &ndp, vector<int> &dp, int i) {
for (int j = 0; j <= K; j++) {
dp[j] = ndp[j];
if (j >= w[i]) {
dp[j] = max(dp[j], ndp[j - w[i]] + v[i]);
}
}
};
auto solve = [&](auto solve, int l, int r, const vector<int> &qid) -> void {
if (l == r) {
for (auto i : qid) {
auto [nl, nr, nc] = query[i];
assert(nl == l && nr == r);
ans[i] = nc >= w[l] ? v[l] : 0;
}
return;
}
int m = l + (r - l) / 2;
fill(dpl[m + 1].begin(), dpl[m + 1].end(), 0);
fill(dpr[m].begin(), dpr[m].end(), 0);
for (int i = m; i >= l; i--) {
update(dpl[i + 1], dpl[i], i);
}
for (int i = m + 1; i <= r; i++) {
update(dpr[i - 1], dpr[i], i);
}
vector<int> qid_l, qid_r;
for (auto i : qid) {
auto [nl, nr, nc] = query[i];
if (nr <= m) {
qid_l.push_back(i);
} else if (nl > m) {
qid_r.push_back(i);
} else {
for (int j = 0; j <= nc; j++) {
ans[i] = max(ans[i], dpl[nl][j] + dpr[nr][nc - j]);
}
}
}
solve(solve, l, m, qid_l);
solve(solve, m + 1, r, qid_r);
};
vector<int> qid(q);
iota(qid.begin(), qid.end(), 0);
solve(solve, 1, n, qid);
for (int i = 0; i < q; i++) {
cout << ans[i] << "\n";
}

return 0;
}