[PATCH] libstdc++: Simplify metaprogramming in <random>

Jonathan Wakely jwakely@redhat.com
Fri Oct 9 17:09:03 GMT 2020


This removes the __detail::_Shift class template, replacing it with a
constexpr function template __pow2m1. Instead of using the _Mod class
template to calculate a modulus, just perform a bitwise AND with the
result of __pow2m1. This works because every place that changes
performs a modulus operation with a power of two, x mod 2^w, which can
be replaced with x & (2^w - 1).

The _Mod class template is still needed by linear_congruential_engine
which needs to calculate (a * x + c) % m without overflow.

I'm not committing this yet; please review it and check that I haven't
broken anything.


-------------- next part --------------
commit 707145dfc0a034df1be027080c2f7c9dfe314c2b
Author: Jonathan Wakely <jwakely@redhat.com>
Date:   Fri Oct 9 18:06:00 2020

    libstdc++: Simplify metaprogramming in <random>
    
    This removes the __detail::_Shift class template, replacing it with a
    constexpr function template __pow2m1. Instead of using the _Mod class
    template to calculate a modulus, just perform a bitwise AND with the
    result of __pow2m1. This works because every place that changes
    performs a modulus operation with a power of two, x mod 2^w, which can
    be replaced with x & (2^w - 1).
    
    The _Mod class template is still needed by linear_congruential_engine
    which needs to calculate (a * x + c) % m without overflow.
    
    libstdc++-v3/ChangeLog:
    
            * include/bits/random.h (__detail::_Shift): Remove.
            (__detail::_Select_uint_least_t<__s, 1>): Allow __int128 to be
            used when supported by the compiler, even if _GLIBCXX_USE_INT128
            is not defined.
            (__detail::__pow2m1): New constexpr function template for 2^w - 1.
            (__detail::__mod): Remove.
            (_Adaptor::min(), _Adaptor::max()): Add constexpr.
            (linear_congruential_engine::operator()): Use _Mod::__calc
            directly instead of __mod.
            (mersenne_twister_engine): Assert 2u < w. Use max() in
            assertions.
            (mersenne_twister_engine::max()): Use __pow2m1.
            (subtract_with_carry_engine::max()): Likewise.
            (independent_bits_engine::max()): Likewise.
            (seed_seq::seed_seq(initializer_list<IntType>)): Define inline,
            using constructor delegation.
            * include/bits/random.tcc (__detail::_Mod<>::__calc): Use
            constexpr instead of static const for local constants.
            (linear_congruential_engine::seed(result_type)): Replace uses
            of __mod function with explicit % operations.
            (linear_congruential_engine::seed(Sseq&)): Remove factor
            variable and replace multiplications by shifts.
            (mersenne_twister_engine::seed(result_type)): Replace uses of
            __mod and _Shift by % and & operations.
            (mersenne_twister_engine::seed(Sseq&)): Likewise. Replace
            multiplications by shifts.
            (subtract_with_carry_engine::seed(result_type)): Likewise.
            (subtract_with_carry_engine::seed(Sseq&)): Likewise.
            (subtract_with_carry_engine::operator()()): Replace _Shift with
            __pow2m1.
            (seed_seq::seed_seq(initializer_list<IntType>)): Remove
            out-of-line definition.
            (seed_seq::seed_seq(InputIterator, InputIterator)): Replace
            __mod and _Shift by bitwise AND.
            * testsuite/26_numerics/random/pr60037-neg.cc: Adjust dg-error
            line number.

