6

从零到一实现 Rust 的 Channel 并发处理模型

 1 year ago
source link: https://www.51cto.com/article/751236.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

从零到一实现 Rust 的 Channel 并发处理模型

作者:三元同学 2023-04-06 08:01:30
这篇文章我们介绍 Rust 中并发的基础概念,包括 Mutex、Condvar、Arc、Atomic 等,然后我们实现了一个简单的 MPSC channel,即多生产者单消费者模型,理解了 channel 内部的实现原理,其内部也是基于 Mutex 和 Condvar 这些基础的原语来实现的。
294399339363dede56696566f8e7fbea9a7d9b.png

随着 SWC、NAPI-RS、Rspack 等等 Rust 前端工具链的出现,Rust 正在逐步成为前端工程化的一种新的选择,无论是在性能、安全性还是开发体验上都有着很大的优势。笔者在工作中也在使用 Rust 进行一些前端工具链的开发工作,对于 Rust 的一些特性也在不断的学习和探索,最近也会不定期的分享一些 Rust 的相关内容,比如: 如何用 napi-rs 搭建一个 Node.js 可以调用的 Rust 库、Rust 并发和异步模型、Rust 宏编程 等等话题。

这篇文章将会围绕 Rust 的并发模型展开,首先会介绍并发的基本概念,然后会对 Rust 中一些重要的并发工具进行介绍,比如 Atomic、Mutex、Condvar 等等,最后会实现一个 channel 并发处理模型。

注: 关于基础的环境搭建和语法内容不会进行讲解,可以参考 《Rust 语言圣经》这本书,相信对于初学者是一个不错的选择,地址: https://course.rs/about-book.html。

什么是并发?

要理解并发,我们绕不开另外一个相似的概念——并行,这两个概念也是计算机科学中经常被提到的两个概念,它们之间到底有什么区别?

这里引入非常经典的解释,来自 Golang 之父 Rob Pike 的一段话:

Concurrency is about dealing with lots of things at once. Parallelism is about doing lots of things at once.

翻译过来就是: 并发是指同时处理很多事情,而并行是指同时做很多事情。

在并发的场景中,对于正在处理的一些任务,虽然看起来好像它们在同时执行,但实际上是通过在单个处理器上交替轮流运行,某个时刻只有一个任务在运行,而其他任务都处于等待状态。

而在并行的场景中,对于正在处理的一些任务,它们是真正的同时执行。

而两者也并不是相互排斥的,并发和并行可以同时存在,比如在多核的 CPU 中,我们可以同时运行多个并发的任务,这样就可以充分利用多核 CPU 的优势,提高程序的执行效率。

Rust 中的并发原语

我们通常可以通过把任务放到多线程,或者多个异步任务来实现并发,在这个过程中,其实真正的难点不在于如何创建多个线程或者异步任务,而在于如何处理这些并发任务的同步和竞态问题。

在 Rust 中,提供了一些并发原语,来帮助我们处理并发任务的同步和竞态问题,这些原语包括: Atomic、Mutex、Condvar、Arc 等等,下面我们来逐一介绍一下。

Atomic

Atomic 是原子操作,它提供了一些原子操作的方法,比如 fetch_add、fetch_sub 等等,这些方法都是原子化的,也就是说,这些方法在执行的过程中,不会被其他线程打断,也不会被其他线程修改,这样就可以保证这些方法的执行是安全的。比如:

use std::sync::atomic::{AtomicUsize, Ordering};

let a = AtomicUsize::new(0);
a.fetch_add(1, Ordering::SeqCst);

Ordering::SeqCst 代表严格控制操作顺序的一致性,可以参考: https://doc.rust-lang.org/std/sync/atomic/enum.Ordering.html

上面的代码中,我们创建了一个 AtomicUsize 类型的变量 a,然后调用了 fetch_add 方法,这个方法会将 a 的值加 1,这个过程是原子化的。

为什么这里要突出强调一下原子化呢?这里我们来举个例子:

use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;

let counter = AtomicUsize::new(0);

let t1 = thread::spawn(|| {
    for _ in 0..100 {
        counter.fetch_add(1, Ordering::Relaxed);
    }
});

