// -*- C++ -*-
// Copyright (C) 2007, 2008, 2009 Free Software Foundation, Inc.
//
// This file is part of the GNU ISO C++ Library. This library is free
// software; you can redistribute it and/or modify it under the terms
// of the GNU General Public License as published by the Free Software
// Foundation; either version 3, or (at your option) any later
// version.
// This library is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// General Public License for more details.
// Under Section 7 of GPL version 3, you are granted additional
// permissions described in the GCC Runtime Library Exception, version
// 3.1, as published by the Free Software Foundation.
// You should have received a copy of the GNU General Public License and
// a copy of the GCC Runtime Library Exception along with this program;
// see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
// <http://www.gnu.org/licenses/>.
/** @file parallel/multiway_mergesort.h
* @brief Parallel multiway merge sort.
* This file is a GNU parallel extension to the Standard C++ Library.
*/
// Written by Johannes Singler.
#ifndef _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H
#define _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H 1
#include <vector>
#include <parallel/basic_iterator.h>
#include <bits/stl_algo.h>
#include <parallel/parallel.h>
#include <parallel/multiway_merge.h>
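// Algorithm overview (a summary of the code below):
//  1. The input is divided into num_threads consecutive chunks (sd->starts).
//  2. Each thread copies its chunk into temporary storage and sorts it there.
//  3. The locally sorted runs are split consistently across threads, either
//     exactly (multiseq_partition) or approximately (by sampling).
//  4. Each thread multiway-merges its assigned pieces of all runs back into
//     the original sequence.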
namespace __gnu_parallel
{
/** @brief Subsequence description. */
template<typename _DifferenceTp>
struct Piece
{
typedef _DifferenceTp difference_type;
/** @brief Begin index of the subsequence. */
difference_type begin;
/** @brief End index of the subsequence (exclusive). */
difference_type end;
};
/** @brief Data accessed by all threads.
*
* PMWMS = parallel multiway mergesort */
template<typename RandomAccessIterator>
struct PMWMSSortingData
{
typedef std::iterator_traits<RandomAccessIterator> traits_type;
typedef typename traits_type::value_type value_type;
typedef typename traits_type::difference_type difference_type;
/** @brief Number of threads involved. */
thread_index_t num_threads;
/** @brief Input begin. */
RandomAccessIterator source;
/** @brief Start indices, per thread. */
difference_type* starts;
/** @brief Storage in which to sort. */
value_type** temporary;
/** @brief Samples. */
value_type* samples;
/** @brief Offsets to add to the found positions. */
difference_type* offsets;
/** @brief Pieces of data to merge @c [thread][sequence] */
std::vector<Piece<difference_type> >* pieces;
};
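// Note: pieces[t][s] is the half-open range [begin, end) of thread s's
// locally sorted run that thread t will merge; for a fixed run s, these
// ranges are adjacent and non-overlapping across the merging threads t.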
/**
* @brief Select samples from a sequence.
* @param sd Pointer to algorithm data. Result will be placed in
* @c sd->samples.
* @param num_samples Number of samples to select per thread.
*/
template<typename RandomAccessIterator, typename _DifferenceTp>
void
determine_samples(PMWMSSortingData<RandomAccessIterator>* sd,
_DifferenceTp num_samples)
{
typedef std::iterator_traits<RandomAccessIterator> traits_type;
typedef typename traits_type::value_type value_type;
typedef _DifferenceTp difference_type;
thread_index_t iam = omp_get_thread_num();
difference_type* es = new difference_type[num_samples + 2];
equally_split(sd->starts[iam + 1] - sd->starts[iam],
num_samples + 1, es);
for (difference_type i = 0; i < num_samples; ++i)
::new(&(sd->samples[iam * num_samples + i]))
value_type(sd->source[sd->starts[iam] + es[i + 1]]);
delete[] es;
}
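// Illustration (hypothetical numbers): for a local chunk of 20 elements and
// num_samples == 3, equally_split yields the boundaries 0, 5, 10, 15, 20;
// the elements at local offsets es[1], es[2], es[3] == 5, 10, 15 become this
// thread's entries in sd->samples.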
/** @brief Split consistently. */
template<bool exact, typename RandomAccessIterator,
typename Comparator, typename SortingPlacesIterator>
struct split_consistently
{
};
/** @brief Split by exact splitting. */
template<typename RandomAccessIterator, typename Comparator,
typename SortingPlacesIterator>
struct split_consistently
<true, RandomAccessIterator, Comparator, SortingPlacesIterator>
{
void operator()(
const thread_index_t iam,
PMWMSSortingData<RandomAccessIterator>* sd,
Comparator& comp,
const typename
std::iterator_traits<RandomAccessIterator>::difference_type
num_samples)
const
{
# pragma omp barrier
std::vector<std::pair<SortingPlacesIterator, SortingPlacesIterator> >
seqs(sd->num_threads);
for (thread_index_t s = 0; s < sd->num_threads; s++)
seqs[s] = std::make_pair(sd->temporary[s],
sd->temporary[s]
+ (sd->starts[s + 1] - sd->starts[s]));
std::vector<SortingPlacesIterator> offsets(sd->num_threads);
// if not last thread
if (iam < sd->num_threads - 1)
multiseq_partition(seqs.begin(), seqs.end(),
sd->starts[iam + 1], offsets.begin(), comp);
for (int seq = 0; seq < sd->num_threads; seq++)
{
// for each sequence
if (iam < (sd->num_threads - 1))
sd->pieces[iam][seq].end = offsets[seq] - seqs[seq].first;
else
// very end of this sequence
sd->pieces[iam][seq].end =
sd->starts[seq + 1] - sd->starts[seq];
}
# pragma omp barrier
for (thread_index_t seq = 0; seq < sd->num_threads; seq++)
{
// For each sequence.
if (iam > 0)
sd->pieces[iam][seq].begin = sd->pieces[iam - 1][seq].end;
else
// Absolute beginning.
sd->pieces[iam][seq].begin = 0;
}
}
};
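// In the exact variant above, multiseq_partition determines, across all
// locally sorted runs, split positions whose combined length equals the
// global rank sd->starts[iam + 1]; because every thread partitions at its
// own rank, consecutive threads obtain adjacent, non-overlapping pieces of
// each run.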
/** @brief Split by sampling. */
template<typename RandomAccessIterator, typename Comparator,
typename SortingPlacesIterator>
struct split_consistently<false, RandomAccessIterator, Comparator,
SortingPlacesIterator>
{
void operator()(
const thread_index_t iam,
PMWMSSortingData<RandomAccessIterator>* sd,
Comparator& comp,
const typename
std::iterator_traits<RandomAccessIterator>::difference_type
num_samples)
const
{
typedef std::iterator_traits<RandomAccessIterator> traits_type;
typedef typename traits_type::value_type value_type;
typedef typename traits_type::difference_type difference_type;
determine_samples(sd, num_samples);
# pragma omp barrier
# pragma omp single
__gnu_sequential::sort(sd->samples,
sd->samples + (num_samples * sd->num_threads),
comp);
# pragma omp barrier
for (thread_index_t s = 0; s < sd->num_threads; ++s)
{
// For each sequence.
if (num_samples * iam > 0)
sd->pieces[iam][s].begin =
std::lower_bound(sd->temporary[s],
sd->temporary[s]
+ (sd->starts[s + 1] - sd->starts[s]),
sd->samples[num_samples * iam],
comp)
- sd->temporary[s];
else
// Absolute beginning.
sd->pieces[iam][s].begin = 0;
if ((num_samples * (iam + 1)) < (num_samples * sd->num_threads))
sd->pieces[iam][s].end =
std::lower_bound(sd->temporary[s],
sd->temporary[s]
+ (sd->starts[s + 1] - sd->starts[s]),
sd->samples[num_samples * (iam + 1)],
comp)
- sd->temporary[s];
else
// Absolute end.
sd->pieces[iam][s].end = sd->starts[s + 1] - sd->starts[s];
}
}
};
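// In the sampling variant above, the globally sorted sample array serves as
// a set of splitters: thread iam merges, from every run, the elements that
// fall between samples[num_samples * iam] and samples[num_samples * (iam + 1)]
// (the first and last thread extend to the run boundaries), with lower_bound
// locating each splitter inside each run.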
template<bool stable, typename RandomAccessIterator, typename Comparator>
struct possibly_stable_sort
{
};
template<typename RandomAccessIterator, typename Comparator>
struct possibly_stable_sort<true, RandomAccessIterator, Comparator>
{
void operator()(const RandomAccessIterator& begin,
const RandomAccessIterator& end, Comparator& comp) const
{
__gnu_sequential::stable_sort(begin, end, comp);
}
};
template<typename RandomAccessIterator, typename Comparator>
struct possibly_stable_sort<false, RandomAccessIterator, Comparator>
{
void operator()(const RandomAccessIterator begin,
const RandomAccessIterator end, Comparator& comp) const
{
__gnu_sequential::sort(begin, end, comp);
}
};
template<bool stable, typename SeqRandomAccessIterator,
typename RandomAccessIterator, typename Comparator,
typename DiffType>
struct possibly_stable_multiway_merge
{
};
template<typename SeqRandomAccessIterator, typename RandomAccessIterator,
typename Comparator, typename DiffType>
struct possibly_stable_multiway_merge
<true, SeqRandomAccessIterator, RandomAccessIterator, Comparator,
DiffType>
{
void operator()(const SeqRandomAccessIterator& seqs_begin,
const SeqRandomAccessIterator& seqs_end,
const RandomAccessIterator& target,
Comparator& comp,
DiffType length_am) const
{
stable_multiway_merge(seqs_begin, seqs_end, target, length_am, comp,
sequential_tag());
}
};
template<typename SeqRandomAccessIterator, typename RandomAccessIterator,
typename Comparator, typename DiffType>
struct possibly_stable_multiway_merge
<false, SeqRandomAccessIterator, RandomAccessIterator, Comparator,
DiffType>
{
void operator()(const SeqRandomAccessIterator& seqs_begin,
const SeqRandomAccessIterator& seqs_end,
const RandomAccessIterator& target,
Comparator& comp,
DiffType length_am) const
{
multiway_merge(seqs_begin, seqs_end, target, length_am, comp,
sequential_tag());
}
};
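// possibly_stable_sort and possibly_stable_multiway_merge dispatch on the
// compile-time stable flag, forwarding to the stable or unstable sequential
// routines; the sequential_tag ensures that the merge performed inside each
// thread is itself sequential.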
/** @brief PMWMS code executed by each thread.
* @param sd Pointer to algorithm data.
* @param comp Comparator.
*/
template<bool stable, bool exact, typename RandomAccessIterator,
typename Comparator>
void
parallel_sort_mwms_pu(PMWMSSortingData<RandomAccessIterator>* sd,
Comparator& comp)
{
typedef std::iterator_traits<RandomAccessIterator> traits_type;
typedef typename traits_type::value_type value_type;
typedef typename traits_type::difference_type difference_type;
thread_index_t iam = omp_get_thread_num();
// Length of this thread's chunk, before merging.
difference_type length_local = sd->starts[iam + 1] - sd->starts[iam];
// Sort in temporary storage, leave space for sentinel.
typedef value_type* SortingPlacesIterator;
sd->temporary[iam] =
static_cast<value_type*>(
::operator new(sizeof(value_type) * (length_local + 1)));
// Copy there.
std::uninitialized_copy(sd->source + sd->starts[iam],
sd->source + sd->starts[iam] + length_local,
sd->temporary[iam]);
possibly_stable_sort<stable, SortingPlacesIterator, Comparator>()
(sd->temporary[iam], sd->temporary[iam] + length_local, comp);
// Invariant: locally sorted subsequence in
// [sd->temporary[iam], sd->temporary[iam] + length_local).
// No barrier here: Synchronization is done by the splitting routine.
difference_type num_samples =
_Settings::get().sort_mwms_oversampling * sd->num_threads - 1;
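// num_samples is the per-thread sample count; in the sampling variant,
// num_samples * sd->num_threads samples are collected in total, matching
// the size of the sd->samples allocation in parallel_sort_mwms below.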
split_consistently
<exact, RandomAccessIterator, Comparator, SortingPlacesIterator>()
(iam, sd, comp, num_samples);
// Offset from target begin, length after merging.
difference_type offset = 0, length_am = 0;
for (thread_index_t s = 0; s < sd->num_threads; s++)
{
length_am += sd->pieces[iam][s].end - sd->pieces[iam][s].begin;
offset += sd->pieces[iam][s].begin;
}
typedef std::vector<
std::pair<SortingPlacesIterator, SortingPlacesIterator> >
seq_vector_type;
seq_vector_type seqs(sd->num_threads);
for (int s = 0; s < sd->num_threads; ++s)
{
seqs[s] =
std::make_pair(sd->temporary[s] + sd->pieces[iam][s].begin,
sd->temporary[s] + sd->pieces[iam][s].end);
}
possibly_stable_multiway_merge<
stable,
typename seq_vector_type::iterator,
RandomAccessIterator,
Comparator, difference_type>()
(seqs.begin(), seqs.end(),
sd->source + offset, comp,
length_am);
# pragma omp barrier
::operator delete(sd->temporary[iam]);
}
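// Each thread writes its merged output directly into the original sequence
// at sd->source + offset, so no final copy step is needed; the barrier
// before ::operator delete keeps a thread from freeing its temporary run
// while other threads may still be reading from it during their merges.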
/** @brief PMWMS main call.
* @param begin Begin iterator of sequence.
* @param end End iterator of sequence.
* @param comp Comparator.
* @param num_threads Number of threads to use.
*/
template<bool stable, bool exact, typename RandomAccessIterator,
typename Comparator>
void
parallel_sort_mwms(RandomAccessIterator begin, RandomAccessIterator end,
Comparator comp,
thread_index_t num_threads)
{
_GLIBCXX_CALL(end - begin)
typedef std::iterator_traits<RandomAccessIterator> traits_type;
typedef typename traits_type::value_type value_type;
typedef typename traits_type::difference_type difference_type;
difference_type n = end - begin;
if (n <= 1)
return;
// at least one element per thread
if (num_threads > n)
num_threads = static_cast<thread_index_t>(n);
// shared variables
PMWMSSortingData<RandomAccessIterator> sd;
difference_type* starts;
# pragma omp parallel num_threads(num_threads)
{
num_threads = omp_get_num_threads(); // Number of threads actually granted; may be fewer than requested.
# pragma omp single
{
sd.num_threads = num_threads;
sd.source = begin;
sd.temporary = new value_type*[num_threads];
if (!exact)
{
difference_type size =
(_Settings::get().sort_mwms_oversampling * num_threads - 1)
* num_threads;
sd.samples = static_cast<value_type*>(
::operator new(size * sizeof(value_type)));
}
else
sd.samples = NULL;
sd.offsets = new difference_type[num_threads - 1];
sd.pieces = new std::vector<Piece<difference_type> >[num_threads];
for (int s = 0; s < num_threads; ++s)
sd.pieces[s].resize(num_threads);
starts = sd.starts = new difference_type[num_threads + 1];
difference_type chunk_length = n / num_threads;
difference_type split = n % num_threads;
difference_type pos = 0;
for (int i = 0; i < num_threads; ++i)
{
starts[i] = pos;
pos += (i < split) ? (chunk_length + 1) : chunk_length;
}
starts[num_threads] = pos;
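// These chunk borders match what equally_split would produce: the first
// (n % num_threads) chunks receive one extra element.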
} //single
// Now sort in parallel.
parallel_sort_mwms_pu<stable, exact>(&sd, comp);
} //parallel
delete[] starts;
delete[] sd.temporary;
if (!exact)
::operator delete(sd.samples);
delete[] sd.offsets;
delete[] sd.pieces;
}
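// Illustrative direct call (a sketch; in practice this routine is reached
// through the parallel-mode sort dispatch rather than called directly):
//
//   #include <vector>
//   #include <functional>
//   #include <omp.h>
//
//   std::vector<int> v = { 3, 1, 2 };
//   __gnu_parallel::parallel_sort_mwms</* stable */ false, /* exact */ true>
//     (v.begin(), v.end(), std::less<int>(), omp_get_max_threads());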
} //namespace __gnu_parallel
#endif /* _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H */