[gcc(refs/users/redi/heads/pr92895)] libstdc++: Fix conformance issues in <stop_token> (PR92895)

Jonathan Wakely redi@gcc.gnu.org
Wed Jan 22 14:22:00 GMT 2020


https://gcc.gnu.org/g:4681b7a3f8977a825523417c3f3a6926a6e19187

commit 4681b7a3f8977a825523417c3f3a6926a6e19187
Author: Jonathan Wakely <jwakely@redhat.com>
Date:   Wed Jan 22 14:09:56 2020 +0000

    libstdc++: Fix conformance issues in <stop_token> (PR92895)
    
    Replaces shared_ptr with _Stop_state_ref.
    
    Replaces std::mutex with a spinlock that uses one bit of a std::atomic<>
    which also tracks whether a stop request has been made and how many
    stop_source objects share ownership of the state.
    
    TODO: see TODO comments in the header about callback synchronization
    
    TODO: more tests, including verifying the following:
    
    - stop_source move assignment leaves RHS empty
    
    - stop_token::stop_possible() must be false if there are no more stop_source
      owners and no stop request has been made
    
    - no lock is held while callbacks are invoked (so it won't deadlock if a
      callback tries to unregister another callback).
    
    - stop_callback<CB> is ill-formed unless destructible<CB> && invocable<CB>
    
    - callback must be forwarded and invoked with the correct value category
    
    - try to test that the TOCTTOU race in the stop_callback constructor is fixed.

Diff:
---
 libstdc++-v3/include/std/stop_token | 406 +++++++++++++++++++++++++++---------
 1 file changed, 306 insertions(+), 100 deletions(-)

diff --git a/libstdc++-v3/include/std/stop_token b/libstdc++-v3/include/std/stop_token
index e23d139..37dd80d 100644
--- a/libstdc++-v3/include/std/stop_token
+++ b/libstdc++-v3/include/std/stop_token
@@ -32,10 +32,6 @@
 #if __cplusplus > 201703L
 
 #include <atomic>
-#include <bits/std_mutex.h>
-#include <ext/concurrence.h>
-#include <bits/unique_ptr.h>
-#include <bits/shared_ptr.h>
 
 #ifdef _GLIBCXX_HAS_GTHREADS
 # define __cpp_lib_jthread 201907L
@@ -49,6 +45,8 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
   struct nostopstate_t { explicit nostopstate_t() = default; };
   inline constexpr nostopstate_t nostopstate{};
 
