C 语言

关注公众号 jb51net

关闭
首页 > 软件编程 > C 语言 > C++20 协程库

基于C++20实现协程库的示例代码

作者:哦咧哇岸居

本文主要简单地介绍了C++20协程的关键点实现一个简单可用的协程库,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

前言

要理解C++20的协程怎么用,得先理解关键字,三个关键字:co_awaitco_yieldco_return,所有函数体内有这三个关键字其中一个或以上的,都将被转换成协程函数。还有两个概念:awaitergenerator,awaiter是co_await作用的对象,而generator是管理协程句柄(std::coroutine_handle)的对象,也是协程函数的返回值。

另外,C++20的协程是缺少“调度器”的,需要自行实现协程的调度

注意点:下面我会说“协程”和“协程函数”,注意两者区别

C++20协程关键字

co_return [result]

result是协程的最终返回结果,没有result时即无返回

co_yield value

co_yield会将协程挂起,value是协程挂起时返回的值,不可省略,且必须与co_return同类型,当co_return无返回值时,不可使用co_yield,每次执行到co_yield都可以获取一次协程的返回值(即多次返回值)

co_await value

①value是一个“awaiter”对象,不可省略。

②co_await可重载,也是一个“运算符函数”,重载的co_await运算符函数必须返回一个awaiter。

awaiter说明

awaiter是co_await直接作用的对象,它决定了co_await操作时协程如何响应

awaiter类型必须拥有以下三个接口函数:

await_ready

co_await接收到awaiter参数时首先执行的函数,返回值bool。当返回值为false时,协程将被挂起,然后执行await_suspend函数;当返回值为true时,协程不会挂起而继续执行,将跳过await_suspend函数。

await_suspend

await_suspend返回类型有三种:

①bool,如果返回false,则立即恢复协程而不挂起。如果返回true,协程挂起。

②void,视为返回true

③std::coroutine_handle<>,挂起当前协程,并立即恢复返回的另一个协程,就是该重要的机制,支撑起协程的嵌套和转移

关于await_suspend的实参——协程的句柄,可以是默认的std::coroutine_handle<>,也可以是指定的std::coroutine_handle<promise_type>,视功能需求自定义。

await_resume

await_resume函数的返回值就是co_await运算符的返回值。当协程在co_await挂起后被恢复时,或者协程在await_ready返回true时,将会调用await_resume。

其中,标准库提供了两个内置的awaiter,分别是std::suspend_always和std::suspend_never

promise_type说明

promise_type是协程内部状态的控制核心,控制着协程的创建、挂起、恢复和销毁,它还负责协程的最终返回(或异常)以及协程返回值的传递。

promise_type类型的接口函数如下表:

函数描述函数说明
generator get_return_object()在协程执行前调用,用于构造generator对象,然后将存放了promise_type引用的协程句柄传送到generator对象中存放。该接口的返回值就是协程函数的直接返回值
awaiter initial_suspend()在协程初始化时调用,返回一个awaiter对象
awaiter final_suspend() noexcept(true)在协程结束时调用。必须有noexcept修饰
awaiter yield_value(type value)co_yield操作时调用,value即是co_yield的参数。【无co_yield时非必须】
void unhandled_exception()协程发生异常时调用
void return_void()当co_return执行且无返回值时调用。有该函数的时候不能存在return_value函数。【co_return有返回值时非必须】
void return_value(type value)当co_return执行且有返回值时调用,value即是co_return的参数。有该函数的时候不能存在return_void函数。【co_return无返回值时非必须】
static generator get_return_object_on_allocation_failure()当标识为noexcept的内存分配函数返回nullptr时,协程函数在返回generator时将会通过调用此接口获得返回值。【非必须】

generator说明

generator是管理协程句柄(std::coroutine_handle)的对象,也是协程函数的返回值,一般由generator对象来操作promise_type对象

generator的对象由promise_type的get_return_object函数生成并返回

协程库源码实现

理解了上面的各个要点后,就能着手实现一个简单的协程库了。

以下就是源码,支持协程嵌套和转移,也支持异步执行及完成后唤醒,直接include本文件即可使用

#ifndef CO_TASK_H
#define CO_TASK_H

#include <coroutine>
#include <functional>  // std::function
#include <memory>  // std::shared_ptr
#include <optional>  // std::optional

// 用于在协程结束时自动恢复续体的 Awaitable
template<typename _Tp>
struct FinalAwaitable 
{
    bool await_ready() const noexcept 
    { return false; }
    
    // 对 C++20 对话式对称传输的支持:返回 handle 会立即切换到该协程
    std::coroutine_handle<> await_suspend(std::coroutine_handle<_Tp> h) noexcept 
    {
        // 这里的 h 是当前结束的协程(子协程)
        if constexpr (std::is_same<_Tp, void>::value)
        { return std::noop_coroutine(); }
        else 
        {
            auto& promise = h.promise(); 
            if (promise.continuation) 
            { return promise.continuation; }
            return std::noop_coroutine();
        }
    }
    
    void await_resume() noexcept 
    {}
};

template<typename _Ret, typename _ThreadPoolEnqueue, typename _EventLoopEnqueue>
struct ThreadPoolAwaiter 
{
    _ThreadPoolEnqueue pool_enqueue;
    _EventLoopEnqueue loop_enqueue;
    std::function<_Ret()> func;
    _Ret result;

    ThreadPoolAwaiter(_ThreadPoolEnqueue &&p, _EventLoopEnqueue &&l, std::function<_Ret()> &&f)
     : pool_enqueue(std::move(p))
     , loop_enqueue(std::move(l))
     , func(std::move(f))
    {}

    // 永远挂起,以便将任务切走
    bool await_ready() const noexcept 
    { return false; }

    void await_suspend(std::coroutine_handle<> handle) 
    {
        // 1. 将任务提交给线程池
        pool_enqueue([this, handle]() mutable {
            try 
            {
                result = func();
            } 
            catch (...) 
            {
                // 异常处理逻辑可以根据需求扩展
            }

            // 2. 线程池任务完成后,将“恢复协程”的操作投递回主事件循环
            loop_enqueue([handle]() {
                handle.resume();
            });
        });
    }

