Finding a billion factorials in 60 ms with SIMD
Hi everyone!
This blog is not about asymptotic optimization. Refer to this comment for those.
There is a problem on Library Checker that goes as follows:
Many Factorials: You're given $$$n_1,\dots,n_t$$$, where $$$t \leq 10^5$$$. For each $$$i$$$, find $$$n_i! \bmod M$$$, where $$$M = 998\;244\;353$$$.
In this blog, we will learn how to solve this task in 61 ms, without precalc and without FFT. That's right, we will take the dumbest solution we can imagine and improve its constant factor until it's decent enough. How decent? Well, let's use the following naive solution as a baseline:
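#include <bits/stdc++.h>
using namespace std;

const int mod = 998'244'353;

// Needs C++23 for views::enumerate.
vector<int> facts(vector<int> args) {
    const int block = 1 << 16;
    // Indices of the queries that fall into each block.
    vector<int> args_per_block[mod / block + 1];
    vector<int> res(size(args));
    for(auto [i, x]: args | views::enumerate) {
        args_per_block[x / block].push_back(i);
    }
    uint64_t fact = 1; // factorial of the last number of the previous block
    for(int i = 0; i < mod; i += block) {
        static array<int, block> facts; // facts[j] = (i + j)! mod M
        facts[0] = fact * (i + !i) % mod; // i + !i turns the factor 0 into 1
        for(int j = 1; j < block; j++) {
            facts[j] = (facts[j - 1] * uint64_t(i + j)) % mod;
        }
        for(int idx: args_per_block[i / block]) {
            int x = args[idx];
            res[idx] = facts[x - i];
        }
        fact = facts.back();
    }
    return res;
}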
It doesn't do anything fancy: it splits the domain into blocks of length $$$2^{16}$$$, naively computes all factorials within each block, and uses the intermediate results to answer each query in $$$O(1)$$$, for a total runtime of $$$O(M+t)$$$.
The solution above takes 3745 ms, so the final result of 61 ms is a 60x improvement.
Many thanks to Qwerty1232 for useful discussions!
Wilson's theorem
As a first step, let's use Wilson's theorem in the following formulation: $$$n! (p-1-n)! \equiv (-1)^{n+1} \pmod{p}$$$.
This is just a direct consequence of a more standard formulation, which is $$$(p-1)! \equiv -1 \pmod {p}$$$.
If we pick the smaller of $$$n$$$ and $$$p-1-n$$$, we will only need to go up to $$$\frac{p}{2}$$$ instead of going over all $$$p$$$ numbers.
Doing this reduces our runtime to 1889 ms, which is an expected 2x improvement.
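A minimal sketch of this reflection trick, assuming an odd prime $$$p < 2^{32}$$$. The helper names (`pow_mod`, `naive_factorial`, `wilson_factorial`) are made up for illustration, and a single query is answered at a time instead of sweeping over blocks as the real solution does. The sign comes from rewriting the last $$$p-1-n$$$ factors of $$$(p-1)!$$$ as $$$-1, -2, \dots$$$, which gives $$$n! \, (p-1-n)! \equiv (-1)^{n+1}$$$ for odd $$$p$$$.

```cpp
#include <cassert>
#include <cstdint>

// base^exp mod p by repeated squaring (hypothetical helper).
uint64_t pow_mod(uint64_t base, uint64_t exp, uint64_t p) {
    uint64_t result = 1 % p;
    base %= p;
    while (exp > 0) {
        if (exp & 1) result = result * base % p;
        base = base * base % p;
        exp >>= 1;
    }
    return result;
}

// Plain product 1 * 2 * ... * n mod p.
uint64_t naive_factorial(uint64_t n, uint64_t p) {
    uint64_t result = 1 % p;
    for (uint64_t k = 2; k <= n; k++) result = result * k % p;
    return result;
}

// n! mod p for an odd prime p < 2^32 and 0 <= n < p.
// The product loop never runs past (p - 1) / 2.
uint64_t wilson_factorial(uint64_t n, uint64_t p) {
    assert(p < (1ull << 32) && n < p);
    uint64_t m = p - 1 - n;
    if (n <= m) return naive_factorial(n, p);
    // n! * m! == (-1)^(n+1) (mod p), so n! == (-1)^(n+1) / m!.
    uint64_t inv = pow_mod(naive_factorial(m, p), p - 2, p);  // Fermat inverse
    return (n % 2 == 1) ? inv : (p - inv) % p;
}
```

The modular inverse here costs $$$O(\log p)$$$ per reflected query, which is harmless at this stage of the optimization.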
Skipping even factorials
Note that $$$2 \cdot 4 \cdot \ldots \cdot 2n = (2n)!! = 2^n n!$$$. This allows us to repeatedly factor out even numbers from the computation, until we only need to find certain products of odd numbers and multiply them by a really large power of $$$2$$$.
Let $$$f(n)$$$ be the product of all odd numbers up to $$$n$$$, inclusive. Then, we can rewrite the factorial as
$$$n! = f(n) \cdot 2^{\lfloor n/2 \rfloor} \cdot \lfloor n/2 \rfloor!$$$
Applying this repeatedly reduces the computation of $$$n!$$$ to finding the product of about $$$\log n$$$ values of $$$f(t)$$$ with $$$t \leq n$$$, times a power of $$$2$$$. Each $$$f(t)$$$ is a product of only about $$$\frac{t}{2}$$$ numbers, rather than $$$t$$$ (as was the case with factorials), so the main loop does roughly half as many multiplications.
Implementing this change, we get a runtime of 998 ms, which is almost another 2x improvement.
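Here is a small sketch of the identity $$$n! = f(n) \cdot 2^{\lfloor n/2 \rfloor} \cdot \lfloor n/2 \rfloor!$$$ unrolled into a loop. The function names are invented for this example, and `odd_product` recomputes every $$$f(t)$$$ from scratch, whereas the actual solution would presumably answer all $$$f(t)$$$ queries during one sweep over the odd numbers.

```cpp
#include <cassert>
#include <cstdint>

// f(n): product of all odd numbers in [1, n], modulo p.
uint64_t odd_product(uint64_t n, uint64_t p) {
    uint64_t result = 1 % p;
    for (uint64_t k = 3; k <= n; k += 2) result = result * k % p;
    return result;
}

// n! mod p as 2^twos * f(n) * f(n / 2) * f(n / 4) * ...
uint64_t factorial_via_odd_products(uint64_t n, uint64_t p) {
    assert(p >= 2 && p < (1ull << 32));  // keeps the products below 2^64
    uint64_t result = 1 % p;
    uint64_t twos = 0;  // total exponent of the factored-out power of 2
    for (uint64_t t = n; t > 1; t /= 2) {
        result = result * odd_product(t, p) % p;
        twos += t / 2;
    }
    // A linear loop is enough for a sketch; a fast power would be used in practice.
    for (uint64_t e = 0; e < twos; e++) result = result * 2 % p;
    return result;
}
```

For example, $$$6! = f(6) \cdot f(3) \cdot 2^{3+1} = 15 \cdot 3 \cdot 16 = 720$$$.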
Utilizing instruction-level parallelism
Further reading: (1)
Before we bring on the heavy machinery of SIMD and vectorization, there is still one more optimization we could use, though it already relies on some knowledge of how CPUs work under the hood. Specifically, we can utilize instruction-level parallelism.
Normally, modern CPUs are able to pipeline operations in a way that operations with high latency (such as multiplication) can overlap thanks to higher throughput. This should significantly reduce the number of CPU cycles we spend to execute a sequence of similar instructions, if they are independent.
Naturally, our problem here is that the instructions are not actually independent: before multiplying by $$$k+1$$$, we first need to wait for the result of the multiplication by $$$k$$$. To mitigate this, we can process $$$K$$$ independent blocks in parallel.
In particular, taking $$$K=8$$$ allows us to further reduce the runtime to 287 ms, making it a 3.5x improvement.
This is almost the best we can get without vectorization.
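To illustrate the idea, here is a simplified sketch that computes a single range product with $$$K$$$ independent accumulators. It is not the blog's actual main loop (which also has to record intermediate values for the queries), and the names `range_product_single` and `range_product_ilp` are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// lo * (lo + 1) * ... * (hi - 1) mod p as one long dependency chain.
uint64_t range_product_single(uint64_t lo, uint64_t hi, uint64_t p) {
    assert(p >= 2 && p < (1ull << 32) && lo <= hi);
    uint64_t result = 1 % p;
    for (uint64_t x = lo; x < hi; x++) result = result * (x % p) % p;
    return result;
}

// The same product, split into K sub-blocks that are advanced in lockstep.
template <int K = 8>
uint64_t range_product_ilp(uint64_t lo, uint64_t hi, uint64_t p) {
    assert(p >= 2 && p < (1ull << 32) && lo <= hi);
    std::array<uint64_t, K> acc;
    acc.fill(1 % p);
    uint64_t len = (hi - lo) / K;  // length of each sub-block
    for (uint64_t j = 0; j < len; j++) {
        for (int k = 0; k < K; k++) {
            // Chain k only depends on its own previous value, so the CPU
            // is free to overlap the K multiplications of one iteration.
            acc[k] = acc[k] * ((lo + k * len + j) % p) % p;
        }
    }
    uint64_t result = 1 % p;
    for (int k = 0; k < K; k++) result = result * acc[k] % p;
    // Leftover tail that did not fit into the K equal sub-blocks.
    for (uint64_t x = lo + K * len; x < hi; x++) result = result * (x % p) % p;
    return result;
}
```

Whether the inner loop is actually unrolled and overlapped depends on the compiler and the CPU, so the speedup should be measured rather than assumed.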
Vectorization
The next thing we can do is add vectorization. For this, we will need Montgomery multiplication, which vectorizes as follows:
#include <immintrin.h>

using u64x4 [[gnu::vector_size(32)]] = uint64_t;
using u32x8 [[gnu::vector_size(32)]] = uint32_t;

// Swaps the two 32-bit halves of every 64-bit lane.
auto swap_bytes(auto x) {
    return decltype(x)(__builtin_shufflevector(u32x8(x), u32x8(x), 1, 0, 3, 2, 5, 4, 7, 6));
}
// Computes x * 2^(-32) in every 64-bit lane; imod must be -mod^(-1) modulo 2^32.
// The result is not normalized into [0, mod).
u64x4 montgomery_reduce(u64x4 x, uint32_t mod, uint32_t imod) {
    auto x_ninv = u64x4(_mm256_mul_epu32(__m256i(x), __m256i() + imod));
    x += u64x4(_mm256_mul_epu32(__m256i(x_ninv), __m256i() + mod)); // low halves become zero
    return swap_bytes(x); // move the high halves down
}
// Multiplies the low 32-bit halves of every 64-bit lane.
u64x4 montgomery_mul(u64x4 x, u64x4 y, uint32_t mod, uint32_t imod) {
    return montgomery_reduce(u64x4(_mm256_mul_epu32(__m256i(x), __m256i(y))), mod, imod);
}
// Multiplies all eight 32-bit lanes: even lanes directly, odd lanes via a swap.
u32x8 montgomery_mul(u32x8 x, u32x8 y, uint32_t mod, uint32_t imod) {
    return u32x8(montgomery_mul(u64x4(x), u64x4(y), mod, imod)) |
           u32x8(swap_bytes(montgomery_mul(u64x4(swap_bytes(x)), u64x4(swap_bytes(y)), mod, imod)));
}
Now, the way we add vectorization to the main loop is actually very similar to what we did so far: we just split the current blocks into sub-blocks, each accumulated in one vector component. Additionally, since Montgomery multiplication transforms $$$(a, b)$$$ into $$$ab \cdot 2^{-32}$$$, the best thing we can do is keep track of the total power of $$$2$$$ picked up along the way and cancel the accumulated powers of $$$2^{-32}$$$ at the very end, so that the main loop doesn't have to deal with them.
Doing this reduces the time to 119 ms, another 2.4x improvement.
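The bookkeeping of the $$$2^{-32}$$$ factors is easier to see in a scalar version. The sketch below is an illustration under assumptions, not the blog's code: `neg_inverse` and `product` are hypothetical helpers, and the correction factor is accumulated step by step, whereas a real implementation would more likely compute it once as a single power.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr uint32_t mod = 998244353;

// -m^(-1) modulo 2^32 for odd m, via Newton iteration (hypothetical helper).
constexpr uint32_t neg_inverse(uint32_t m) {
    uint32_t inv = 1;                                // correct modulo 2
    for (int i = 0; i < 5; i++) inv *= 2 - m * inv;  // doubles the correct bits
    return 0u - inv;
}
constexpr uint32_t imod = neg_inverse(mod);

// Returns a value congruent to x * 2^(-32), below 2 * mod for x < 4 * mod^2.
uint32_t montgomery_reduce(uint64_t x) {
    uint32_t q = uint32_t(x) * imod;                 // makes x + q * mod divisible by 2^32
    return uint32_t((x + uint64_t(q) * mod) >> 32);
}

uint32_t montgomery_mul(uint32_t a, uint32_t b) {
    return montgomery_reduce(uint64_t(a) * b);
}

// Product of all values modulo mod. Every montgomery_mul picks up a factor
// of 2^(-32); all of them are cancelled at once after the loop.
uint32_t product(const std::vector<uint32_t>& values) {
    const uint64_t r = (1ull << 32) % mod;
    uint32_t acc = 1;
    uint64_t correction = 1;  // 2^(32 * number of multiplications) mod `mod`
    for (uint32_t v : values) {
        assert(v < mod);
        acc = montgomery_mul(acc, v);
        correction = correction * r % mod;
    }
    return uint32_t(uint64_t(acc % mod) * correction % mod);
}
```

The intermediate values stay in $$$[0, 2 \cdot \mathrm{mod})$$$ without any conditional subtraction, which works here because $$$4 \cdot \mathrm{mod} < 2^{32}$$$.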
So, should we call it a day?
Quick inverses, powers and input
Oh, but wait, shouldn't vectorization be doing 4 operations at once? Why is the improvement not 4x then?
When going from runtime $$$p+n$$$ to $$$\frac{p}{4} + n \log p$$$, we gave up on linear scaling in $$$n$$$, and added a factor of $$$\log p$$$ to it. It didn't matter as much previously, but at this stage it eats up a lot of compute. There are two main offenders that we can fix here:
- Multiplication by $$$2^x$$$ for particularly large $$$x$$$;
- Finding the inverse of the result for certain results.
If done naively, each of them takes $$$O(\log p)$$$, but we can improve both to $$$O(1)$$$ per query (amortized, in the case of inverses).
In particular, for a quick power of $$$2$$$ modulo $$$p$$$, we can precompute the first $$$2^{16}$$$ powers of $$$2$$$ and of $$$2^{2^{16}}$$$, and then combine one entry from each table:
// bpow(base, k) is a binary exponentiation helper defined elsewhere.
template<unsigned base>
unsigned pow_fixed(int n) {
    static vector<unsigned> prec_low(1 << 16);  // base^i
    static vector<unsigned> prec_high(1 << 16); // base^(i * 2^16)
    static bool init = false;
    if(!init) {
        init = true;
        prec_low[0] = prec_high[0] = 1;
        unsigned step_low = base;
        unsigned step_high = bpow(base, 1 << 16);
        for(int i = 1; i < (1 << 16); i++) {
            prec_low[i] = uint64_t(prec_low[i - 1]) * step_low % mod;
            prec_high[i] = uint64_t(prec_high[i - 1]) * step_high % mod;
        }
    }
    // base^n = base^(low 16 bits of n) * (base^(2^16))^(high bits of n)
    return uint64_t(prec_low[n & 0xFFFF]) * prec_high[n >> 16] % mod;
}
And for the inverses of multiple numbers, we can find all of them in $$$O(n + \log p)$$$ instead of $$$O(n \log p)$$$ by inverting only the product of all of them:
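// inv(x) is a single modular inverse helper defined elsewhere; args must be non-empty.
vector<unsigned> bulk_invs(auto const& args) {
    // res[i] = args[0] * ... * args[i]
    vector<unsigned> res(size(args), args[0]);
    for(size_t i = 1; i < size(args); i++) {
        res[i] = uint64_t(res[i - 1]) * args[i] % mod;
    }
    // Invariant: all_invs is the inverse of args[0] * ... * args[i].
    auto all_invs = inv(res.back());
    for(size_t i = size(args) - 1; i > 0; i--) {
        res[i] = uint64_t(all_invs) * res[i - 1] % mod;
        all_invs = uint64_t(all_invs) * args[i] % mod;
    }
    res[0] = all_invs;
    return res;
}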
Adding all of this, plus faster IO, lets us drop the runtime further to 99 ms, a 1.2x improvement.
Fallback to regular factorial computation
Alright, going below 100 ms is good, but I promised 60 ms, didn't I?.. So, is there still something that we didn't optimize enough?
There sure is! We reduce each $$$n!$$$ computation to $$$\log p$$$ queries on "odd factorials", and it puts a lot of strain on storing these queries, as well as on actually multiplying the sub-query results. What we can do instead is use odd factorials only on a few top levels, and then fall back to regular factorial computation once the current value of $$$n$$$ is small enough.
Finally, adding this optimization gets us to 64 ms, a final 1.5x improvement.
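One way to picture the hybrid scheme is as a decomposition of each query into a few odd-product queries plus one ordinary factorial query on a much smaller argument. The sketch below assumes a fixed number of top levels; the `Decomposition` struct, `decompose`, and `evaluate` are hypothetical names, and the blog does not say how its cutoff is actually chosen.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// n! = 2^twos * f(t_1) * ... * f(t_k) * (factorial_query)!
struct Decomposition {
    std::vector<uint64_t> odd_queries;  // arguments t_i of the odd products f(t_i)
    uint64_t factorial_query;           // argument of the remaining plain factorial
    uint64_t twos;                      // exponent of the factored-out power of 2
};

// Peels off at most `levels` odd products, then stops halving.
Decomposition decompose(uint64_t n, int levels) {
    Decomposition d{{}, n, 0};
    for (int i = 0; i < levels && d.factorial_query > 1; i++) {
        d.odd_queries.push_back(d.factorial_query);
        d.factorial_query /= 2;
        d.twos += d.factorial_query;
    }
    return d;
}

// Slow reference evaluation, only meant to check the decomposition.
uint64_t evaluate(const Decomposition& d, uint64_t p) {
    assert(p >= 2 && p < (1ull << 32));
    uint64_t result = 1 % p;
    for (uint64_t t : d.odd_queries)
        for (uint64_t k = 3; k <= t; k += 2) result = result * k % p;
    for (uint64_t k = 2; k <= d.factorial_query; k++) result = result * k % p;
    for (uint64_t e = 0; e < d.twos; e++) result = result * 2 % p;
    return result;
}
```

With `levels = 3`, the query $$$n = 100$$$ becomes $$$f(100) \cdot f(50) \cdot f(25) \cdot 12! \cdot 2^{87}$$$, so only three odd-product queries need to be stored instead of six.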
There are still some minor optimizations to push it to 61 ms, but they're somewhat ugly, so let's skip those :)
Want some more?
Further reading: (3)
I hope the blog was somehow somewhat useful to someone :)