数位dp学习笔记

​ 数位 dp 的常见形式,就是统计给定范围 \([a, b]\) 内满足某种条件的数有多少个,通过前缀和思想,我们可以拆分成 $[1,b] $ 的答案减去 \([1, a-1]\) 的答案。那么问题就变成了如何统计这样的答案,方法就是数位 dp。

​ 数位 dp 的基本思想是从高位到低位的顺序记忆化搜索,因为经常会用到之前重复出现的状态,但有些状态虽然数位相同,但是不能使用。举个例子:给定 \(b = 589\) ,由 \(1\) 开始搜到低位 89 的状态,对于高位为 \(2,3,4\) 的情况同样可以复用,但是高位为 5 的话,由于缺少了 590 ~ 599 ,可能导致答案的减少,这里就不能记忆化,必须单独考虑——这就被我们称为有限制的情况。

​ 既然使用记忆化搜索,我们一般通过 DFS 实现,这里的函数声明一般是:int dfs (int pos, bool sta, bool limit),其中 pos 代表我们正在处理的位,limit 代表当前位是否存在最大限制(即上面的 589 的情况),sta 代表题目中的其他限制。

​ 下面以例题的形式来理解 数位 dp 的过程:

HDU 2089 不要 62

题意

给定一个范围 \([n,m]\) ,求这个区间上既不存在“62”,也不存在“4”的数有多少个。

题解

对于每一位来说,如果它为 2,它的前一位不能为 6;同时它一定不能为 4。因此,我们在 DFS 的参数中定义 sta 表示前一位是否为 6。设 \(dp_{i,sta}\) 表示对于第 \(i\) 位,在 \(sta\) 状态下的满足题意的数量,那么排除上面的两种情况后继续 DFS 记忆化搜索即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include <bits/stdc++.h>
using namespace std;
#define int long long

int dp[10][2];
int n, m, sz;
int c[10];
int dfs(int pos, bool sta, bool limit) {
int ans = 0;
if (pos == sz + 1) {
return 1;
}
if (!limit && dp[pos][sta] != -1) {
return dp[pos][sta];
}
int up = limit ? c[pos] : 9;
for (int i = 0; i <= up; i++) {
if (sta && i == 2) {
continue;
}
if (i == 4) {
continue;
}
ans += dfs(pos + 1, i == 6, limit && i == up);
}
if (!limit) {
dp[pos][sta] = ans;
}
return ans;
}

void solve() {
while (true) {
cin >> n >> m;
if (n == 0 && m == 0) {
break;
}
for (int i = 0; i < 8; i++) {
for (int j = 0; j < 2; j++) {
dp[i][j] = -1;
}
}
string s = to_string(m);
sz = s.size();
s = " " + s;
for (int i = 1; i <= sz; i++) {
c[i] = s[i] - '0';
}
int ans = dfs(1, false, true);
for (int i = 0; i < 8; i++) {
for (int j = 0; j < 2; j++) {
dp[i][j] = -1;
}
}
s = to_string(n - 1);
sz = s.size();
s = " " + s;
for (int i = 1; i <= sz; i++) {
c[i] = s[i] - '0';
}
ans -= dfs(1, false, true);
cout << ans << "\n";
}
}

signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
int t = 1;
// cin >> t;
while (t--) {
solve();
}
return 0;
}

HDU 3555 bomb

题意

给定一个数 \(x\),求区间 \([1,x]\) 中存在“49”的数有多少。

题解

从反向考虑的话和上一题非常相似,这里不给出解答。

下面从正向考虑:若一个位置满足条件,必须是这一位取 9,同时前一位取 4,所以 sta 表示前一位是否为 4,如果满足条件,以它开头的所有数都满足条件了,不需要继续搜索下去。但是这种情况还需要注意 limit 的限制,如果后面数字不能去满的话,需要用 \(x\) 对当前位的基数取模(预处理十的幂)后再加 1(考虑0)。如果不满足条件继续搜索即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <bits/stdc++.h>
using namespace std;
#define int long long