    _Ret await_resume() 
    { return std::move(result); }
};

template<typename _ThreadPoolEnqueue, typename _EventLoopEnqueue>
struct ThreadPoolAwaiter<void, _ThreadPoolEnqueue, _EventLoopEnqueue>
{
    _ThreadPoolEnqueue pool_enqueue;
    _EventLoopEnqueue loop_enqueue;
    std::function<void()> func;

    ThreadPoolAwaiter(_ThreadPoolEnqueue &&p, _EventLoopEnqueue &&l, std::function<void()> &&f)
     : pool_enqueue(std::move(p))
     , loop_enqueue(std::move(l))
     , func(std::move(f)) 
    {}

    // 永远挂起,以便将任务切走
    bool await_ready() const noexcept 
    { return false; }

    void await_suspend(std::coroutine_handle<> handle) 
    {
        // 1. 将任务提交给线程池
        pool_enqueue([this, handle]() mutable {
            try 
            {
                func();
            } 
            catch (...) 
            {
                // 异常处理逻辑可以根据需求扩展
            }

            // 2. 线程池任务完成后,将“恢复协程”的操作投递回主事件循环
            loop_enqueue([handle]() {
                handle.resume();
            });
        });
    }

    void await_resume() 
    { }
};

namespace _detail
{

template<typename _Func, typename ... _Args>
concept CanCallObject = requires(_Func &&f)
{ f(std::declval<std::function<void(_Args...)>>()); };

template<typename _Func>
concept IsNotBindExpression = requires(_Func f)
{ !std::is_bind_expression<std::decay_t<_Func>>::value; };

template<std::size_t, typename ... _Args>
struct RetType
{ 
    using type = std::tuple<std::decay_t<_Args>...>; 
    static constexpr type Make(_Args && ... args) 
    { return std::make_tuple(std::forward<_Args>(args)...); }
};

template<typename _Tp>
struct RetType<1, _Tp>
{ 
    using type = std::decay_t<_Tp>; 
    static constexpr type Make(_Tp && arg) 
    { return std::forward<_Tp>(arg); }
};

}

// 回调函数有参数的awaiter
template<typename _Functor, typename ... _Args>
struct AsyncAwaiterWithArgs
{
    static constexpr std::size_t ArgCount = sizeof...(_Args);
    using Caller = _Functor;
    using ReturnType = typename _detail::RetType<ArgCount, _Args...>::type;

    Caller caller;
    std::optional<ReturnType> result;

    AsyncAwaiterWithArgs(Caller &&caller)
     : caller(std::move(caller))
    {}

    // 永远挂起,以便将任务切走
    bool await_ready() const noexcept 
    { return false; }

    void await_suspend(std::coroutine_handle<> handle) 
    {
        caller([handle, this](_Args && ... args) {
            result.emplace(_detail::RetType<ArgCount, _Args...>::Make(std::forward<_Args>(args)...));
            handle.resume();
        });
    }

    ReturnType await_resume() 
    { return std::move(result.value()); }
};

// 回调函数无参数的awaiter
template<typename _Functor>
struct AsyncAwaiterWithoutArgs
{
    using Caller = _Functor;

    Caller caller;

    AsyncAwaiterWithoutArgs(Caller &&caller)
     : caller(std::move(caller))
    {}

    // 永远挂起,以便将任务切走
    bool await_ready() const noexcept 
    { return false; }

    void await_suspend(std::coroutine_handle<> handle) 
    {
        caller([handle]() {
            handle.resume();
        });
    }

    void await_resume() 
    { }
};


namespace _detail
{
    
template<typename _Tp, typename = void>
struct FunctionTraits
{ static_assert(false, "Unsupported type for FunctionTraits"); };

// 针对普通函数指针的特化
template <typename _Ret, typename... _Args>
struct FunctionTraits<_Ret(*)(_Args...)> 
{ 
    using arg_types = std::tuple<_Args...>; 
    using first_type = std::decay_t<std::tuple_element_t<0, arg_types>>;

    template<typename _Func>
    static auto GetAwaiter(_Func &&f)
    {
        if constexpr (sizeof...(_Args) == 0)
        { return AsyncAwaiterWithoutArgs<_Func>(std::move(f)); }
        else 
        { return AsyncAwaiterWithArgs<_Func, _Args...>(std::move(f)); }
    }
};

template <typename _Ret, typename... _Args>
struct FunctionTraits<_Ret(&)(_Args...)> 
{ 
    using arg_types = std::tuple<_Args...>; 
    using first_type = std::decay_t<std::tuple_element_t<0, arg_types>>;

    template<typename _Func>
    static auto GetAwaiter(_Func &&f)
    {
        if constexpr (sizeof...(_Args) == 0)
        { return AsyncAwaiterWithoutArgs<_Func>(std::move(f)); }
        else 
        { return AsyncAwaiterWithArgs<_Func, _Args...>(std::move(f)); }
    }
};

// 针对成员函数指针的特化
template <typename _ClassType, typename _Ret, typename ... _Args>
struct FunctionTraits<_Ret(_ClassType::*)(_Args...)> 
{ 
    using arg_types = std::tuple<_Args...>; 
    using first_type = std::decay_t<std::tuple_element_t<0, arg_types>>;

    template<typename _Func>
    static auto GetAwaiter(_Func &&f)
    {
        if constexpr (sizeof...(_Args) == 0)
        { return AsyncAwaiterWithoutArgs<_Func>(std::move(f)); }
        else 
        { return AsyncAwaiterWithArgs<_Func, _Args...>(std::move(f)); }
    }
};

// 针对const成员函数指针的特化
template <typename _ClassType, typename _Ret, typename ... _Args>
struct FunctionTraits<_Ret(_ClassType::*)(_Args...) const> 
{ 
    using arg_types = std::tuple<_Args...>; 
    using first_type = std::decay_t<std::tuple_element_t<0, arg_types>>;

