[PATCH] libstdc++: Simplify metaprogramming in <random>
Jonathan Wakely
jwakely@redhat.com
Fri Oct 9 17:09:03 GMT 2020
This removes the __detail::_Shift class template, replacing it with a
constexpr function template __pow2m1. Instead of using the _Mod class
template to calculate a modulus, just perform a bitwise AND with the
result of __pow2m1. This works because every place being changed
computes a modulus by a power of two, x mod 2^w, which can be
replaced with x & (2^w - 1).
The _Mod class template is still needed by linear_congruential_engine,
which must calculate (a * x + c) % m without overflow.
I'm not committing this yet; please review it and check that I haven't
broken anything.
-------------- next part --------------
commit 707145dfc0a034df1be027080c2f7c9dfe314c2b
Author: Jonathan Wakely <jwakely@redhat.com>
Date: Fri Oct 9 18:06:00 2020
libstdc++: Simplify metaprogramming in <random>
This removes the __detail::_Shift class template, replacing it with a
constexpr function template __pow2m1. Instead of using the _Mod class
template to calculate a modulus, just perform a bitwise AND with the
result of __pow2m1. This works because every place being changed
computes a modulus by a power of two, x mod 2^w, which can be
replaced with x & (2^w - 1).
The _Mod class template is still needed by linear_congruential_engine,
which must calculate (a * x + c) % m without overflow.
libstdc++-v3/ChangeLog:
* include/bits/random.h (__detail::_Shift): Remove.
(__detail::_Select_uint_least_t<__s, 1>): Allow __int128 to be
used when supported by the compiler, even if _GLIBCXX_USE_INT128
is not defined.
(__pow2m1): New constexpr function template for 2^w - 1.
(__detail::__mod): Remove.
(_Adaptor::min(), _Adaptor::max()): Add constexpr.
(linear_congruential_engine::operator()): Use _Mod::__calc
directly instead of __mod.
(mersenne_twister_engine): Assert 2u < w. Use max() in
assertions.
(mersenne_twister_engine::max()): Use __pow2m1.
(subtract_with_carry_engine::max()): Likewise.
(independent_bits_engine::max()): Likewise.
(seed_seq::seed_seq(initializer_list<IntType>)): Define inline,
using constructor delegation.
* include/bits/random.tcc (__detail::_Mod<>::__calc): Use
constexpr instead of static const for local variables.
(linear_congruential_engine::seed(result_type)): Replace uses
of __mod function with explicit % operations.
(linear_congruential_engine::seed(Sseq&)): Remove factor
variable and replace multiplications by shifts.
(mersenne_twister_engine::seed(result_type)): Replace uses of
__mod and _Shift by % and & operations.
(mersenne_twister_engine::seed(Sseq&)): Likewise. Replace
multiplications by shifts.
(subtract_with_carry_engine::seed(result_type)): Likewise.
(subtract_with_carry_engine::seed(Sseq&)): Likewise.
(subtract_with_carry_engine::operator()()): Replace _Shift with
__pow2m1.
(seed_seq::seed_seq(initializer_list<IntType>)): Remove
out-of-line definition.
(seed_seq::seed_seq(InputIterator, InputIterator)): Replace
__mod and _Shift by bitwise AND.
* testsuite/26_numerics/random/pr60037-neg.cc: Adjust dg-error
line number.
diff --git a/libstdc++-v3/include/bits/random.h b/libstdc++-v3/include/bits/random.h
index 0be1191e07d..4be1819d465 100644
--- a/libstdc++-v3/include/bits/random.h
+++ b/libstdc++-v3/include/bits/random.h
@@ -65,16 +65,6 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
*/
namespace __detail
{
- template<typename _UIntType, size_t __w,
- bool = __w < static_cast<size_t>
- (std::numeric_limits<_UIntType>::digits)>
- struct _Shift
- { static const _UIntType __value = 0; };
-
- template<typename _UIntType, size_t __w>
- struct _Shift<_UIntType, __w, true>
- { static const _UIntType __value = _UIntType(1) << __w; };
-
template<int __s,
int __which = ((__s <= __CHAR_BIT__ * sizeof (int))
+ (__s <= __CHAR_BIT__ * sizeof (long))
@@ -99,12 +89,13 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
struct _Select_uint_least_t<__s, 2>
{ typedef unsigned long long type; };
-#ifdef _GLIBCXX_USE_INT128
+#ifdef __SIZEOF_INT128__
template<int __s>
struct _Select_uint_least_t<__s, 1>
-{ typedef unsigned __int128 type; };
+{ __extension__ typedef unsigned __int128 type; };
#endif
+ // `_Mod<T, m, a, c>::__calc(x)` returns `(a x + c) mod m`.
// Assume a != 0, a < m, c < m, x < m.
template<typename _Tp, _Tp __m, _Tp __a, _Tp __c,
bool __big_enough = (!(__m & (__m - 1))
@@ -143,18 +134,20 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
}
};
- template<typename _Tp, _Tp __m, _Tp __a = 1, _Tp __c = 0>
- inline _Tp
- __mod(_Tp __x)
+ // Returns `static_cast<_Tp>(pow(2, w) - 1)`, for 0 < w <= digits of _Tp.
+ template<typename _Tp>
+ constexpr _Tp
+ __pow2m1(_Tp __w)
{
- if _GLIBCXX17_CONSTEXPR (__a == 0)
- return __c;
- else
- {
- // _Mod must not be instantiated with a == 0
- constexpr _Tp __a1 = __a ? __a : 1;
- return _Mod<_Tp, __m, __a1, __c>::__calc(__x);
- }
+ static_assert(!numeric_limits<_Tp>::is_signed,
+ "type must be unsigned");
+
+#if __cplusplus >= 201402L
+ if (__w > numeric_limits<_Tp>::digits)
+ __builtin_abort();
+#endif
+ // Cast back to _Tp, because ~ promotes types narrower than int to int.
+ return _Tp(~_Tp(0)) >> (numeric_limits<_Tp>::digits - __w);
}
/*
@@ -171,11 +164,11 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
_Adaptor(_Engine& __g)
: _M_g(__g) { }
- _DInputType
+ constexpr _DInputType
min() const
{ return _DInputType(0); }
- _DInputType
+ constexpr _DInputType
max() const
{ return _DInputType(1); }
@@ -240,7 +233,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
* A random number generator that produces pseudorandom numbers via
* linear function:
* @f[
- * x_{i+1}\leftarrow(ax_{i} + c) \bmod m
+ * x_{i+1}\leftarrow(a x_{i} + c) \bmod m
* @f]
*
* The template parameter @p _UIntType must be an unsigned integral type
@@ -357,7 +350,14 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
result_type
operator()()
{
- _M_x = __detail::__mod<_UIntType, __m, __a, __c>(_M_x);
+ if _GLIBCXX17_CONSTEXPR (__a == 0)
+ _M_x = __c;
+ else
+ {
+ // _Mod must not be instantiated with a == 0
+ constexpr _UIntType __a1 = __a ? __a : 1;
+ _M_x = __detail::_Mod<_UIntType, __m, __a1, __c>::__calc(_M_x);
+ }
return _M_x;
}
@@ -485,18 +485,8 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
"__t out of bound");
static_assert(__l <= __w, "template argument substituting "
"__l out of bound");
- static_assert(__w <= std::numeric_limits<_UIntType>::digits,
+ static_assert(2u < __w && __w <= std::numeric_limits<_UIntType>::digits,
"template argument substituting __w out of bound");
- static_assert(__a <= (__detail::_Shift<_UIntType, __w>::__value - 1),
- "template argument substituting __a out of bound");
- static_assert(__b <= (__detail::_Shift<_UIntType, __w>::__value - 1),
- "template argument substituting __b out of bound");
- static_assert(__c <= (__detail::_Shift<_UIntType, __w>::__value - 1),
- "template argument substituting __c out of bound");
- static_assert(__d <= (__detail::_Shift<_UIntType, __w>::__value - 1),
- "template argument substituting __d out of bound");
- static_assert(__f <= (__detail::_Shift<_UIntType, __w>::__value - 1),
- "template argument substituting __f out of bound");
template<typename _Sseq>
using _If_seed_seq = typename enable_if<__detail::__is_seed_seq<
@@ -560,7 +550,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
*/
static constexpr result_type
max()
- { return __detail::_Shift<_UIntType, __w>::__value - 1; }
+ { return __detail::__pow2m1(result_type(__w)); }
/**
* @brief Discard a sequence of random numbers.
@@ -642,6 +632,17 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
__l1, __f1>& __x);
private:
+ static_assert(__a <= max(),
+ "template argument substituting __a out of bound");
+ static_assert(__b <= max(),
+ "template argument substituting __b out of bound");
+ static_assert(__c <= max(),
+ "template argument substituting __c out of bound");
+ static_assert(__d <= max(),
+ "template argument substituting __d out of bound");
+ static_assert(__f <= max(),
+ "template argument substituting __f out of bound");
+
void _M_gen_rand();
_UIntType _M_x[state_size];
@@ -771,7 +772,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
*/
static constexpr result_type
max()
- { return __detail::_Shift<_UIntType, __w>::__value - 1; }
+ { return __detail::__pow2m1(result_type(__w)); }
/**
* @brief Discard a sequence of random numbers.
@@ -1212,7 +1213,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
*/
static constexpr result_type
max()
- { return __detail::_Shift<_UIntType, __w>::__value - 1; }
+ { return __detail::__pow2m1(result_type(__w)); }
/**
* @brief Discard a sequence of random numbers.
@@ -6072,7 +6073,9 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
{ }
template<typename _IntType>
- seed_seq(std::initializer_list<_IntType> __il);
+ seed_seq(std::initializer_list<_IntType> __il)
+ : seed_seq(__il.begin(), __il.end())
+ { }
template<typename _InputIterator>
seed_seq(_InputIterator __begin, _InputIterator __end);
diff --git a/libstdc++-v3/include/bits/random.tcc b/libstdc++-v3/include/bits/random.tcc
index bf39a51559b..1cd840cfd7a 100644
--- a/libstdc++-v3/include/bits/random.tcc
+++ b/libstdc++-v3/include/bits/random.tcc
@@ -56,8 +56,8 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
__x %= __m;
else
{
- static const _Tp __q = __m / __a;
- static const _Tp __r = __m % __a;
+ constexpr _Tp __q = __m / __a;
+ constexpr _Tp __r = __m % __a;
_Tp __t1 = __a * (__x % __q);
_Tp __t2 = __r * (__x / __q);
@@ -116,11 +116,12 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
linear_congruential_engine<_UIntType, __a, __c, __m>::
seed(result_type __s)
{
- if ((__detail::__mod<_UIntType, __m>(__c) == 0)
- && (__detail::__mod<_UIntType, __m>(__s) == 0))
+ if _GLIBCXX17_CONSTEXPR (__m == 0)
+ _M_x = (__s == 0 && __c == 0) ? 1 : __s;
+ else if ((__s % __m) == 0 && (__c % __m) == 0)
_M_x = 1;
else
- _M_x = __detail::__mod<_UIntType, __m>(__s);
+ _M_x = __s % __m;
}
/**
@@ -135,15 +136,13 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
{
const _UIntType __k0 = __m == 0 ? std::numeric_limits<_UIntType>::digits
: std::__lg(__m);
- const _UIntType __k = (__k0 + 31) / 32;
+ constexpr _UIntType __k = (__k0 + 31) / 32;
uint_least32_t __arr[__k + 3];
__q.generate(__arr + 0, __arr + __k + 3);
- _UIntType __factor = 1u;
_UIntType __sum = 0u;
for (size_t __j = 0; __j < __k; ++__j)
{
- __sum += __arr[__j + 3] * __factor;
- __factor *= __detail::_Shift<_UIntType, 32>::__value;
+ __sum += _UIntType(__arr[__j + 3]) << (32 * __j);
}
seed(__sum);
}
@@ -324,17 +323,15 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
__s, __b, __t, __c, __l, __f>::
seed(result_type __sd)
{
- _M_x[0] = __detail::__mod<_UIntType,
- __detail::_Shift<_UIntType, __w>::__value>(__sd);
+ _M_x[0] = __sd & max();
for (size_t __i = 1; __i < state_size; ++__i)
{
_UIntType __x = _M_x[__i - 1];
__x ^= __x >> (__w - 2);
__x *= __f;
- __x += __detail::__mod<_UIntType, __n>(__i);
- _M_x[__i] = __detail::__mod<_UIntType,
- __detail::_Shift<_UIntType, __w>::__value>(__x);
+ __x += __i % __n;
+ _M_x[__i] = __x & max();
}
_M_p = state_size;
}
@@ -352,22 +349,19 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
-> _If_seed_seq<_Sseq>
{
const _UIntType __upper_mask = (~_UIntType()) << __r;
- const size_t __k = (__w + 31) / 32;
+ constexpr size_t __k = (__w + 31) / 32;
uint_least32_t __arr[__n * __k];
__q.generate(__arr + 0, __arr + __n * __k);
bool __zero = true;
for (size_t __i = 0; __i < state_size; ++__i)
{
- _UIntType __factor = 1u;
_UIntType __sum = 0u;
for (size_t __j = 0; __j < __k; ++__j)
{
- __sum += __arr[__k * __i + __j] * __factor;
- __factor *= __detail::_Shift<_UIntType, 32>::__value;
+ __sum += _UIntType(__arr[__k * __i + __j]) << (32 * __j);
}
- _M_x[__i] = __detail::__mod<_UIntType,
- __detail::_Shift<_UIntType, __w>::__value>(__sum);
+ _M_x[__i] = __sum & max();
if (__zero)
{
@@ -381,7 +375,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
}
}
if (__zero)
- _M_x[0] = __detail::_Shift<_UIntType, __w - 1>::__value;
+ _M_x[0] = _UIntType(1) << (__w - 1);
_M_p = state_size;
}
@@ -540,21 +534,16 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
std::linear_congruential_engine<result_type, 40014u, 0u, 2147483563u>
__lcg(__value == 0u ? default_seed : __value);
- const size_t __n = (__w + 31) / 32;
+ constexpr size_t __n = (__w + 31) / 32;
for (size_t __i = 0; __i < long_lag; ++__i)
{
_UIntType __sum = 0u;
- _UIntType __factor = 1u;
for (size_t __j = 0; __j < __n; ++__j)
{
- __sum += __detail::__mod<uint_least32_t,
- __detail::_Shift<uint_least32_t, 32>::__value>
- (__lcg()) * __factor;
- __factor *= __detail::_Shift<_UIntType, 32>::__value;
+ __sum += __lcg() << (32 * __j);
}
- _M_x[__i] = __detail::__mod<_UIntType,
- __detail::_Shift<_UIntType, __w>::__value>(__sum);
+ _M_x[__i] = __sum & max();
}
_M_carry = (_M_x[long_lag - 1] == 0) ? 1 : 0;
_M_p = 0;
@@ -574,14 +563,11 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
for (size_t __i = 0; __i < long_lag; ++__i)
{
_UIntType __sum = 0u;
- _UIntType __factor = 1u;
for (size_t __j = 0; __j < __k; ++__j)
{
- __sum += __arr[__k * __i + __j] * __factor;
- __factor *= __detail::_Shift<_UIntType, 32>::__value;
+ __sum += _UIntType(__arr[__k * __i + __j]) << (32 * __j);
}
- _M_x[__i] = __detail::__mod<_UIntType,
- __detail::_Shift<_UIntType, __w>::__value>(__sum);
+ _M_x[__i] = __sum & max();
}
_M_carry = (_M_x[long_lag - 1] == 0) ? 1 : 0;
_M_p = 0;
@@ -609,7 +595,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
}
else
{
- __xi = (__detail::_Shift<_UIntType, __w>::__value
+ __xi = (__detail::__pow2m1(_UIntType(__w)) + _UIntType(1)
- _M_x[_M_p] - _M_carry + _M_x[__ps]);
_M_carry = 1;
}
@@ -3197,20 +3183,11 @@ namespace __detail
}
- template<typename _IntType>
- seed_seq::seed_seq(std::initializer_list<_IntType> __il)
- {
- for (auto __iter = __il.begin(); __iter != __il.end(); ++__iter)
- _M_v.push_back(__detail::__mod<result_type,
- __detail::_Shift<result_type, 32>::__value>(*__iter));
- }
-
template<typename _InputIterator>
seed_seq::seed_seq(_InputIterator __begin, _InputIterator __end)
{
for (_InputIterator __iter = __begin; __iter != __end; ++__iter)
- _M_v.push_back(__detail::__mod<result_type,
- __detail::_Shift<result_type, 32>::__value>(*__iter));
+ _M_v.push_back(*__iter & 0xffffffffu);
}
template<typename _RandomAccessIterator>
diff --git a/libstdc++-v3/testsuite/26_numerics/random/pr60037-neg.cc b/libstdc++-v3/testsuite/26_numerics/random/pr60037-neg.cc
index 0b5f597040b..f4a59799d9c 100644
--- a/libstdc++-v3/testsuite/26_numerics/random/pr60037-neg.cc
+++ b/libstdc++-v3/testsuite/26_numerics/random/pr60037-neg.cc
@@ -10,6 +10,6 @@ std::__detail::_Adaptor<std::mt19937, unsigned long> aurng(urng);
auto x = std::generate_canonical<std::size_t,
std::numeric_limits<std::size_t>::digits>(urng);
-// { dg-error "static assertion failed: template argument must be a floating point type" "" { target *-*-* } 167 }
+// { dg-error "static assertion failed: template argument must be a floating point type" "" { target *-*-* } 160 }
-// { dg-error "static assertion failed: template argument must be a floating point type" "" { target *-*-* } 3312 }
+// { dg-error "static assertion failed: template argument must be a floating point type" "" { target *-*-* } 3289 }
More information about the Gcc-patches
mailing list