let t2 = thread::spawn(|| {
    for _ in 0..100 {
        counter.fetch_add(1, Ordering::Relaxed);
    }
});

t1.join().unwrap();
t2.join().unwrap();

assert_eq!(counter.load(Ordering::Relaxed), 200);

如果 fetch_add 方法执行不是原子化的,那么就可能出现竞态问题。例如,当线程 t1 和 t2 同时运行时,它们可能读取相同的计数器值,然后各自将其增加,并将结果存回计数器中,从而导致丢失一次增加的操作。这样就会导致最终结果小于预期值 200。

所以所谓的原子化,实际上是将某些步骤合并成一个原子操作,不能中断,拿这里的 fetch_add 来说:

  1. 读取 counter 的值。
  2. 将 counter 的值加 1。

这两个步骤不能中断,如果中断了,那么就会导致竞态问题。

Mutex

Mutex 是常用的一种互斥锁,它可以保证在同一时刻,只有一个线程可以访问某个数据,其他线程必须等待,直到锁被释放。

Mutex 有两种状态: 锁定和未锁定,当 Mutex 处于锁定状态时,其他线程就无法再次获取锁,直到 Mutex 处于未锁定状态。

举一个例子:

use std::sync::Mutex;
use std::thread;

let counter = Mutex::new(0);

let mut handles = vec![];

for _ in 0..10 {
    let handle = thread::spawn(move || {
        let mut value = counter.lock().unwrap();
        *value += 1;
    });
    handles.push(handle);
}

for handle in handles {
    handle.join().unwrap();
}

println!("Result: {}", *counter.lock().unwrap());

这段代码会有编译问题,后续会分析。

这里我们通过循环创建了 10 个线程来增加计数器的值。每个线程都获取了 Mutex 锁,并修改了计数器的值。当某个线程完成时,它会释放互斥锁,允许其他线程进行修改。

最后,我们使用 join() 方法等待所有线程完成,并打印出最终结果。

但这里的代码涉及到所有权转移的问题,我们知道,在 Rust 中,同一时间一个变量只能有一个所有者,当我们将 counter 传递给线程时,就会发生所有权转移,这样就会导致其它的线程无法获取 counter 的所有权,导致编译报错。

我们需要使用 Arc 来解决这个问题。

Arc 是原子引用计数,它可以在多个线程之间共享数据,它的内部实现是通过原子操作来实现的,所以它是线程安全的。

我们可以通过 Arc::new 来创建一个 Arc 对象,然后通过 Arc::clone 来克隆一个 Arc 对象,这样就可以在多个线程之间共享数据了。

use std::sync::{Arc, Mutex};
use std::thread;

let counter = Arc::new(Mutex::new(0));

let mut handles = vec![];

for _ in 0..10 {
    let counter = Arc::clone(&counter);
    let handle = thread::spawn(move || {
        let mut value = counter.lock().unwrap();
        *value += 1;
    });
    handles.push(handle);
}

for handle in handles {
    handle.join().unwrap();
}

println!("Result: {}", *counter.lock().unwrap());

Condvar

Condvar 是一个条件变量,它可以让线程等待某个条件满足,然后再执行。比如:

use std::sync::{Arc, Condvar, Mutex};

let pair = Arc::new((Mutex::new(false), Condvar::new()));

let pair2 = Arc::clone(&pair);

let thread1 = std::thread::spawn(move || {
    let (lock, cvar) = &*pair2;
    let mut started = lock.lock().unwrap();
    *started = true;
    cvar.notify_one();
});

let (lock, cvar) = &*pair;

let mut started = lock.lock().unwrap();

while !*started {
    started = cvar.wait(started).unwrap();
}

thread1.join().unwrap();

上面的代码中,我们创建了一个 pair,它是一个元组,第一个元素是一个 Mutex,第二个元素是一个 Condvar。然后我们创建了一个线程 thread1,它会将 Mutex 中的值设置为 true,然后调用 Condvar 的 notify_one 方法,通知 Condvar 等待的线程。

