00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042 #ifndef _GLIBCXX_PARALLEL_BALANCED_QUICKSORT_H
00043 #define _GLIBCXX_PARALLEL_BALANCED_QUICKSORT_H 1
00044
00045 #include <parallel/basic_iterator.h>
00046 #include <bits/stl_algo.h>
00047
00048 #include <parallel/settings.h>
00049 #include <parallel/partition.h>
00050 #include <parallel/random_number.h>
00051 #include <parallel/queue.h>
00052 #include <functional>
00053
00054 #if _GLIBCXX_ASSERTIONS
00055 #include <parallel/checkers.h>
00056 #endif
00057
00058 namespace __gnu_parallel
00059 {
00060
00061 template<typename RandomAccessIterator>
00062 struct QSBThreadLocal
00063 {
00064 typedef std::iterator_traits<RandomAccessIterator> traits_type;
00065 typedef typename traits_type::difference_type difference_type;
00066
00067
00068
00069 typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
00070
00071
00072 Piece initial;
00073
00074
00075 RestrictedBoundedConcurrentQueue<Piece> leftover_parts;
00076
00077
00078 thread_index_t num_threads;
00079
00080
00081 volatile difference_type* elements_leftover;
00082
00083
00084 Piece global;
00085
00086
00087
00088 QSBThreadLocal(int queue_size) : leftover_parts(queue_size) { }
00089 };
00090
00091
00092
00093
00094
00095
00096
00097
00098 template<typename RandomAccessIterator, typename Comparator>
00099 typename std::iterator_traits<RandomAccessIterator>::difference_type
00100 qsb_divide(RandomAccessIterator begin, RandomAccessIterator end,
00101 Comparator comp, thread_index_t num_threads)
00102 {
00103 _GLIBCXX_PARALLEL_ASSERT(num_threads > 0);
00104
00105 typedef std::iterator_traits<RandomAccessIterator> traits_type;
00106 typedef typename traits_type::value_type value_type;
00107 typedef typename traits_type::difference_type difference_type;
00108
00109 RandomAccessIterator pivot_pos =
00110 median_of_three_iterators(begin, begin + (end - begin) / 2,
00111 end - 1, comp);
00112
00113 #if defined(_GLIBCXX_ASSERTIONS)
00114
00115 difference_type n = end - begin;
00116
00117 _GLIBCXX_PARALLEL_ASSERT(
00118 (!comp(*pivot_pos, *begin) && !comp(*(begin + n / 2), *pivot_pos))
00119 || (!comp(*pivot_pos, *begin) && !comp(*(end - 1), *pivot_pos))
00120 || (!comp(*pivot_pos, *(begin + n / 2)) && !comp(*begin, *pivot_pos))
00121 || (!comp(*pivot_pos, *(begin + n / 2)) && !comp(*(end - 1), *pivot_pos))
00122 || (!comp(*pivot_pos, *(end - 1)) && !comp(*begin, *pivot_pos))
00123 || (!comp(*pivot_pos, *(end - 1)) && !comp(*(begin + n / 2), *pivot_pos)));
00124 #endif
00125
00126
00127 if (pivot_pos != (end - 1))
00128 std::swap(*pivot_pos, *(end - 1));
00129 pivot_pos = end - 1;
00130
00131 __gnu_parallel::binder2nd<Comparator, value_type, value_type, bool>
00132 pred(comp, *pivot_pos);
00133
00134
00135 difference_type split_pos = parallel_partition(
00136 begin, end - 1, pred, num_threads);
00137
00138
00139 std::swap(*(begin + split_pos), *pivot_pos);
00140 pivot_pos = begin + split_pos;
00141
00142 #if _GLIBCXX_ASSERTIONS
00143 RandomAccessIterator r;
00144 for (r = begin; r != pivot_pos; ++r)
00145 _GLIBCXX_PARALLEL_ASSERT(comp(*r, *pivot_pos));
00146 for (; r != end; ++r)
00147 _GLIBCXX_PARALLEL_ASSERT(!comp(*r, *pivot_pos));
00148 #endif
00149
00150 return split_pos;
00151 }
00152
00153
00154
00155
00156
00157
00158
00159
00160
00161 template<typename RandomAccessIterator, typename Comparator>
00162 void
00163 qsb_conquer(QSBThreadLocal<RandomAccessIterator>** tls,
00164 RandomAccessIterator begin, RandomAccessIterator end,
00165 Comparator comp,
00166 thread_index_t iam, thread_index_t num_threads,
00167 bool parent_wait)
00168 {
00169 typedef std::iterator_traits<RandomAccessIterator> traits_type;
00170 typedef typename traits_type::value_type value_type;
00171 typedef typename traits_type::difference_type difference_type;
00172
00173 difference_type n = end - begin;
00174
00175 if (num_threads <= 1 || n <= 1)
00176 {
00177 tls[iam]->initial.first = begin;
00178 tls[iam]->initial.second = end;
00179
00180 qsb_local_sort_with_helping(tls, comp, iam, parent_wait);
00181
00182 return;
00183 }
00184
00185
00186 difference_type split_pos = qsb_divide(begin, end, comp, num_threads);
00187
00188 #if _GLIBCXX_ASSERTIONS
00189 _GLIBCXX_PARALLEL_ASSERT(0 <= split_pos && split_pos < (end - begin));
00190 #endif
00191
00192 thread_index_t num_threads_leftside =
00193 std::max<thread_index_t>(1, std::min<thread_index_t>(
00194 num_threads - 1, split_pos * num_threads / n));
00195
00196 # pragma omp atomic
00197 *tls[iam]->elements_leftover -= (difference_type)1;
00198
00199
00200 # pragma omp parallel num_threads(2)
00201 {
00202 bool wait;
00203 if(omp_get_num_threads() < 2)
00204 wait = false;
00205 else
00206 wait = parent_wait;
00207
00208 # pragma omp sections
00209 {
00210 # pragma omp section
00211 {
00212 qsb_conquer(tls, begin, begin + split_pos, comp,
00213 iam,
00214 num_threads_leftside,
00215 wait);
00216 wait = parent_wait;
00217 }
00218
00219 # pragma omp section
00220 {
00221 qsb_conquer(tls, begin + split_pos + 1, end, comp,
00222 iam + num_threads_leftside,
00223 num_threads - num_threads_leftside,
00224 wait);
00225 wait = parent_wait;
00226 }
00227 }
00228 }
00229 }
00230
00231
00232
00233
00234
00235
00236
00237 template<typename RandomAccessIterator, typename Comparator>
00238 void
00239 qsb_local_sort_with_helping(QSBThreadLocal<RandomAccessIterator>** tls,
00240 Comparator& comp, int iam, bool wait)
00241 {
00242 typedef std::iterator_traits<RandomAccessIterator> traits_type;
00243 typedef typename traits_type::value_type value_type;
00244 typedef typename traits_type::difference_type difference_type;
00245 typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
00246
00247 QSBThreadLocal<RandomAccessIterator>& tl = *tls[iam];
00248
00249 difference_type base_case_n =
00250 _Settings::get().sort_qsb_base_case_maximal_n;
00251 if (base_case_n < 2)
00252 base_case_n = 2;
00253 thread_index_t num_threads = tl.num_threads;
00254
00255
00256 random_number rng(iam + 1);
00257
00258 Piece current = tl.initial;
00259
00260 difference_type elements_done = 0;
00261 #if _GLIBCXX_ASSERTIONS
00262 difference_type total_elements_done = 0;
00263 #endif
00264
00265 for (;;)
00266 {
00267
00268 RandomAccessIterator begin = current.first, end = current.second;
00269 difference_type n = end - begin;
00270
00271 if (n > base_case_n)
00272 {
00273
00274 RandomAccessIterator pivot_pos = begin + rng(n);
00275
00276
00277 if (pivot_pos != (end - 1))
00278 std::swap(*pivot_pos, *(end - 1));
00279 pivot_pos = end - 1;
00280
00281 __gnu_parallel::binder2nd
00282 <Comparator, value_type, value_type, bool>
00283 pred(comp, *pivot_pos);
00284
00285
00286 RandomAccessIterator split_pos1, split_pos2;
00287 split_pos1 = __gnu_sequential::partition(begin, end - 1, pred);
00288
00289
00290 #if _GLIBCXX_ASSERTIONS
00291 _GLIBCXX_PARALLEL_ASSERT(begin <= split_pos1 && split_pos1 < end);
00292 #endif
00293
00294 if (split_pos1 != pivot_pos)
00295 std::swap(*split_pos1, *pivot_pos);
00296 pivot_pos = split_pos1;
00297
00298
00299 if ((split_pos1 + 1 - begin) < (n >> 7)
00300 || (end - split_pos1) < (n >> 7))
00301 {
00302
00303
00304 __gnu_parallel::unary_negate<__gnu_parallel::binder1st
00305 <Comparator, value_type, value_type, bool>, value_type>
00306 pred(__gnu_parallel::binder1st
00307 <Comparator, value_type, value_type, bool>(comp,
00308 *pivot_pos));
00309
00310
00311 split_pos2 = __gnu_sequential::partition(split_pos1 + 1,
00312 end, pred);
00313 }
00314 else
00315
00316 split_pos2 = split_pos1 + 1;
00317
00318
00319 elements_done += (split_pos2 - split_pos1);
00320 #if _GLIBCXX_ASSERTIONS
00321 total_elements_done += (split_pos2 - split_pos1);
00322 #endif
00323
00324 if (((split_pos1 + 1) - begin) < (end - (split_pos2)))
00325 {
00326
00327 if ((split_pos2) != end)
00328 tl.leftover_parts.push_front(std::make_pair(split_pos2,
00329 end));
00330
00331
00332 current.second = split_pos1;
00333 continue;
00334 }
00335 else
00336 {
00337
00338 if (begin != split_pos1)
00339 tl.leftover_parts.push_front(std::make_pair(begin,
00340 split_pos1));
00341
00342 current.first = split_pos2;
00343
00344 continue;
00345 }
00346 }
00347 else
00348 {
00349 __gnu_sequential::sort(begin, end, comp);
00350 elements_done += n;
00351 #if _GLIBCXX_ASSERTIONS
00352 total_elements_done += n;
00353 #endif
00354
00355
00356 if (tl.leftover_parts.pop_front(current))
00357 continue;
00358
00359 # pragma omp atomic
00360 *tl.elements_leftover -= elements_done;
00361
00362 elements_done = 0;
00363
00364 #if _GLIBCXX_ASSERTIONS
00365 double search_start = omp_get_wtime();
00366 #endif
00367
00368
00369 bool successfully_stolen = false;
00370 while (wait && *tl.elements_leftover > 0 && !successfully_stolen
00371 #if _GLIBCXX_ASSERTIONS
00372
00373 && (omp_get_wtime() < (search_start + 1.0))
00374 #endif
00375 )
00376 {
00377 thread_index_t victim;
00378 victim = rng(num_threads);
00379
00380
00381 successfully_stolen = (victim != iam)
00382 && tls[victim]->leftover_parts.pop_back(current);
00383 if (!successfully_stolen)
00384 yield();
00385 #if !defined(__ICC) && !defined(__ECC)
00386 # pragma omp flush
00387 #endif
00388 }
00389
00390 #if _GLIBCXX_ASSERTIONS
00391 if (omp_get_wtime() >= (search_start + 1.0))
00392 {
00393 sleep(1);
00394 _GLIBCXX_PARALLEL_ASSERT(omp_get_wtime()
00395 < (search_start + 1.0));
00396 }
00397 #endif
00398 if (!successfully_stolen)
00399 {
00400 #if _GLIBCXX_ASSERTIONS
00401 _GLIBCXX_PARALLEL_ASSERT(*tl.elements_leftover == 0);
00402 #endif
00403 return;
00404 }
00405 }
00406 }
00407 }
00408
00409
00410
00411
00412
00413
00414
00415
00416 template<typename RandomAccessIterator, typename Comparator>
00417 void
00418 parallel_sort_qsb(RandomAccessIterator begin, RandomAccessIterator end,
00419 Comparator comp,
00420 thread_index_t num_threads)
00421 {
00422 _GLIBCXX_CALL(end - begin)
00423
00424 typedef std::iterator_traits<RandomAccessIterator> traits_type;
00425 typedef typename traits_type::value_type value_type;
00426 typedef typename traits_type::difference_type difference_type;
00427 typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
00428
00429 typedef QSBThreadLocal<RandomAccessIterator> tls_type;
00430
00431 difference_type n = end - begin;
00432
00433 if (n <= 1)
00434 return;
00435
00436
00437 if (num_threads > n)
00438 num_threads = static_cast<thread_index_t>(n);
00439
00440
00441 tls_type** tls = new tls_type*[num_threads];
00442 difference_type queue_size = num_threads * (thread_index_t)(log2(n) + 1);
00443 for (thread_index_t t = 0; t < num_threads; ++t)
00444 tls[t] = new QSBThreadLocal<RandomAccessIterator>(queue_size);
00445
00446
00447
00448
00449
00450 volatile difference_type elements_leftover = n;
00451 for (int i = 0; i < num_threads; ++i)
00452 {
00453 tls[i]->elements_leftover = &elements_leftover;
00454 tls[i]->num_threads = num_threads;
00455 tls[i]->global = std::make_pair(begin, end);
00456
00457
00458 tls[i]->initial = std::make_pair(end, end);
00459 }
00460
00461
00462 qsb_conquer(tls, begin, begin + n, comp, 0, num_threads, true);
00463
00464 #if _GLIBCXX_ASSERTIONS
00465
00466 Piece dummy;
00467 for (int i = 1; i < num_threads; ++i)
00468 _GLIBCXX_PARALLEL_ASSERT(!tls[i]->leftover_parts.pop_back(dummy));
00469 #endif
00470
00471 for (int i = 0; i < num_threads; ++i)
00472 delete tls[i];
00473 delete[] tls;
00474 }
00475 }
00476
00477 #endif