快速除法取模

最后更新于:2023年2月24日 下午

在知乎好友 SuperSodaSea文章中看到,对于模数为常量的情况编译器会帮我们做优化(所以我们按照同种方式手写肯定跑不过编译器,真的如此吗?)

模数为常量的快速模

1
2
3
4
5
6
7
8
/// Computes x^n mod M by binary (square-and-multiply) exponentiation.
///
/// Every product is widened to 64 bits through the `1LL *` factor before the
/// reduction, so the intermediate x * x and r * x do not overflow `int`.
///
/// @param x base; pass a value in [0, M) to get a result in [0, M)
/// @param n exponent; must be non-negative (a negative n is not handled)
/// @param M modulus; must be positive
/// @return x^n mod M
int powMod(int x, int n, int M) {
    // Start from 1 % M rather than 1 so that M == 1 yields 0 even when n == 0.
    int r = 1 % M;
    while (n) {
        if (n & 1) r = 1LL * r * x % M;
        n >>= 1;
        x = 1LL * x * x % M;
    }
    return r;
}

因此对于 powMod,上面写法会慢于下面

1
2
3
4
5
6
7
8
9
const int M = 998244353;

/// Computes x^n mod M for the compile-time constant modulus M above.
///
/// Because M is a constant the compiler can replace each `% M` with a
/// multiplication-based sequence, which is the point this example makes.
///
/// @param x base
/// @param n exponent; must be non-negative
/// @return x^n mod M
int powMod(int x, int n) {
    int result = 1;
    for (; n != 0; n >>= 1) {
        if ((n & 1) != 0) {
            result = static_cast<int>(static_cast<long long>(result) * x % M);
        }
        x = static_cast<int>(static_cast<long long>(x) * x % M);
    }
    return result;
}

这是因为对于模为 const 的情形,编译器可以在编译期利用 Barrett Modular Multiplication 把除法变成乘法(如果我们自定义高精度还是需要自己写这些优化)。但是其实我们可以利用 Montgomery multiplication 做得更好。下面先说明原理,具体代码见证明之后:

我们来证明一波。首先给定常数 \(m, n\) 使得 \(m < 2^n\) 且 \(m\) 为奇数(这是必不可少的限制条件),并且全程使用无符号整型。对于任意数字 \(a\),定义 \(a' \equiv 2^n a \mod m\)。我们想要计算 \(c \equiv ab \mod m\),那么只需要计算出 \[ c' \equiv (ab)' \equiv 2^n ab \equiv 2^{-n} a'b' \mod m \] 所以我们需要找到一个 \(s\) 使得 \(a'b' + sm\) 是 \(2^n\) 的倍数,然后 \(\frac{a'b' + sm}{2^n}\) 即为所求。我们可以取 \(s < 2^n\),那么当 \(a', b' < m\) 时这个结果会小于 \(2m\),然后做一个判断(必要时减去 \(m\))即可。而 \(s\) 显然可以借助 exgcd 来求(这也是为什么要求 \(m\) 为奇数):先求出 \(-\frac{1}{m} \mod 2^n\)(代码中为 mr),再取 \(s \equiv a'b' \cdot \left(-\frac{1}{m}\right) \mod 2^n\)。

为了高效,实现的时候我们取 \(n = 32\)。注意到一般我们的模都小于 \(2^{30}\), 即 \(m < 2^{n - 2}\),若 \(a', b' < 2^{n - 1}\), 又因为 \(s < 2^n\),从而 \(sm + a'b' < 2^{2n - 1}\),从而保证了 \(c' < 2^{n - 1}\),从而就可以不要每次都取一次最小值,只需最后搞一下即可(下面情形依然可以省去)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// 40% faster
using ULL = unsigned long long;

/// Computes x^n mod 998244353 with Montgomery multiplication (R = 2^32).
///
/// Both the running base and the accumulator are kept in Montgomery form
/// (value * 2^32 mod m) and are only partially reduced between steps; the
/// final `% m` produces the fully reduced answer.
///
/// @param x base (any 32-bit value)
/// @param n exponent
/// @return x^n mod 998244353, in [0, 998244353)
unsigned fastPowMod998244353(unsigned x, unsigned n) {
    constexpr unsigned m = 998244353U;      // the modulus
    constexpr unsigned mr = 998244351U;     // -1/m mod 2^32
    constexpr unsigned m1 = 301989884U;     // 2^32 mod m, i.e. 1 in Montgomery form
    constexpr unsigned m1inv = 232013824U;  // inverse of m1 mod m, leaves Montgomery form
    unsigned base = (ULL(x) << 32) % m;
    unsigned acc = m1;
    for (; n != 0; n >>= 1) {
        if ((n & 1) != 0) {
            const ULL prod = ULL(acc) * base;
            // Add the multiple of m that clears the low 32 bits, then drop them.
            acc = (prod + ULL(unsigned(prod) * mr) * m) >> 32;
        }
        const ULL square = ULL(base) * base;
        base = (square + ULL(unsigned(square) * mr) * m) >> 32;
    }
    return ULL(acc) * m1inv % m;
}

实测 xx = std::min(xx, xx - m); 会比 if (xx >= m) xx -= m 快不少,主要是因为后者有分支,比较耗时;不过在上面的代码中这一步已经被直接优化掉了。其实上述做法对 \(m\) 不是常数的情形也能做,只是此时不能在编译期计算 mr, m1, m1inv,所以会很慢,从而就没有太大的意义了(极端特殊的场景还是能用的)。上面的代码里有两段特别类似的代码(也可以写成函数然后强制内联,但没必要),实测写成函数会增加耗时,所以还是分开写吧。我原以为如果有如下的函数就可以跑得更快,但实际上并不会更快,因为 (x + y) >> 32 并不等于 (x >> 32) + (y >> 32)。不过这不妨碍我们写汇编(x86_64 gcc),以下汇编修改自 SuperSodaSea 给我的汇编代码;可是 arm64 的版本一直没能调通。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/// Full 32x32 -> 64-bit unsigned multiplication.
///
/// On GCC-compatible x86 compilers this is a single `mull`; everywhere else
/// it falls back to a portable 64-bit multiplication with the same result.
/// Assumes `unsigned` is 32 bits wide.
///
/// @param a first factor
/// @param b second factor
/// @param c receives the low 32 bits of a * b
/// @return the high 32 bits of a * b
static inline __attribute__((always_inline))
unsigned muluh(unsigned a, unsigned b, unsigned& c) {
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
    unsigned hi;
    // `mull %edx` multiplies EAX by EDX and leaves the product in EDX:EAX.
    // The result depends only on the inputs, so the asm need not be volatile.
    __asm__(
        "mull %%edx"
        : "=a"(c), "=d"(hi)
        : "a"(a), "d"(b)
        : "cc"
    );
    return hi;
#else
    const unsigned long long product = static_cast<unsigned long long>(a) * b;
    c = static_cast<unsigned>(product);
    return static_cast<unsigned>(product >> 32);
#endif
}
/// Returns the high 32 bits of the 64-bit product a * b.
///
/// On GCC-compatible x86 compilers this is a single `mull`; everywhere else
/// it falls back to a portable 64-bit multiplication with the same result.
/// Assumes `unsigned` is 32 bits wide.
///
/// @param a first factor
/// @param b second factor
/// @return (a * b) >> 32, computed without overflow
static inline __attribute__((always_inline))
unsigned muluh(unsigned a, unsigned b) {
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
    unsigned hi;
    // `mull %edx` writes the low half of the product back into EAX, so EAX
    // must be declared read-write ("+a"); as a plain input the compiler would
    // be free to assume it still holds `a` afterwards.
    __asm__(
        "mull %%edx"
        : "+a"(a), "=d"(hi)
        : "d"(b)
        : "cc"
    );
    return hi;
#else
    return static_cast<unsigned>((static_cast<unsigned long long>(a) * b) >> 32);
#endif
}

其实我们可以写代码生成常数模的 fastPowMod 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include <bits/stdc++.h>
#define cerr(x) std::cerr << (#x) << " is " << (x) << '\n'
using LL = long long;

/// Computes x^n mod M by binary (square-and-multiply) exponentiation.
///
/// Every product is widened to 64 bits through the `1LL *` factor before the
/// reduction, so the intermediate x * x and r * x do not overflow `int`.
///
/// @param x base; pass a value in [0, M) to get a result in [0, M)
/// @param n exponent; must be non-negative (a negative n is not handled)
/// @param M modulus; must be positive
/// @return x^n mod M
int powMod(int x, int n, int M) {
    // Start from 1 % M rather than 1 so that M == 1 yields 0 even when n == 0.
    int r = 1 % M;
    while (n) {
        if (n & 1) r = 1LL * r * x % M;
        n >>= 1;
        x = 1LL * x * x % M;
    }
    return r;
}

/// Extended Euclidean algorithm.
///
/// @param a first operand
/// @param b second operand
/// @return a tuple (g, x, y) with g = gcd(a, b) and a * x + b * y == g
template<typename T>
std::tuple<T, T, T> exGcd(T a, T b) {
    if (b == 0) {
        return std::make_tuple(a, T(1), T(0));
    }
    const std::tuple<T, T, T> sub = exGcd(b, a % b);
    const T g = std::get<0>(sub);
    const T coefB = std::get<1>(sub);    // coefficient of b in the sub-problem
    const T coefRem = std::get<2>(sub);  // coefficient of a % b in the sub-problem
    // g = coefB * b + coefRem * (a - (a / b) * b)
    //   = coefRem * a + (coefB - (a / b) * coefRem) * b
    return std::make_tuple(g, coefRem, coefB - a / b * coefRem);
}

/// Prints to std::cout the source of a `fastPowMod<m>` function specialised
/// for the modulus m.
///
/// - m <= 0: rejected with a message on std::cerr.
/// - m a power of two: emits a loop that reduces with a bit mask (a notice
///   goes to std::clog).
/// - any other even m, or m >= 2^30: unsupported, message on std::cerr and
///   nothing on std::cout.
/// - otherwise (odd m < 2^30): emits a Montgomery-multiplication routine
///   with the constants mr = -1/m mod 2^32, m1 = 2^32 mod m and
///   m1inv = 1/m1 mod m computed here.
///
/// @param m the modulus to generate code for
void generate(int m) {
    if (m <= 0) {
        std::cerr << "Sorry, not support for non-positive number: " << m << "\n\n";
        return;
    }
    if ((m & -m) == m) {
        std::clog << "You may use bit operators instead:\n\n";
        std::cout << "unsigned fastPowMod" << m << "(unsigned x, unsigned n) {\n";
        // Modulo 1 every residue is 0, including x^0.
        std::cout << " unsigned r = " << (m == 1 ? 0 : 1) << ";\n";
        std::cout << " while (n) {\n";
        std::cout << " if (n & 1) r = 1LL * r * x &" << m - 1 << "U;\n";
        // The squaring step must update x; assigning it to r (as the emitted
        // code used to) overwrites the accumulated result.
        std::cout << " n >>= 1; x = 1LL * x * x &" << m - 1 << "U;\n";
        std::cout << " }\n";
        std::cout << " return r;\n";
        std::cout << "}\n\n";
        return;
    } else if (m % 2 == 0) {
        std::cerr << "Sorry, not support for even number which is not power of 2: " << m << "\n\n";
        return;
    } else if (m >> 30) {
        std::cerr << "Sorry, not support for number bigger than 2^30: " << m << "\n\n";
        return;
    }
    // x * 2^32 + y * m == 1, hence -y is the inverse of -m modulo 2^32.
    auto [d, x, y] = exGcd(1LL << 32, 1LL * m);
    y = -y;
    if (y < 0) y += 1LL << 32;
    int m1 = (1LL << 32) % m;
    // x1 * m1 + y1 * m == 1, hence x1 is the inverse of m1 modulo m.
    auto [d1, x1, y1] = exGcd(m1, m);
    if (x1 < 0) x1 += m;
    std::cout << "using ULL = unsigned long long;\n";
    std::cout << "unsigned fastPowMod" << m << "(unsigned x, unsigned n) {\n";
    std::cout << " static const unsigned m = " << m << "U;\n";
    std::cout << " static const unsigned mr = " << y << "U;\n";
    std::cout << " static const unsigned m1 = " << m1 << "U;\n";
    std::cout << " static const unsigned m1inv = " << x1 << "U;\n";
    std::cout << " unsigned xx = (ULL(x) << 32) % m, rr = m1;\n";
    std::cout << " while (n) {\n";
    std::cout << " if (n & 1) {\n";
    std::cout << " ULL t = ULL(rr) * xx;\n";
    std::cout << " rr = (t + ULL(unsigned(t) * mr) * m) >> 32;\n";
    std::cout << " }\n";
    std::cout << " ULL t = ULL(xx) * xx;\n";
    std::cout << " xx = (t + ULL(unsigned(t) * mr) * m) >> 32;\n";
    std::cout << " n >>= 1;\n";
    std::cout << " }\n";
    std::cout << " return ULL(rr) * m1inv % m;\n";
    std::cout << "}\n\n";
}

/// Runs the generator over a fixed list of sample moduli, covering the
/// supported odd case, a power of two, and the rejected inputs.
int main() {
    constexpr int moduli[] = {
        998244353, 1000000007, 1000000009, 1024, 39, 24, (1 << 30) + 1,
    };
    for (const int modulus : moduli) {
        generate(modulus);
    }
    return 0;
}

然后我们验证效率和正确性的代码(以 998244353 为例)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <bits/stdc++.h>
#define cerr(x) std::cerr << (#x) << " is " << (x) << '\n'
using LL = long long;
using ULL = unsigned long long;

/// Scope timer: records the time at construction and, on destruction, prints
/// the elapsed time in milliseconds together with its label to std::cerr.
class Timer final {
    using Clock = std::chrono::high_resolution_clock;
    Clock::time_point start_;
    std::string name_;
public:
    /// @param name label printed alongside the measurement
    Timer(std::string name = {}) : start_(Clock::now()), name_(std::move(name)) {}
    ~Timer() {
        // Convert through a millisecond duration instead of dividing the raw
        // tick count by 1e6, which is only right when the clock ticks in ns.
        const std::chrono::duration<double, std::milli> elapsedMs = Clock::now() - start_;
        // std::fixed and std::setprecision are sticky; put the stream back
        // the way it was so later output is not silently reformatted.
        const std::ios_base::fmtflags oldFlags = std::cerr.flags();
        const std::streamsize oldPrecision = std::cerr.precision();
        std::cerr << std::setprecision(3) << std::fixed << "[Time used: " <<
            name_ << "] " << elapsedMs.count() << "ms\n";
        std::cerr.flags(oldFlags);
        std::cerr.precision(oldPrecision);
    }
};

const int M = 998244353;

/// Baseline implementation: x^n mod M with the built-in `%` operator, used
/// as the reference the Montgomery version is compared against.
///
/// @param x base
/// @param n exponent; must be non-negative
/// @return x^n mod M
int powMod(int x, int n) {
    int result = 1;
    for (; n != 0; n >>= 1) {
        if ((n & 1) != 0) {
            result = static_cast<int>(static_cast<long long>(result) * x % M);
        }
        x = static_cast<int>(static_cast<long long>(x) * x % M);
    }
    return result;
}

// 40% faster
using ULL = unsigned long long;

/// Computes x^n mod 998244353 with Montgomery multiplication (R = 2^32).
///
/// Both the running base and the accumulator are kept in Montgomery form
/// (value * 2^32 mod m) and are only partially reduced between steps; the
/// final `% m` produces the fully reduced answer.
///
/// @param x base (any 32-bit value)
/// @param n exponent
/// @return x^n mod 998244353, in [0, 998244353)
unsigned fastPowMod998244353(unsigned x, unsigned n) {
    constexpr unsigned m = 998244353U;      // the modulus
    constexpr unsigned mr = 998244351U;     // -1/m mod 2^32
    constexpr unsigned m1 = 301989884U;     // 2^32 mod m, i.e. 1 in Montgomery form
    constexpr unsigned m1inv = 232013824U;  // inverse of m1 mod m, leaves Montgomery form
    unsigned base = (ULL(x) << 32) % m;
    unsigned acc = m1;
    for (; n != 0; n >>= 1) {
        if ((n & 1) != 0) {
            const ULL prod = ULL(acc) * base;
            // Add the multiple of m that clears the low 32 bits, then drop them.
            acc = (prod + ULL(unsigned(prod) * mr) * m) >> 32;
        }
        const ULL square = ULL(base) * base;
        base = (square + ULL(unsigned(square) * mr) * m) >> 32;
    }
    return ULL(acc) * m1inv % m;
}

// Global 32-bit Mersenne Twister, seeded with the steady clock's current tick
// count, so every run of the benchmark draws different data.
std::mt19937 rnd(std::chrono::steady_clock::now().time_since_epoch().count());

/// Benchmark driver: fills a vector with random bases below M, then times
/// fastPowMod998244353 and powMod on the same (base, index) pairs. Each
/// section prints its checksum followed by the elapsed time, so the two
/// sums can be compared for equality.
int main() {
    const int N = 3e7;
    std::vector<int> a(N);
    for (int& value : a) value = rnd() % M;
    {
        Timer A("fast?");
        LL sum = 0;
        int exponent = 0;
        for (const int base : a) {
            sum += fastPowMod998244353(base, exponent);
            ++exponent;
        }
        cerr(sum);
    }
    {
        Timer A("default");
        LL sum = 0;
        int exponent = 0;
        for (const int base : a) {
            sum += powMod(base, exponent);
            ++exponent;
        }
        cerr(sum);
    }
    return 0;
}

至此快速模部分已经搞定。我们下面来看 SuperSodaSea 在知乎介绍的 Barrett Modular Multiplication方法(写的特别好,严谨,循序渐进,一点点的发现问题修复问题),当然了他参考了 Barrett 的论文

最后这部分可以写成一个类,对于模非常量时也能用上述方式做快速幂

模数为常量的快速除法

主定理:设 \(m, l \geq 0, d > 0\) 且满足 \(2^{N + l} \leq m \cdot d \leq 2^{N + l} + 2^l\), 则对于 \(0 \leq n < 2^N\) 成立 \(\lfloor \frac{n}{d} \rfloor = \lfloor \frac{n m}{2^{N + l}}\rfloor\)(证明自行推导或看原始论文或看 SuperSodaSea 的文章)

从而可以取 \(l = \lceil \log_2 d \rceil\) 保证区间非空,然后不断的减小 \(l\),从而让 \(m\) 变小。然后注意到最初 \(l\) 的取值会导致最终 \(m \leq \frac{2^{N + l} + 2^{l}}{d} < 2^{N + 1}\)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/// Counts the trailing zero bits of x using the GCC/Clang builtins.
///
/// 4-byte types go through __builtin_ctz, everything else through
/// __builtin_ctzll. The builtins leave the result undefined for x == 0.
///
/// @param x value to inspect; must be non-zero
/// @return index of the lowest set bit of x
template<typename T>
int ctz(T x) {
    if constexpr (sizeof (T) == 4) {
        return __builtin_ctz(x);
    } else {
        return __builtin_ctzll(x);
    }
}
/// Picks the multiplier/shift pair that replaces division by d with a
/// multiplication: for every n representable in T (N = bit width of T),
/// n / d == (n * m) >> (N + l).
///
/// Starting from l = ceil(log2(d)) the shift is lowered for as long as a
/// multiple of d still lies in [2^(N+l), 2^(N+l) + 2^l]; the smallest such
/// l and the matching m are returned. m may need N + 1 bits.
///
/// For a power of two the pair (log2(d), 1) is returned as a marker; callers
/// are expected to emit a plain shift in that case rather than use m.
///
/// @tparam T  unsigned type of the dividend and divisor
/// @tparam T2 unsigned type twice as wide as T, used for the intermediates
/// @param d divisor; must be non-zero, and at most 2^(N-1) unless it is a
///          power of two (larger values overflow the T2 intermediates)
/// @return the pair (l, m)
template<typename T, typename T2>
std::pair<int, T2> chooseMultiplier(T d) {
    constexpr int bitLen = __CHAR_BIT__ * sizeof (T);
    // Bit width of d - 1, which equals ceil(log2(d)) for d >= 1. Counting the
    // bits by hand avoids the libstdc++-internal std::__lg.
    int l = 0;
    for (T rest = d - 1; rest != 0; rest >>= 1) ++l;
    // For a power of two, d - 1 has exactly log2(d) bits.
    if ((d & -d) == d) return {l, T2(1)};
    const T2 lb = T2(1) << bitLen;
    // The two quotients differ exactly when some multiple of d falls inside
    // the admissible interval for this l.
    while (l >= 0 && (lb << l) / d != ((lb + 1) << l) / d) --l;
    ++l;
    return {l, ((lb + 1) << l) / d};
}

void solve() {
unsigned x;
std::cin >> x;
auto [l, m] = chooseMultiplier(x);
std::cout << l << ' ' << m << '\n';
}

然后确实存在 \(m \geq 2^N\) 的情形,例如 \(d = 7\),此时可以把 \(nm\) 拆成 \(n m = n(m - 2^N) + n \cdot 2^N\)。即令 \(t = \lfloor n (m - 2^N) / 2^N \rfloor\),那么最终答案就是 \(\lfloor (n + t) / 2^l \rfloor\)。但是若 \(n\) 很大,\(n + t\) 还是有溢出的风险,因此可以写成 (n - t >> 1) + t >> (l - 1)(注意此处 \(n \geq t\) 恒成立)。注意此时还有一个风险:\(l = 0\),但是这不可能发生,因为 \(l = 0\) 时 \(m < 2^N\)。另外注意到若 \(d\) 是 2 的倍数,那么我们可以计算 \((n / 2) / (d / 2)\),这样我们就可以保证 \(\frac{n}{2} m < 2^{2N}\) 了。因此最终代码生成的方式为

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include <bits/stdc++.h>
#define cerr(x) std::cerr << (#x) << " is " << (x) << '\n'
using LL = long long;

/// Counts the trailing zero bits of x using the GCC/Clang builtins.
///
/// 4-byte types go through __builtin_ctz, everything else through
/// __builtin_ctzll. The builtins leave the result undefined for x == 0.
///
/// @param x value to inspect; must be non-zero
/// @return index of the lowest set bit of x
template<typename T>
int ctz(T x) {
    if constexpr (sizeof (T) == 4) {
        return __builtin_ctz(x);
    } else {
        return __builtin_ctzll(x);
    }
}

/// Picks the multiplier/shift pair that replaces division by d with a
/// multiplication: for every n representable in T (N = bit width of T),
/// n / d == (n * m) >> (N + l).
///
/// Starting from l = ceil(log2(d)) the shift is lowered for as long as a
/// multiple of d still lies in [2^(N+l), 2^(N+l) + 2^l]; the smallest such
/// l and the matching m are returned. m may need N + 1 bits.
///
/// For a power of two the pair (log2(d), 1) is returned as a marker; callers
/// are expected to emit a plain shift in that case rather than use m.
///
/// @tparam T  unsigned type of the dividend and divisor
/// @tparam T2 unsigned type twice as wide as T, used for the intermediates
/// @param d divisor; must be non-zero, and at most 2^(N-1) unless it is a
///          power of two (larger values overflow the T2 intermediates)
/// @return the pair (l, m)
template<typename T, typename T2>
std::pair<int, T2> chooseMultiplier(T d) {
    constexpr int bitLen = __CHAR_BIT__ * sizeof (T);
    // Bit width of d - 1, which equals ceil(log2(d)) for d >= 1. Counting the
    // bits by hand avoids the libstdc++-internal std::__lg.
    int l = 0;
    for (T rest = d - 1; rest != 0; rest >>= 1) ++l;
    // For a power of two, d - 1 has exactly log2(d) bits.
    if ((d & -d) == d) return {l, T2(1)};
    const T2 lb = T2(1) << bitLen;
    // The two quotients differ exactly when some multiple of d falls inside
    // the admissible interval for this l.
    while (l >= 0 && (lb << l) / d != ((lb + 1) << l) / d) --l;
    ++l;
    return {l, ((lb + 1) << l) / d};
}

/// Prints to std::cout a C++ function `div<d>(uint32_t n)` that computes
/// n / d without a division instruction.
///
/// Cases, in order:
///  - d >= 2^31: the quotient is 0 or 1, emitted as a comparison;
///  - d a power of two: emitted as a right shift;
///  - multiplier below 2^32: one 64-bit multiply followed by a shift;
///  - 33-bit multiplier, odd d: multiply by the low 32 bits and fold n back
///    in with the overflow-free (n - t >> 1) + t form;
///  - 33-bit multiplier, even d: halve both n and d first.
///
/// @param d the divisor; must be non-zero
void generateUnsignedDivision(uint32_t d) {
    assert(d != 0);
    const int twoExp = ctz(d);
    std::ostream& out = std::cout;
    out << "inline uint32_t div" << d << "(uint32_t n) {\n";
    const auto emitBody = [&]() {
        if (d >= (1u << 31)) {
            out << " return n >= " << d << ";\n";
            return;
        }
        if (d == (1u << twoExp)) {
            out << " return n";
            if (twoExp != 0) out << " >> " << twoExp;
            out << ";\n";
            return;
        }
        const auto multiplier = chooseMultiplier<uint32_t, uint64_t>(d);
        const int l = multiplier.first;
        const uint64_t m = multiplier.second;
        if (m < (1ull << 32)) {
            out << " return " << m << "ull * n >> " << l + 32 << ";\n";
            return;
        }
        if (twoExp == 0) {
            out << " uint32_t t = " << (m - (1ull << 32)) << "ull * n >> 32;\n";
            out << " return (n - t >> 1) + t";
            if (l > 1) out << " >> " << l - 1;
            out << ";\n";
            return;
        }
        const auto halved = chooseMultiplier<uint32_t, uint64_t>(d >> 1);
        out << " return " << halved.second << "ull * (n >> 1) >> " << halved.first + 32 << ";\n";
    };
    emitBody();
    out << "}\n\n";
}

// only for gcc and x86_64
/// Prints to std::cout a C++ function `div<d>Asm(uint32_t n)` that computes
/// n / d through `muluh` (the high 32 bits of a 32x32 product) instead of a
/// division instruction.
///
/// The cases mirror generateUnsignedDivision. Because muluh only accepts
/// 32-bit factors, a 33-bit multiplier is always split into 2^32 plus its
/// low 32 bits, and the 2^32 part is applied as an extra addition.
///
/// @param d the divisor; must be non-zero
void generateUnsignedDivisionAsm(uint32_t d) {
    assert(d != 0);
    int dd = ctz(d);
    std::cout << "inline uint32_t div" << d << "Asm(uint32_t n) {\n";
    if (d >= 1u << 31) {
        std::cout << " return n >= " << d << ";\n";
    } else if (d == (1u << dd)) {
        std::cout << " return n";
        if (dd) std::cout << " >> " << dd;
        std::cout << ";\n";
    } else {
        auto [l, m] = chooseMultiplier<uint32_t, uint64_t>(d);
        if (m < (1ull << 32)) {
            std::cout << " return muluh(" << m << ", n)";
            if (l > 0) std::cout << " >> " << l;
            std::cout << ";\n";
        } else if (dd == 0) {
            std::cout << " uint32_t t = " << "muluh(" << m - (1ull << 32) << ", n);\n";
            std::cout << " return (n - t >> 1) + t";
            if (l > 1) std::cout << " >> " << l - 1;
            std::cout << ";\n";
        } else {
            auto [l2, m2] = chooseMultiplier<uint32_t, uint64_t>(d >> 1);
            if (m2 < (1ull << 32)) {
                std::cout << " return muluh(" << m2 << ", n >> 1)";
                if (l2 > 0) std::cout << " >> " << l2;
                std::cout << ";\n";
            } else {
                // m2 has 33 bits and would be truncated if handed to muluh
                // directly. With h = n >> 1 < 2^31 and t = muluh(m2 - 2^32, h) < h,
                // the sum h + t cannot overflow 32 bits, so it is used as is.
                std::cout << " uint32_t h = n >> 1;\n";
                std::cout << " return h + muluh(" << m2 - (1ull << 32) << ", h)";
                if (l2 > 0) std::cout << " >> " << l2;
                std::cout << ";\n";
            }
        }
    }
    std::cout << "}\n\n";
}

/// Emits both the plain and the muluh-based division routine for every
/// divisor from 1 to 15, and then for 6700417.
void solve() {
    for (uint32_t divisor = 1; divisor < 16; ++divisor) {
        generateUnsignedDivision(divisor);
        generateUnsignedDivisionAsm(divisor);
    }
    // Swap the constant for `std::cin >> x;` to try a divisor interactively.
    const uint32_t x = 6700417;
    generateUnsignedDivision(x);
    generateUnsignedDivisionAsm(x);
}

/// Entry point: unties std::cin from std::cout, turns off synchronisation
/// with C stdio, then runs solve().
int main() {
    std::cin.tie(nullptr);
    std::ios_base::sync_with_stdio(false);
    solve();
}

测试后发现跟预期的一样,自己写的并没有默认跑的快,即 以上所有代码仅有理论意义,那么我们就止步于此了吗?

模数非常量的快速除法和快速模乘

必须承认一点:对于单次除法,无论除数 d 是否为常量,我们都不可能比默认的快;但是如果 d 为变量且会被多次使用,那么我们就大概率可以比默认的快。这部分其实还要部分归功于 lambda 函数(否则要用函数指针,写起来十分麻烦),但是这部分如何做成内联呢?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
#include <bits/stdc++.h>
#define cerr(x) std::cerr << (#x) << " is " << (x) << '\n'
#define NewLine std::cerr << '\n'

/// SFINAE helper: names `void` when T is exactly one of int32_t, uint32_t,
/// int64_t or uint64_t, and is ill-formed for any other type.
template<typename T>
using IntLong = std::enable_if_t<
    std::disjunction_v<
        std::is_same<T, int32_t>,
        std::is_same<T, uint32_t>,
        std::is_same<T, int64_t>,
        std::is_same<T, uint64_t>>>;

/// SFINAE helper: names `void` when T2 occupies exactly twice as many bytes
/// as T, and is ill-formed otherwise.
template<typename T, typename T2>
using Twice = std::enable_if_t<sizeof (T2) == sizeof (T) * 2>;

/// Counts the trailing zero bits of x using the GCC/Clang builtins.
///
/// Only participates in overload resolution for the types IntLong accepts.
/// 4-byte types go through __builtin_ctz, the others through
/// __builtin_ctzll. The builtins leave the result undefined for x == 0.
///
/// @param x value to inspect; must be non-zero
/// @return index of the lowest set bit of x
template<typename T, typename check = IntLong<T>>
int ctz(T x) {
    if constexpr (sizeof (T) == 4) {
        return __builtin_ctz(x);
    } else {
        return __builtin_ctzll(x);
    }
}

/// Picks the multiplier/shift pair that replaces division by d with a
/// multiplication: for every n representable in T (N = bit width of T),
/// n / d == (n * m) >> (N + l).
///
/// Starting from l = ceil(log2(d)) the shift is lowered for as long as a
/// multiple of d still lies in [2^(N+l), 2^(N+l) + 2^l]; the smallest such
/// l and the matching m are returned. m may need N + 1 bits.
///
/// For a power of two the pair (log2(d), 1) is returned as a marker; callers
/// are expected to handle powers of two with a plain shift rather than use m.
///
/// @tparam T  unsigned type of the dividend and divisor
/// @tparam T2 unsigned type twice as wide as T (enforced by Twice)
/// @param d divisor; must be non-zero, and at most 2^(N-1) unless it is a
///          power of two (larger values overflow the T2 intermediates)
/// @return the pair (l, m)
template<typename T, typename T2, typename check = Twice<T,T2>>
std::pair<int, T2> chooseMultiplier(T d) {
    constexpr int bitLen = __CHAR_BIT__ * sizeof (T);
    // Bit width of d - 1, which equals ceil(log2(d)) for d >= 1. Counting the
    // bits by hand avoids the libstdc++-internal std::__lg.
    int l = 0;
    for (T rest = d - 1; rest != 0; rest >>= 1) ++l;
    // For a power of two, d - 1 has exactly log2(d) bits.
    if ((d & -d) == d) return {l, T2(1)};
    const T2 lb = T2(1) << bitLen;
    // The two quotients differ exactly when some multiple of d falls inside
    // the admissible interval for this l.
    while (l >= 0 && (lb << l) / d != ((lb + 1) << l) / d) --l;
    ++l;
    return {l, ((lb + 1) << l) / d};
}

/// Builds a callable that computes n / d for a divisor d fixed at run time,
/// using a precomputed multiplier instead of a division per call.
///
/// The returned function picks one of these strategies once, up front:
///  - d >= 2^(N-1): the quotient is 0 or 1, so compare n against d;
///  - d a power of two: right shift;
///  - multiplier fits in N bits: one widening multiply and a shift;
///  - (N+1)-bit multiplier, odd d: multiply by its low N bits and fold n
///    back in with the overflow-free (n - t >> 1) + t form;
///  - (N+1)-bit multiplier, even d: halve both n and d first.
///
/// @tparam T  unsigned type of dividend and divisor (N = its bit width)
/// @tparam T2 unsigned type twice as wide as T (enforced by Twice)
/// @param d the divisor; must be non-zero
/// @return a function mapping n to n / d
template<typename T, typename T2, typename check = Twice<T,T2>>
std::function<T(T)> getDivFun(T d) {
    assert(d != 0);
    constexpr int bitLen = __CHAR_BIT__ * sizeof (T);
    if (d >= T(1) << (bitLen - 1)) {
        return [d](T n) {
            return n >= d;
        };
    }
    const int dd = ctz(d);
    // Shift a T, not an int: for a 64-bit T, dd may be 31 or more, which
    // would overflow `1 << dd`.
    if (d == T(1) << dd) {
        return [dd](T n) {
            return n >> dd;
        };
    }
    auto [l, m] = chooseMultiplier<T, T2>(d);
    if (m < T2(1) << bitLen) {
        return [mf = m, lf = l + bitLen](T n) {
            return mf * n >> lf;
        };
    }
    if (dd == 0) {
        if (l > 1) {
            return [mf = m - (T2(1) << bitLen), lf = l - 1](T n) {
                T t = mf * n >> __CHAR_BIT__ * sizeof (T);
                return (n - t >> 1) + t >> lf;
            };
        } else {
            return [mf = m - (T2(1) << bitLen)](T n) {
                T t = mf * n >> __CHAR_BIT__ * sizeof (T);
                return (n - t >> 1) + t;
            };
        }
    }
    // d is even here: n / d == (n / 2) / (d / 2), and the halved dividend
    // leaves room for the (N+1)-bit multiplier inside T2.
    auto [l2, m2] = chooseMultiplier<T, T2>(d >> 1);
    return [mf = m2, lf = l2 + bitLen](T n) {
        return mf * (n >> 1) >> lf;
    };
}

// Global 32-bit and 64-bit Mersenne Twister generators, each seeded with the
// steady clock's current tick count, so every run draws different test data.
std::mt19937 rnd(std::chrono::steady_clock::now().time_since_epoch().count());
std::mt19937_64 rnd64(std::chrono::steady_clock::now().time_since_epoch().count());
/// Scope timer: records the time at construction and, on destruction, prints
/// the elapsed time in milliseconds together with its label to std::cerr.
class Timer final {
    using Clock = std::chrono::high_resolution_clock;
    Clock::time_point start_;
    std::string name_;
public:
    /// @param name label printed alongside the measurement
    Timer(std::string name = {}) : start_(Clock::now()), name_(std::move(name)) {}
    ~Timer() {
        // Convert through a millisecond duration instead of dividing the raw
        // tick count by 1e6, which is only right when the clock ticks in ns.
        const std::chrono::duration<double, std::milli> elapsedMs = Clock::now() - start_;
        // std::fixed and std::setprecision are sticky; put the stream back
        // the way it was so later output is not silently reformatted.
        const std::ios_base::fmtflags oldFlags = std::cerr.flags();
        const std::streamsize oldPrecision = std::cerr.precision();
        std::cerr << std::setprecision(3) << std::fixed << "[Time used: " << name_ << "] " << elapsedMs.count() << "ms\n";
        std::cerr.flags(oldFlags);
        std::cerr.precision(oldPrecision);
    }
};

/// Divides n by 14 without a division instruction: n is halved, multiplied
/// by the 33-bit constant 4908534053 in 64-bit arithmetic, and the product
/// is shifted right by 35 bits.
///
/// @param n dividend
/// @return n / 14
inline uint32_t div14(uint32_t n) {
    const uint64_t half = n >> 1;
    return static_cast<uint32_t>((4908534053ull * half) >> 35);
}

/// Correctness check and benchmark for the multiplication-based division.
///
/// Three sections, each on freshly generated random dividends:
///  1. the hand-written div14 against `/` with a run-time and a constant 14;
///  2. getDivFun on a random non-zero 32-bit divisor;
///  3. getDivFun on 64-bit dividends with a random divisor below INT_MAX.
/// Every section first verifies each quotient (printing the mismatch and
/// returning -1 on failure) and then times the variants, printing a checksum
/// for each so the sums can be compared.
int main() {
    const int N = 1e7 + 2;
    // test for mod = 14
    {
        std::vector<uint32_t> a(N);
        for (auto &x : a) x = rnd();
        const uint32_t modConst = 14;
        uint32_t mod = 14;
        for (auto x : a) {
            if (div14(x) != x / mod) {
                std::cerr << x << ' ' << div14(x) << ' ' << x / mod << '\n';
                return -1;
            }
        }
        {
            Timer A("mod14");
            uint64_t sum = 0;
            for (auto x : a) sum += div14(x);
            cerr(sum);
        }
        {
            Timer A("default");
            uint64_t sum = 0;
            for (auto x : a) sum += x / mod;
            cerr(sum);
        }
        {
            Timer A("const default");
            uint64_t sum = 0;
            for (auto x : a) sum += x / modConst;
            cerr(sum);
        }
    }
    NewLine;
    // test for uint32
    {
        std::vector<uint32_t> a(N);
        for (auto &x : a) x = rnd();
        uint32_t mod = 0;
        while ((mod = rnd()) == 0);
        auto f = getDivFun<uint32_t, uint64_t>(mod);
        for (auto x : a) {
            if (f(x) != x / mod) {
                std::cerr << x << ' ' << f(x) << ' ' << x / mod << '\n';
                return -1;
            }
        }
        {
            Timer A("default");
            uint64_t sum = 0;
            for (auto x : a) sum += x / mod;
            cerr(sum);
        }
        {
            Timer A("fast?");
            uint64_t sum = 0;
            for (auto x : a) sum += f(x);
            cerr(sum);
        }
    }
    NewLine;
    // test for uint64
    {
        std::vector<uint64_t> a(N);
        // Draw from the 64-bit generator so the dividends really use the
        // upper 32 bits; the 32-bit rnd() never would.
        for (auto &x : a) x = rnd64();
        uint64_t mod = 0;
        // Keep the divisor below INT_MAX so the quotients are not too small,
        // and retry when the reduction happens to land on zero.
        do {
            mod = rnd64() % INT_MAX;
        } while (mod == 0);
        auto f = getDivFun<uint64_t, __uint128_t>(mod);
        for (auto x : a) {
            if (f(x) != x / mod) {
                std::cerr << x << ' ' << f(x) << ' ' << x / mod << '\n';
                return -1;
            }
        }
        {
            Timer A("default");
            __uint128_t sum = 0;
            for (auto x : a) sum += x / mod;
            cerr(uint64_t(sum >> 64));
            cerr(uint64_t(sum));
        }
        {
            Timer A("fast?");
            __uint128_t sum = 0;
            for (auto x : a) sum += f(x);
            cerr(uint64_t(sum >> 64));
            cerr(uint64_t(sum));
        }
    }
}

实测代码发现 Mac M1 上比不过,但是其他指令集表现还不错

遗留问题

  • 关于 muluh 所有主流指令集的汇编写法
  • lambda 如何变成内联函数,有没有可能用函数指针做多情况下内联(貌似此处没有必要)