This is the mail archive of the
libstdc++@gcc.gnu.org
mailing list for the libstdc++ project.
[Patch] Change std::nth_element to O(N log N) worst case
- From: Paolo Carlini <pcarlini at suse dot de>
- To: libstdc++ <libstdc++ at gcc dot gnu dot org>
- Cc: Roger Sayle <roger at eyesopen dot com>
- Date: Sun, 27 Aug 2006 19:44:16 +0200
- Subject: [Patch] Change std::nth_element to O(N log N) worst case
Hi,
Some background: a few days ago I was contacted privately by
Roger, who was concerned about the quadratic worst case of our
implementation of nth_element and full of ideas for attacking the problem.
I replied to his kind offer of collaboration by pointing him to Musser's
paper on introsort and introselect, the foundation of our sorting
routines. Roger quickly came back with a straightforward implementation
of introselect, which we were missing, plus a few additional proposals to
improve the constant factor. For 4.2.0 I'd like to apply the below *minimal*
patch, which I tested carefully and adjusted a bit to our needs. But
hopefully much more is to come!
I'd like to go ahead on Monday, after Roger gets a chance to see how I
"censored" his work ;)
Paolo.
////////////////////
2006-08-28 Roger Sayle <roger@eyesopen.com>
Paolo Carlini <pcarlini@suse.de>
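* include/bits/stl_algo.h (__heap_select, __introselect): New.
(nth_element): New implementation.
(partial_sort): Use __heap_select.
(sort): Qualify the __lg call with std::.
* testsuite/performance/25_algorithms/nth_element_worst_case.cc: New.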
Index: include/bits/stl_algo.h
===================================================================
--- include/bits/stl_algo.h (revision 116401)
+++ include/bits/stl_algo.h (working copy)
@@ -2464,9 +2464,49 @@
/**
* @if maint
- * This is a helper function for the sort routine.
+ * This is a helper function for the sort routines.
* @endif
*/
+ // Rearrange [__first, __last) so that [__first, __middle) holds the
+ // (__middle - __first) smallest elements (by operator<), organized as
+ // a max-heap whose top is *__first; the order of [__middle, __last)
+ // is left unspecified.  Used by partial_sort (followed by sort_heap)
+ // and by __introselect's depth-limit fallback.
+ template<typename _RandomAccessIterator>
+ void
+ __heap_select(_RandomAccessIterator __first,
+ _RandomAccessIterator __middle,
+ _RandomAccessIterator __last)
+ {
+ typedef typename iterator_traits<_RandomAccessIterator>::value_type
+ _ValueType;
+
+ std::make_heap(__first, __middle);
+ // Any later element smaller than the current heap top displaces it.
+ // NOTE(review): presumably __pop_heap moves the old top into *__i and
+ // sifts the copied value into the heap -- confirm against its definition.
+ for (_RandomAccessIterator __i = __middle; __i < __last; ++__i)
+ if (*__i < *__first)
+ std::__pop_heap(__first, __middle, __i, _ValueType(*__i));
+ }
+
+ /**
+ * @if maint
+ * This is a helper function for the sort routines.
+ * @endif
+ */
+ // Comparator version: rearrange [__first, __last) so that
+ // [__first, __middle) holds the (__middle - __first) smallest elements
+ // according to __comp, organized as a heap (w.r.t. __comp) whose top,
+ // *__first, is the greatest of them; [__middle, __last) is left in
+ // unspecified order.
+ template<typename _RandomAccessIterator, typename _Compare>
+ void
+ __heap_select(_RandomAccessIterator __first,
+ _RandomAccessIterator __middle,
+ _RandomAccessIterator __last, _Compare __comp)
+ {
+ typedef typename iterator_traits<_RandomAccessIterator>::value_type
+ _ValueType;
+
+ std::make_heap(__first, __middle, __comp);
+ // Any later element that compares less than the current heap top
+ // displaces it.  NOTE(review): presumably __pop_heap moves the old top
+ // into *__i and sifts the copied value into the heap -- confirm.
+ for (_RandomAccessIterator __i = __middle; __i < __last; ++__i)
+ if (__comp(*__i, *__first))
+ std::__pop_heap(__first, __middle, __i, _ValueType(*__i), __comp);
+ }
+
+ /**
+ * @if maint
+ * This is a helper function for the sort routines.
+ * @endif
+ */
template<typename _Size>
inline _Size
__lg(_Size __n)
@@ -2493,7 +2533,7 @@
* the range @p [middle,last) then @p *j<*i and @p *k<*i are both false.
*/
template<typename _RandomAccessIterator>
- void
+ inline void
partial_sort(_RandomAccessIterator __first,
_RandomAccessIterator __middle,
_RandomAccessIterator __last)
@@ -2508,10 +2548,7 @@
__glibcxx_requires_valid_range(__first, __middle);
__glibcxx_requires_valid_range(__middle, __last);
- std::make_heap(__first, __middle);
- for (_RandomAccessIterator __i = __middle; __i < __last; ++__i)
- if (*__i < *__first)
- std::__pop_heap(__first, __middle, __i, _ValueType(*__i));
+ std::__heap_select(__first, __middle, __last);
std::sort_heap(__first, __middle);
}
@@ -2534,7 +2571,7 @@
* are both false.
*/
template<typename _RandomAccessIterator, typename _Compare>
- void
+ inline void
partial_sort(_RandomAccessIterator __first,
_RandomAccessIterator __middle,
_RandomAccessIterator __last,
@@ -2551,10 +2588,7 @@
__glibcxx_requires_valid_range(__first, __middle);
__glibcxx_requires_valid_range(__middle, __last);
- std::make_heap(__first, __middle, __comp);
- for (_RandomAccessIterator __i = __middle; __i < __last; ++__i)
- if (__comp(*__i, *__first))
- std::__pop_heap(__first, __middle, __i, _ValueType(*__i), __comp);
+ std::__heap_select(__first, __middle, __last, __comp);
std::sort_heap(__first, __middle, __comp);
}
@@ -2792,7 +2826,8 @@
if (__first != __last)
{
- std::__introsort_loop(__first, __last, __lg(__last - __first) * 2);
+ std::__introsort_loop(__first, __last,
+ std::__lg(__last - __first) * 2);
std::__final_insertion_sort(__first, __last);
}
}
@@ -2828,8 +2863,8 @@
if (__first != __last)
{
- std::__introsort_loop(__first, __last, __lg(__last - __first) * 2,
- __comp);
+ std::__introsort_loop(__first, __last,
+ std::__lg(__last - __first) * 2, __comp);
std::__final_insertion_sort(__first, __last, __comp);
}
}
@@ -3904,6 +3939,75 @@
_DistanceType(__buf.size()), __comp);
}
+
+  /**
+   *  @if maint
+   *  Introspective selection helper for nth_element.  Repeatedly
+   *  partitions around a median-of-three pivot, keeping only the side
+   *  that contains __nth.  After __depth_limit partitions it falls back
+   *  to __heap_select, bounding the worst case at O(N log N).
+   *  @endif
+   */
+  template<typename _RandomAccessIterator, typename _Size>
+    void
+    __introselect(_RandomAccessIterator __first, _RandomAccessIterator __nth,
+                  _RandomAccessIterator __last, _Size __depth_limit)
+    {
+      typedef typename iterator_traits<_RandomAccessIterator>::value_type
+        _ValueType;
+
+      // nth == last selects nothing; bail out so that __nth + 1 below is
+      // never formed past the end of the range.
+      if (__nth == __last)
+        return;
+
+      while (__last - __first > 3)
+        {
+          if (__depth_limit == 0)
+            {
+              // Gather the (__nth - __first + 1) smallest elements into a
+              // max-heap at the front: its top is exactly the element that
+              // belongs at __nth.  Selecting only [__first, __nth) would
+              // leave *__nth unconstrained.
+              std::__heap_select(__first, __nth + 1, __last);
+              // Place the nth element in its final position.
+              std::iter_swap(__first, __nth);
+              return;
+            }
+          --__depth_limit;
+          _RandomAccessIterator __cut =
+            std::__unguarded_partition(__first, __last,
+                                       _ValueType(std::__median(*__first,
+                                                                *(__first
+                                                                  + (__last
+                                                                     - __first)
+                                                                  / 2),
+                                                                *(__last
+                                                                  - 1))));
+          if (__cut <= __nth)
+            __first = __cut;
+          else
+            __last = __cut;
+        }
+      std::__insertion_sort(__first, __last);
+    }
+
+  /**
+   *  @if maint
+   *  Comparator version of the introspective selection helper used by
+   *  nth_element.  Partitions around a median-of-three pivot (ordered by
+   *  __comp), keeping only the side that contains __nth, and falls back
+   *  to __heap_select after __depth_limit partitions.
+   *  @endif
+   */
+  template<typename _RandomAccessIterator, typename _Size, typename _Compare>
+    void
+    __introselect(_RandomAccessIterator __first, _RandomAccessIterator __nth,
+                  _RandomAccessIterator __last, _Size __depth_limit,
+                  _Compare __comp)
+    {
+      typedef typename iterator_traits<_RandomAccessIterator>::value_type
+        _ValueType;
+
+      // nth == last selects nothing; bail out so that __nth + 1 below is
+      // never formed past the end of the range.
+      if (__nth == __last)
+        return;
+
+      while (__last - __first > 3)
+        {
+          if (__depth_limit == 0)
+            {
+              // Gather the (__nth - __first + 1) smallest elements (by
+              // __comp) into a heap at the front: its top is exactly the
+              // element that belongs at __nth.  Selecting only
+              // [__first, __nth) would leave *__nth unconstrained.
+              std::__heap_select(__first, __nth + 1, __last, __comp);
+              // Place the nth element in its final position.
+              std::iter_swap(__first, __nth);
+              return;
+            }
+          --__depth_limit;
+          _RandomAccessIterator __cut =
+            std::__unguarded_partition(__first, __last,
+                                       _ValueType(std::__median(*__first,
+                                                                *(__first
+                                                                  + (__last
+                                                                     - __first)
+                                                                  / 2),
+                                                                *(__last - 1),
+                                                                __comp)),
+                                       __comp);
+          if (__cut <= __nth)
+            __first = __cut;
+          else
+            __last = __cut;
+        }
+      std::__insertion_sort(__first, __last, __comp);
+    }
+
/**
* @brief Sort a sequence just enough to find a particular position.
* @param first An iterator.
@@ -3920,9 +4024,8 @@
* holds that @p *j<*i is false.
*/
template<typename _RandomAccessIterator>
- void
- nth_element(_RandomAccessIterator __first,
- _RandomAccessIterator __nth,
+ inline void
+ nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth,
_RandomAccessIterator __last)
{
typedef typename iterator_traits<_RandomAccessIterator>::value_type
@@ -3935,23 +4038,9 @@
__glibcxx_requires_valid_range(__first, __nth);
__glibcxx_requires_valid_range(__nth, __last);
- while (__last - __first > 3)
- {
- _RandomAccessIterator __cut =
- std::__unguarded_partition(__first, __last,
- _ValueType(std::__median(*__first,
- *(__first
- + (__last
- - __first)
- / 2),
- *(__last
- - 1))));
- if (__cut <= __nth)
- __first = __cut;
- else
- __last = __cut;
- }
- std::__insertion_sort(__first, __last);
+ if (__first != __last)
+ std::__introselect(__first, __nth, __last,
+ std::__lg(__last - __first) * 2);
}
/**
@@ -3971,11 +4060,9 @@
* holds that @p comp(*j,*i) is false.
*/
template<typename _RandomAccessIterator, typename _Compare>
- void
- nth_element(_RandomAccessIterator __first,
- _RandomAccessIterator __nth,
- _RandomAccessIterator __last,
- _Compare __comp)
+ inline void
+ nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth,
+ _RandomAccessIterator __last, _Compare __comp)
{
typedef typename iterator_traits<_RandomAccessIterator>::value_type
_ValueType;
@@ -3988,23 +4075,9 @@
__glibcxx_requires_valid_range(__first, __nth);
__glibcxx_requires_valid_range(__nth, __last);
- while (__last - __first > 3)
- {
- _RandomAccessIterator __cut =
- std::__unguarded_partition(__first, __last,
- _ValueType(std::__median(*__first,
- *(__first
- + (__last
- - __first)
- / 2),
- *(__last - 1),
- __comp)), __comp);
- if (__cut <= __nth)
- __first = __cut;
- else
- __last = __cut;
- }
- std::__insertion_sort(__first, __last, __comp);
+ if (__first != __last)
+ std::__introselect(__first, __nth, __last,
+ std::__lg(__last - __first) * 2, __comp);
}
/**
Index: testsuite/performance/25_algorithms/nth_element_worst_case.cc
===================================================================
--- testsuite/performance/25_algorithms/nth_element_worst_case.cc (revision 0)
+++ testsuite/performance/25_algorithms/nth_element_worst_case.cc (revision 0)
@@ -0,0 +1,62 @@
+// Copyright (C) 2006 Free Software Foundation, Inc.
+//
+// This file is part of the GNU ISO C++ Library. This library is free
+// software; you can redistribute it and/or modify it under the
+// terms of the GNU General Public License as published by the
+// Free Software Foundation; either version 2, or (at your option)
+// any later version.
+
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+
+// You should have received a copy of the GNU General Public License along
+// with this library; see the file COPYING. If not, write to the Free
+// Software Foundation, 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301,
+// USA.
+
+// As a special exception, you may use this file as part of a free software
+// library without restriction. Specifically, if other files instantiate
+// templates or use macros or inline functions from this file, or you compile
+// this file and link it with other files to produce an executable, this
+// file does not by itself cause the resulting executable to be covered by
+// the GNU General Public License. This exception does not however
+// invalidate any other reasons why the executable file might be covered by
+// the GNU General Public License.
+
+#include <vector>
+#include <algorithm>
+#include <testsuite_performance.h>
+
+int main()
+{
+  using namespace __gnu_test;
+
+  time_counter time;
+  resource_counter resource;
+
+  const int max_size = 8192;
+
+  // One input per size: v[i] ends up with i or i + 1 elements
+  // (presumably laid out to defeat median-of-three pivoting, given the
+  // file name -- confirm against the intended killer sequence).
+  std::vector<int> v[max_size];
+
+  for (int i = 0; i < max_size; ++i)
+    {
+      for (int j = 0; j < i; j += 4)
+        {
+          v[i].push_back(j / 2);
+          v[i].push_back((i - 2) - (j / 2));
+        }
+
+      for (int j = 1; j < i; j += 2)
+        v[i].push_back(j);
+    }
+
+  start_counters(time, resource);
+  for (int i = 0; i < max_size; ++i)
+    {
+      // Select the middle element.  Since v[i].size() is i or i + 1,
+      // begin() + i would name end() (a request with no effect) or the
+      // last element, and would never target the median.
+      std::nth_element(v[i].begin(), v[i].begin() + v[i].size() / 2,
+                       v[i].end());
+    }
+  stop_counters(time, resource);
+  report_performance(__FILE__, "", time, resource);
+
+  return 0;
+}