    template<typename _Func>
    static auto GetAwaiter(_Func &&f)
    {
        if constexpr (sizeof...(_Args) == 0)
        { return AsyncAwaiterWithoutArgs<_Func>(std::move(f)); }
        else 
        { return AsyncAwaiterWithArgs<_Func, _Args...>(std::move(f)); }
    }
};

// std::function的特化
template <typename _Ret, typename... _Args>
struct FunctionTraits<std::function<_Ret(_Args...)>> 
{ 
    using arg_types = std::tuple<_Args...>; 
    using first_type = std::decay_t<std::tuple_element_t<0, arg_types>>;

    template<typename _Func>
    static auto GetAwaiter(_Func &&f)
    {
        if constexpr (sizeof...(_Args) == 0)
        { return AsyncAwaiterWithoutArgs<_Func>(std::move(f)); }
        else 
        { return AsyncAwaiterWithArgs<_Func, _Args...>(std::move(f)); }
    }
};

// 针对lambda和一般可调用对象的特化
template <typename _Functor>
struct FunctionTraits<_Functor, std::void_t<decltype(&_Functor::operator())>>
{ 
    using operator_type = decltype(&_Functor::operator());
    using traits = FunctionTraits<operator_type>;
    using arg_types = typename traits::arg_types;
    using first_type = typename traits::first_type;

    template<typename _Func>
    static auto GetAwaiter(_Func &&f)
    { return traits::GetAwaiter(std::move(f)); }
};

}

template<typename _Tp>
struct Task 
{
public:
    struct promise_type;
    using co_handle = std::coroutine_handle<promise_type>;

    using value_type = _Tp;
    using reference = _Tp &;
    using rvalue_reference = _Tp &&;
    using const_reference = const _Tp &;

    struct promise_type 
    {
        value_type value;
        std::exception_ptr exception;
        std::coroutine_handle<> continuation; // 父协程的句柄

        Task get_return_object() 
        { return Task(co_handle::from_promise(*this)); }

        std::suspend_always initial_suspend() 
        { return {}; }

        FinalAwaitable<promise_type> final_suspend() noexcept 
        { return {}; }

        void return_value(rvalue_reference v) 
        { value = std::move(v); }

        void return_value(const_reference v) 
        { value = v; }

        std::suspend_always yield_value(rvalue_reference v) 
        { 
            value = std::move(v); 
            return {}; 
        }

        std::suspend_always yield_value(const_reference v) 
        { 
            value = v; 
            return {}; 
        }

        void unhandled_exception() 
        { exception = std::current_exception(); }
    };

public:
    Task()
     : _M_handle(nullptr) 
    { }

    Task(co_handle h)
     : _M_handle(new co_handle(h), _S_on_release) 
    { }

    // 是否已执行完毕
    operator bool() const noexcept 
    { return _M_handle && *_M_handle && _M_handle->done(); }

    void operator()() const
    {
        if(_M_handle && *_M_handle)
        { _M_handle->resume();  }
    }

    // 获取返回值
    reference operator*() noexcept 
    { return _M_handle->promise().value; }

    auto operator co_await() noexcept 
    {
        struct Awaiter 
        {
            co_handle callee;

            bool await_ready() const noexcept 
            { return !callee || callee.done(); }

            // h 是当前正在 co_await 的父协程句柄
            std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept 
            {
                callee.promise().continuation = h; // 绑定续体
                return callee; // 对称传输:切换到子协程执行
            }

            value_type await_resume() 
            {
                if (callee.promise().exception) 
                { std::rethrow_exception(callee.promise().exception); }
                return std::move(callee.promise().value);
            }
        };
        return Awaiter{*_M_handle};
    }

private:
    static void _S_on_release(co_handle *h)
    {
        h->destroy();
        delete h;
    }

private:
    std::shared_ptr<co_handle> _M_handle;
};

// 针对 void 的特化
template<>
struct Task<void> 
{
public:
    struct promise_type;
    using co_handle = std::coroutine_handle<promise_type>;

    struct promise_type 
    {
        std::exception_ptr exception;
        std::coroutine_handle<> continuation; // 父协程的句柄

        Task get_return_object() 
        { return Task(co_handle::from_promise(*this)); }

        std::suspend_always initial_suspend() 
        { return {}; }

        FinalAwaitable<promise_type> final_suspend() noexcept 
        { return {}; }

        void return_void() 
        {}

        std::suspend_always yield_void() 
        { return {}; }

        void unhandled_exception() 
        { exception = std::current_exception(); }
    };

    Task()
     : _M_handle(nullptr) 
    { }

    Task(co_handle h)
     : _M_handle(new co_handle(h), _S_on_release) 
    { }

    // 是否已执行完毕
    operator bool() const noexcept
    { return _M_handle && *_M_handle && _M_handle->done(); }

    void operator()() const
    {
        if(_M_handle && *_M_handle)
        { _M_handle->resume(); }
    }

    // 获取返回值
    void operator*() const noexcept 
    { }

    auto operator co_await() noexcept 
    {
        struct Awaiter 
        {
            co_handle callee;

            bool await_ready() const noexcept 
            { return !callee || callee.done(); }

            // h 是当前正在 co_await 的父协程句柄
            std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept 
            {
                callee.promise().continuation = h; // 绑定续体
                return callee; // 对称传输:切换到子协程执行
            }

            void await_resume() 
            {
                if (callee.promise().exception) 
                { std::rethrow_exception(callee.promise().exception); }
            }
        };
        return Awaiter{*_M_handle};
    }

private:
    static void _S_on_release(co_handle *h)
    {
        h->destroy();
        delete h;
    }

private:
    std::shared_ptr<co_handle> _M_handle;
};

/**
 * 在线程池pool中异步执行一个函数,得到结果后在loop中唤醒协程
 * @param pool  线程池的提交函数(也是异步执行的事件队列插入函数),接收一个 std::function<void()> 参数,用于执行 @f 函数
 * @param loop  事件循环的提交函数,接收一个 std::function<void()> 参数,用于唤醒协程
 * @param f  需要异步执行的函数,其返回值将被协程接收,作为co_await表达式的结果(返回值可为void)
 */
template<typename _ThreadPoolEnqueue, typename _EventLoopEnqueue, typename _Func>
auto AsyncExec(_ThreadPoolEnqueue &&pool, _EventLoopEnqueue &&loop, _Func &&f) 
{
    using Ret = std::invoke_result_t<_Func>;
    return ThreadPoolAwaiter<Ret, _ThreadPoolEnqueue, _EventLoopEnqueue>(std::move(pool), std::move(loop), std::move(f));
}