int dp[65][2], nums[65], power[19];
int sz, x;
int dfs(int pos, bool sta, bool limit) {
if (pos == sz + 1) {
return 0;
}
if (limit == false && dp[pos][sta] != -1) {
return dp[pos][sta];
}
int ans = 0;
int up = limit ? nums[pos] : 9;
for (int i = 0; i <= up; i++) {
if (sta && i == 9) {
ans += (limit == true ? x % power[sz - pos] + 1 : power[sz - pos]);
} else {
ans += dfs(pos + 1, i == 4, limit && i == up);
}
}
if (limit == false) {
dp[pos][sta] = ans;
}
return ans;
}

void solve() {
cin >> x;
string s = to_string(x);
sz = s.size();
s = " " + s;
for (int i = 1; i <= sz; i++) {
nums[i] = s[i] - '0';
}
for (int i = 0; i < 65; i++) {
for (int j = 0; j < 2; j++) {
dp[i][j] = -1;
}
}
cout << dfs(1, false, true) << "\n";
}

signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
int t = 1;
cin >> t;
power[0] = 1;
for (int i = 1; i <= 18; i++) {
power[i] = power[i - 1] * 10;
}
while (t--) {
solve();
}
return 0;
}

POJ 3252 Round Numbers

题意

给定一个范围 \([n,m]\) ,求这个区间上二进制下的 0 的数目比二进制下 1 的数目多的数有多少个。

题解

这里搜索的时候需要知道二进制下 0 和 1 的数目,所以设额外的变量 x0 和 x1,在搜索结束时用于判断是否成立。

此外,本题还涉及到数位 dp 中前导 0 的处理,需要着重说一下。由于前导 0 的存在,如果当前位也加上一个 0 ,其实是对答案没有任何贡献的,因为当前状态搜索的根本就不是一个数,因此用 sta 判断前面的状态是否是前导 0,如果是的话且当前数为 0 ,就不能加入 x0。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include <bits/stdc++.h>
using namespace std;
#define int long long

int dp[65][65][65], nums[65];
int sz;
int dfs(int pos, int x0, int x1, bool sta, bool limit) {
if (pos == sz + 1) {
return x0 >= x1;
}
if (limit == false && dp[pos][x0][x1] != -1) {
return dp[pos][x0][x1];
}
int up = limit == true ? nums[pos] : 1;
int ans = 0;
for (int i = 0; i <= up; i++) {
if (sta && i == 0) {
ans += dfs(pos + 1, 0, 0, true, limit && i == up);
} else {
ans += dfs(pos + 1, x0 + (i == 0), x1 + (i == 1), false,
limit && i == up);
}
}
if (limit == false) {
dp[pos][x0][x1] = ans;
}
return ans;
}

int get(int x) {
string s;
while (x) {
if (x % 2 == 1) {
s += '1';
} else {
s += '0';
}
x /= 2;
}
reverse(s.begin(), s.end());
sz = s.size();
s = " " + s;
for (int i = 1; i <= sz; i++) {
nums[i] = s[i] - '0';
}
for (int i = 0; i <= sz; i++) {
for (int j = 0; j <= sz; j++) {
for (int k = 0; k <= sz; k++) {
dp[i][j][k] = -1;
}
}
}
return dfs(1, 0, 0, true, true);
}

void solve() {
int n, m;
while (cin >> n >> m) {
cout << get(m) - get(n - 1) << "\n";
}
}

signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
int t = 1;
// cin >> t;
while (t--) {
solve();
}
return 0;
}

洛谷 P2657 Windy数

题意

不含前导零且相邻两个数字之差至少为 2 的正整数被称为 windy 数。windy 想知道,在 ab 之间,包括 ab ,总共有多少个 windy 数?

题解

