// fork_join.cpp
#include <asio/execution.hpp>
#include <asio/static_thread_pool.hpp>
#include <algorithm>
#include <cerrno>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <numeric>
#include <queue>
#include <thread>

using asio::static_thread_pool;
namespace execution = asio::execution;

// A fixed-size thread pool used to implement fork/join semantics. Functions
// are scheduled using a simple FIFO queue. Implementing work stealing, or
// using a queue based on atomic operations, are left as tasks for the reader.
//
// Once the pool threads are running, the queue, the pool-wide use count and
// the per-group work counters are only touched while mutex_ is held. Queued
// functions themselves are invoked with the mutex released.
class fork_join_pool
{
public:
  // The constructor starts a thread pool with the specified number of threads.
  // Note that the thread_count is not a fixed limit on the pool's concurrency.
  // Additional threads may temporarily be added to the pool if they join a
  // fork_executor.
  //
  // If submitting a worker loop throws, the loops already started are asked
  // to stop and are waited for before the exception is rethrown.
  explicit fork_join_pool(
      std::size_t thread_count = std::max(std::thread::hardware_concurrency(), 1u) * 2)
    : use_count_(1),
      threads_(thread_count)
  {
    try
    {
      // Ask each thread in the pool to dequeue and execute functions until
      // it is time to shut down, i.e. the use count is zero. thread_count_
      // doubles as the loop counter, so it always holds the number of worker
      // loops that were successfully submitted.
      for (thread_count_ = 0; thread_count_ < thread_count; ++thread_count_)
      {
        execution::execute(
            threads_.executor(),
            [this]
            {
              std::unique_lock<std::mutex> lock(mutex_);
              while (use_count_ > 0)
                if (!execute_next(lock))
                  condition_.wait(lock);
            });
      }
    }
    catch (...)
    {
      stop_threads();
      threads_.wait();
      throw;
    }
  }

  // The worker loops capture `this`, so the pool must stay at a fixed address.
  // (The mutex member already makes the class non-copyable; this states it.)
  fork_join_pool(const fork_join_pool&) = delete;
  fork_join_pool& operator=(const fork_join_pool&) = delete;

  // The destructor waits for the pool to finish executing functions.
  ~fork_join_pool()
  {
    stop_threads();
    threads_.wait();
  }

private:
  friend class fork_executor;

  // The base for all functions that are queued in the pool.
  struct function_base
  {
    // Counter of outstanding functions for the group this function belongs to.
    std::shared_ptr<std::size_t> work_count_;

    // Type-erased invoker. It receives the owning pointer by reference so
    // that it can release the queue entry before running the user function.
    void (*execute_)(std::shared_ptr<function_base>& p);
  };

  // Execute the next function from the queue, if any. Returns true if a
  // function was executed, and false if the queue was empty.
  //
  // Must be called with `lock` held. The lock is released while the function
  // runs and is held again when this call returns or throws.
  bool execute_next(std::unique_lock<std::mutex>& lock)
  {
    if (queue_.empty())
      return false;
    // Move the entry out of the queue instead of copying it: the slot is
    // popped immediately, so a copy would only add an atomic reference-count
    // increment and decrement while the mutex is held.
    auto p(std::move(queue_.front()));
    queue_.pop();
    lock.unlock();
    execute(lock, p);
    return true;
  }

  // Execute a function and decrement the outstanding work.
  //
  // Must be called with `lock` released. The lock is re-acquired before the
  // work counter is decremented, on both the normal and the exceptional path,
  // so the caller always gets the lock back in the held state. An exception
  // thrown by the function is rethrown after the bookkeeping is done.
  void execute(std::unique_lock<std::mutex>& lock,
      std::shared_ptr<function_base>& p)
  {
    // Take the counter out of the function object first: execute_ may reset
    // `p` and destroy the object before the counter is needed again.
    std::shared_ptr<std::size_t> work_count(std::move(p->work_count_));
    try
    {
      p->execute_(p);
      lock.lock();
      do_work_finished(work_count);
    }
    catch (...)
    {
      lock.lock();
      do_work_finished(work_count);
      throw;
    }
  }

  // Increment outstanding work. The first outstanding function of a group
  // also adds one to the pool-wide use count. Caller must hold mutex_.
  void do_work_started(const std::shared_ptr<std::size_t>& work_count) noexcept
  {
    if (++(*work_count) == 1)
      ++use_count_;
  }

  // Decrement outstanding work. Notify waiting threads if we run out.
  // Caller must hold mutex_.
  void do_work_finished(const std::shared_ptr<std::size_t>& work_count) noexcept
  {
    if (--(*work_count) == 0)
    {
      --use_count_;
      condition_.notify_all();
    }
  }

  // Dispatch a function, executing it immediately if the queue is already
  // loaded. Otherwise adds the function to the queue and wakes a thread.
  //
  // "Loaded" means the queue holds more than 16 entries per started worker
  // loop; in that case the function runs on the calling thread.
  void do_execute(std::shared_ptr<function_base> p,
      const std::shared_ptr<std::size_t>& work_count)
  {
    std::unique_lock<std::mutex> lock(mutex_);
    if (queue_.size() > thread_count_ * 16)
    {
      do_work_started(work_count);
      lock.unlock();
      execute(lock, p);
    }
    else
    {
      // `p` is a by-value parameter that is not used again, so hand it to the
      // queue by move instead of paying for another reference-count bump.
      // The push comes before the counter update so that a throwing push
      // leaves the counters untouched.
      queue_.push(std::move(p));
      do_work_started(work_count);
      condition_.notify_one();
    }
  }

  // Ask all threads to shut down. This drops the use count that the
  // constructor started at one; the worker loops exit once it reaches zero.
  void stop_threads()
  {
    std::lock_guard<std::mutex> lock(mutex_);
    --use_count_;
    condition_.notify_all();
  }

  std::mutex mutex_;
  std::condition_variable condition_;
  std::queue<std::shared_ptr<function_base>> queue_;
  std::size_t use_count_;     // 1 for the pool itself + 1 per group with work
  std::size_t thread_count_;  // number of worker loops submitted
  static_thread_pool threads_;
};

// An executor type whose submitted functions all belong to one joinable
// group. Copies of a fork_executor share the same group counter, so join()
// on any copy waits for the work submitted through all of them.
class fork_executor
{
public:
  // Creates a new, empty group that schedules its functions on `ctx`.
  fork_executor(fork_join_pool& ctx)
    : context_(ctx),
      work_count_(std::make_shared<std::size_t>(0))
  {
  }

  // Context query: returns the pool this executor submits to.
  fork_join_pool& query(execution::context_t) const noexcept
  {
    return context_;
  }

  // Submits `f` to the pool as part of this executor's group. Depending on
  // the pool's queue length the function is either queued or run right away
  // on the calling thread.
  template <class Func>
  void execute(Func f) const
  {
    using task_type = function<Func>;
    std::shared_ptr<task_type> task =
        std::make_shared<task_type>(std::move(f), work_count_);
    context_.do_execute(task, work_count_);
  }

  // Two executors are equal when they refer to the same group.
  friend bool operator==(const fork_executor& lhs,
      const fork_executor& rhs) noexcept
  {
    return lhs.work_count_ == rhs.work_count_;
  }

  friend bool operator!=(const fork_executor& lhs,
      const fork_executor& rhs) noexcept
  {
    return !(lhs == rhs);
  }

  // Block until all work associated with the executor is complete. While it is
  // waiting, the thread may be borrowed to execute functions from the queue.
  void join() const
  {
    std::unique_lock<std::mutex> lock(context_.mutex_);
    for (;;)
    {
      // Done as soon as the group has no outstanding functions left.
      if (*work_count_ == 0)
        return;

      // Help out by running a queued function; only sleep when there is
      // nothing to run.
      if (context_.execute_next(lock))
        continue;
      context_.condition_.wait(lock);
    }
  }

private:
  // Concrete queue entry holding a function object of type Func.
  template <class Func>
  struct function : fork_join_pool::function_base
  {
    explicit function(Func callable, const std::shared_ptr<std::size_t>& group)
      : function_(std::move(callable))
    {
      work_count_ = group;
      execute_ = &function::invoke;
    }

    // Moves the stored function object out, releases the queue entry, and
    // only then calls it.
    static void invoke(std::shared_ptr<fork_join_pool::function_base>& p)
    {
      function* self = static_cast<function*>(p.get());
      Func callable(std::move(self->function_));
      p.reset();
      callable();
    }

    Func function_;
  };

  fork_join_pool& context_;
  std::shared_ptr<std::size_t> work_count_;
};

// Helper class to automatically join a fork_executor when exiting a scope.
class join_guard
{
public:
  explicit join_guard(const fork_executor& ex) : ex_(ex) {}
  join_guard(const join_guard&) = delete;
  join_guard(join_guard&&) = delete;
  ~join_guard() { ex_.join(); }

private:
  fork_executor ex_;
};

//------------------------------------------------------------------------------

#include <algorithm>
#include <iostream>
#include <random>
#include <vector>

// The pool instance used by fork_join_sort. It is constructed with the
// default thread count.
fork_join_pool pool;

// Sorts [begin, end) with a parallel merge sort. Ranges of at most 32768
// elements are sorted directly with std::sort. Larger ranges are split in
// half, each half is sorted recursively as a function submitted to a
// fork_executor on the global pool, and the halves are merged once both
// have been joined.
template <class Iterator>
void fork_join_sort(Iterator begin, Iterator end)
{
  const std::size_t count = end - begin;
  if (count <= 32768)
  {
    std::sort(begin, end);
    return;
  }

  Iterator middle = begin + count / 2;
  {
    // The guard's destructor joins the executor, so this scope does not end
    // until both halves have been sorted.
    fork_executor fork(pool);
    join_guard join(fork);
    execution::execute(fork, [=]{ fork_join_sort(begin, middle); });
    execution::execute(fork, [=]{ fork_join_sort(middle, end); });
  }
  std::inplace_merge(begin, middle, end);
}

// Entry point: fills a vector with <size> distinct values, shuffles it, sorts
// it with fork_join_sort and prints the elapsed time in microseconds.
//
// Expects exactly one argument, the number of elements. Prints a usage
// message to stderr and returns 1 if the argument is missing or is not a
// non-negative decimal integer.
int main(int argc, char* argv[])
{
  if (argc != 2)
  {
    std::cerr << "Usage: fork_join <size>\n";
    return 1;
  }

  // Parse the size strictly. std::atoll would silently turn garbage into 0
  // and would let a negative value wrap around to an enormous vector size.
  char* parse_end = nullptr;
  errno = 0;
  const long long requested = std::strtoll(argv[1], &parse_end, 10);
  const bool consumed_all = parse_end != argv[1] && *parse_end == '\0';
  if (!consumed_all || errno == ERANGE || requested < 0)
  {
    std::cerr << "Usage: fork_join <size>\n";
    return 1;
  }

  std::vector<double> vec(static_cast<std::size_t>(requested));
  std::iota(vec.begin(), vec.end(), 0);

  std::random_device rd;
  std::mt19937 g(rd());
  std::shuffle(vec.begin(), vec.end(), g);

  // Only the sort itself is timed; setup and shuffling are excluded.
  const std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();

  fork_join_sort(vec.begin(), vec.end());

  const std::chrono::steady_clock::duration elapsed = std::chrono::steady_clock::now() - start;

  std::cout << "sort took ";
  std::cout << std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count();
  std::cout << " microseconds" << std::endl;
}