/**
 * 执行一个带回调函数的异步操作,回调函数的参数列表将被协程接收,作为co_await表达式的结果(留着回调函数的空位,由AsyncExec函数填充)
 * 如果参数有1个,即co_await表达式直接返回该值;如果参数有2个及以上,将打包成std::tuple作为返回值
 * @param caller  限定了只能有一个参数,且该参数必须是一个可调用对象(回调函数)
 */
template<typename ... _Args, typename _Functor>
requires _detail::CanCallObject<_Functor, _Args...>
auto AsyncExec(_Functor &&caller)
{
    if constexpr (sizeof...(_Args) == 0)
    { return AsyncAwaiterWithoutArgs<_Functor>(std::move(caller)); }
    else 
    { return AsyncAwaiterWithArgs<_Functor, _Args...>(std::move(caller)); }
}
// 非std::bind返回值的特化版本
template<typename _Functor>
requires _detail::IsNotBindExpression<_Functor>
auto AsyncExec(_Functor &&f) 
{
    using CallbackType = typename _detail::FunctionTraits<std::decay_t<_Functor>>::first_type;
    return _detail::FunctionTraits<CallbackType>::GetAwaiter(std::move(f));
}

#endif // CO_TASK_H

使用例子

1.单线程基础功能测试案例

测试案例:展示协程库接口的基础使用

// 事件队列类实现
#include <atomic>  // std::atomic_flag
#include <queue>  // std::queue
#include <functional>  // std::function
#include <thread>  //  std::this_thread::yield

template<typename _Tp>
class EventQueue
{
public:
    using EventWrapper = std::function<_Tp()>;

public:
    void Push(EventWrapper &&e)
    {
        _M_Lock();
        _M_event_queue.push(std::move(e));
        _M_Unlock();
    }

    EventWrapper Pop()
    {
        EventWrapper result;
        _M_Lock();
        if(_M_event_queue.empty())
        {
            _M_Unlock();
            return result;
        }
        EventWrapper e = std::move(_M_event_queue.front());
        _M_event_queue.pop();
        _M_Unlock();
        result.swap(e);
        return result;
    }

private:
    void _M_Lock() noexcept
    {
        while(_M_lock.test_and_set())
        { std::this_thread::yield(); }
    }

    void _M_Unlock() noexcept
    { _M_lock.clear(); }

private:
    std::atomic_flag _M_lock;
    std::queue<EventWrapper> _M_event_queue;
};

// 事件循环类实现
#include <thread>  // std::this_thread::sleep_for
#include <chrono>  // std::chrono::milliseconds
#include <functional>  // std::function
#include <future>  // std::future  std::future_status
#include <atomic>  // std::atomic_bool

class EventLoop
{
public:
    using EventCallback = std::function<void()>;

    template<std::size_t _EmptySleep = 10, std::size_t _NormalSleep = 1>
    void Run()
    {
        _M_flag.store(true, std::memory_order_release);
        while(_M_flag.load(std::memory_order_acquire))
        {
            typename EventQueue<void>::EventWrapper f = _M_event_queue.Pop();
            if(!f)
            {
                std::this_thread::sleep_for(std::chrono::milliseconds(_EmptySleep));
                continue;
            }

            f();
            std::this_thread::sleep_for(std::chrono::milliseconds(_NormalSleep));
        }
    }

    /**
     * 等待一个 future 对象完成,并在完成后执行回调函数 e
     * 回调函数 e 接受一个 std::future<_Tp> 参数,表示等待完成的 future 对象
     */
    template<typename _Tp>
    void WaitFor(std::future<_Tp> &&f, std::function<void(std::future<_Tp> &&)> &&e)
    {
        typename EventQueue<void>::EventWrapper func = [this, f = std::move(f), e = std::move(e)]() mutable {
            if(f.wait_for(std::chrono::milliseconds(0)) != std::future_status::ready)
            {
                WaitFor<_Tp>(std::move(f), std::move(e));
                return;
            }
            e(std::move(f));
        };
        Enqueue(std::move(func));
    }

    void Enqueue(EventCallback &&e)
    { _M_event_queue.Push(std::move(e)); }

    void Stop() noexcept
    { _M_flag.store(false, std::memory_order_release); }

private:
    std::atomic_bool _M_flag;
    EventQueue<void> _M_event_queue;
};

// 测试案例
#ifdef _WIN32
#include <cstdlib>  // system
#endif
#include <iostream>  // std::cout
#include <format>  // std::format

static EventLoop Main;  // 主事件循环

// 执行数值翻倍操作
int Computation(int input) 
{ return input * 2; }

// 协程函数3
Task<double> MyCoroutine3(double base)
{
    std::cout << std::format("[Event Loop] start coroutine 3") << std::endl;
    co_yield base + 1.1;
    std::cout << std::format("[Event Loop] coroutine 3 step 1") << std::endl;
    co_yield base + 2.2;
    std::cout << std::format("[Event Loop] coroutine 3 step 2") << std::endl;
    co_yield base + 3.3;
    std::cout << std::format("[Event Loop] coroutine 3 step 3") << std::endl;
    co_return base + 4.4;
}

// 协程函数2
Task<int> MyCoroutine2(int value)
{
    std::cout << std::format("[Event Loop] start coroutine 2") << std::endl;
    // 测试案例4:嵌套协程中的co_await
    int result = co_await AsyncExec(
        std::bind(&EventLoop::Enqueue, &Main, std::placeholders::_1),
        std::bind(&EventLoop::Enqueue, &Main, std::placeholders::_1), 
        [value]() {
            // 异步执行语句,return的值将作为co_await表达式的结果
            return Computation(value);
        }
    );
    std::cout << std::format("[Event Loop] get coroutine 2 result: {}", result) << std::endl;

    double v = 0;
    auto t = MyCoroutine3(100);
    int step = 1;
    while(!t)
    {
        double tmp = co_await AsyncExec(
            std::bind(&EventLoop::Enqueue, &Main, std::placeholders::_1),
            std::bind(&EventLoop::Enqueue, &Main, std::placeholders::_1), 
            [t]() mutable {
                t();
                return *t;
            }
        );
        std::cout << std::format("[Event Loop] coroutine 2 step {} result: {}", step++, tmp) << std::endl;
        v += tmp;
    }
    std::cout << "[Event Loop] coroutine 2 get co3 result:" << v << std::endl;

    // 本协程的返回值为result
    co_return result;
}

