4

.NET Task 揭秘(3)async 与 AsyncMethodBuilder - 黑洞视界

 1 year ago
source link: https://www.cnblogs.com/eventhorizon/p/17220585.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

.NET Task 揭秘(3)async 与 AsyncMethodBuilder

本文为系列博客

  1. 什么是 Task
  2. Task 的回调执行与 await
  3. async 与 AsyncMethodBuilder(本文)
  4. 总结与常见误区(TODO)

上文我们学习了 await 这个语法糖背后的实现,了解了 await 这个关键词是如何去等待 Task 的完成并获取 Task 执行结果。并且我们还实现了一个简单的 awaitable 类型,它可以让我们自定义 await 的行为。

class FooAwaitable<TResult>
{
    // 回调,简化起见,未将其包裹到 TaskContinuation 这样的容器里
    private Action _continuation;

    private TResult _result;
    
    private Exception _exception;

    private volatile bool _completed;

    public bool IsCompleted => _completed;

    // Awaitable 中的关键部分,提供 GetAwaiter 方法
    public FooAwaiter<TResult> GetAwaiter() => new FooAwaiter<TResult>(this);

    public void Run(Func<TResult> func)
    {
        new Thread(() =>
        {
            var result = func();
            TrySetResult(result);
        })
        {
            IsBackground = true
        }.Start();
    }

    private bool AddFooContinuation(Action action)
    {
        if (_completed)
        {
            return false;
        }
        _continuation += action;
        return true;
    }

    internal void TrySetResult(TResult result)
    {
        _result = result;
        _completed = true;
        _continuation?.Invoke();
    }
    
    internal void TrySetException(Exception exception)
    {
        _exception = exception;
        _completed = true;
        _continuation?.Invoke();
    }

    // 1 实现 ICriticalNotifyCompletion
    public struct FooAwaiter<TResult> : ICriticalNotifyCompletion
    {
        private readonly FooAwaitable<TResult> _fooAwaitable;
        
        // 2 实现 IsCompleted 属性
        public bool IsCompleted => _fooAwaitable.IsCompleted;

        public FooAwaiter(FooAwaitable<TResult> fooAwaitable)
        {
            _fooAwaitable = fooAwaitable;
        }

        public void OnCompleted(Action continuation)
        {
            Console.WriteLine("FooAwaiter.OnCompleted");
            if (_fooAwaitable.AddFooContinuation(continuation))
            {
                Console.WriteLine("FooAwaiter.OnCompleted: added continuation");
            }
            else
            {
                Console.WriteLine("FooAwaiter.OnCompleted: already completed, invoking continuation");
                continuation();
            }
        }

        public void UnsafeOnCompleted(Action continuation)
        {
            Console.WriteLine("FooAwaiter.UnsafeOnCompleted");
            if (_fooAwaitable.AddFooContinuation(continuation))
            {
                Console.WriteLine("FooAwaiter.UnsafeOnCompleted: added continuation");
            }
            else
            {
                Console.WriteLine("FooAwaiter.UnsafeOnCompleted: already completed, invoking continuation");
                continuation();
            }
        }

        // 3. 实现 GetResult 方法
        public TResult GetResult()
        {
            if (_fooAwaitable._exception != null)
            {
                // 4. 如果 awaitable 中有异常,则抛出
                throw _fooAwaitable._exception;
            }
            Console.WriteLine("FooAwaiter.GetResult");
            return _fooAwaitable._result;
        }
    }
}

如果在一个方法中使用了 await,那么这个方法就必须添加 async 修饰符。并且这个方法的返回类型通常是 Task 或者 其它 runtime 里定义的 awaitable 类型。

int foo = await FooAsync();
Console.WriteLine(foo); // 1

async Task<int> FooAsync()
{
    await Task.Delay(1000);
    return 1;
}

问题1: 上面的代码中,FooAsync 方法是一个异步方法,它的返回类型是 Task。但代码中的 await FooAsync() 并不会返回 Task,而是返回 int。这是为什么呢?

如果我们把 FooAsync 的返回值改成我们自己实现的 awaitable 类型,编译器会报错:

问题2: 明明我们可以在 FooAwaitable 实例上使用 await 关键词,为什么把它作为 FooAsync 的返回类型就会报错呢?且提示它不是一个 task-like 类型?

实际上我们在上篇文章实现的 awaitable 类型 FooAwaitable,只是支持了 await 关键词,并不是一个完整的 task-like 类型。