而在主线程中,我们会调用 Condvar 的 wait 方法,等待 Condvar 的通知,当主线程收到通知后,就会继续执行。

使用 Channel 处理并发

读到这里,你可能会说了,我们使用 Mutex、Arc、Condvar 等方式来处理并发,看起来很麻烦呀?其实,Rust 中还有一种更简单的方式来处理并发,那就是通过 Channel。

Channel 的本质是一个消息队列,它可以让多个线程之间进行消息通信,把读者和写者分离。根据读者和写者的数量,Channel 可以分为下面的几个类型:

  • 单生产者单消费者(Single Producer, Single Consumer, SPSC)
  • 单生产者多消费者(Single Producer, Multiple Consumer, SPMC)
  • 多生产者单消费者(Multiple Producer, Single Consumer, MPSC)
  • 多生产者多消费者(Multiple Producer, Multiple Consumer, MPMC)

其中 MPSC 是最常用的,在 Rust 中,它是通过 std::sync::mpsc 模块来实现的。我们来看看它是如何使用的。

use std::sync::mpsc;

let (s, r) = mpsc::channel();

let s1 = mpsc::Sender::clone(&s);

std::thread::spawn(move || {
    let val = String::from("hi");
    s1.send(val).unwrap();
});

let received = r.recv().unwrap();

println!("Got: {}", received);

上面的代码中,我们创建了一个 Channel,它是一个元组,第一个元素是一个 Sender,第二个元素是一个 Receiver。Sender 用来发送消息,Receiver 用来接收消息。

我们通过 mpsc::Sender::clone 方法来克隆一个 Sender,然后将克隆的 Sender 传递给线程,线程中通过 Sender 的 send 发送消息。而在主线程中,我们通过 Receiver 的 recv 方法来接收消息。

实现一个 Channel

接下来我们基于 Arc、Mutex、Condvar 来实现一个 Channel,它的功能和 std::sync::mpsc 中的 channel 类似,支持多生产者单消费者。

1、创建项目

首先我们通过 cargo new my-channel --lib 来创建一个库项目,然后在 Cargo.toml 中添加依赖:

[dependencies]
anyhow="1.0.40"

anyhow 是一个错误处理库,它可以让我们更方便的处理错误。

2、整体设计

对外暴露一个 channel 函数,它返回一个 Sender 和 Receiver,Sender 用来发送消息,Receiver 用来接收消息。

pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    todo!()
}

因此关键的数据结构就是 Sender 和 Receiver,它们都需要持有一个共享的内部数据结构,我们将其命名为 Inner,它的定义如下:

// src/lib.rs
use anyhow::{anyhow, Ok, Result};
use std::{
    collections::VecDeque,
    sync::{atomic::AtomicUsize, Arc, Condvar, Mutex},
};

struct Inner<T> {
    // 共享的数据
    data: Mutex<VecDeque<T>>,
    // 条件变量
    condvar: Condvar,
    // 发送者数量,使用原子操作
    senders: AtomicUsize,
    // 接收者数量,使用原子操作
    receivers: AtomicUsize,
}

pub struct Sender<T> {
    inner: Arc<Inner<T>>,
}

pub struct Receiver<T> {
    inner: Arc<Inner<T>>,
}

OK,确定了数据结构之后,我们来实现 Sender 和 Receiver 的行为。

3、实现 Sender

首先我们来实现 Sender:

impl<T> Sender<T> {
    pub fn send(&self, value: T) -> Result<()> {
        todo!()
    }

    pub fn get_receivers_count(&self) -> usize {
        todo!()
    }
}

我们需要实现下面的方法:

  • send 方法,用来发送消息。
  • get_receivers_count 方法,用来获取接收者的数量。

具体实现如下:

impl<T> Sender<T> {
    pub fn send(&self, value: T) -> Result<()> {
        // 如果没有接收者了,就抛错
        if self.get_receivers_count() == 0 {
            return Err(anyhow!("no more receivers"));
        }
        let mut data = self.inner.data.lock().unwrap();
        data.push_back(value);
        // 通知接收者
        self.inner.condvar.notify_one();
        Ok(())
    }