// 协程函数
Task<void> MyCoroutine() 
{
    std::cout << std::format("[Event Loop] start coroutine 1") << std::endl;
    // 测试案例1:无返回值
    co_await AsyncExec(
        std::bind(&EventLoop::Enqueue, &Main, std::placeholders::_1), 
        std::bind(&EventLoop::Enqueue, &Main, std::placeholders::_1), 
        []() {
            // 异步执行语句
            std::cout << "test return void" << std::endl;
        }
    );
    // 测试案例2:有返回值
    int result = co_await AsyncExec(
        std::bind(&EventLoop::Enqueue, &Main, std::placeholders::_1), 
        std::bind(&EventLoop::Enqueue, &Main, std::placeholders::_1), 
        []() {
            // 异步执行语句,return的值将作为co_await表达式的结果
            return Computation(10);
        }
    );
    std::cout << std::format("[Event Loop] get coroutine 1 result: {}", result)  << std::endl;
    // 测试案例3:嵌套协程,获取子协程的返回值
    int ret = co_await MyCoroutine2(100);
    std::cout << std::format("[Event Loop] get coroutine 2 result: {}", ret) << std::endl;
    Main.Stop(); // 结束后停止循环
}

int main() 
{
    // 投递协程任务
    auto t1 = MyCoroutine();
    Main.Enqueue([&t1](){
        t1();  // 开启协程
    });

    // 运行主循环
    std::cout << "main loop start" << std::endl;
    Main.Run(); 
#ifdef _WIN32
    system("pause");
#endif
    return 0;
}

测试执行结果:

main loop start
[Event Loop] start coroutine 1
test return void
[Event Loop] get coroutine 1 result: 20
[Event Loop] start coroutine 2
[Event Loop] get coroutine 2 result: 200
[Event Loop] start coroutine 3
[Event Loop] coroutine 2 step 1 result: 101.1
[Event Loop] coroutine 3 step 1
[Event Loop] coroutine 2 step 2 result: 102.2
[Event Loop] coroutine 3 step 2
[Event Loop] coroutine 2 step 3 result: 103.3
[Event Loop] coroutine 3 step 3
[Event Loop] coroutine 2 step 4 result: 104.4
[Event Loop] coroutine 2 get co3 result:411
[Event Loop] get coroutine 2 result: 200

2.基于线程池异步执行的测试案例

测试案例:主线程向线程池插入一个耗时操作让线程池执行,得到结果后返回给主线程

// 事件队列类实现
#include <atomic>  // std::atomic_flag
#include <queue>  // std::queue
#include <functional>  // std::function
#include <thread>  //  std::this_thread::yield

template<typename _Tp>
class EventQueue
{
public:
    using EventWrapper = std::function<_Tp()>;

public:
    void Push(EventWrapper &&e)
    {
        _M_Lock();
        _M_event_queue.push(std::move(e));
        _M_Unlock();
    }

    EventWrapper Pop()
    {
        EventWrapper result;
        _M_Lock();
        if(_M_event_queue.empty())
        {
            _M_Unlock();
            return result;
        }
        EventWrapper e = std::move(_M_event_queue.front());
        _M_event_queue.pop();
        _M_Unlock();
        result.swap(e);
        return result;
    }

private:
    void _M_Lock() noexcept
    {
        while(_M_lock.test_and_set())
        { std::this_thread::yield(); }
    }

    void _M_Unlock() noexcept
    { _M_lock.clear(); }

private:
    std::atomic_flag _M_lock;
    std::queue<EventWrapper> _M_event_queue;
};

// 事件循环类实现
#include <thread>  // std::this_thread::sleep_for
#include <chrono>  // std::chrono::milliseconds
#include <functional>  // std::function
#include <future>  // std::future  std::future_status
#include <atomic>  // std::atomic_bool

class EventLoop
{
public:
    using EventCallback = std::function<void()>;

    template<std::size_t _EmptySleep = 10, std::size_t _NormalSleep = 1>
    void Run()
    {
        _M_flag.store(true, std::memory_order_release);
        while(_M_flag.load(std::memory_order_acquire))
        {
            typename EventQueue<void>::EventWrapper f = _M_event_queue.Pop();
            if(!f)
            {
                std::this_thread::sleep_for(std::chrono::milliseconds(_EmptySleep));
                continue;
            }

            f();
            std::this_thread::sleep_for(std::chrono::milliseconds(_NormalSleep));
        }
    }

    /**
     * 等待一个 future 对象完成,并在完成后执行回调函数 e
     * 回调函数 e 接受一个 std::future<_Tp> 参数,表示等待完成的 future 对象
     */
    template<typename _Tp>
    void WaitFor(std::future<_Tp> &&f, std::function<void(std::future<_Tp> &&)> &&e)
    {
        typename EventQueue<void>::EventWrapper func = [this, f = std::move(f), e = std::move(e)]() mutable {
            if(f.wait_for(std::chrono::milliseconds(0)) != std::future_status::ready)
            {
                WaitFor<_Tp>(std::move(f), std::move(e));
                return;
            }
            e(std::move(f));
        };
        Enqueue(std::move(func));
    }

    void Enqueue(EventCallback &&e)
    { _M_event_queue.Push(std::move(e)); }

    void Stop() noexcept
    { _M_flag.store(false, std::memory_order_release); }

private:
    std::atomic_bool _M_flag;
    EventQueue<void> _M_event_queue;
};

// 线程池的实现
#include <future>  // std::future
#include <atomic>  // std::atomic_flag  std::atomic_bool
#include <queue>  // std::queue
#include <functional>  // std::function
#include <thread>  // std::this_thread::sleep_for  std::this_thread::yield
#include <vector>  // std::vector
#include <memory>  // std::shared_ptr  std::make_shared