而上面两个问题的答案就是本文要讲的内容:AsyncMethodBuilder

AsyncMethodBuilder 介绍

AsyncMethodBuilder 是状态机的重要组成部分#

引用上一篇文章介绍状态机的代码:

class Program
{
    static async Task Main(string[] args)
    {
        var a = 1;
        Console.WriteLine(await FooAsync(a));
    }

    static async Task<int> FooAsync(int a)
    {
        int b = 2;
        int c = await BarAsync();
        return a + b + c;
    }

    static async Task<int> BarAsync()
    {
        await Task.Delay(100);
        return 3;
    }
}

由 FooAsync 编译成的 IL 代码经整理后的等效 C# 代码如下:

using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

class Program
{
    static async Task Main(string[] args)
    {
        var a = 1;
        Console.WriteLine(await FooAsync(a));
    }

    static Task<int> FooAsync(int a)
    {
        var stateMachine = new FooStateMachine
        {
            _asyncTaskMethodBuilder = AsyncTaskMethodBuilder<int>.Create(),
    
            _state = -1, // 初始化状态
            _a = a // 将实参拷贝到状态机字段
        };
        // 开始执行状态机
        stateMachine._asyncTaskMethodBuilder.Start(ref stateMachine);
        return stateMachine._asyncTaskMethodBuilder.Task;
    }

    static async Task<int> BarAsync()
    {
        await Task.Delay(100);
        return 3;
    }

    public class FooStateMachine : IAsyncStateMachine
    {
        // 方法的参数和局部变量被编译会字段
        public int _a;
        public AsyncTaskMethodBuilder<int> _asyncTaskMethodBuilder;
        private int _b;

        private int _c;

        // -1: 初始化状态
        // 0: 等到 Task 执行完成
        // -2: 状态机执行完成
        public int _state;

        private TaskAwaiter<int> _taskAwaiter;

        public void MoveNext()
        {
            var result = 0;
            TaskAwaiter<int> taskAwaiter;
            try
            {
                // 状态不是0,代表 Task 未完成
                if (_state != 0)
                {
                    // 初始化局部变量
                    _b = 2;

                    taskAwaiter = Program.BarAsync().GetAwaiter();
                    if (!taskAwaiter.IsCompleted)
                    {
                        // state: -1 => 0,异步等待 Task 完成
                        _state = 0;
                        _taskAwaiter = taskAwaiter;
                        var stateMachine = this;
                        // 内部会调用 将 stateMachine.MoveNext 注册为 Task 的回调
                        _asyncTaskMethodBuilder.AwaitUnsafeOnCompleted(ref taskAwaiter, ref stateMachine);
                        return;
                    }
                }
                else
                {
                    taskAwaiter = _taskAwaiter;
                    // TaskAwaiter 是个结构体,这边相当于是个清空 _taskAwaiter 字段的操作
                    _taskAwaiter = new TaskAwaiter<int>();
                    // state: 0 => -1,状态机恢复到初始化状态
                    _state = -1;
                }

                _c = taskAwaiter.GetResult();
                result = _a + _b + _c;
            }
            catch (Exception e)
            {
                // state: any => -2,状态机执行完成
                _state = -2;
                _asyncTaskMethodBuilder.SetException(e);
                return;
            }

            // state: -1 => -2,状态机执行完成
            _state = -2;
            // 将 result 设置为 FooAsync 方法的返回值
            _asyncTaskMethodBuilder.SetResult(result);
        }

        public void SetStateMachine(IAsyncStateMachine stateMachine)
        {
        }
    }
}

在编译器生成的状态机类中,我们可以看到一个名为 _asyncTaskMethodBuilder 的字段,它的类型是 AsyncTaskMethodBuilder<int>。
这个 AsyncTaskMethodBuilder 就是 Task所绑定的 AsyncMethodBuilder。

AsyncMethodBuilder 的结构#

以 AsyncTaskMethodBuilder<TResult> 为例,我们来看下 AsyncMethodBuilder 的结构:

public struct AsyncTaskMethodBuilder<TResult>
{
    // 保存最后作为返回值的 Task
    private Task<TResult>? m_task;

    // 创建一个 AsyncTaskMethodBuilder
    public static AsyncTaskMethodBuilder<TResult> Create() => default;

    // 开始执行 AsyncTaskMethodBuilder 及其绑定的状态机 
    public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine =>
        AsyncMethodBuilderCore.Start(ref stateMachine);

    // 绑定状态机,但编译器的编译结果不会调用
    public void SetStateMachine(IAsyncStateMachine stateMachine) =>
        AsyncMethodBuilderCore.SetStateMachine(stateMachine, m_task);

    // 将状态机的 MoveNext 方法注册为 async方法 内 await 的 Task 的回调
    public void AwaitOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : INotifyCompletion
        where TStateMachine : IAsyncStateMachine =>
        AwaitOnCompleted(ref awaiter, ref stateMachine, ref m_task);

    // 同上,参考前一篇文章讲 UnsafeOnCompleted 和 OnCompleted 的区别
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : ICriticalNotifyCompletion
        where TStateMachine : IAsyncStateMachine =>
        AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine, ref m_task);

    public Task<TResult> Task
    {
        get => m_task ?? InitializeTaskAsPromise();
    }

    public void SetResult(TResult result)
    {
        if (m_task is null)
        {
            m_task = Threading.Tasks.Task.FromResult(result);
        }
        else
        {
            SetExistingTaskResult(m_task, result);
        }
    }

    public void SetException(Exception exception) => SetException(exception, ref m_task);
}

非泛型的 Task 对应的 AsyncMethodBuilder 是 AsyncTaskMethodBuilder,它的结构与泛型的 AsyncTaskMethodBuilder<TResult> 类似,但因为最终返回的 Task 没有执行结果,它的 SetResult 只是为了标记 Task 的完成状态并触发 Task 的回调。

public struct AsyncTaskMethodBuilder
{
    private Task<VoidTaskResult>? m_task;

    public static AsyncTaskMethodBuilder Create() => default;


    public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine =>
        AsyncMethodBuilderCore.Start(ref stateMachine);

    public void SetStateMachine(IAsyncStateMachine stateMachine) =>
        AsyncMethodBuilderCore.SetStateMachine(stateMachine, task: null);

    public void AwaitOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : INotifyCompletion
        where TStateMachine : IAsyncStateMachine =>
        AsyncTaskMethodBuilder<VoidTaskResult>.AwaitOnCompleted(ref awaiter, ref stateMachine, ref m_task);

    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : ICriticalNotifyCompletion
        where TStateMachine : IAsyncStateMachine =>
        AsyncTaskMethodBuilder<VoidTaskResult>.AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine, ref m_task);

    public Task Task
    {
        get => m_task ?? InitializeTaskAsPromise();
    }

    public void SetResult()
    {
        if (m_task is null)
        {
            m_task = Task.s_cachedCompleted;
        }
        else
        {
            AsyncTaskMethodBuilder<VoidTaskResult>.SetExistingTaskResult(m_task, default!);
        }
    }

    public void SetException(Exception exception) =>
        AsyncTaskMethodBuilder<VoidTaskResult>.SetException(exception, ref m_task);
}            

AsyncMethodBuilder 功能分析#

AsyncTaskMethodBuilder 在 FooAsync 方法的执行过程中,起到了以下作用:

  1. 对内:关联状态机和状态机执行的上下文,管理状态机的生命周期。
  2. 对外:构建一个 Task 对象,作为异步方法的返回值,并会触发该 Task 执行的完成或异常。

为了方便说明,下文我们将 FooAsync 方法返回的 Task 称为 FooTask,BarAsync 方法返回的 Task 称为 BarTask。

对状态机的生命周期进行管理#

状态机通过 _asyncTaskMethodBuilder.Start 方法来启动且其 MoveNext 方式是通过 _asyncTaskMethodBuilder.AwaitUnsafeOnCompleted 方法来注册为 BarTask 的回调的。

对 async 方法的返回值进行包装#

_asyncTaskMethodBuilder 是用来构建一个 Task 对象,_asyncTaskMethodBuilder 的 Task 属性就是 FooAsync 方法返回的 FooTask。通过 _asyncTaskMethodBuilder 的 SetResult 方法,我们可以设置 FooTask 的执行结果, 通过 SetException 方法,我们可以设置 FooTask 的异常。

小结#