    pub fn get_receivers_count(&self) -> usize {
        self.inner
            .receivers
            .load(std::sync::atomic::Ordering::SeqCst)
    }
}

上面的代码中,我们通过 get_receivers_count 方法来获取接收者的数量,如果没有接收者了,就抛错。然后我们通过 Mutex 的 lock 方法来获取锁,然后将消息放入队列中,最后通过 Condvar 的 notify_one 方法来通知接收者。

4、实现 Receiver

接下来我们来实现 Receiver:

impl<T> Receiver<T> {
    pub fn recv(&self) -> Result<T> {
        todo!()
    }

    pub fn get_senders_count(&self) -> usize {
        todo!()
    }
}

我们需要实现下面的方法:

  • recv 方法,用来接收消息。
  • get_senders_count 方法,用来获取发送者的数量。

具体实现如下:

impl<T> Receiver<T> {
    pub fn recv(&self) -> Result<T> {
        let mut data = self.inner.data.lock().unwrap();
        loop {
            // 如果没有发送者了,就抛错
            if self.get_senders_count() == 0 {
                return Err(anyhow!("no more senders"));
            }
            // 如果队列中有消息,就返回
            if let Some(value) = data.pop_front() {
                return Ok(value);
            }
            // 如果队列中没有消息,就等待
            data = self.inner.condvar.wait(data).unwrap();
        }
    }

    pub fn get_senders_count(&self) -> usize {
        self.inner
            .senders
            .load(std::sync::atomic::Ordering::SeqCst)
    }
}

上面的代码中,我们通过 get_senders_count 方法来获取发送者的数量,如果没有发送者了,就抛错。

然后我们通过 Mutex 的 lock 方法来获取锁,通过 Condvar 的 wait 方法来等待消息,如果队列中有消息,就返回,如果队列中没有消息,就继续等待,直到有消息为止。

当然,我们还需要实现 Drop trait,当 Sender 或者 Receiver 被释放时,我们需要更新发送者数量或者接收者数量:

impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        self.inner
            .senders
            .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
    }
}

impl<T> Drop for Receiver<T> {
    fn drop(&mut self) {
        self.inner
            .receivers
            .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
    }
}

5、实现 channel 函数

最后我们来实现 channel 函数:

pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    let inner = Arc::new(Inner {
        data: Mutex::new(VecDeque::new()),
        condvar: Condvar::new(),
        senders: AtomicUsize::new(1),
        receivers: AtomicUsize::new(1),
    });
    (
        Sender {
            inner: inner.clone(),
        },
        Receiver { inner },
    )
}

我们通过 Arc 来包装 Inner,然后创建一个 Sender 和一个 Receiver,最后返回。

我们来测试一下目前的 channel 能否正常工作:

#[test]
fn test_channel() {
    let (mut s, r) = channel();
    let mut s1 = s.clone();
    let mut s2 = s.clone();
    let t = thread::spawn(move || {
        s.send(1).unwrap();
    });
    let t1 = thread::spawn(move || {
        s1.send(10).unwrap();
    });
    let t2 = thread::spawn(move || {
        s2.send(100).unwrap();
    });
    for handle in [t, t1, t2] {
        handle.join().unwrap();
    }

    let mut result = [r.recv().unwrap(), r.recv().unwrap(), r.recv().unwrap()];
    // 保证顺序的稳定
    result.sort();

    assert_eq!(result, [1, 10, 100]);
}

#[test]
fn with_no_senders() {
    let (s, r) = channel::<i32>();
    drop(s);
    assert!(r.recv().is_err());
}

#[test]
fn with_no_receivers() {
  let (mut s, _) = channel::<i32>();
  assert!(s.send(1).is_err());
}

OK,目前的 channel 已经可以正常工作了。

这篇文章中,我们介绍了 Rust 中并发的基础概念,包括 Mutex、Condvar、Arc、Atomic 等,然后我们实现了一个简单的 MPSC channel,即多生产者单消费者模型,理解了 channel 内部的实现原理,其内部也是基于 Mutex 和 Condvar 这些基础的原语来实现的。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK