multiway_mergesort.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 
00003 // Copyright (C) 2007, 2008, 2009 Free Software Foundation, Inc.
00004 //
00005 // This file is part of the GNU ISO C++ Library.  This library is free
00006 // software; you can redistribute it and/or modify it under the terms
00007 // of the GNU General Public License as published by the Free Software
00008 // Foundation; either version 3, or (at your option) any later
00009 // version.
00010 
00011 // This library is distributed in the hope that it will be useful, but
00012 // WITHOUT ANY WARRANTY; without even the implied warranty of
00013 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00014 // General Public License for more details.
00015 
00016 // Under Section 7 of GPL version 3, you are granted additional
00017 // permissions described in the GCC Runtime Library Exception, version
00018 // 3.1, as published by the Free Software Foundation.
00019 
00020 // You should have received a copy of the GNU General Public License and
00021 // a copy of the GCC Runtime Library Exception along with this program;
00022 // see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
00023 // <http://www.gnu.org/licenses/>.
00024 
00025 /** @file parallel/multiway_mergesort.h
00026  *  @brief Parallel multiway merge sort.
00027  *  This file is a GNU parallel extension to the Standard C++ Library.
00028  */
00029 
00030 // Written by Johannes Singler.
00031 
00032 #ifndef _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H
00033 #define _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H 1
00034 
00035 #include <vector>
00036 
00037 #include <parallel/basic_iterator.h>
00038 #include <bits/stl_algo.h>
00039 #include <parallel/parallel.h>
00040 #include <parallel/multiway_merge.h>
00041 
00042 namespace __gnu_parallel
00043 {
00044 
/** @brief Index range describing one subsequence.
  *
  *  The range is half-open: @c begin is the first position that belongs
  *  to the subsequence, @c end is the position one past its last element,
  *  so the length is @c end - @c begin. */
template<typename _DifferenceTp>
  struct Piece
  {
    /** @brief Type used to express positions within a sequence. */
    typedef _DifferenceTp difference_type;

    /** @brief Position of the first element of the subsequence. */
    difference_type begin;

    /** @brief Position one past the last element of the subsequence. */
    difference_type end;
  };
00057 
00058 /** @brief Data accessed by all threads.
00059   *
00060   *  PMWMS = parallel multiway mergesort */
00061 template<typename RandomAccessIterator>
00062   struct PMWMSSortingData
00063   {
00064     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00065     typedef typename traits_type::value_type value_type;
00066     typedef typename traits_type::difference_type difference_type;
00067 
00068     /** @brief Number of threads involved. */
00069     thread_index_t num_threads;
00070 
00071     /** @brief Input begin. */
00072     RandomAccessIterator source;
00073 
00074     /** @brief Start indices, per thread. */
00075     difference_type* starts;
00076 
00077     /** @brief Storage in which to sort. */
00078     value_type** temporary;
00079 
00080     /** @brief Samples. */
00081     value_type* samples;
00082 
00083     /** @brief Offsets to add to the found positions. */
00084     difference_type* offsets;
00085 
00086     /** @brief Pieces of data to merge @c [thread][sequence] */
00087     std::vector<Piece<difference_type> >* pieces;
00088 };
00089 
00090 /**
00091   *  @brief Select samples from a sequence.
00092   *  @param sd Pointer to algorithm data. Result will be placed in
00093   *  @c sd->samples.
00094   *  @param num_samples Number of samples to select.
00095   */
00096 template<typename RandomAccessIterator, typename _DifferenceTp>
00097   void 
00098   determine_samples(PMWMSSortingData<RandomAccessIterator>* sd,
00099                     _DifferenceTp num_samples)
00100   {
00101     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00102     typedef typename traits_type::value_type value_type;
00103     typedef _DifferenceTp difference_type;
00104 
00105     thread_index_t iam = omp_get_thread_num();
00106 
00107     difference_type* es = new difference_type[num_samples + 2];
00108 
00109     equally_split(sd->starts[iam + 1] - sd->starts[iam], 
00110                   num_samples + 1, es);
00111 
00112     for (difference_type i = 0; i < num_samples; ++i)
00113       ::new(&(sd->samples[iam * num_samples + i]))
00114       value_type(sd->source[sd->starts[iam] + es[i + 1]]);
00115 
00116     delete[] es;
00117   }
00118 
/** @brief Split consistently.
  *
  *  Primary template, deliberately empty.  The partial specializations
  *  for @c exact == true and @c exact == false supply the operator(). */
template<bool exact, typename RandomAccessIterator,
          typename Comparator, typename SortingPlacesIterator>
  struct split_consistently
  { };
00125 
00126 /** @brief Split by exact splitting. */
00127 template<typename RandomAccessIterator, typename Comparator,
00128           typename SortingPlacesIterator>
00129   struct split_consistently
00130     <true, RandomAccessIterator, Comparator, SortingPlacesIterator>
00131   {
00132     void operator()(
00133       const thread_index_t iam,
00134       PMWMSSortingData<RandomAccessIterator>* sd,
00135       Comparator& comp,
00136       const typename
00137         std::iterator_traits<RandomAccessIterator>::difference_type
00138           num_samples)
00139       const
00140   {
00141 #   pragma omp barrier
00142 
00143     std::vector<std::pair<SortingPlacesIterator, SortingPlacesIterator> >
00144         seqs(sd->num_threads);
00145     for (thread_index_t s = 0; s < sd->num_threads; s++)
00146       seqs[s] = std::make_pair(sd->temporary[s],
00147                                 sd->temporary[s]
00148                                     + (sd->starts[s + 1] - sd->starts[s]));
00149 
00150     std::vector<SortingPlacesIterator> offsets(sd->num_threads);
00151 
00152     // if not last thread
00153     if (iam < sd->num_threads - 1)
00154       multiseq_partition(seqs.begin(), seqs.end(),
00155                           sd->starts[iam + 1], offsets.begin(), comp);
00156 
00157     for (int seq = 0; seq < sd->num_threads; seq++)
00158       {
00159         // for each sequence
00160         if (iam < (sd->num_threads - 1))
00161           sd->pieces[iam][seq].end = offsets[seq] - seqs[seq].first;
00162         else
00163           // very end of this sequence
00164           sd->pieces[iam][seq].end =
00165               sd->starts[seq + 1] - sd->starts[seq];
00166       }
00167 
00168 #   pragma omp barrier
00169 
00170     for (thread_index_t seq = 0; seq < sd->num_threads; seq++)
00171       {
00172         // For each sequence.
00173         if (iam > 0)
00174           sd->pieces[iam][seq].begin = sd->pieces[iam - 1][seq].end;
00175         else
00176           // Absolute beginning.
00177           sd->pieces[iam][seq].begin = 0;
00178       }
00179   }   
00180   };
00181 
00182 /** @brief Split by sampling. */ 
00183 template<typename RandomAccessIterator, typename Comparator,
00184           typename SortingPlacesIterator>
00185   struct split_consistently<false, RandomAccessIterator, Comparator,
00186                              SortingPlacesIterator>
00187   {
00188     void operator()(
00189         const thread_index_t iam,
00190         PMWMSSortingData<RandomAccessIterator>* sd,
00191         Comparator& comp,
00192         const typename
00193           std::iterator_traits<RandomAccessIterator>::difference_type
00194             num_samples)
00195         const
00196     {
00197       typedef std::iterator_traits<RandomAccessIterator> traits_type;
00198       typedef typename traits_type::value_type value_type;
00199       typedef typename traits_type::difference_type difference_type;
00200 
00201       determine_samples(sd, num_samples);
00202 
00203 #     pragma omp barrier
00204 
00205 #     pragma omp single
00206       __gnu_sequential::sort(sd->samples,
00207                              sd->samples + (num_samples * sd->num_threads),
00208                              comp);
00209 
00210 #     pragma omp barrier
00211 
00212       for (thread_index_t s = 0; s < sd->num_threads; ++s)
00213         {
00214           // For each sequence.
00215           if (num_samples * iam > 0)
00216             sd->pieces[iam][s].begin =
00217                 std::lower_bound(sd->temporary[s],
00218                     sd->temporary[s]
00219                         + (sd->starts[s + 1] - sd->starts[s]),
00220                     sd->samples[num_samples * iam],
00221                     comp)
00222                 - sd->temporary[s];
00223           else
00224             // Absolute beginning.
00225             sd->pieces[iam][s].begin = 0;
00226 
00227           if ((num_samples * (iam + 1)) < (num_samples * sd->num_threads))
00228             sd->pieces[iam][s].end =
00229                 std::lower_bound(sd->temporary[s],
00230                         sd->temporary[s]
00231                             + (sd->starts[s + 1] - sd->starts[s]),
00232                         sd->samples[num_samples * (iam + 1)],
00233                         comp)
00234                 - sd->temporary[s];
00235           else
00236             // Absolute end.
00237             sd->pieces[iam][s].end = sd->starts[s + 1] - sd->starts[s];
00238         }
00239     }
00240   };
00241   
/** @brief Sort functor selected at compile time by @c stable.
  *
  *  Primary template, deliberately empty.  The specializations for
  *  @c stable == true and @c stable == false supply the operator(). */
template<bool stable, typename RandomAccessIterator, typename Comparator>
  struct possibly_stable_sort
  { };
00246 
00247 template<typename RandomAccessIterator, typename Comparator>
00248   struct possibly_stable_sort<true, RandomAccessIterator, Comparator>
00249   {
00250     void operator()(const RandomAccessIterator& begin,
00251                      const RandomAccessIterator& end, Comparator& comp) const
00252     {
00253       __gnu_sequential::stable_sort(begin, end, comp); 
00254     }
00255   };
00256 
00257 template<typename RandomAccessIterator, typename Comparator>
00258   struct possibly_stable_sort<false, RandomAccessIterator, Comparator>
00259   {
00260     void operator()(const RandomAccessIterator begin,
00261                      const RandomAccessIterator end, Comparator& comp) const
00262     {
00263       __gnu_sequential::sort(begin, end, comp); 
00264     }
00265   };
00266 
/** @brief Multiway-merge functor selected at compile time by @c stable.
  *
  *  Primary template, deliberately empty.  The specializations for
  *  @c stable == true and @c stable == false supply the operator(). */
template<bool stable, typename SeqRandomAccessIterator,
          typename RandomAccessIterator, typename Comparator,
          typename DiffType>
  struct possibly_stable_multiway_merge
  { };
00273 
00274 template<typename SeqRandomAccessIterator, typename RandomAccessIterator,
00275           typename Comparator, typename DiffType>
00276   struct possibly_stable_multiway_merge
00277     <true, SeqRandomAccessIterator, RandomAccessIterator, Comparator,
00278     DiffType>
00279   {
00280     void operator()(const SeqRandomAccessIterator& seqs_begin,
00281                       const SeqRandomAccessIterator& seqs_end,
00282                       const RandomAccessIterator& target,
00283                       Comparator& comp,
00284                       DiffType length_am) const
00285     {
00286       stable_multiway_merge(seqs_begin, seqs_end, target, length_am, comp,
00287                        sequential_tag());
00288     }
00289   };
00290 
00291 template<typename SeqRandomAccessIterator, typename RandomAccessIterator,
00292           typename Comparator, typename DiffType>
00293   struct possibly_stable_multiway_merge
00294     <false, SeqRandomAccessIterator, RandomAccessIterator, Comparator,
00295     DiffType>
00296   {
00297     void operator()(const SeqRandomAccessIterator& seqs_begin,
00298                       const SeqRandomAccessIterator& seqs_end,
00299                       const RandomAccessIterator& target,
00300                       Comparator& comp,
00301                       DiffType length_am) const
00302     {
00303       multiway_merge(seqs_begin, seqs_end, target, length_am, comp,
00304                        sequential_tag());
00305     }
00306   };
00307 
00308 /** @brief PMWMS code executed by each thread.
00309   *  @param sd Pointer to algorithm data.
00310   *  @param comp Comparator.
00311   */
00312 template<bool stable, bool exact, typename RandomAccessIterator,
00313           typename Comparator>
00314   void 
00315   parallel_sort_mwms_pu(PMWMSSortingData<RandomAccessIterator>* sd,
00316                         Comparator& comp)
00317   {
00318     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00319     typedef typename traits_type::value_type value_type;
00320     typedef typename traits_type::difference_type difference_type;
00321 
00322     thread_index_t iam = omp_get_thread_num();
00323 
00324     // Length of this thread's chunk, before merging.
00325     difference_type length_local = sd->starts[iam + 1] - sd->starts[iam];
00326 
00327     // Sort in temporary storage, leave space for sentinel.
00328 
00329     typedef value_type* SortingPlacesIterator;
00330 
00331     sd->temporary[iam] =
00332         static_cast<value_type*>(
00333         ::operator new(sizeof(value_type) * (length_local + 1)));
00334 
00335     // Copy there.
00336     std::uninitialized_copy(sd->source + sd->starts[iam],
00337                             sd->source + sd->starts[iam] + length_local,
00338                             sd->temporary[iam]);
00339 
00340     possibly_stable_sort<stable, SortingPlacesIterator, Comparator>()
00341         (sd->temporary[iam], sd->temporary[iam] + length_local, comp);
00342 
00343     // Invariant: locally sorted subsequence in sd->temporary[iam],
00344     // sd->temporary[iam] + length_local.
00345 
00346     // No barrier here: Synchronization is done by the splitting routine.
00347 
00348     difference_type num_samples =
00349         _Settings::get().sort_mwms_oversampling * sd->num_threads - 1;
00350     split_consistently
00351       <exact, RandomAccessIterator, Comparator, SortingPlacesIterator>()
00352         (iam, sd, comp, num_samples);
00353 
00354     // Offset from target begin, length after merging.
00355     difference_type offset = 0, length_am = 0;
00356     for (thread_index_t s = 0; s < sd->num_threads; s++)
00357       {
00358         length_am += sd->pieces[iam][s].end - sd->pieces[iam][s].begin;
00359         offset += sd->pieces[iam][s].begin;
00360       }
00361 
00362     typedef std::vector<
00363       std::pair<SortingPlacesIterator, SortingPlacesIterator> >
00364         seq_vector_type;
00365     seq_vector_type seqs(sd->num_threads);
00366 
00367     for (int s = 0; s < sd->num_threads; ++s)
00368       {
00369         seqs[s] =
00370           std::make_pair(sd->temporary[s] + sd->pieces[iam][s].begin,
00371         sd->temporary[s] + sd->pieces[iam][s].end);
00372       }
00373 
00374     possibly_stable_multiway_merge<
00375         stable,
00376         typename seq_vector_type::iterator,
00377         RandomAccessIterator,
00378         Comparator, difference_type>()
00379           (seqs.begin(), seqs.end(),
00380            sd->source + offset, comp,
00381            length_am);
00382 
00383 #   pragma omp barrier
00384 
00385     ::operator delete(sd->temporary[iam]);
00386   }
00387 
00388 /** @brief PMWMS main call.
00389   *  @param begin Begin iterator of sequence.
00390   *  @param end End iterator of sequence.
00391   *  @param comp Comparator.
00392   *  @param n Length of sequence.
00393   *  @param num_threads Number of threads to use.
00394   */
00395 template<bool stable, bool exact, typename RandomAccessIterator,
00396            typename Comparator>
00397   void
00398   parallel_sort_mwms(RandomAccessIterator begin, RandomAccessIterator end,
00399                      Comparator comp,
00400                      thread_index_t num_threads)
00401   {
00402     _GLIBCXX_CALL(end - begin)
00403 
00404     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00405     typedef typename traits_type::value_type value_type;
00406     typedef typename traits_type::difference_type difference_type;
00407 
00408     difference_type n = end - begin;
00409 
00410     if (n <= 1)
00411       return;
00412 
00413     // at least one element per thread
00414     if (num_threads > n)
00415       num_threads = static_cast<thread_index_t>(n);
00416 
00417     // shared variables
00418     PMWMSSortingData<RandomAccessIterator> sd;
00419     difference_type* starts;
00420 
00421 #   pragma omp parallel num_threads(num_threads)
00422       {
00423         num_threads = omp_get_num_threads();  //no more threads than requested
00424 
00425 #       pragma omp single
00426           {
00427             sd.num_threads = num_threads;
00428             sd.source = begin;
00429 
00430             sd.temporary = new value_type*[num_threads];
00431 
00432             if (!exact)
00433               {
00434                 difference_type size =
00435                     (_Settings::get().sort_mwms_oversampling * num_threads - 1)
00436                         * num_threads;
00437                 sd.samples = static_cast<value_type*>(
00438                               ::operator new(size * sizeof(value_type)));
00439               }
00440             else
00441               sd.samples = NULL;
00442 
00443             sd.offsets = new difference_type[num_threads - 1];
00444             sd.pieces = new std::vector<Piece<difference_type> >[num_threads];
00445             for (int s = 0; s < num_threads; ++s)
00446               sd.pieces[s].resize(num_threads);
00447             starts = sd.starts = new difference_type[num_threads + 1];
00448 
00449             difference_type chunk_length = n / num_threads;
00450             difference_type split = n % num_threads;
00451             difference_type pos = 0;
00452             for (int i = 0; i < num_threads; ++i)
00453               {
00454                 starts[i] = pos;
00455                 pos += (i < split) ? (chunk_length + 1) : chunk_length;
00456               }
00457             starts[num_threads] = pos;
00458           } //single
00459 
00460         // Now sort in parallel.
00461         parallel_sort_mwms_pu<stable, exact>(&sd, comp);
00462       } //parallel
00463 
00464     delete[] starts;
00465     delete[] sd.temporary;
00466 
00467     if (!exact)
00468       ::operator delete(sd.samples);
00469 
00470     delete[] sd.offsets;
00471     delete[] sd.pieces;
00472   }
00473 } //namespace __gnu_parallel
00474 
00475 #endif /* _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H */

Generated on Tue Apr 21 13:13:29 2009 for libstdc++ by  doxygen 1.5.8