Cantor Trading

source link: https://cantortrading.fi/rust_decimal_str/

Parsing Decimals 4 times faster

December 31, 2021

Introduction

At Cantor, our systems must rapidly react to new updates in market data. Market data tends to be very bursty, and we only have a few microseconds to handle a message during these bursts. Unfortunately, these messages tend to come in JSON format, optimized for simplicity and readability instead of parsing performance. The performance of our market data handlers is thus limited by how fast we can parse these JSON messages into a native Rust format.

We use rust-decimal to represent numeric values in market data. rust-decimal represents a number using a 128-bit binary representation consisting of a 96-bit integer mantissa, an exponent in the range 0 to 28 inclusive, and a 1-bit sign: sign * mantissa * 10^-exponent.

We love this library for its rich API, exact base-10 representation, and nice interoperation with other crates, but we noticed one problem: it regularly costs ~100ns to parse a decimal from a string, which became a bottleneck in our application:

Perf dump of a benchmark
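To make the representation concrete, here is a small example using rust_decimal's public API (our sketch, not from the original post; the mantissa and scale accessors exist on recent versions of the crate):

use std::str::FromStr;
use rust_decimal::Decimal;

fn main() {
    // "1.23" is stored exactly as sign * mantissa * 10^-exponent:
    // +123 * 10^-2, with the 96-bit mantissa held in three u32 words.
    let d = Decimal::from_str("1.23").unwrap();
    assert_eq!(d.mantissa(), 123); // mantissa, widened to i128 by the API
    assert_eq!(d.scale(), 2);      // exponent ("scale"), always in 0..=28
}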

This prompted us to improve the parsing performance of rust_decimal. We track two benchmarks as the baseline: a decimal sample benchmark, which parses decimal strings in isolation, and a market data benchmark, which parses full market data messages.

Single-pass parsing

The original implementation of decimal parsing looks like:

// heavily simplified semi-pseudo code
fn parse_str_radix_10(bytes: &[u8]) -> Result<Decimal> {
    // check the first character and determine `sign`
    let mut coeff = ArrayVec::<u32, 32>::new();
    for b in bytes.iter() {
        match b {
            b'0'..=b'9' => { coeff.push(b - b'0'); }
            b'.' => { /* b is the decimal point. record its position to calculate `scale` */ }
            b'_' => {} // b is ignored
            _ => return Err,
        }
    }

    if needs_rounding {
        let b = next_digit(bytes);
        if b >= 5 { // round up
            // handle overflow by looping over `coeff` and incrementing the digits
        }
    }

    let mut num: [u32; 3] = [0, 0, 0];
    for c in coeff.iter() {
        mul_by_10(&mut num); // num *= 10
        add_by(&mut num, c); // num += c
        if overflow(num) { return Err; }
    }

    return Decimal::from_parts(num, scale, sign)
}

It first loops over the characters in the input string, converts each digit to a u32, and stores the digits in a coeff array. If the input has more digits than the mantissa can hold (which rarely happens), it rounds the stored digits based on the next digit. Finally, it loops over coeff and computes the output number using manual 96-bit arithmetic functions.

It is inefficient and unnecessary to loop twice. In addition, we have to reserve 128 (32 × 4) extra bytes on the stack to store the digits. We simplified the logic to compute num in one pass.

The single-pass code looks like:

// heavily simplified semi-pseudo code
fn parse_str_radix_10(bytes: &[u8]) -> Result<Decimal> {
    // check the first character and determine `sign`
    let mut num: [u32; 3] = [0, 0, 0];

    for b in bytes.iter() {
        match b {
            b'0'..=b'9' => {
                let c = b - b'0';
                mul_by_10(&mut num); // num *= 10
                add_by(&mut num, c); // num += c
                if passed_decimal_point { scale += 1 }
                if overflow(num) { return Err; }
            }
            b'.' => { passed_decimal_point = true; }
            b'_' => {} // b is ignored
            _ => return Err,
        }
    }

    if needs_rounding {
        let b = next_digit(bytes);
        if b >= 5 { // round up
            add_by(&mut num, 1); // num += 1
        }
    }

    return Decimal::from_parts(num, scale, sign)
}

mul_by_10 and add_by are manual implementations of 96-bit arithmetic; the functions look roughly as follows:

fn add_by(num: &mut [u32], by: u32) {
    for i in num.iter_mut() {
        // do some computation
    }
}

In Rust, a slice's length is only known at runtime. Even though we always pass num as an array of 3 u32s, the compiler cannot take advantage of this fact to optimize the function. If we change the signature to:

fn add_by(num: &mut [u32; 3], by: u32)

The compiler now knows that num is exactly 3 u32s at compile time, and will have a lot more opportunities to optimize this function, e.g. by replacing num.len() with a constant and optimizing specific code paths away. Since there is only one caller throughout the crate, we manually flattened the function and unrolled the loop.
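As an illustration, a carry-propagating add_by over the fixed-size array might look like the following sketch (ours, not the crate's exact code); with the [u32; 3] signature the compiler can fully unroll this loop:

fn add_by(num: &mut [u32; 3], by: u32) {
    let mut carry = by as u64;
    for word in num.iter_mut() {
        let sum = *word as u64 + carry;
        *word = sum as u32; // keep the low 32 bits
        carry = sum >> 32;  // propagate the high bits to the next word
    }
    // a nonzero carry here would mean 96-bit overflow, which the parser
    // checks for separately
}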

With these modifications, we already see a ~35% reduction in parse time - the benchmarks are now:

  • Decimal sample benchmark: 356ns/iter
  • Market data benchmark: 9.5us/iter

Using compiler-supported integer operations

rust-decimal internally stores a 96-bit integer mantissa and, as a result, has to perform 96-bit integer arithmetic on it. For example, the code to multiply by 10, with the number represented as three u32 words, is:

/// words[n] = words[n] * 10 + overflow(words[n-1])
#[inline]
pub(crate) fn mul_by_10(bits: &mut [u32; 3]) -> u32 {
    let mut overflow = 0u64;
    for b in bits.iter_mut() {
        // Calculate word * 10 and add the overflow from the previous word
        let result = u64::from(*b) * 10u64 + overflow;
        let hi = (result >> 32) & U32_MASK;
        let lo = (result & U32_MASK) as u32;
        *b = lo;
        overflow = hi;
    }

    overflow as u32
}

generating the respectable, but long, code:

mul_by_10:
    movl    (%rdi), %eax
    movl    4(%rdi), %ecx
    addq    %rax, %rax
    leaq    (%rcx,%rcx,4), %rcx
    leaq    (%rax,%rax,4), %rax
    movl    %eax, (%rdi)
    shrq    $32, %rax
    leaq    (%rax,%rcx,2), %rax
    movl    %eax, 4(%rdi)
    shrq    $32, %rax
    movl    8(%rdi), %ecx
    leaq    (%rcx,%rcx,4), %rcx
    leaq    (%rax,%rcx,2), %rax
    movl    %eax, 8(%rdi)
    shrq    $32, %rax
    retq

The CPU can carry out the per-word multiplications in parallel, but the carry propagation creates a dependency chain from each word to the next.

An immediate optimization is to perform a 128-bit operation:

/// words[n] = words[n] * 10 + overflow(words[n-1])
#[inline]
pub(crate) fn mul_by_10(bits: &mut u128) {
    *bits *= 10;
}

This generates the compact (and in fact pessimistic) code

mul_by_10_u128:
    movl    $10, %edx
    movq    8(%rdi), %rax
    mulxq   (%rdi), %rdx, %rcx
    leaq    (%rax,%rax,4), %rax
    movq    %rdx, (%rdi)
    leaq    (%rcx,%rax,2), %rax
    movq    %rax, 8(%rdi)
    retq

This code is somewhat better - we only have two multiplications - but there's a dependency between them since the multiplication properly (and pointlessly) handles overflow. It doesn't buy us very much, though.

However, an extension of this idea yields a much larger improvement. We maintain state in a 64-bit integer and only fall back to 128-bit handling if needed.
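A sketch of the idea (hypothetical helper; we use checked arithmetic in place of the crate's exact overflow test):

// Accumulate one digit into a u64; widen to u128 only when 64 bits would
// overflow, so the common case never touches 128-bit arithmetic.
fn accumulate(data64: u64, digit: u8) -> Result<u64, u128> {
    match data64.checked_mul(10).and_then(|v| v.checked_add(digit as u64)) {
        // Fast path: the value still fits in 64 bits.
        Some(v) => Ok(v),
        // Slow path: redo the step in 128 bits and continue there.
        None => Err(data64 as u128 * 10 + digit as u128),
    }
}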

This optimization yields much better results, another ~45% improvement in parse time:

  • Decimal sample benchmark: 193ns/iter
  • Market data benchmark: 7.9us/iter

Threading the needle with tail calls and const generics

The final optimization involves using tail calls to optimally split out fast/slow paths and rapidly switch between optimal parser loops. First, let’s talk about tail calls.

A tail call is a function call that is the final action the calling function performs, with no access to local variables after the call. In this case, it is valid for the compiler to replace the call + ret sequence with a simple jump to the callee.
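Consider a small tail-recursive sum (our example; in release builds LLVM usually turns the self-call into a jump, though Rust does not guarantee this):

// The recursive call is the last action taken, so the optimizer can turn
// call + ret into a plain jump, making this equivalent to a loop.
fn sum(bytes: &[u8], acc: u64) -> u64 {
    match bytes.split_first() {
        Some((b, rest)) => sum(rest, acc + *b as u64), // tail position
        None => acc,
    }
}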

In simple cases, like the one above, the code compiles into something very close to the loop version of the code. In the general case, tail calls act as a sort of structured goto, which can be very advantageous for performance purposes. Specifically, tail calls make it much easier to optimize interpreters and parsers - internally, we brought about a ~30% performance improvement to an in-house interpreter by utilizing tail calls.

These articles about high performance protobuf parsing and why LuaJIT’s interpreter is fast do a great job of detailing the difficulties that compilers have with naively written interpreters so I won’t repeat everything here, but the main points are:

  • Compilers have trouble reasoning about large functions with complex control flow
    • Register allocation makes poor decisions about what data should be kept in registers or spilled to the stack
    • The optimizer will aggressively try to merge identical snippets in different paths, hurting overall code quality
    • Compilers cannot reason about fast/slow paths mixed into the same function and compromise fast paths to optimize slow paths
    • Branches are often shared between paths, losing prediction information
  • Having multiple specialized fast-paths in a single loop confuses compilers
    • Switching between the fast-paths complicates control flow, leading to the above problems
    • Code duplication can help solve this at a high maintenance cost, but not perfectly

Tail calls offer a solution here since they:

  • Segregate control flow into small and easy-to-reason-about functions
  • Allow simple jumps into distinct parts of the control flow graph while informing the compiler that the local state isn’t shared
  • Give each operation dedicated dispatch and branch predictions

How does this help us? Let’s consider the u64 parsing loop.

A basic graph of the control flow looks like:

Control flow of looping parser

I’m using green nodes for fast paths, yellow nodes for slow fallbacks that may exit the loop, red for a terminal error, and blue for leaving the loop.

Something that's immediately clear is that the control flow is not straightforward - in addition to the branch-out-and-converge pattern, there are multiple exit paths from numerous points in the loop to both fast and slow paths, and all sorts of conditions get checked repeatedly. Many of these checks don't even make sense - why would we re-check for overflow after observing a '_', or try to handle a '.' again after already seeing a '.'?

We solve this with tail calls by giving each operation a small function, which only does whatever work is required and then dispatches the next call.

For example, here’s a simplified version of the dispatch code and two handlers:

#[inline]
fn byte_dispatch_u64(bytes: &[u8], data64: u64, b: u8) -> Result<Decimal, Error> {
    match b {
        b'0'..=b'9' => handle_digit_64(bytes, data64, b - b'0'),
        b'.' => handle_point(bytes, data64),
        b => non_digit_dispatch_u64(bytes, data64, b),
    }
}

#[inline(never)]
fn handle_digit_64(bytes: &[u8], data64: u64, digit: u8) -> Result<Decimal, Error> {
    // we have already validated that we cannot overflow
    let data64 = data64 * 10 + digit as u64;

    if let Some((next, bytes)) = bytes.split_first() {
        let next = *next;
        if overflow_64(data64) {
            handle_full_128(data64 as u128, bytes, next)
        } else {
            byte_dispatch_u64(bytes, data64, next)
        }
    } else {
        finish_data(data64 as u128)
    }
}

#[inline(never)]
fn handle_point(bytes: &[u8], data64: u64) -> Result<Decimal, Error> {
    if let Some((next, bytes)) = bytes.split_first() {
        byte_dispatch_u64(bytes, data64, *next)
    } else {
        finish_data(data64 as u128)
    }
}

You can see from the code that each call does some small incremental part of parsing and decides where to dispatch or whether to build the final Decimal. The expanded control flow graph, with functions outlined by blocks and error paths removed for brevity, looks something like:

Control flow of tail call parser

You can also see that the generated code is very straightforward, especially given that each call includes dispatching and decimal-completion code:

handle_digit:
    -- Decimal manipulation
    lea    (%rcx,%rcx,4),%rax
    movzbl %r8b,%ecx
    lea    (%rcx,%rax,2),%rcx

    -- Bounds checking
    sub    $0x1,%rdx
    jb     <complete decimal>

    -- Dispatch next
    movzbl (%rsi),%eax
    add    $0x1,%rsi
    lea    -0x30(%rax),%r8d
    cmp    $0xa,%r8b
    jb     <handle_digit>
    cmp    $0x2e,%al
    jne    <handle_digit>
    jmp    <handle_point>

    -- Construct Decimal (entered from function prelude)
    test   %ecx,%ecx
    setne  %al
    mov    %rcx,%rdx
    shr    $0x20,%rdx
    setne  %dl
    or     %al,%dl
    movzbl %dl,%eax
    shl    $0x1f,%rax
    mov    %rcx,0xc(%rdi)
    mov    %rax,0x4(%rdi)
    movl   $0x0,(%rdi)
    ret

    -- Dispatch a non-digit, non-'.' character
    movzbl %al,%r8d
    jmp    <non_digit_dispatch>

While embedding dispatch and decimal construction into each operation increases code size, it offsets that by allowing for more optimization and specialized predictions on a per-node basis.

We take the specialization trick one step further to get another speedup - we move some of the parser state to a set of compile-time booleans and use these to gate off checks in the code.

For example, the actual handle_digit_64 signature is more like

fn handle_digit_64<
    const POINT: bool,
    const NEG: bool,
    const HAS: bool,
    const BIG: bool,
>(
    bytes: &[u8],
    data64: u64,
    digit: u8,
)

The function is parameterized by a set of flags recording which checkpoints we have passed while parsing. For example, POINT and BIG record whether we have seen a '.' yet and whether the string is long enough to cause overflow.

We can re-write some of the code above to remove branches whose outcome is known at compile time:

// Inside byte_dispatch_u64 - only catch the '.' case if we haven't seen a point
b'.' if !POINT => handle_point(bytes, data64)

// ...

// Inside handle_digit_u64 - skip overflow checks if the string is not big
if BIG && overflow_64(data64) {
    handle_full_128(data64 as u128, bytes, next)
} else {
    byte_dispatch_u64(bytes, data64, next)
}

// Inside handle_point - dispatch bytes with point flag set

if let Some((next, bytes)) = bytes.split_first() {
    // Re-dispatch with the POINT const parameter set to true
    byte_dispatch_u64::<true, NEG, HAS, BIG>(bytes, data64, *next)
} else {
    finish_data(data64 as u128)
}

This also allows us to swap between specialized parsers at runtime. The only serious cost is additional code bloat, which the more aggressive optimizations and specialized dispatch sites offset.
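To make the mechanism concrete, here is a self-contained toy version (ours, not the crate's code) where every call is in tail position and seeing a '.' permanently switches to the POINT = true instantiation:

// Toy const-generic specialization: digits::<false> never increments
// `scale`, while digits::<true> increments it unconditionally and has no
// '.' branch left to take.
fn digits<const POINT: bool>(bytes: &[u8], num: u64, scale: u32) -> Option<(u64, u32)> {
    let Some((&b, rest)) = bytes.split_first() else {
        return Some((num, scale));
    };
    match b {
        b'0'..=b'9' => {
            let num = num.checked_mul(10)?.checked_add((b - b'0') as u64)?;
            digits::<POINT>(rest, num, scale + POINT as u32)
        }
        b'.' if !POINT => digits::<true>(rest, num, scale),
        _ => None,
    }
}

fn main() {
    // "12.34" yields mantissa 1234 with scale 2.
    assert_eq!(digits::<false>(b"12.34", 0, 0), Some((1234, 2)));
}

Each instantiation gets its own specialized body and its own branch-prediction sites, which is exactly the effect the tail-called handlers rely on.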

This change brings us another 30% performance gain (on my M1 laptop an impressive 50%!):

  • Decimal sample benchmark: 132ns/iter
  • Market data benchmark: 7.4us/iter

Since we've already removed most of the decimal parsing overhead, this change "just" leads to a 6% improvement in overall message parsing. Given the tight performance constraints, any progress is still a win.

Non-guaranteed tail calls

Unfortunately, Rust cannot guarantee tail calls. An RFC has been in the works since 2017, but it appears to be in some mix of dead/stuck-in-limbo. Even though tail calls bring a massive performance increase (1.5-2x depending on platform), the lack of compiler support means our options are:

  • Trust the optimizer in release and use a different strategy in debug
  • Try to construct some fast loop/match structure (see the sketch below), although it's hard to express things like specialized parsers and per-location dispatch.
  • Fall back to C/C++ and clang's [[clang::musttail]] attribute for the 50%+ performance gain, or keep a slower Rust implementation
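For a sense of the second option, here is a minimal loop/match sketch (ours, not the crate's code). The state that the tail-call version encodes in the instruction pointer becomes runtime flags that are re-checked on every byte:

fn parse_u64_loop(bytes: &[u8]) -> Option<(u64, u32)> {
    let (mut num, mut scale, mut seen_point) = (0u64, 0u32, false);
    for &b in bytes {
        match b {
            b'0'..=b'9' => {
                num = num.checked_mul(10)?.checked_add((b - b'0') as u64)?;
                if seen_point {
                    scale += 1; // checked on every digit, point or not
                }
            }
            b'.' if !seen_point => seen_point = true,
            b'_' => {} // ignore separators
            _ => return None,
        }
    }
    Some((num, scale))
}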

Writing parsers is where Rust shines - the rust-decimal parser only uses safe and panic-free methods to access the array while using high-level abstractions like const generics to get many specialized parsers out of one function. Adding guaranteed tail calls will be another reason to choose Rust over less-safe alternatives.

What’s next

Optimizing rust-decimal's string parsing is a great demonstration that you don't always need magic algorithms to make things fast. We eschewed complex float-parsing optimizations in favor of aggressively optimizing the simple code, and removed decimal parsing as a significant bottleneck. We're also upstreaming this change! There are plenty of places throughout the rest of rust-decimal where similar optimizations apply, and we plan to apply them piece-by-piece to help make rust-decimal the fastest base-10 number library.

If you find this article interesting, we’re hiring rust engineers! We’re especially interested in speaking to people with a background in high-performance computing and deep knowledge of networking/web protocols (TCP, DNS, HTTP, etc). Send us an email at [email protected] if you’d like to talk.


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK