source link: https://users.rust-lang.org/t/exploring-rustffts-simd-architecture/53780/13
Exploring RustFFT's SIMD Architecture

After releasing RustFFT 5.0 yesterday, a few people asked for details on how RustFFT 5.0 achieved its speed improvements over RustFFT 4.0 and its newfound speed advantage over the C FFT library FFTW.

In this post, I'll start by giving a high-level overview about FFT computation, then discuss how RustFFT computes FFTs, and finally dive into exactly how the AVX code is structured to maximize speed.

FFT Overview

The FFT is fundamentally a recursive process. Nearly all FFT algorithms come down to finding a way to express a FFT as some combination of smaller FFTs. At the heart of FFT computation is the Mixed-Radix Algorithm, which depends on the prime factorization of the FFT input size.

If you can factor a FFT size N into integers A * B, you can use that factorization to compute the size-N FFT:

Algorithm 1:

  1. Treat the input array as a 2D array of width A, and height B.
  2. Compute size-B sub-FFTs down each column of our 2D array.
  3. Apply an O(N) processing step by multiplying the array with precomputed complex numbers called "twiddle factors".
  4. Compute size-A sub-FFTs down the rows of our 2D array.
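
The four steps above can be sketched in plain Rust. This is a minimal scalar illustration, not RustFFT's implementation: the sub-FFTs are naive O(n^2) DFTs rather than recursive calls, and complex numbers are modeled as (f64, f64) tuples.

```rust
use std::f64::consts::PI;

type C = (f64, f64); // (re, im)

// Naive O(n^2) DFT, standing in for a recursive sub-FFT.
fn dft(x: &[C]) -> Vec<C> {
    let n = x.len();
    (0..n)
        .map(|k| {
            let mut acc = (0.0, 0.0);
            for (j, &(re, im)) in x.iter().enumerate() {
                let ang = -2.0 * PI * (j * k) as f64 / n as f64;
                acc.0 += re * ang.cos() - im * ang.sin();
                acc.1 += re * ang.sin() + im * ang.cos();
            }
            acc
        })
        .collect()
}

// Algorithm 1: factor n = a * b, view the input as b rows by a columns.
fn mixed_radix(x: &[C], a: usize, b: usize) -> Vec<C> {
    let n = a * b;
    assert_eq!(x.len(), n);
    // Steps 1-2: size-b FFTs down each column (stride-a memory access).
    let mut cols: Vec<Vec<C>> = (0..a)
        .map(|col| dft(&(0..b).map(|row| x[row * a + col]).collect::<Vec<C>>()))
        .collect();
    // Step 3: twiddle factor W_n^(col * kb) at row kb of column col.
    for col in 0..a {
        for kb in 0..b {
            let ang = -2.0 * PI * (col * kb) as f64 / n as f64;
            let (re, im) = cols[col][kb];
            cols[col][kb] = (re * ang.cos() - im * ang.sin(), re * ang.sin() + im * ang.cos());
        }
    }
    // Step 4: size-a FFTs across each row, writing output in natural order.
    let mut out = vec![(0.0, 0.0); n];
    for kb in 0..b {
        let row: Vec<C> = (0..a).map(|col| cols[col][kb]).collect();
        for (ka, &v) in dft(&row).iter().enumerate() {
            out[b * ka + kb] = v;
        }
    }
    out
}

fn main() {
    let x: Vec<C> = (0..12).map(|i| ((i as f64).sin(), 0.25 * i as f64)).collect();
    let direct = dft(&x);
    let mixed = mixed_radix(&x, 4, 3); // 12 = 4 * 3
    for (p, q) in direct.iter().zip(&mixed) {
        assert!((p.0 - q.0).abs() < 1e-9 && (p.1 - q.1).abs() < 1e-9);
    }
    println!("mixed-radix output matches the direct DFT");
}
```

Note how the output index decomposes as k = B*ka + kb: the column FFTs determine k mod B, and the row FFTs determine the rest.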

The most well-known FFT algorithm is probably the Cooley-Tukey Algorithm, and by reading the linked pseudocode, it becomes clear that Cooley-Tukey is just a special case of Algorithm 1 where A=2 and B=N/2, or vice versa.

If the size-A and size-B sub-FFTs are also computed with Algorithm 1, the result is an O(n log n) divide-and-conquer algorithm with prime-number base cases, corresponding to the original prime factorization of N. These prime-number base cases can then be computed using optimized hard-coded functions, or by FFT algorithms that specialize in prime number sizes.

For a simple example of a hardcoded base case, a size-2 FFT is just some additions and subtractions:

output[0] = input[0] + input[1];
output[1] = input[0] - input[1];

RustFFT's Algorithm

RustFFT's FFT computation is based on a specific formulation of the Mixed Radix algorithm, optimized for modern systems which are heavily dependent on CPU cache: the Six-Step FFT. Like Algorithm 1, we're going to factor our FFT size into A * B, and then treat our input array as a 2D array with width A and height B:

Algorithm 2:

  1. Transpose our 2D array
  2. Compute sub-FFTs of size B down the rows of our transposed array
  3. Multiply with precomputed twiddle factors
  4. Transpose again
  5. Compute sub-FFTs of size A down the rows of our transposed array
  6. Transpose one more time.
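
As a sanity check on the six steps, here is a scalar sketch in the same illustrative style (naive O(n^2) DFTs standing in for the sub-FFTs, complex numbers as (f64, f64) tuples); it shows the structure, not RustFFT's actual code:

```rust
use std::f64::consts::PI;

type C = (f64, f64); // (re, im)

// Naive O(n^2) DFT, standing in for a recursive sub-FFT.
fn dft(x: &[C]) -> Vec<C> {
    let n = x.len();
    (0..n)
        .map(|k| {
            let mut acc = (0.0, 0.0);
            for (j, &(re, im)) in x.iter().enumerate() {
                let ang = -2.0 * PI * (j * k) as f64 / n as f64;
                acc.0 += re * ang.cos() - im * ang.sin();
                acc.1 += re * ang.sin() + im * ang.cos();
            }
            acc
        })
        .collect()
}

// View `src` as a `rows` x `cols` row-major array; return its transpose.
fn transpose(src: &[C], rows: usize, cols: usize) -> Vec<C> {
    let mut dst = vec![(0.0, 0.0); src.len()];
    for r in 0..rows {
        for c in 0..cols {
            dst[c * rows + r] = src[r * cols + c];
        }
    }
    dst
}

// Six-step FFT for n = a * b; input is viewed as b rows by a columns.
fn six_step(x: &[C], a: usize, b: usize) -> Vec<C> {
    let n = a * b;
    // Step 1: transpose, so the strided columns become contiguous rows.
    let mut t = transpose(x, b, a); // now a rows by b columns
    // Step 2: size-b FFTs down the rows.
    for r in 0..a {
        let row = dft(&t[r * b..(r + 1) * b]);
        t[r * b..(r + 1) * b].copy_from_slice(&row);
    }
    // Step 3: twiddle factor W_n^(r * k) at row r, column k.
    for r in 0..a {
        for k in 0..b {
            let ang = -2.0 * PI * (r * k) as f64 / n as f64;
            let (re, im) = t[r * b + k];
            t[r * b + k] = (re * ang.cos() - im * ang.sin(), re * ang.sin() + im * ang.cos());
        }
    }
    // Step 4: transpose again, back to b rows by a columns.
    let mut t = transpose(&t, a, b);
    // Step 5: size-a FFTs down the rows.
    for r in 0..b {
        let row = dft(&t[r * a..(r + 1) * a]);
        t[r * a..(r + 1) * a].copy_from_slice(&row);
    }
    // Step 6: one last transpose puts the output in natural order.
    transpose(&t, b, a)
}

fn main() {
    let x: Vec<C> = (0..12).map(|i| ((i as f64).cos(), 0.5 * i as f64)).collect();
    let (direct, six) = (dft(&x), six_step(&x, 4, 3));
    for (p, q) in direct.iter().zip(&six) {
        assert!((p.0 - q.0).abs() < 1e-9 && (p.1 - q.1).abs() < 1e-9);
    }
    println!("six-step output matches the direct DFT");
}
```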

The MixedRadixSmall struct in the RustFFT source code has an easy-to-read real-world example of Algorithm 2.

The key difference between Algorithm 2 and Algorithm 1 described in the previous section is step 2: Instead of computing FFTs down the columns, which requires strided memory access and is cache-unfriendly, we transpose the data first, so that the strided columns become contiguous rows. The overall result is that this FFT algorithm is inherently cache-friendly.

When the whole problem fits comfortably in cache, other algorithms such as RustFFT's Radix4 algorithm (which is also a special case of the Mixed-Radix Algorithm) are faster, but the larger your FFT size is, the faster the six-step FFT will be compared to anything else.

The six-step FFT appears to be fastest when the two factors A and B are as close to sqrt(N) as possible, with the ideal case being A=B, and RustFFT's planner has code in it to divide factors equally when planning a step of Mixed Radix.
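
The "divide factors equally" idea can be illustrated with a hypothetical helper (this is a simplification, not RustFFT's actual planner code) that picks the divisor of N closest to sqrt(N):

```rust
// Hypothetical illustration: split n into a * b with a and b as close
// to sqrt(n) as possible, walking down from sqrt(n) until we hit a divisor.
fn balanced_split(n: usize) -> (usize, usize) {
    let mut a = (n as f64).sqrt() as usize;
    while a > 1 && n % a != 0 {
        a -= 1;
    }
    let a = a.max(1);
    (a, n / a)
}

fn main() {
    assert_eq!(balanced_split(4096), (64, 64)); // ideal case: a == b
    assert_eq!(balanced_split(12), (3, 4));
    assert_eq!(balanced_split(7), (1, 7)); // primes can't be split
    println!("balanced splits look reasonable");
}
```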

SIMD Challenges

Unfortunately, even though Algorithm 2 is very fast for scalar operations, it's just about the worst possible architecture for SIMD optimization. All SIMD instruction sets are optimized around the idea of loading a handful of blocks of memory into registers, doing parallel pairwise operations until your data has been completely processed, and then writing it out and never touching it again.

As soon as you have to do one of:

  1. Swap data around within a SIMD register
  2. Load and store the same data more than once
  3. Load more than 10-15 chunks of memory at once (resulting in a lot of loads/stores to temporary memory, which violates rule #2)

You're going to see slowdowns. Any nontrivial application will have to do at least some of these, but as a general rule of thumb, if you're writing SIMD code, you should structure your algorithms to minimize them.

The problem is, Algorithm 2 is almost intentionally designed to break these rules. 3 of the 6 steps are transposes, and during a transpose you're doing nothing but violating rule #1. If we're recursively using this same algorithm in our sub-FFTs, then 5 of our 6 steps spend almost all of their time violating rule #1.

SIMD Twiddle Factors

Step 3 of Algorithm 2, where we apply twiddle factors, is well-suited to SIMD. You load some data from the FFT array, load some twiddle factors, multiply them together, then write the result back to the FFT array. All three rules are satisfied.

When I first began experimenting with SIMD code back in April 2020, switching this step to use SIMD intrinsics was one of the first things I tried. Unfortunately, the overall performance barely changed. The most obvious reason is that the twiddle factor step is such a small percentage of the FFT's overall time. But there's an equally serious problem to overcome:

AVX has 16 registers to work with, letting you store a total of 16 * 8 floats at once - but when you're just doing a pairwise multiplication of 2 arrays, you're only using 3-4 of those registers, maximum, at a time. Most of the CPU's ability to pipeline operations is sitting unused.

There are a few tricks like auto-unrolling that help make pairwise multiplication more pipelined, but that doesn't solve the transpose problem. I quickly realized that a different architecture would be needed, one which satisfied the three SIMD rules above while taking advantage of as much of the CPU as possible.

AVX-Friendly Architecture

One approach would be to take Algorithm 2 and skip steps 1 and 4, giving us the "Four Step FFT" algorithm:

Algorithm 3:

  1. Compute sub-FFTs of size B down the columns of our 2D array
  2. Multiply with precomputed twiddle factors
  3. Compute sub-FFTs of size A down the rows of our 2D array
  4. Transpose

This setup gets us somewhere, because we can start computing the size-B column FFTs via AVX: Use _mm256_loadu_ps() to load 4 complex numbers from row 0, 4 complex numbers from row 1, 4 complex numbers from row 2, etc. Load one AVX register from each row, and we can compute 4 size-B FFTs in parallel.
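
The "one register per row, 4 FFTs in parallel" idea can be sketched without intrinsics. Below, [f32; 8] is a portable stand-in for the __m256 register that _mm256_loadu_ps() would fill (8 f32 lanes = 4 interleaved complex numbers), and the size-2 butterfly is the simplest possible column FFT:

```rust
// Portable stand-in for a __m256 register: 8 f32 lanes = 4 interleaved complex numbers.
type V8 = [f32; 8];

fn add8(x: V8, y: V8) -> V8 {
    std::array::from_fn(|i| x[i] + y[i])
}
fn sub8(x: V8, y: V8) -> V8 {
    std::array::from_fn(|i| x[i] - y[i])
}

// 4 size-2 column FFTs at once: `row0` and `row1` each hold the same 4 columns'
// complex values, so one lane-wise add and subtract compute all 4 butterflies.
fn parallel_butterfly2(row0: V8, row1: V8) -> (V8, V8) {
    (add8(row0, row1), sub8(row0, row1))
}

fn main() {
    let row0 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let row1 = [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
    let (sum, diff) = parallel_butterfly2(row0, row1);
    assert_eq!(sum, [9.0; 8]);
    assert_eq!(diff, [-7.0, -5.0, -3.0, -1.0, 1.0, 3.0, 5.0, 7.0]);
    println!("4 size-2 FFTs computed in parallel");
}
```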

This doesn't quite work, because we have to be able to handle arbitrary values of B. We only have 16 AVX registers to work with, so if B is more than 16, then we're violating rule #3 from our SIMD rules up above.

Hardcoding a Specific Column FFT Size

The solution is to make B a compile-time constant: Instead of having to worry about handling arbitrary values of B all in one implementation, we can monomorphize our FFT algorithms for various values of B: One FFT function that factors N into A=N/2, B=2, another FFT function that factors N into A=N/3, B=3, another that factors N into A=N/4, B=4, etc.

If we pick 12, for example, as the value for B, we get the following FFT algorithm.

Algorithm 4:

  1. Compute sub-FFTs of size 12 down the columns of our 2D array
    1.1. Load a single AVX vector from each of the 12 rows. Each vector can hold 8 floats, so we can load 4 columns' worth of data at a time.
    1.2. Compute 4 parallel FFTs of size 12 using SIMD instructions.
    1.3. Write our 12 AVX vectors back to where we originally loaded them.
  2. Multiply with precomputed twiddle factors
  3. Compute sub-FFTs of size N / 12 down the rows of our 2D array
  4. Transpose

Part of the reason why this setup is so fast is that in step 1.2, we're using 12 of our 16 AVX registers for FFT data, giving us 4 registers for temporary data, which turns out to be just barely enough. So as we compute these size-12 FFTs, we're doing so in a way that gives the compiler and processor the best possible opportunity to pipeline, parallelize, and reorder everything to be as fast as possible.

Merging Twiddle Factors

There is one final problem: We're still violating rule #2 from the SIMD rules earlier: Step 1 processes a bunch of data and writes it back to memory, and step 2 immediately loads that data back to do more operations on it. So the final step is to merge step 2 into step 1, giving us the algorithm found in RustFFT AVX code:

Algorithm 5:

  1. Compute sub-FFTs of size 12 down the columns of our 2D array
    1.1. Load a single AVX vector from each of the 12 rows. Each vector can hold 8 floats, so we can load 4 columns' worth of data at a time.
    1.2. Compute 4 parallel FFTs of size 12 using SIMD instructions.
    1.3. Multiply with precomputed twiddle factors
    1.4. Write our 12 AVX vectors back to where we originally loaded them.
  2. Compute sub-FFTs of size N / 12 down the rows of our 2D array
  3. Transpose

I guess we could call this the "Three Step FFT With Four Substeps" algorithm. Not quite as catchy as Six Step FFT or Four Step FFT. I think the latter two have some satisfying alliteration that this new name doesn't quite capture.

We're still violating Rule #1 in the Step 3 Transpose, and benchmarking shows that we do indeed spend around 40%-50% of our time on the transpose. I suspect that transposing, or bit reversing, or some other nontrivial reordering of the data appears to be a fact of life with FFT computation - but at least this way, it's only a third of the process, rather than most of it.

The MixedRadix8xnAvx, MixedRadix9xnAvx, and MixedRadix12xnAvx structs (and several others in that file) all implement Algorithm 5. "8xn", "9xn", etc. indicate that they factor the FFT size into 8 * n, 9 * n, etc.

Composing Algorithms Into a FFT Plan

In Algorithm 5, if step 2 were a static function call, then we could only ever process FFTs of size 12^U, or more generally B^U, for any given radix B. In order to be composable, we need some form of dynamic dispatch for these inner FFTs.

The final special sauce of RustFFT is that the inner FFT call in step 2 is a trait object of the Fft trait. The inner FFT of a MixedRadix12xn can be a MixedRadix5xn. The inner FFT of that MixedRadix5xn can be a hardcoded base case, optimized for a specific size. Any struct that implements the Fft trait can be an inner FFT for any other FFT, letting us compose FFTs however we need.

Once we have this capability, however, there's a new challenge: If we have a FFT size with a prime factorization of 2^10 * 3^8, there are 18 total prime factors that we could arrange in different ways, and the result is a combinatoric explosion of choices. How do we know where to start?

To begin with, it might be best to know which of our FFT algorithms is fastest.

Algorithm Performance

RustFFT's AVX code has implementations for Mixed-Radix 2xn, 3xn, 4xn, 5xn, 6xn, 7xn, 8xn, 9xn, 11xn, 12xn, and 16xn.

Of these, 12xn is the fastest in benchmarks, followed closely by 8xn and 9xn. I suspect that they're faster than the others because, like I mentioned before, these algorithms do the best job of using CPU registers to their fullest potential. Also recall something I said way back at the beginning of this article: The Six-Step FFT appears to perform best when N is factored into A and B where A and B are close, and ideally equal. It stands to reason that the larger we can make the constant value in our monomorphized algorithms, the faster the resulting FFT will be.

With that in mind, one might think that 16xn would be the fastest -- but unfortunately, it is just barely too large to fit into AVX registers. We need 16 vectors for just the FFT data, and more for temporaries. The end result is that the CPU has to spill data to temporary memory during the computation, which violates rule #2 of our SIMD rules up above. These problems hamper 16xn just enough to make it lose out.

(Footnote: AVX-512 has 32 SIMD registers instead of 16, but that doesn't offer much hope, because in this architecture, a Bxn algorithm requires both B AVX registers storing FFT data, and B general-purpose registers storing array pointers. Our architecture thus doesn't give us an obvious way to use all 32 of those AVX-512 registers, since AVX-512 still only has 16 general-purpose registers. If you have any ideas here, I'd love to hear them!)

Planning Heuristics

With the benchmarking results in mind, the strategy for our FFT planner becomes clear: Pass the ball to MixedRadix12xnAvx. If there are no factors of 3 remaining in the FFT size, begrudgingly use 8xn instead. If instead there are no factors of 4 left, use 9xn.
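
That heuristic can be summarized as a hypothetical selection function (a deliberate simplification; the real planner handles many more cases):

```rust
// Hypothetical sketch of the radix-selection heuristic described above.
fn choose_radix(n: usize) -> Option<usize> {
    if n % 12 == 0 {
        Some(12) // first choice: MixedRadix12xnAvx
    } else if n % 8 == 0 {
        Some(8) // no factor of 3 available: fall back to 8xn
    } else if n % 9 == 0 {
        Some(9) // no factor of 4 available: fall back to 9xn
    } else {
        None // defer to smaller radices or a hardcoded base case
    }
}

fn main() {
    assert_eq!(choose_radix(1200), Some(12));
    assert_eq!(choose_radix(1024), Some(8)); // pure power of two, no factor of 3
    assert_eq!(choose_radix(729), Some(9)); // pure power of three, no factor of 4
    assert_eq!(choose_radix(35), None);
    println!("radix choices match the heuristic");
}
```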

In practice, there is an entire other planning challenge around how to choose a recursion base case -- and there's an entire separate architecture for writing hardcoded base cases -- but those seem to have a less meaningful impact on performance than most of the things discussed so far, and this article is already too long, so I'll stop here.

Unanswered Questions

  • FFTW beats RustFFT at a size of 512. 512 is a hardcoded base case in RustFFT, so it doesn't even have anything to do with the architecture this article discusses. How is their size-512 FFT so fast?

  • Is there a smarter way to plan which algorithm to use? "Pass the ball to 12xn" gets us to the point where we beat FFTW, but that model of heuristics (i.e., run benchmarks, find a pattern, then hardcode the planner to follow the pattern) is really labor-intensive, and is subject to human bias in a number of ways.

  • Like I mentioned above, this architecture doesn't extend well to the 32 registers of AVX-512. How can we take advantage of those registers?

  • One thing this article didn't even touch is split-vs-interleaved complex numbers.

    • RustFFT stores its complex numbers interleaved, i.e. in memory we have something like [0re, 0im, 1re, 1im, 2re, 2im]. This is very convenient to work with at a high level, but makes multiplications a little more complicated, because we have to shuffle data between lanes. If you remember way back to the SIMD rules in the middle of the article, shuffling data between lanes is a violation of Rule #1.
    • An alternative memory layout is to store the real and imaginary components in separate arrays. This would mean no data reshuffling when multiplying.
    • Is this alternative faster, even considering that you have to run O(N) pre- and post-processing steps to split the data apart and put it back together? Does it make the algorithms harder to understand? In the transpose steps, you'd have to do 8x8 f32 transposes instead of 4x4 complex transposes, which are considerably more complicated. Do the extra costs of those transposes outweigh the benefits?
  • Another issue this article didn't touch is f64 FFTs. RustFFT supports them using the exact same architecture, but there's a unique challenge around how to be generic over the floating-point type. RustFFT has the AvxVector trait, which has common operations on it, so that the MixedRadix8xn structs can call these generic methods instead, which has worked out well.

    • One challenge that I haven't been able to solve is hardcoded base cases: In order to compute a size-12 f32 FFT, you need an array of 8 twiddle factors, represented as [__m256; 2]. In order to compute a size-12 f64 FFT, you need an array of 6 twiddle factors, represented as [__m256d; 3]. As far as I know, there's no proposal for specializing the layout of a struct based on a generic parameter, and in general there are countless challenges in the way of making hardcoded base cases generic.
    • Another challenge is that putting your SIMD intrinsics inside trait methods more or less requires you to make every single line of your AVX project unsafe. This is unreasonable, but unfortunately there's no solution for now. I opened an RFC to fix this, but I think that's going to (very understandably) move pretty slowly.
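
To make the split-vs-interleaved question above concrete, here is a scalar comparison of complex multiplication in the two layouts. In SIMD form, the interleaved version additionally needs lane shuffles to pair each real with its neighboring imaginary, while the split version is pure lane-wise arithmetic:

```rust
// Interleaved layout: [re0, im0, re1, im1, ...]. A SIMD version must shuffle
// within the register to pair each real lane with its imaginary neighbor.
fn cmul_interleaved(x: &[f64], y: &[f64], out: &mut [f64]) {
    for i in (0..x.len()).step_by(2) {
        out[i] = x[i] * y[i] - x[i + 1] * y[i + 1];
        out[i + 1] = x[i] * y[i + 1] + x[i + 1] * y[i];
    }
}

// Split layout: separate real and imaginary arrays. Every operation is
// lane-wise, so a SIMD version needs no reshuffling at all.
fn cmul_split(xr: &[f64], xi: &[f64], yr: &[f64], yi: &[f64], out_r: &mut [f64], out_i: &mut [f64]) {
    for i in 0..xr.len() {
        out_r[i] = xr[i] * yr[i] - xi[i] * yi[i];
        out_i[i] = xr[i] * yi[i] + xi[i] * yr[i];
    }
}

fn main() {
    // (1 + 2i) * (3 + 4i) = -5 + 10i
    let mut out = [0.0; 2];
    cmul_interleaved(&[1.0, 2.0], &[3.0, 4.0], &mut out);
    assert_eq!(out, [-5.0, 10.0]);

    let (mut re, mut im) = ([0.0], [0.0]);
    cmul_split(&[1.0], &[2.0], &[3.0], &[4.0], &mut re, &mut im);
    assert_eq!((re[0], im[0]), (-5.0, 10.0));
    println!("both layouts agree");
}
```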

Thanks for reading! If you want to help answer these questions, or just have comments in general, leave a reply here or on reddit, or send me a private message on either of the two.

