[gcc r10-10196] libstdc++: Reduce ranges::minmax/minmax_element comparison complexity

Patrick Palka ppalka@gcc.gnu.org
Tue Oct 12 18:37:40 GMT 2021


https://gcc.gnu.org/g:1335d35a530f31f60874cea1d5c98de81deed335

commit r10-10196-g1335d35a530f31f60874cea1d5c98de81deed335
Author: Patrick Palka <ppalka@redhat.com>
Date:   Fri Jun 18 19:33:39 2021 -0400

    libstdc++: Reduce ranges::minmax/minmax_element comparison complexity
    
    This rewrites ranges::minmax and ranges::minmax_element so that they
    perform at most 3*N/2 comparisons, as required by the standard.
    In passing, this also fixes PR100387 by avoiding a premature std::move
    in ranges::minmax and in std::shift_right.
    
            PR libstdc++/100387
    
    libstdc++-v3/ChangeLog:
    
            * include/bits/ranges_algo.h (__minmax_fn::operator()): Rewrite
            to limit comparison complexity to 3*N/2.
            (__minmax_element_fn::operator()): Likewise.
            (shift_right): Avoid premature std::move of __result.
            * testsuite/25_algorithms/minmax/constrained.cc (test04, test05):
            New tests.
            * testsuite/25_algorithms/minmax_element/constrained.cc (test02):
            Likewise.
    
    (cherry picked from commit cc9c94d43dcfa98436152af9c00f011e9dab25f6)

Diff:
---
 libstdc++-v3/include/bits/ranges_algo.h            | 113 ++++++++++++++++-----
 .../testsuite/25_algorithms/minmax/constrained.cc  |  42 ++++++++
 .../25_algorithms/minmax_element/constrained.cc    |  27 +++++
 3 files changed, 156 insertions(+), 26 deletions(-)

diff --git a/libstdc++-v3/include/bits/ranges_algo.h b/libstdc++-v3/include/bits/ranges_algo.h
index 92ab045aa6a..fe9ee31d8a2 100644
--- a/libstdc++-v3/include/bits/ranges_algo.h
+++ b/libstdc++-v3/include/bits/ranges_algo.h
@@ -3284,26 +3284,59 @@ namespace ranges
     template<input_range _Range, typename _Proj = identity,
 	     indirect_strict_weak_order<projected<iterator_t<_Range>, _Proj>>
 	       _Comp = ranges::less>
-      requires indirectly_copyable_storable<iterator_t<_Range>,
-      range_value_t<_Range>*>
+      requires indirectly_copyable_storable<iterator_t<_Range>, range_value_t<_Range>*>
       constexpr minmax_result<range_value_t<_Range>>
       operator()(_Range&& __r, _Comp __comp = {}, _Proj __proj = {}) const
       {
 	auto __first = ranges::begin(__r);
 	auto __last = ranges::end(__r);
 	__glibcxx_assert(__first != __last);
+	auto __comp_proj = __detail::__make_comp_proj(__comp, __proj);
 	minmax_result<range_value_t<_Range>> __result = {*__first, *__first};
+	if (++__first == __last)
+	  return __result;
+	else
+	  {
+	    // At this point __result.min == __result.max, so a single
+	    // comparison suffices; an equivalent element becomes the max.
+	    auto&& __val = *__first;
+	    if (__comp_proj(__val, __result.min))
+	      __result.min = std::forward<decltype(__val)>(__val);
+	    else
+	      __result.max = std::forward<decltype(__val)>(__val);
+	  }
 	while (++__first != __last)
 	  {
-	    auto __tmp = *__first;
-	    if (std::__invoke(__comp,
-			      std::__invoke(__proj, __tmp),
-			      std::__invoke(__proj, __result.min)))
-	      __result.min = std::move(__tmp);
-	    if (!(bool)std::__invoke(__comp,
-				     std::__invoke(__proj, __tmp),
-				     std::__invoke(__proj, __result.max)))
-	      __result.max = std::move(__tmp);
+	    // Process two elements per iteration: order the pair, then test
+	    // the smaller against min and the larger against max, i.e. three
+	    // comparisons per pair and at most 1 + 3*(N-2)/2 in total.
+	    range_value_t<_Range> __val1 = *__first;
+	    if (++__first == __last)
+	      {
+		// N is odd: at most two more comparisons, for a total of at
+		// most 1 + 3*(N-3)/2 + 2 <= 3*N/2.  The else-if ensures that
+		// __val1 is not compared after being moved from.
+		if (__comp_proj(__val1, __result.min))
+		  __result.min = std::move(__val1);
+		else if (!__comp_proj(__val1, __result.max))
+		  __result.max = std::move(__val1);
+		break;
+	      }
+	    auto&& __val2 = *__first;
+	    if (!__comp_proj(__val2, __val1))
+	      {
+		if (__comp_proj(__val1, __result.min))
+		  __result.min = std::move(__val1);
+		if (!__comp_proj(__val2, __result.max))
+		  __result.max = std::forward<decltype(__val2)>(__val2);
+	      }
+	    else
+	      {
+		if (__comp_proj(__val2, __result.min))
+		  __result.min = std::forward<decltype(__val2)>(__val2);
+		if (!__comp_proj(__val1, __result.max))
+		  __result.max = std::move(__val1);
+	      }
 	  }
 	return __result;
       }
@@ -3409,21 +3442,50 @@ namespace ranges
       operator()(_Iter __first, _Sent __last,
 		 _Comp __comp = {}, _Proj __proj = {}) const
       {
-	if (__first == __last)
-	  return {__first, __first};
-
+	auto __comp_proj = __detail::__make_comp_proj(__comp, __proj);
 	minmax_element_result<_Iter> __result = {__first, __first};
-	auto __i = __first;
-	while (++__i != __last)
+	if (__first == __last || ++__first == __last)
+	  return __result;
+	else
 	  {
-	    if (std::__invoke(__comp,
-			      std::__invoke(__proj, *__i),
-			      std::__invoke(__proj, *__result.min)))
-	      __result.min = __i;
-	    if (!(bool)std::__invoke(__comp,
-				     std::__invoke(__proj, *__i),
-				     std::__invoke(__proj, *__result.max)))
-	      __result.max = __i;
+	    // At this point __result.min == __result.max, so a single
+	    // comparison suffices; an equivalent element becomes the max.
+	    if (__comp_proj(*__first, *__result.min))
+	      __result.min = __first;
+	    else
+	      __result.max = __first;
+	  }
+	while (++__first != __last)
+	  {
+	    // Process two elements per iteration: order the pair, then test
+	    // the smaller against min and the larger against max, i.e. three
+	    // comparisons per pair and at most 1 + 3*(N-2)/2 in total.
+	    auto __prev = __first;
+	    if (++__first == __last)
+	      {
+		// N is odd; in this final iteration, we perform at most two
+		// comparisons, for a total of 1 + 3*(N-3)/2 + 2 comparisons,
+		// which is not more than 3*N/2, as required.
+		if (__comp_proj(*__prev, *__result.min))
+		  __result.min = __prev;
+		else if (!__comp_proj(*__prev, *__result.max))
+		  __result.max = __prev;
+		break;
+	      }
+	    if (!__comp_proj(*__first, *__prev))
+	      {
+		if (__comp_proj(*__prev, *__result.min))
+		  __result.min = __prev;
+		if (!__comp_proj(*__first, *__result.max))
+		  __result.max = __first;
+	      }
+	    else
+	      {
+		if (__comp_proj(*__first, *__result.min))
+		  __result.min = __first;
+		if (!__comp_proj(*__prev, *__result.max))
+		  __result.max = __prev;
+	      }
 	  }
 	return __result;
       }
@@ -3750,8 +3812,7 @@ namespace ranges
 		  // i.e. we are shifting out at least half of the range.  In
 		  // this case we can safely perform the shift with a single
 		  // move.
-		  std::move(std::move(__first), std::move(__dest_head),
-			    std::move(__result));
+		  std::move(std::move(__first), std::move(__dest_head), __result);
 		  return __result;
 		}
 	      ++__dest_head;
diff --git a/libstdc++-v3/testsuite/25_algorithms/minmax/constrained.cc b/libstdc++-v3/testsuite/25_algorithms/minmax/constrained.cc
index aa9364ab04c..af14152d345 100644
--- a/libstdc++-v3/testsuite/25_algorithms/minmax/constrained.cc
+++ b/libstdc++-v3/testsuite/25_algorithms/minmax/constrained.cc
@@ -19,6 +19,8 @@
 // { dg-do run { target c++2a } }
 
 #include <algorithm>
+#include <string>
+#include <vector>
 #include <testsuite_hooks.h>
 #include <testsuite_iterators.h>
 
@@ -89,10 +91,50 @@ test03()
 	  == res_t(1,4) );
 }
 
+void
+test04()
+{
+  // Verify we perform at most 3*N/2 applications of the comparison predicate.
+  static int counter;  // zero-initialized, so the first check needs no reset
+  struct counted_less
+  { bool operator()(int a, int b) { ++counter; return a < b; } };
+
+  ranges::minmax({1,2}, counted_less{});  // N=2: one comparison
+  VERIFY( counter == 1 );
+
+  counter = 0;
+  ranges::minmax({1,2,3}, counted_less{});  // N=3: 1 + 2 comparisons
+  VERIFY( counter == 3 );
+
+  counter = 0;
+  ranges::minmax({1,2,3,4,5,6,7,8,9,10}, counted_less{});  // 3*N/2 == 15
+  VERIFY( counter <= 15 );
+
+  counter = 0;
+  ranges::minmax({10,9,8,7,6,5,4,3,2,1}, counted_less{});  // descending
+  VERIFY( counter <= 15 );
+}
+
+void
+test05()
+{
+  // PR libstdc++/100387: don't compare an element after moving from it.
+  using namespace std::literals::string_literals;
+  auto comp = [](const auto& a, const auto& b) {  // longer strings first
+    return a.size() == b.size() ? a.front() < b.front() : a.size() > b.size();
+  };
+  auto result = ranges::minmax({"b"s, "a"s}, comp);
+  VERIFY( result.min == "a"s && result.max == "b"s );
+  result = ranges::minmax({"c"s, "b"s, "a"s}, comp);
+  VERIFY( result.min == "a"s && result.max == "c"s );
+}
+
 int
 main()
 {
   test01();
   test02();
   test03();
+  test04();
+  test05();
 }
diff --git a/libstdc++-v3/testsuite/25_algorithms/minmax_element/constrained.cc b/libstdc++-v3/testsuite/25_algorithms/minmax_element/constrained.cc
index 40019c43326..ece1f93d04e 100644
--- a/libstdc++-v3/testsuite/25_algorithms/minmax_element/constrained.cc
+++ b/libstdc++-v3/testsuite/25_algorithms/minmax_element/constrained.cc
@@ -61,8 +61,35 @@ test01()
   static_assert(ranges::minmax_element(y, y+3, {}, &X::i).max->j == 3);
 }
 
+void
+test02()
+{
+  // Verify we perform at most 3*N/2 applications of the comparison predicate.
+  static int counter;  // zero-initialized, so the first check needs no reset
+  struct counted_less
+  { bool operator()(int a, int b) { ++counter; return a < b; } };
+
+  int x[] = {1,2,3,4,5,6,7,8,9,10};
+  ranges::minmax_element(x, x+2, counted_less{});  // N=2: one comparison
+  VERIFY( counter == 1 );
+
+  counter = 0;
+  ranges::minmax_element(x, x+3, counted_less{});  // N=3: 1 + 2 comparisons
+  VERIFY( counter == 3 );
+
+  counter = 0;
+  ranges::minmax_element(x, counted_less{});  // N=10: 3*N/2 == 15
+  VERIFY( counter <= 15 );
+
+  ranges::reverse(x);  // repeat with descending input
+  counter = 0;
+  ranges::minmax_element(x, counted_less{});
+  VERIFY( counter <= 15 );
+}
+
 int
 main()
 {
   test01();
+  test02();
 }


More information about the Libstdc++-cvs mailing list