和上一题类似,需要考虑到前导零的影响,因为数字的第一位是可以不管大小的,我们在 dfs 中额外记录上一个选择的数字即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include <bits/stdc++.h>
using namespace std;
// #define int long long

int dp[11][10];
int sz, nums[11];
int dfs(int pos, int lst, bool sta, bool limit) {
if (pos == sz + 1) {
return 1;
}
if (lst >= 0 && limit == true && dp[pos][lst] != -1) {
return dp[pos][lst];
}
int ans = 0;
int up = limit == true ? nums[pos] : 9;
for (int i = 0; i <= up; i++) {
if (i == 0 && sta) {
ans += dfs(pos + 1, -2, true, limit && i == up);
} else if (abs(i - lst) >= 2) {
ans += dfs(pos + 1, i, false, limit && i == up);
}
}
if (limit) {
dp[pos][lst] = ans;
}
return ans;
}
int get(int x) {
string s = to_string(x);
sz = s.size();
s = " " + s;
for (int i = 1; i <= sz; i++) {
nums[i] = s[i] - '0';
}
for (int i = 0; i <= sz; i++) {
for (int j = 0; j < 10; j++) {
dp[i][j] = -1;
}
}
return dfs(1, -2, true, true);
}
void solve() {
int a, b;
cin >> a >> b;
cout << get(b) - get(a - 1) << "\n";
}

signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
int t = 1;
// cin >> t;
while (t--) {
solve();
}
return 0;
}

Acwing 1081 度的数量

题意

求十进制区间 \([X,Y]\) 内,由 \(K\) 个互不相等(注意,此处的意思就是每个整数幂的个数不能 \(>1\),所以只能是选 0 / 1 )的 \(B\) 的整数幂之和组成的整数个数,也就是这个区间内的数字转化为 \(B\) 进制后恰好有 \(K\) 位上的数字位 1

题解

注意到题意就是一个数拆成 B 进制后每个位上的数最大为 1,且一共有 K 个就可以数位 dp 了,记录下访问的 1 的数量即可,前导零并没有影响。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <bits/stdc++.h>
using namespace std;

using i64 = long long;
using u64 = unsigned long long;
using u32 = unsigned;
using u128 = unsigned __int128;

int dp[35][35], nums[35], sz, b, k;
int dfs(int pos, int cnt, bool limit) {
if (pos == sz + 1) {
return cnt == k;
}
if (!limit && dp[pos][cnt] != -1) {
return dp[pos][cnt];
}
int up = limit ? nums[pos] - 1 : b - 1;
int ans = 0;
for (int i = 0; i <= up; i++) {
if (i > 1 || (i == 0 && cnt == 0)) {
continue;
}
ans += dfs(pos + 1, cnt + (i == 1), limit && i == up);
}
if (!limit) {
dp[pos][cnt] = ans;
}
return ans;
}
int get(int x) {
for (int i = 0; i < 35; i++) {
for (int j = 0; j < 35; j++) {
dp[i][j] = -1;
}
}
string s;
while (x) {
s += '0' + x % b;
x /= b;
}
sz = s.size();
reverse(s.begin(), s.end());
for (int i = 1; i <= sz; i++) {
nums[i] = s[i] - '0';
}
return dfs(1, 0, true);
}
void solve() {
int l, r;
cin >> l >> r >> k >> b;
cout << get(r) - get(l - 1) << "\n";
}

signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
int t = 1;
cin >> t;
while (t--) {
solve();
}
return 0;
}

洛谷 P2602 数字计数

题意

给定两个正整数 ab,求在 \([a, b]\) 中的所有整数中,每个数码(digit)各出现了多少次。

题解

我们单独考虑每个数码,那么在 dfs 中加入 \(x\) 来表示当前考虑的数码,注意需要考虑前导 0 的影响,因为前导零会导致计数异常。我们枚举数字时,如果恰好为 \(x\),此时还需要考虑 limit 的影响,如果存在限制,此时位能增加的答案就是对当前位的基数取模后加1(和第二题不同的是,下面的处理预处理了取模后的结果,但本质是一样的),如果没有限制,答案就是对应的十的幂。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include <bits/stdc++.h>
using namespace std;
#define int long long