namespace detail
{

template<typename _AtomicFlag = std::atomic_flag>
class SpinLock
{
public:
    SpinLock()
     : _M_lock(ATOMIC_FLAG_INIT) 
    {}

    void lock() noexcept
    {
        while(_M_lock.test_and_set())
        { std::this_thread::yield(); }
    }

    void unlock() noexcept
    { _M_lock.clear(std::memory_order_release); }

private:    
    std::atomic_flag _M_lock;
};

} // namespace detail

template<typename _Lock = detail::SpinLock<>>
class ThreadPool
{
private:
    using AsyncTaskWrapper = std::function<void()>;

public:
    ThreadPool(std::size_t num_threads = 1)
    {
        _M_flag.store(true, std::memory_order_release);
        while(num_threads-- > 0)
        {
            _M_threads.emplace_back(std::async(std::launch::async, [this](){
                _M_MainLoop();
            }));
        }
    }
    ~ThreadPool()
    {
        Stop();
        for(auto &t : _M_threads)
        { t.get(); }
    }

    // 异步执行,不关心结果
    void Async(std::function<void()> &&async_task)
    {
        _M_Lock();
        _M_async_task_queue.push(std::move(async_task));
        _M_Unlock();
    }

    void Stop() noexcept
    { _M_flag.store(false, std::memory_order_release); }

private:
    void _M_MainLoop()
    {
        while(_M_flag.load(std::memory_order_acquire))
        {
            _M_Lock();

            if(_M_async_task_queue.empty())
            {
                _M_Unlock();
                std::this_thread::sleep_for(std::chrono::milliseconds(10));
                continue;
            }

            AsyncTaskWrapper task = std::move(_M_async_task_queue.front());
            _M_async_task_queue.pop();
            _M_Unlock();

            task();
        }
    }

    void _M_Lock() noexcept
    { _M_lock.lock(); }

    void _M_Unlock() noexcept
    { _M_lock.unlock(); }

private:
    std::queue<AsyncTaskWrapper> _M_async_task_queue;
    _Lock _M_lock;
    std::atomic_bool _M_flag;
    std::vector<std::future<void>> _M_threads;
};

// 测试案例源码

#ifdef _WIN32
#include <cstdlib>  // system
#endif
#include <iostream>  // std::cout
#include <format>  // std::format

static EventLoop Main;  // 主消息循环
static ThreadPool<> Pool(2);  // 2个worker线程的线程池

// 一个模拟耗时任务
int HeavyComputation(int input) 
{
    std::cout << "[Thread Pool] caling  thread_ID: " << std::this_thread::get_id() << std::endl;
    std::this_thread::sleep_for(std::chrono::seconds(2));
    return input * 2;
}

// 协程函数2
Task<int> MyCoroutine2(int coid)
{
    std::cout << std::format("[Event Loop---{}---] start coroutine 2 thread_ID: ", coid) << std::this_thread::get_id() << std::endl;
    // 关键点:co_await 会触发 ThreadPoolAwaiter
    int result = co_await AsyncExec(
        std::bind(&ThreadPool<>::Async, &Pool, std::placeholders::_1), 
        std::bind(&EventLoop::Enqueue, &Main, std::placeholders::_1), 
        []() {
            return HeavyComputation(20);
        }
    );
    std::cout << std::format("[Event Loop---{}---] get coroutine 2 result: {}", coid, result) << std::endl;
    co_return result;
}

// 协程函数
Task<void> MyCoroutine(int coid) 
{
    std::cout << std::format("[Event Loop---{}---] start coroutine 1 thread_ID: ", coid) << std::this_thread::get_id() << std::endl;
    // 关键点:co_await 会触发 ThreadPoolAwaiter
    co_await AsyncExec(
        std::bind(&ThreadPool<>::Async, &Pool, std::placeholders::_1), 
        std::bind(&EventLoop::Enqueue, &Main, std::placeholders::_1), 
        []() {
            std::cout << "test return void" << std::endl;
        }
    );
    int result = co_await AsyncExec(
        std::bind(&ThreadPool<>::Async, &Pool, std::placeholders::_1), 
        std::bind(&EventLoop::Enqueue, &Main, std::placeholders::_1), 
        []() {
            return HeavyComputation(10);
        }
    );
    std::cout << std::format("[Event Loop---{}---] get coroutine 1 result: {} thread_ID: ", coid, result) << std::this_thread::get_id() << std::endl;
    // 嵌套协程,这里接收协程返回值
    int ret = co_await MyCoroutine2(coid);
    std::cout << std::format("[Event Loop---{}---] get coroutine 2 result: {}", coid, ret) << std::endl;
    Main.Stop(); // 结束后停止循环
}

int main() 
{
    // 投递协程任务,创建2个协程
    auto t1 = MyCoroutine(1);
    auto t2 = MyCoroutine(2);
    // 将协程投递到主循环队列执行
    Main.Enqueue([&t1](){
        t1();
    });
    Main.Enqueue([&t2](){
        t2();
    });
    // 运行主循环
    std::cout << "main loop start" << std::endl;
    Main.Run(); 
#ifdef _WIN32
    system("pause");
#endif
    return 0;
}

测试执行结果:

main loop start
[Event Loop---1---] start coroutine 1 thread_ID: 1
test return void
[Event Loop---2---] start coroutine 1 thread_ID: 1
test return void
[Thread Pool] caling  thread_ID: 3
[Thread Pool] caling  thread_ID: 2
[Event Loop---1---] get coroutine 1 result: 20 thread_ID: 1
[Event Loop---1---] start coroutine 2 thread_ID: 1
[Thread Pool] caling  thread_ID: 3
[Event Loop---2---] get coroutine 1 result: 20 thread_ID: 1
[Event Loop---2---] start coroutine 2 thread_ID: 1
[Thread Pool] caling  thread_ID: 2
[Event Loop---1---] get coroutine 2 result: 40
[Event Loop---1---] get coroutine 2 result: 40

3.嵌入事件回调机制的测试案例

有部分第三方库的接口,会需要传入一个回调函数的参数,这里展示如何嵌入使用

// 事件队列类实现
#include <atomic>  // std::atomic_flag
#include <queue>  // std::queue
#include <functional>  // std::function
#include <thread>  //  std::this_thread::yield

template<typename _Tp>
class EventQueue
{
public:
    using EventWrapper = std::function<_Tp()>;

public:
    void Push(EventWrapper &&e)
    {
        _M_Lock();
        _M_event_queue.push(std::move(e));
        _M_Unlock();
    }

    EventWrapper Pop()
    {
        EventWrapper result;
        _M_Lock();
        if(_M_event_queue.empty())
        {
            _M_Unlock();
            return result;
        }
        EventWrapper e = std::move(_M_event_queue.front());
        _M_event_queue.pop();
        _M_Unlock();
        result.swap(e);
        return result;
    }

private:
    void _M_Lock() noexcept
    {
        while(_M_lock.test_and_set())
        { std::this_thread::yield(); }
    }

    void _M_Unlock() noexcept
    { _M_lock.clear(); }

private:
    std::atomic_flag _M_lock;
    std::queue<EventWrapper> _M_event_queue;
};

// 事件循环类实现
#include <thread>  // std::this_thread::sleep_for
#include <chrono>  // std::chrono::milliseconds
#include <functional>  // std::function
#include <future>  // std::future  std::future_status
#include <atomic>  // std::atomic_bool

class EventLoop
{
public:
    using EventCallback = std::function<void()>;

    template<std::size_t _EmptySleep = 10, std::size_t _NormalSleep = 1>
    void Run()
    {
        _M_flag.store(true, std::memory_order_release);
        while(_M_flag.load(std::memory_order_acquire))
        {
            typename EventQueue<void>::EventWrapper f = _M_event_queue.Pop();
            if(!f)
            {
                std::this_thread::sleep_for(std::chrono::milliseconds(_EmptySleep));
                continue;
            }

            f();
            std::this_thread::sleep_for(std::chrono::milliseconds(_NormalSleep));
        }
    }

    /**
     * 等待一个 future 对象完成,并在完成后执行回调函数 e
     * 回调函数 e 接受一个 std::future<_Tp> 参数,表示等待完成的 future 对象
     */
    template<typename _Tp>
    void WaitFor(std::future<_Tp> &&f, std::function<void(std::future<_Tp> &&)> &&e)
    {
        typename EventQueue<void>::EventWrapper func = [this, f = std::move(f), e = std::move(e)]() mutable {
            if(f.wait_for(std::chrono::milliseconds(0)) != std::future_status::ready)
            {
                WaitFor<_Tp>(std::move(f), std::move(e));
                return;
            }
            e(std::move(f));
        };
        Enqueue(std::move(func));
    }

    void Enqueue(EventCallback &&e)
    { _M_event_queue.Push(std::move(e)); }

    void Stop() noexcept
    { _M_flag.store(false, std::memory_order_release); }

private:
    std::atomic_bool _M_flag;
    EventQueue<void> _M_event_queue;
};

// 测试源码
#ifdef _WIN32
#include <cstdlib>  // system
#endif
#include <iostream>  // std::cout
#include <format>  // std::format

static EventLoop Main;

void TestRequest(const std::string &req, std::function<void(const std::string &)> &&callback)
{
    std::cout << "TestRequest start" << std::endl;
    Main.Enqueue(std::bind(callback, req + " get"));
    std::cout << "TestRequest end" << std::endl;
}

struct A
{
    std::string value;
};

/**
 * 请求接口,请求参数是req,然后回调函数是callback
 * @param req  请求参数
 * @param callback  回调函数,当请求完成后,会调用该函数,并传入请求结果
 */
void TestRequestA(const std::string &req, std::function<void(const A &, const A &)> &&callback)
{
    std::cout << "TestRequest start" << std::endl;
    A a1, a2;
    a1.value = req + "1";
    a2.value = req + "2";
    Main.Enqueue(std::bind(callback, std::move(a1), std::move(a2)));
    std::cout << "TestRequest end" << std::endl;
}

// 协程函数
Task<void> MyCoroutine()
{
    std::cout << std::format("[Event Loop] start coroutine 3 thread_ID: ") << std::this_thread::get_id() << std::endl;
    using CallbackType = std::function<void(const std::string &)>;
    std::function<void(CallbackType &&)> f = std::bind(TestRequest, "test request", std::placeholders::_1);
    // 第一种方式
    // std::function对象可直接传入AsyncExec,该对象的形参类型就是co_await表达式的返回类型
    auto resp1 = co_await AsyncExec(std::move(f));
    std::cout << std::format("[Event Loop] get response: {}", resp1) << std::endl;
	
    // 第二种方式
    // AsyncExec的实参是std::bind时,需要指定模板参数类型,类型是回调函数中的形参类型,也是co_await表达式的返回类型
    auto resp2 = co_await AsyncExec<const A &, const A &>(std::bind(TestRequestA, "test request", std::placeholders::_1));
    A &a1 = std::get<0>(resp2);
    A &a2 = std::get<1>(resp2);
    std::cout << std::format("[Event Loop] get response 1: {}", a1.value) << std::endl;
    std::cout << std::format("[Event Loop] get response 2: {}", a2.value) << std::endl;

    // 第三种方式,传入lambda
    auto f2 = [](std::function<void(const std::string &)> &&callback) {
        TestRequest("test request 2", std::move(callback));
    };
    co_await AsyncExec(std::move(f2));
}

int main() 
{
    // 投递协程任务
    auto t1 = MyCoroutine();
    Main.Enqueue([&t1](){
        t1();
    });
    // 运行主循环
    std::cout << "main loop start" << std::endl;
    Main.Run(); 
    std::cout << "main loop end" << std::endl;
#ifdef _WIN32
    system("pause");
#endif
    return 0;
}

测试执行结果:

main loop start
[Event Loop] start coroutine 3 thread_ID: 1
TestRequest start
TestRequest end
[Event Loop] get response: test request get
TestRequest start
TestRequest end
[Event Loop] get response 1: test request1
[Event Loop] get response 2: test request2
TestRequest start
TestRequest end
main loop end

4.嵌入第三方网络库

以libhv库为例,展示如何嵌入协程机制。如果第三方库是纯C,需要一层C++封装才能使用本协程库

下面写了服务端和客户端的测试代码

// 嵌入 第三方网络库 的测试案例

#ifdef _WIN32
#include <cstdlib>  // system
#endif
#include <iostream>
#include <chrono>  // std::chrono::milliseconds
#include <hv/EventLoop.h>  // hv::EventLoop  hv::TimerID  INVALID_TIMER_ID
#include <hv/TcpServer.h>  // hv::TcpServer
#include <hv/TcpClient.h>  // hv::TcpClient
#include "co_task.h"

static constexpr unsigned short PORT = 23333;
std::shared_ptr<hv::EventLoop> Main = std::make_shared<hv::EventLoop>();



hv::TcpServer Server(Main);
// 展示监听网络消息的协程
Task<void> OnMessageCoroutine()
{
    while(Main->isRunning())
    {
        std::cout << "start listen message" << std::endl;
        auto c = [](std::function<void(const hv::SocketChannelPtr &, hv::Buffer*)> &&f){
            Server.onMessage = std::move(f);
        };
        auto arg = co_await AsyncExec(std::move(c));
        hv::SocketChannelPtr channel = std::get<0>(arg);
        hv::Buffer* buffer = std::get<1>(arg);
        // TODO: 处理网络消息
        std::string msg(static_cast<const char *>(buffer->data()), buffer->size());
        std::cout << "recv: " << msg << std::endl;

        channel->write("i am server");
    }
}

// 展示监听网络连接事件的协程
Task<void> ListenCoroutine() 
{
    int listenfd = Server.createsocket(PORT);
    if(listenfd < 0) 
    { co_return; }

    Server.startAccept();
    while(Main->isRunning())
    {
        std::cout << "start listen connection" << std::endl;
        auto c = [](std::function<void(const hv::SocketChannelPtr &)> &&f){
            Server.onConnection = std::move(f);
        };
        hv::SocketChannelPtr channel = co_await AsyncExec(std::move(c));
        if(channel->isConnected())
        {
            // 新连接接入
            channel->setKeepaliveTimeout(3000);
            // 处理网络连接事件
            std::cout << "new connection" << std::endl;
        }
        else 
        {
            // 连接断开
            std::cout << "connection lost" << std::endl;
        }
    }
}

Task<void> listenco;
Task<void> msgco;
void StartServer()
{
    listenco = ListenCoroutine();
    Main->postEvent([](hv::Event *){
        listenco();
    });
    msgco = OnMessageCoroutine();
    Main->postEvent([](hv::Event *){
        msgco();
    });
}



hv::TcpClient Client(Main);
Task<void> StartConnectCoroutine()
{
    std::cout << "start connect to server" << std::endl;
    int listenfd = Client.createsocket(PORT, "127.0.0.1");
    if(listenfd < 0) 
    { co_return; }

    Client.startConnect();
    // 等待连接成功
    auto c1 = [](std::function<void(const hv::SocketChannelPtr &)> &&f){
        Client.onConnection = std::move(f);
    };
    hv::SocketChannelPtr channel = co_await AsyncExec(std::move(c1));
    if(channel->isConnected())
    {
        std::cout << "connect server success" << std::endl;
    }
    else 
    {
        std::cout << "disconnect server" << std::endl;
        Client.closesocket(); // 断开连接
        co_return;
    }
    // 发送消息
    channel->write("hello world! i am client");

    // 等待消息返回
    std::cout << "wait for recv message" << std::endl;
    auto c2 = [](std::function<void(const hv::SocketChannelPtr &, hv::Buffer*)> &&f){
        Client.onMessage = std::move(f);
    };
    auto arg = co_await AsyncExec(std::move(c2));
    hv::SocketChannelPtr c = std::get<0>(arg);
    hv::Buffer* buffer = std::get<1>(arg);
    std::string msg(static_cast<const char *>(buffer->data()), buffer->size());
    std::cout << "recv server response: " << msg << std::endl;
    c->write("i recv your message");

    Client.closesocket(); // 断开连接
}

Task<void> clientco;
void StartClient()
{
    clientco = StartConnectCoroutine();
    Main->postEvent([](hv::Event *){
        clientco();
    });
}

// 用于终止循环的协程函数
Task<void> StopLoopCoroutine()
{
    // 设置定时器,10秒后停止主循环(展示两种方法)
    if(true)
    {
        // 方法1
        auto timer_func = std::bind(&hv::EventLoop::setTimer, Main.get(), 10000, std::placeholders::_1, 1, INVALID_TIMER_ID);
        [[maybe_unused]] hv::TimerID timer_id1 = co_await AsyncExec<hv::TimerID>(std::move(timer_func));  // 这里接收到的定时器ID已经是过期了的,所以可以不接收
    }
    else 
    {
        // 方法2
        auto timer_func2 = [](hv::TimerCallback &&cb){
            Main->setTimer(10000, std::move(cb));
        };
        [[maybe_unused]] hv::TimerID timer_id2 = co_await AsyncExec(timer_func2);  // 这里接收到的定时器ID已经是过期了的,所以可以不接收
    }
    std::cout << "timeout ---->>> stop loop" << std::endl;
    Main->stop();  // 停止主循环
}

int main(int argc, char* argv[]) 
{
    // 投递协程任务
    if(argc > 1 && std::string(argv[1]) == "0")
    {
        StartServer();
    }
    else 
    {
        StartClient();
    }
    auto c = StopLoopCoroutine();
    Main->postEvent([c](hv::Event *){
        c();  // 我想在主循环启动后再启动这个协程
    });
    // 运行主循环
    std::cout << "main loop start" << std::endl;
    Main->run();
    std::cout << "main loop end" << std::endl;
#ifdef _WIN32
    system("pause");
#endif
    return 0;
}

服务端输出结果:

main loop start
start listen connection
start listen message
new connection
start listen connection
recv: hello world! i am client
start listen message
recv: i recv your message
start listen message
connection lost
start listen connection
timeout ---->>> stop loop
main loop end

客户端输出结果:

main loop start
start connect to server
connect server success
wait for recv message
recv server response: i am server
timeout ---->>> stop loop
main loop end

到此这篇关于基于C++20实现协程库的示例代码的文章就介绍到这了,更多相关C++20 协程库内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:
阅读全文