diff --git a/libstdc++-v3/include/bits/random.h b/libstdc++-v3/include/bits/random.h
index 0be1191e07d..4be1819d465 100644
--- a/libstdc++-v3/include/bits/random.h
+++ b/libstdc++-v3/include/bits/random.h
@@ -65,16 +65,6 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
    */
   namespace __detail
   {
-    template<typename _UIntType, size_t __w,
-	     bool = __w < static_cast<size_t>
-			  (std::numeric_limits<_UIntType>::digits)>
-      struct _Shift
-      { static const _UIntType __value = 0; };
-
-    template<typename _UIntType, size_t __w>
-      struct _Shift<_UIntType, __w, true>
-      { static const _UIntType __value = _UIntType(1) << __w; };
-
     template<int __s,
 	     int __which = ((__s <= __CHAR_BIT__ * sizeof (int))
 			    + (__s <= __CHAR_BIT__ * sizeof (long))
@@ -99,12 +89,13 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       struct _Select_uint_least_t<__s, 2>
       { typedef unsigned long long type; };
 
-#ifdef _GLIBCXX_USE_INT128
+#ifdef __SIZEOF_INT128__
     template<int __s>
       struct _Select_uint_least_t<__s, 1>
       { typedef unsigned __int128 type; };
 #endif
 
+    // `_Mod<T, m, a, c>::__calc(x)` returns `(a x + c) mod m`.
     // Assume a != 0, a < m, c < m, x < m.
     template<typename _Tp, _Tp __m, _Tp __a, _Tp __c,
 	     bool __big_enough = (!(__m & (__m - 1))
@@ -143,18 +134,20 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	}
       };
 
-    template<typename _Tp, _Tp __m, _Tp __a = 1, _Tp __c = 0>
-      inline _Tp
-      __mod(_Tp __x)
+    // Returns `static_cast<_Tp>(pow(2, w) - 1)`, for 0 < w <= digits.
+    template<typename _Tp>
+      constexpr _Tp
+      __pow2m1(_Tp __w)
       {
-	if _GLIBCXX17_CONSTEXPR (__a == 0)
-	  return __c;
-	else
-	  {
-	    // _Mod must not be instantiated with a == 0
-	    constexpr _Tp __a1 = __a ? __a : 1;
-	    return _Mod<_Tp, __m, __a1, __c>::__calc(__x);
-	  }
+	static_assert(!numeric_limits<_Tp>::is_signed,
+		      "type must be unsigned");
+
+#if __cplusplus >= 201402L
+	if (__w == 0 || __w > numeric_limits<_Tp>::digits)
+	  __builtin_abort();
+#endif
+
+	return _Tp(~_Tp(0)) >> (numeric_limits<_Tp>::digits - __w);
       }
 
     /*
@@ -171,11 +164,11 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	_Adaptor(_Engine& __g)
 	: _M_g(__g) { }
 
-	_DInputType
+	constexpr _DInputType
 	min() const
 	{ return _DInputType(0); }
 
-	_DInputType
+	constexpr _DInputType
 	max() const
 	{ return _DInputType(1); }
 
@@ -240,7 +233,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
    * A random number generator that produces pseudorandom numbers via
    * linear function:
    * @f[
-   *     x_{i+1}\leftarrow(ax_{i} + c) \bmod m 
+   *     x_{i+1}\leftarrow(a x_{i} + c) \bmod m
    * @f]
    *
    * The template parameter @p _UIntType must be an unsigned integral type
@@ -357,7 +350,14 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       result_type
       operator()()
       {
-	_M_x = __detail::__mod<_UIntType, __m, __a, __c>(_M_x);
+	if _GLIBCXX17_CONSTEXPR (__a == 0)
+	  _M_x = __c;
+	else
+	  {
+	    // _Mod must not be instantiated with a == 0
+	    constexpr _UIntType __a1 = __a ? __a : 1;
+	    _M_x = __detail::_Mod<_UIntType, __m, __a1, __c>::__calc(_M_x);
+	  }
 	return _M_x;
       }
 
@@ -485,18 +485,8 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
 		    "__t out of bound");
       static_assert(__l <= __w, "template argument substituting "
 		    "__l out of bound");
-      static_assert(__w <= std::numeric_limits<_UIntType>::digits,
+      static_assert(2u < __w && __w <= std::numeric_limits<_UIntType>::digits,
 		    "template argument substituting __w out of bound");
-      static_assert(__a <= (__detail::_Shift<_UIntType, __w>::__value - 1),
-		    "template argument substituting __a out of bound");
-      static_assert(__b <= (__detail::_Shift<_UIntType, __w>::__value - 1),
-		    "template argument substituting __b out of bound");
-      static_assert(__c <= (__detail::_Shift<_UIntType, __w>::__value - 1),
-		    "template argument substituting __c out of bound");
-      static_assert(__d <= (__detail::_Shift<_UIntType, __w>::__value - 1),
-		    "template argument substituting __d out of bound");
-      static_assert(__f <= (__detail::_Shift<_UIntType, __w>::__value - 1),
-		    "template argument substituting __f out of bound");
 
       template<typename _Sseq>
 	using _If_seed_seq = typename enable_if<__detail::__is_seed_seq<
@@ -560,7 +550,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
        */
       static constexpr result_type
       max()
-      { return __detail::_Shift<_UIntType, __w>::__value - 1; }
+      { return __detail::__pow2m1(result_type(__w)); }
 
       /**
        * @brief Discard a sequence of random numbers.
@@ -642,6 +632,17 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
 		   __l1, __f1>& __x);
 
     private:
+      static_assert(__a <= max(),
+		    "template argument substituting __a out of bound");
+      static_assert(__b <= max(),
+		    "template argument substituting __b out of bound");
+      static_assert(__c <= max(),
+		    "template argument substituting __c out of bound");
+      static_assert(__d <= max(),
+		    "template argument substituting __d out of bound");
+      static_assert(__f <= max(),
+		    "template argument substituting __f out of bound");
+
       void _M_gen_rand();
 
       _UIntType _M_x[state_size];
@@ -771,7 +772,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
        */
       static constexpr result_type
       max()
-      { return __detail::_Shift<_UIntType, __w>::__value - 1; }
+      { return __detail::__pow2m1(result_type(__w)); }
 
       /**
        * @brief Discard a sequence of random numbers.
@@ -1212,7 +1213,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
        */
       static constexpr result_type
       max()
-      { return __detail::_Shift<_UIntType, __w>::__value - 1; }
+      { return __detail::__pow2m1(result_type(__w)); }
 
       /**
        * @brief Discard a sequence of random numbers.
@@ -6072,7 +6073,9 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
     { }
 
     template<typename _IntType>
-      seed_seq(std::initializer_list<_IntType> __il);
+      seed_seq(std::initializer_list<_IntType> __il)
+      : seed_seq(__il.begin(), __il.end())
+      { }
 
     template<typename _InputIterator>
       seed_seq(_InputIterator __begin, _InputIterator __end);
diff --git a/libstdc++-v3/include/bits/random.tcc b/libstdc++-v3/include/bits/random.tcc
index bf39a51559b..1cd840cfd7a 100644
--- a/libstdc++-v3/include/bits/random.tcc
+++ b/libstdc++-v3/include/bits/random.tcc
@@ -56,8 +56,8 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	  __x %= __m;
 	else
 	  {
-	    static const _Tp __q = __m / __a;
-	    static const _Tp __r = __m % __a;
+	    constexpr _Tp __q = __m / __a;
+	    constexpr _Tp __r = __m % __a;
 
 	    _Tp __t1 = __a * (__x % __q);
 	    _Tp __t2 = __r * (__x / __q);
@@ -116,11 +116,12 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
     linear_congruential_engine<_UIntType, __a, __c, __m>::
     seed(result_type __s)
     {
-      if ((__detail::__mod<_UIntType, __m>(__c) == 0)
-	  && (__detail::__mod<_UIntType, __m>(__s) == 0))
+      if _GLIBCXX17_CONSTEXPR (__m == 0)
+	_M_x = (__s == 0 && __c == 0) ? 1 : __s;
+      else if ((__s % __m) == 0 && (__c % __m) == 0)
 	_M_x = 1;
       else
-	_M_x = __detail::__mod<_UIntType, __m>(__s);
+	_M_x = __s % __m;
     }
 
   /**
@@ -135,15 +136,13 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       {
 	const _UIntType __k0 = __m == 0 ? std::numeric_limits<_UIntType>::digits
 	                                : std::__lg(__m);
-	const _UIntType __k = (__k0 + 31) / 32;
+	constexpr _UIntType __k = (__k0 + 31) / 32;
 	uint_least32_t __arr[__k + 3];
 	__q.generate(__arr + 0, __arr + __k + 3);
-	_UIntType __factor = 1u;
 	_UIntType __sum = 0u;
 	for (size_t __j = 0; __j < __k; ++__j)
 	  {
-	    __sum += __arr[__j + 3] * __factor;
-	    __factor *= __detail::_Shift<_UIntType, 32>::__value;
+	    __sum += _UIntType(__arr[__j + 3]) << (32 * __j);
 	  }
 	seed(__sum);
       }
@@ -324,17 +323,15 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
 			    __s, __b, __t, __c, __l, __f>::
     seed(result_type __sd)
     {
-      _M_x[0] = __detail::__mod<_UIntType,
-	__detail::_Shift<_UIntType, __w>::__value>(__sd);
+      _M_x[0] = __sd & max();
 
       for (size_t __i = 1; __i < state_size; ++__i)
 	{
 	  _UIntType __x = _M_x[__i - 1];
 	  __x ^= __x >> (__w - 2);
 	  __x *= __f;
-	  __x += __detail::__mod<_UIntType, __n>(__i);
-	  _M_x[__i] = __detail::__mod<_UIntType,
-	    __detail::_Shift<_UIntType, __w>::__value>(__x);
+	  __x += __i % __n;
+	  _M_x[__i] = __x & max();
 	}
       _M_p = state_size;
     }
@@ -352,22 +349,19 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       -> _If_seed_seq<_Sseq>
       {
 	const _UIntType __upper_mask = (~_UIntType()) << __r;
-	const size_t __k = (__w + 31) / 32;
+	constexpr size_t __k = (__w + 31) / 32;
 	uint_least32_t __arr[__n * __k];
 	__q.generate(__arr + 0, __arr + __n * __k);
 
 	bool __zero = true;
 	for (size_t __i = 0; __i < state_size; ++__i)
 	  {
-	    _UIntType __factor = 1u;
 	    _UIntType __sum = 0u;
 	    for (size_t __j = 0; __j < __k; ++__j)
 	      {
-		__sum += __arr[__k * __i + __j] * __factor;
-		__factor *= __detail::_Shift<_UIntType, 32>::__value;
+		__sum += _UIntType(__arr[__k * __i + __j]) << (32 * __j);
 	      }
-	    _M_x[__i] = __detail::__mod<_UIntType,
-	      __detail::_Shift<_UIntType, __w>::__value>(__sum);
+	    _M_x[__i] = __sum & max();
 
 	    if (__zero)
 	      {
@@ -381,7 +375,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	      }
 	  }
         if (__zero)
-          _M_x[0] = __detail::_Shift<_UIntType, __w - 1>::__value;
+          _M_x[0] = _UIntType(1) << (__w - 1);
 	_M_p = state_size;
       }
 
@@ -540,21 +534,16 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       std::linear_congruential_engine<result_type, 40014u, 0u, 2147483563u>
 	__lcg(__value == 0u ? default_seed : __value);
 
-      const size_t __n = (__w + 31) / 32;
+      constexpr size_t __n = (__w + 31) / 32;
 
       for (size_t __i = 0; __i < long_lag; ++__i)
 	{
 	  _UIntType __sum = 0u;
-	  _UIntType __factor = 1u;
 	  for (size_t __j = 0; __j < __n; ++__j)
 	    {
-	      __sum += __detail::__mod<uint_least32_t,
-		       __detail::_Shift<uint_least32_t, 32>::__value>
-			 (__lcg()) * __factor;
-	      __factor *= __detail::_Shift<_UIntType, 32>::__value;
+	      __sum += __lcg() << (32 * __j);
 	    }
-	  _M_x[__i] = __detail::__mod<_UIntType,
-	    __detail::_Shift<_UIntType, __w>::__value>(__sum);
+	  _M_x[__i] = __sum & max();
 	}
       _M_carry = (_M_x[long_lag - 1] == 0) ? 1 : 0;
       _M_p = 0;
@@ -574,14 +563,11 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	for (size_t __i = 0; __i < long_lag; ++__i)
 	  {
 	    _UIntType __sum = 0u;
-	    _UIntType __factor = 1u;
 	    for (size_t __j = 0; __j < __k; ++__j)
 	      {
-		__sum += __arr[__k * __i + __j] * __factor;
-		__factor *= __detail::_Shift<_UIntType, 32>::__value;
+		__sum += _UIntType(__arr[__k * __i + __j]) << (32 * __j);
 	      }
-	    _M_x[__i] = __detail::__mod<_UIntType,
-	      __detail::_Shift<_UIntType, __w>::__value>(__sum);
+	    _M_x[__i] = __sum & max();
 	  }
 	_M_carry = (_M_x[long_lag - 1] == 0) ? 1 : 0;
 	_M_p = 0;
@@ -609,7 +595,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	}
       else
 	{
-	  __xi = (__detail::_Shift<_UIntType, __w>::__value
+	  __xi = (__detail::__pow2m1(_UIntType(__w)) + _UIntType(1)
 		  - _M_x[_M_p] - _M_carry + _M_x[__ps]);
 	  _M_carry = 1;
 	}
@@ -3197,20 +3183,11 @@ namespace __detail
     }
 
 
-  template<typename _IntType>
-    seed_seq::seed_seq(std::initializer_list<_IntType> __il)
-    {
-      for (auto __iter = __il.begin(); __iter != __il.end(); ++__iter)
-	_M_v.push_back(__detail::__mod<result_type,
-		       __detail::_Shift<result_type, 32>::__value>(*__iter));
-    }
-
   template<typename _InputIterator>
     seed_seq::seed_seq(_InputIterator __begin, _InputIterator __end)
     {
       for (_InputIterator __iter = __begin; __iter != __end; ++__iter)
-	_M_v.push_back(__detail::__mod<result_type,
-		       __detail::_Shift<result_type, 32>::__value>(*__iter));
+	_M_v.push_back(*__iter & 0xffffffffu);
     }
 
   template<typename _RandomAccessIterator>
diff --git a/libstdc++-v3/testsuite/26_numerics/random/pr60037-neg.cc b/libstdc++-v3/testsuite/26_numerics/random/pr60037-neg.cc
index 0b5f597040b..f4a59799d9c 100644
--- a/libstdc++-v3/testsuite/26_numerics/random/pr60037-neg.cc
+++ b/libstdc++-v3/testsuite/26_numerics/random/pr60037-neg.cc
@@ -10,6 +10,6 @@ std::__detail::_Adaptor<std::mt19937, unsigned long> aurng(urng);
 auto x = std::generate_canonical<std::size_t,
 			std::numeric_limits<std::size_t>::digits>(urng);
 
-// { dg-error "static assertion failed: template argument must be a floating point type" "" { target *-*-* } 167 }
+// { dg-error "static assertion failed: template argument must be a floating point type" "" { target *-*-* } 160 }
 
-// { dg-error "static assertion failed: template argument must be a floating point type" "" { target *-*-* } 3312 }
+// { dg-error "static assertion failed: template argument must be a floating point type" "" { target *-*-* } 3289 }


More information about the Gcc-patches mailing list