int sz, nums[15];
int dp[15];
int power[15], now[15];
int dfs(int pos, int x, bool sta, bool limit) {
if (pos == sz + 1) {
return 0;
}
if (limit == false && sta == false && dp[pos] != -1) {
return dp[pos];
}
int cnt = 0;
int up = limit ? nums[pos] : 9;
for (int i = 0; i <= up; i++) {
if (sta == true && i == 0) {
cnt += dfs(pos + 1, x, true, limit && i == up);
} else if (i == x && limit == true && i == up) {
cnt += now[pos + 1] + 1 + dfs(pos + 1, x, false, limit && i == up);
} else if (i == x) {
cnt += power[sz - pos] + dfs(pos + 1, x, false, limit && i == up);
} else {
cnt += dfs(pos + 1, x, false, limit && i == up);
}
}
if (limit == false && sta == false) {
dp[pos] = cnt;
}
return cnt;
}

int get(int x, int digit) {
string s = to_string(x);
sz = s.size();
s = " " + s;
for (int i = 0; i < 15; i++) {
now[i] = 0;
}
for (int i = 1; i <= sz; i++) {
dp[i] = -1;
nums[i] = s[i] - '0';
}
for (int i = sz; i >= 1; i--) {
now[i] = now[i + 1] + nums[i] * power[sz - i];
}
return dfs(1, digit, true, true);
}

void solve() {
int a, b;
cin >> a >> b;
power[0] = 1;
for (int i = 1; i <= 12; i++) {
power[i] = power[i - 1] * 10;
}
for (int i = 0; i <= 9; i++) {
cout << get(b, i) - get(a - 1, i) << " \n"[i == 9];
}
}

signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
int t = 1;
// cin >> t;
while (t--) {
solve();
}
return 0;
}

Atcoder Beginner Contest 387 C

题意

一个不小于 10 的正整数,其十进制表示法的首位(最重要的一位)严格大于该数的其他各位,叫做蛇形数。例如, 31 和 201 是蛇形数,但 35 和 202 不是。

求在 LR 之间(包括首尾两个数)有多少个蛇形数。

题解

本题需要我们指定最高位的情况,因此在 DFS 中加入参数 first ,并注意处理前导零,遇到不成立的情况需要 continue 。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <bits/stdc++.h>
using namespace std;

using i64 = long long;
using u64 = unsigned long long;
using u32 = unsigned;
using u128 = unsigned __int128;

i64 dp[20][10];
int nums[20], sz;
i64 dfs(int pos, int first, bool sta, bool limit) {
if (pos == sz + 1) {
return 1;
}
if (!limit && dp[pos][first] != -1) {
return dp[pos][first];
}
int up = limit ? nums[pos] : 9;
i64 ans = 0;
for (int i = 0; i <= up; i++) {
if (sta == false && i >= first) {
continue;
}
if (i == 0 && sta) {
ans += dfs(pos + 1, 0, true, limit && i == up);
} else {
ans +=
dfs(pos + 1, (first == 0 ? i : first), false, limit && i == up);
}
}
if (!limit && !sta) {
dp[pos][first] = ans;
}
return ans;
}

i64 get(i64 x) {
string s = to_string(x);
sz = s.size();
s = " " + s;
for (int i = 1; i <= sz; i++) {
nums[i] = s[i] - '0';
}
for (int i = 0; i <= 19; i++) {
for (int j = 0; j <= 9; j++) {
dp[i][j] = -1;
}
}
return dfs(1, 0, true, true);
}

void solve() {
i64 l, r;
cin >> l >> r;
cout << get(r) - get(l - 1) << "\n";
}

signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
int t = 1;
// cin >> t;
while (t--) {
solve();
}
return 0;
}