一个 AsyncMethodBuilder 是由下面几个部分组成的:

  1. 一个 Task 对象,作为异步方法的返回值。
  2. Create 方法,用来创建 AsyncMethodBuilder。
  3. Start 方法,用来启动状态机。
  4. AwaitOnCompleted/AwaitUnsafeOnCompleted 方法,用来将状态机的 MoveNext 方法注册为 async方法 内 await 的 Task 的回调。
  5. SetResult/SetException 方法,用来标记 Task 的完成状态并触发 Task 的回调。
  6. SetStateMachine 方法,用来关联状态机,不常用,编译结果也不会调用。

async void

为了让 async 方法适配传统的事件回调,C# 引入了 async void 的概念。

var foo = new Foo();
foo.OnSayHello += FooAsync;
foo.SayHello();

Console.ReadLine();

async void FooAsync(object sender, EventArgs e)
{
    var args = e as SayHelloEventArgs;
    await Task.Delay(1000);
    Console.WriteLine(args.Message);
}

class Foo
{
    public event EventHandler OnSayHello;

    public void SayHello()
    {
        OnSayHello.Invoke(this, new SayHelloEventArgs { Message = "Hello" });
    }
}

class SayHelloEventArgs : EventArgs
{
    public string Message { get; set; }
}

async void 也有一个对应的 AsyncVoidMethodBuilder。

    public struct AsyncVoidMethodBuilder
    {
        // AsyncVoidMethodBuilder 是对 AsyncTaskMethodBuilder 的封装
        private AsyncTaskMethodBuilder _builder;

        public static AsyncVoidMethodBuilder Create()
        {
            // ...
        }

        public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine =>
            AsyncMethodBuilderCore.Start(ref stateMachine);

        public void SetStateMachine(IAsyncStateMachine stateMachine) =>
            _builder.SetStateMachine(stateMachine);

        public void AwaitOnCompleted<TAwaiter, TStateMachine>(
            ref TAwaiter awaiter, ref TStateMachine stateMachine)
            where TAwaiter : INotifyCompletion
            where TStateMachine : IAsyncStateMachine =>
            _builder.AwaitOnCompleted(ref awaiter, ref stateMachine);

        public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
            ref TAwaiter awaiter, ref TStateMachine stateMachine)
            where TAwaiter : ICriticalNotifyCompletion
            where TStateMachine : IAsyncStateMachine =>
            _builder.AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine);

        public void SetResult()
        {
            // 仅仅是做 runtime 的一些状态标记
        }

        public void SetException(Exception exception)
        {
            // 这个异常只能通过 TaskScheduler.UnobservedTaskException 事件来捕获
        }

        // 因为没有返回值,这个 Task 不对外暴露
        private Task Task => _builder.Task;
    }

自定义 AsyncMethodBuilder

自定义一个 AsyncMethodBuilder,不需要实现任意接口,只需要实现上面说的那 6 个主要组成部分,编译器就能够正常编译。

awaitable 绑定 AsyncMethodBuilder 的方式有两种:

  1. 在 awaitable 类型上添加 AsyncMethodBuilderAttribute 来绑定 AsyncMethodBuilder。
  2. 在 async 方法上添加 AsyncMethodBuilderAttribute 来绑定 AsyncMethodBuilder,用来覆盖 awaitable 类型上的 AsyncMethodBuilderAttribute(前提是 awaitable 类型上有 AsyncMethodBuilderAttribute)。
struct FooAsyncMethodBuilder<TResult>
{
    private FooAwaitable<TResult> _awaitable;

    // 1. 定义 Task 属性
    public FooAwaitable<TResult> Task
    {
        get
        {
            Console.WriteLine("FooAsyncMethodBuilder.Task");
            return _awaitable;
        }
    }
    
    // 2. 定义 Create 方法
    public static FooAsyncMethodBuilder<TResult> Create()
    {
        Console.WriteLine("FooAsyncMethodBuilder.Create");
        var awaitable = new FooAwaitable<TResult>();
        var builder = new FooAsyncMethodBuilder<TResult>
        {
            _awaitable = awaitable,
        };
        return builder;
    }
    
    // 3. 定义 Start 方法
    public void Start<TStateMachine>(ref TStateMachine stateMachine)
        where TStateMachine : IAsyncStateMachine
    {
        Console.WriteLine("FooAsyncMethodBuilder.Start");
        stateMachine.MoveNext();
    }

    
    // 4. 定义 AwaitOnCompleted/AwaitUnsafeOnCompleted 方法
    