+  class stop_source;
+
   /// Allow testing whether a stop request has been made on a `stop_source`.
   class stop_token
   {
@@ -70,14 +68,14 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
     bool
     stop_possible() const noexcept
     {
-      return static_cast<bool>(_M_state);
+      return static_cast<bool>(_M_state) && _M_state->_M_stop_possible();
     }
 
     [[nodiscard]]
     bool
     stop_requested() const noexcept
     {
-      return stop_possible() && _M_state->_M_stop_requested();
+      return static_cast<bool>(_M_state) && _M_state->_M_stop_requested();
     }
 
     void
@@ -100,74 +98,154 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
 
     struct _Stop_cb
     {
-      void(*_M_callback)(_Stop_cb*);
+      using __cb_type = void(_Stop_cb*) noexcept;
+      __cb_type* _M_callback;
       _Stop_cb* _M_prev = nullptr;
       _Stop_cb* _M_next = nullptr;
 
-      template<typename _Cb>
-	_Stop_cb(_Cb&& __cb)
-	: _M_callback(std::forward<_Cb>(__cb))
-	{ }
-
-      bool
-      _M_linked() const noexcept
-      {
-        return (_M_prev != nullptr)
-          || (_M_next != nullptr);
-      }
+      [[__gnu__::__nonnull__]]
+      explicit
+      _Stop_cb(__cb_type* __cb)
+      : _M_callback(__cb)
+      { }
 
-      static void
-      _S_execute(_Stop_cb* __cb) noexcept
-      {
-        __cb->_M_callback(__cb);
-        __cb->_M_prev = __cb->_M_next = nullptr;
-      }
+      void _M_run() noexcept { _M_callback(this); }
     };
 
     struct _Stop_state_t
     {
-      std::atomic<bool> _M_stopped{false};
+      using value_type = uint32_t;
+      static constexpr value_type _S_stop_requested_bit = 1;
+      static constexpr value_type _S_locked_bit = 2;
+      static constexpr value_type _S_ssrc_counter_inc = 4;
+
+      std::atomic<value_type> _M_owners{1};
+      std::atomic<value_type> _M_value{_S_ssrc_counter_inc};
       _Stop_cb* _M_head = nullptr;
-#ifdef _GLIBCXX_HAS_GTHREADS
-      std::mutex _M_mtx;
-#endif
 
       _Stop_state_t() = default;
 
       bool
+      _M_stop_possible() noexcept
+      {
+	// true if a stop request has already been made or there are still
+	// stop_source objects that would allow one to be made.
+	return _M_value.load(memory_order::acquire) & ~_S_locked_bit;
+      }
+
+      bool
       _M_stop_requested() noexcept
       {
-        return _M_stopped;
+        return _M_value.load(memory_order::acquire) & _S_stop_requested_bit;
+      }
+
+      void
+      _M_add_owner() noexcept
+      {
+	_M_owners.fetch_add(1, memory_order::relaxed);
+      }
+
+      void
+      _M_release_ownership() noexcept
+      {
+	if (_M_owners.fetch_sub(1, memory_order::release) == 1)
+	  delete this;
+      }
+
+      void
+      _M_add_ssrc() noexcept
+      {
+	_M_value.fetch_add(_S_ssrc_counter_inc, memory_order::relaxed);
+      }
+
+      void
+      _M_sub_ssrc() noexcept
+      {
+	_M_value.fetch_sub(_S_ssrc_counter_inc, memory_order::release);
+      }
+
+      // Obtain lock (even if stop request has already been made).
+      void
+      _M_lock() noexcept
+      {
+	// Can use relaxed loads to get the current value.
+	// The successful call to _M_try_lock is an acquire operation.
+	auto __old = _M_value.load(memory_order::relaxed);
+	while (!_M_try_lock(__old, memory_order::relaxed))
+	  { }
+      }
+
+      // Precondition: calling thread holds the lock.
+      void
+      _M_unlock() noexcept
+      {
+	_M_value.fetch_sub(_S_locked_bit, memory_order::release);
       }
 
       bool
-      _M_request_stop()
+      _M_request_stop() noexcept
       {
-        bool __stopped = false;
-        if (_M_stopped.compare_exchange_strong(__stopped, true))
-          {
-#ifdef _GLIBCXX_HAS_GTHREADS
-            std::lock_guard<std::mutex> __lck{_M_mtx};
-#endif
-            while (_M_head)
-              {
-                auto __p = _M_head;
-                _M_head = _M_head->_M_next;
-                _Stop_cb::_S_execute(__p);
-              }
-            return true;
-          }
-        return false;
+	// obtain lock and set stop_requested bit
+	auto __old = _M_value.load(memory_order::acquire);
+	do
+	  {
+	    if (__old & _S_stop_requested_bit) // stop request already made
+	      return false;
+	  }
+	while (!_M_try_lock_and_stop(__old));
+
+	while (_M_head)
+	  {
+	    bool __last_cb;
+	    _Stop_cb* __cb = _M_head;
+	    _M_head = _M_head->_M_next;
+	    if (_M_head)
+	      {
+		_M_head->_M_prev = nullptr;
+		__last_cb = false;
+	      }
+	    else
+	      __last_cb = true;
+
+	    // Allow other callbacks to be unregistered while __cb runs.
+	    _M_unlock();
+
+	    // run callback
+	    // TODO: synchronize with owning stop_callback's destructor
+	    __cb->_M_run();
+
+	    // TODO: check if __cb is still accessible, might be destroyed
+	    __cb->_M_prev = __cb->_M_next = nullptr;
+
+	    // Avoid relocking if we already know there are no more callbacks.
+	    if (__last_cb)
+	      return true;
+
+	    _M_lock();
+	  }
+
+	_M_unlock();
+	return true;
       }
 
+      [[__gnu__::__nonnull__]]
       bool
-      _M_register_callback(_Stop_cb* __cb)
+      _M_register_callback(_Stop_cb* __cb) noexcept
       {
-#ifdef _GLIBCXX_HAS_GTHREADS
-        std::lock_guard<std::mutex> __lck{_M_mtx};
-#endif
-        if (_M_stopped)
-          return false;
+	auto __old = _M_value.load(memory_order::acquire);
+	do
+	  {
+	    if (__old & _S_stop_requested_bit) // stop request already made
+	      {
+		__cb->_M_run(); // run synchronously
+		return false;
+	      }
+
+	    if (__old < _S_ssrc_counter_inc) // no stop_source owns *this
+	      // No need to register callback if no stop request can be made.
+	      return true;
+	  }
+	while (!_M_try_lock(__old));
 
         __cb->_M_next = _M_head;
         if (_M_head)
@@ -175,43 +253,152 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
             _M_head->_M_prev = __cb;
           }
         _M_head = __cb;
+	_M_unlock();
         return true;
       }
 
+      [[__gnu__::__nonnull__]]
       void
       _M_remove_callback(_Stop_cb* __cb)
       {
-#ifdef _GLIBCXX_HAS_GTHREADS
-        std::lock_guard<std::mutex> __lck{_M_mtx};
-#endif
+	_M_lock();
+
         if (__cb == _M_head)
           {
             _M_head = _M_head->_M_next;
             if (_M_head)
-              {
-                _M_head->_M_prev = nullptr;
-              }
+	      _M_head->_M_prev = nullptr;
+	    _M_unlock();
+	    return;
           }
-        else if (!__cb->_M_linked())
-          {
-            return;
-          }
-        else
+        else if (__cb->_M_prev)
           {
             __cb->_M_prev->_M_next = __cb->_M_next;
             if (__cb->_M_next)
-              {
-                __cb->_M_next->_M_prev = __cb->_M_prev;
-              }
+	      __cb->_M_next->_M_prev = __cb->_M_prev;
+	    _M_unlock();
+            return;
           }
+
+	_M_unlock();
+
+	// Callback is not in the list, so must be currently executing.
+	// TODO: synchronize with completion of callback
+      }
+
+      // Try to obtain the lock.
+      // Returns true if the lock is acquired (with memory order acquire).
+      // Otherwise, sets __curval = _M_value.load(__failure) and returns false.
+      // Might fail spuriously, so must be called in a loop.
+      bool
+      _M_try_lock(value_type& __curval,
+		  memory_order __failure = memory_order::acquire) noexcept
+      {
+	return _M_do_try_lock(__curval, 0, memory_order::acquire, __failure);
+      }
+
+      // Try to obtain the lock to make a stop request.
+      // Returns true if the lock is acquired and the _S_stop_requested_bit is
+      // set (with memory order acq_rel so that other threads see the request).
+      // Otherwise, sets __curval = _M_value.load(memory_order::acquire) and
+      // returns false.
+      // Might fail spuriously, so must be called in a loop.
+      bool
+      _M_try_lock_and_stop(value_type& __curval) noexcept
+      {
+	return _M_do_try_lock(__curval, _S_stop_requested_bit,
+			      memory_order::acq_rel, memory_order::acquire);
+      }
+
+      bool
+      _M_do_try_lock(value_type& __curval, value_type __newbits,
+		     memory_order __success, memory_order __failure) noexcept
+      {
+	if (__curval & _S_locked_bit)
+	  {
+	    if constexpr (__has_builtin(__builtin_ia32_pause))
+	      __builtin_ia32_pause();
+#ifdef _GLIBCXX_USE_SCHED_YIELD
+	    else
+	      __gthread_yield();
+#endif
+	    __curval = _M_value.load(__failure);
+	    return false;
+	  }
+	__newbits |= _S_locked_bit;
+	return _M_value.compare_exchange_weak(__curval, __curval | __newbits,
+					      __success, __failure);
       }
     };
 
-    using _Stop_state = std::shared_ptr<_Stop_state_t>;
-    _Stop_state _M_state;
+    struct _Stop_state_ref
+    {
+      _Stop_state_ref() = default;
+
+      explicit
+      _Stop_state_ref(const stop_source&)
+      : _M_ptr(new _Stop_state_t())
+      { }
+
+      _Stop_state_ref(const _Stop_state_ref& __other) noexcept
+      : _M_ptr(__other._M_ptr)
+      {
+	if (_M_ptr)
+	  _M_ptr->_M_add_owner();
+      }
+
+      _Stop_state_ref(_Stop_state_ref&& __other) noexcept
+      : _M_ptr(__other._M_ptr)
+      {
+	__other._M_ptr = nullptr;
+      }
+
+      _Stop_state_ref&
+      operator=(const _Stop_state_ref& __other) noexcept
+      {
+	if (auto __ptr = __other._M_ptr; __ptr != _M_ptr)
+	  {
+	    if (__ptr)
+	      __ptr->_M_add_owner();
+	    if (_M_ptr)
+	      _M_ptr->_M_release_ownership();
+	    _M_ptr = __ptr;
+	  }
+	return *this;
+      }
+
+      _Stop_state_ref&
+      operator=(_Stop_state_ref&& __other) noexcept
+      {
+	_Stop_state_ref(std::move(__other)).swap(*this);
+	return *this;
+      }
+
+      ~_Stop_state_ref()
+      {
+	if (_M_ptr)
+	  _M_ptr->_M_release_ownership();
+      }
+
+      void
+      swap(_Stop_state_ref& __other) noexcept
+      { std::swap(_M_ptr, __other._M_ptr); }
+
+      explicit operator bool() const noexcept { return _M_ptr != nullptr; }
+
+      _Stop_state_t* operator->() const noexcept { return _M_ptr; }
+
+      friend bool
+      operator==(const _Stop_state_ref&, const _Stop_state_ref&) = default;
+
+    private:
+      _Stop_state_t* _M_ptr = nullptr;
+    };
+
+    _Stop_state_ref _M_state;
 
     explicit
-    stop_token(const _Stop_state& __state) noexcept
+    stop_token(const _Stop_state_ref& __state) noexcept
     : _M_state{__state}
     { }
   };
@@ -220,34 +407,41 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
   class stop_source
   {
   public:
-    stop_source()
-      : _M_state(std::make_shared<stop_token::_Stop_state_t>())
+    stop_source() : _M_state(*this)
     { }
 
     explicit stop_source(std::nostopstate_t) noexcept
     { }
 
     stop_source(const stop_source& __other) noexcept
-      : _M_state(__other._M_state)
-    { }
+    : _M_state(__other._M_state)
+    {
+      if (_M_state)
+	_M_state->_M_add_ssrc();
+    }
 
-    stop_source(stop_source&& __other) noexcept
-      : _M_state(std::move(__other._M_state))
-    { }
+    stop_source(stop_source&&) noexcept = default;
 
     stop_source&
-    operator=(const stop_source& __rhs) noexcept
+    operator=(const stop_source& __other) noexcept
     {
-      if (_M_state != __rhs._M_state)
-        _M_state = __rhs._M_state;
+      if (_M_state != __other._M_state)
+	{
+	  stop_source __sink(std::move(*this));
+	  _M_state = __other._M_state;
+	  if (_M_state)
+	    _M_state->_M_add_ssrc();
+	}
       return *this;
     }
 
     stop_source&
-    operator=(stop_source&& __rhs) noexcept
+    operator=(stop_source&&) noexcept = default;
+
+    ~stop_source()
     {
-      std::swap(_M_state, __rhs._M_state);
-      return *this;
+      if (_M_state)
+	_M_state->_M_sub_ssrc();
     }
 
     [[nodiscard]]
@@ -261,7 +455,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
     bool
     stop_requested() const noexcept
     {
-      return stop_possible() && _M_state->_M_stop_requested();
+      return static_cast<bool>(_M_state) && _M_state->_M_stop_requested();
     }
 
     bool
@@ -299,14 +493,16 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
     }
 
   private:
-    stop_token::_Stop_state _M_state;
+    stop_token::_Stop_state_ref _M_state;
   };
 
   /// A wrapper for callbacks to be run when a stop request is made.
   template<typename _Callback>
     class [[nodiscard]] stop_callback
-      : private stop_token::_Stop_cb
     {
+      static_assert(is_nothrow_destructible_v<_Callback>);
+      static_assert(is_invocable_v<_Callback>);
+
     public:
       using callback_type = _Callback;
 
@@ -315,13 +511,11 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
         explicit
 	stop_callback(const stop_token& __token, _Cb&& __cb)
         noexcept(is_nothrow_constructible_v<_Callback, _Cb>)
-        : _Stop_cb(&_S_execute), _M_cb(std::forward<_Cb>(__cb))
+        : _M_cb(std::forward<_Cb>(__cb))
         {
 	  if (auto __state = __token._M_state)
 	    {
-	      if (__state->_M_stop_requested())
-		_S_execute(this); // ensures std::terminate on throw
-	      else if (__state->_M_register_callback(this))
+	      if (__state->_M_register_callback(&_M_cb))
 		_M_state.swap(__state);
 	    }
         }
@@ -331,13 +525,11 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
         explicit
 	stop_callback(stop_token&& __token, _Cb&& __cb)
         noexcept(is_nothrow_constructible_v<_Callback, _Cb>)
-        : _Stop_cb(&_S_execute), _M_cb(std::forward<_Cb>(__cb))
+        : _M_cb(std::forward<_Cb>(__cb))
 	{
 	  if (auto& __state = __token._M_state)
 	    {
-	      if (__state->_M_stop_requested())
-		_S_execute(this); // ensures std::terminate on throw
-	      else if (__state->_M_register_callback(this))
+	      if (__state->_M_register_callback(&_M_cb))
 		_M_state.swap(__state);
 	    }
 	}
@@ -346,7 +538,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       {
 	if (_M_state)
 	  {
-	    _M_state->_M_remove_callback(this);
+	    _M_state->_M_remove_callback(&_M_cb);
 	  }
       }
 
@@ -356,14 +548,28 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       stop_callback& operator=(stop_callback&&) = delete;
 
     private:
-      _Callback _M_cb;
-      stop_token::_Stop_state _M_state = nullptr;
-
-      static void
-      _S_execute(_Stop_cb* __that) noexcept
+      struct _Cb_impl : stop_token::_Stop_cb
       {
-	static_cast<stop_callback*>(__that)->_M_cb();
-      }
+	template<typename _Cb>
+	  explicit
+	  _Cb_impl(_Cb&& __cb)
+	  : _Stop_cb(&_S_execute),
+	    _M_cb(std::forward<_Cb>(__cb))
+	  { }
+
+	_Callback _M_cb;
+
+	[[__gnu__::__nonnull__]]
+	static void
+	_S_execute(_Stop_cb* __that) noexcept
+	{
+	  _Callback& __cb = static_cast<_Cb_impl*>(__that)->_M_cb;
+	  std::forward<_Callback>(__cb)();
+	}
+      };
+
+      _Cb_impl _M_cb;
+      stop_token::_Stop_state_ref _M_state;
     };
 
   template<typename _Callback>



More information about the Libstdc++-cvs mailing list