    // 如果 awaiter 实现了 INotifyCompletion 接口,就调用 AwaitOnCompleted 方法
    public void AwaitOnCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : INotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        Console.WriteLine("FooAsyncMethodBuilder.AwaitOnCompleted");
        awaiter.OnCompleted(stateMachine.MoveNext);
    }

    [SecuritySafeCritical]
    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter,
        ref TStateMachine stateMachine)
        where TAwaiter : ICriticalNotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        Console.WriteLine("FooAsyncMethodBuilder.AwaitUnsafeOnCompleted");
        awaiter.UnsafeOnCompleted(stateMachine.MoveNext);
    }

    // 5. 定义 SetResult/SetException 方法
    public void SetResult(TResult result)
    {
        Console.WriteLine("FooAsyncMethodBuilder.SetResult");
        _awaitable.TrySetResult(result);
    }
    
    public void SetException(Exception exception)
    {
        Console.WriteLine("FooAsyncMethodBuilder.SetException");
        _awaitable.TrySetException(exception);
    }
    
    // 6. 定义 SetStateMachine 方法,虽然编译器不会调用,但是编译器要求必须有这个方法
    public void SetStateMachine(IAsyncStateMachine stateMachine)
    {
        Console.WriteLine("FooAsyncMethodBuilder.SetStateMachine");
    }
}

// 7. 通过 AsyncMethodBuilderAttribute 绑定 FooAsyncMethodBuilder
[AsyncMethodBuilder(typeof(FooAsyncMethodBuilder<>))]
class FooAwaitable<TResult>
{
    // ...
}
Console.WriteLine("await Foo1Async()");
int foo1= await Foo1Async();
Console.WriteLine("Foo1Async() result: " + foo1);
Console.WriteLine();

Console.WriteLine("await Foo2Async()");

int foo2 = await Foo2Async();
Console.WriteLine("Foo2Async() result: " + foo2);
Console.WriteLine();

Console.WriteLine("await FooExceptionAsync()");
try
{
    await FooExceptionAsync();
}
catch (Exception e)
{
    Console.WriteLine(e.Message);
}

async FooAwaitable<int> Foo1Async()
{
    await Task.Delay(1000);
    return 1;
}

// 覆盖默认的 AsyncMethodBuilder,使用 FooAsyncMethodBuilder2
// 本文省略了 FooAsyncMethodBuilder2 的定义,可以参考上面的 FooAsyncMethodBuilder
[AsyncMethodBuilder(typeof(FooAsyncMethodBuilder2<>))]
async FooAwaitable<int> Foo2Async()
{
    await Task.Delay(1000);
    return 2;
}

执行结果:

await Foo1Async()
FooAsyncMethodBuilder.Create
FooAsyncMethodBuilder.Start
FooAsyncMethodBuilder.AwaitUnsafeOnCompleted
FooAsyncMethodBuilder.Task
FooAwaiter.UnsafeOnCompleted
FooAwaiter.UnsafeOnCompleted: added continuation
FooAsyncMethodBuilder.SetResult
FooAwaiter.GetResult
Foo1Async() result: 1

await Foo2Async()
FooAsyncMethodBuilder2.Create
FooAsyncMethodBuilder2.Start
FooAsyncMethodBuilder2.AwaitUnsafeOnCompleted
FooAsyncMethodBuilder2.Task
FooAwaiter.UnsafeOnCompleted
FooAwaiter.UnsafeOnCompleted: added continuation
FooAsyncMethodBuilder2.SetResult
FooAwaiter.GetResult
Foo2Async() result: 2

await FooExceptionAsync()
FooAsyncMethodBuilder.Create
FooAsyncMethodBuilder.Start
FooAsyncMethodBuilder.AwaitUnsafeOnCompleted
FooAsyncMethodBuilder.Task
FooAwaiter.UnsafeOnCompleted
FooAwaiter.UnsafeOnCompleted: added continuation
FooAsyncMethodBuilder.SetException
Exception from FooExceptionAsync

在方法上添加 AsyncMethodBuilderAttribute 的功能是后来才添加的,通过这个功能,可以覆盖 awaitable 类型上的 AsyncMethodBuilderAttribute,以便进行性能优化。例如 .NET 6 开始提供的 PoolingAsyncValueTaskMethodBuilder,对原始的 AsyncValueTaskMethodBuilder 进行了池化处理,可以通过在方法上添加 AsyncMethodBuilderAttribute 来使用。

欢迎关注个人技术公众号

1201123-20230302194546214-138980196.png

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK