This is the mail archive of the
gcc-bugs@gcc.gnu.org
mailing list for the GCC project.
[Bug c++/57925] New: discrete_distribution can be improved to O(1) per sampling
- From: "yangzhe1990 at gmail dot com" <gcc-bugzilla at gcc dot gnu dot org>
- To: gcc-bugs at gcc dot gnu dot org
- Date: Thu, 18 Jul 2013 11:41:34 +0000
- Subject: [Bug c++/57925] New: discrete_distribution can be improved to O(1) per sampling
- Auto-submitted: auto-generated
http://gcc.gnu.org/bugzilla/show_bug.cgi?id=57925
Bug ID: 57925
Summary: discrete_distribution can be improved to O(1) per
sampling
Product: gcc
Version: 4.8.1
Status: UNCONFIRMED
Severity: enhancement
Priority: P3
Component: c++
Assignee: unassigned at gcc dot gnu.org
Reporter: yangzhe1990 at gmail dot com
The current implementation of discrete_distribution uses a straightforward
algorithm based on partial_sum and std::lower_bound, which costs O(log n) per
sample. It can be improved to O(1) per sample by using the pair and alias
method, as GSL does.
The O(log n) factor of std::lower_bound is large even for n = 3.
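For comparison, the partial_sum plus std::lower_bound approach described above
can be sketched roughly as follows. This is a minimal illustration, not the
libstdc++ source; the names make_cumulative, index_for, and sample_cumulative
are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical helper: builds the running totals once, in O(n).
std::vector<double> make_cumulative(const std::vector<double>& weights) {
    assert(!weights.empty());
    std::vector<double> cumulative(weights.size());
    std::partial_sum(weights.begin(), weights.end(), cumulative.begin());
    return cumulative;
}

// Hypothetical helper: binary search for the first running total that is
// not less than x. This is the O(log n) step paid on every sample.
std::size_t index_for(const std::vector<double>& cumulative, double x) {
    auto it = std::lower_bound(cumulative.begin(), cumulative.end(), x);
    if (it == cumulative.end())
        --it;  // clamp in case x rounds up to the total
    return static_cast<std::size_t>(it - cumulative.begin());
}

// Draws x uniformly from [0, total) and maps it to an index.
template <class RNG>
std::size_t sample_cumulative(const std::vector<double>& cumulative, RNG& rng) {
    std::uniform_real_distribution<double> u(0.0, cumulative.back());
    return index_for(cumulative, u(rng));
}
```

With weights {1, 2, 1} the running totals are {1, 3, 4}, so a draw of 1.5 maps
to index 1 and a draw of 3.5 maps to index 2.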
The algorithm is relatively simple, and you will definitely enjoy implementing
it. Here is my sample code. Although it is quite informal, I hope it convinces
you that the effort required is small and that it would be great to have this
in the STL.
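#include <random>
#include <vector>
using namespace std;

// Informal sketch of the pair and alias method.
// Assumes the input weights already sum to 1.
class my_discrete_distribution {
private:
    vector<int> paired;     // paired[i]: index returned when column i's own share is missed
    vector<double> weight;  // weight[i]: share of column i that belongs to i itself
    uniform_real_distribution<double> u;
public:
    template<class InputIterator>
    my_discrete_distribution(InputIterator begin, InputIterator end)
    {
        for (; begin != end; ++begin)
            weight.push_back(*begin);
        int size = weight.size();
        // Scale so the average weight is 1, and collect the indices whose
        // weight does not exceed that average.
        vector<int> small(size);
        int small_cnt = 0;
        for (int i = 0; i < size; ++i) {
            weight[i] *= size;
            if (weight[i] <= 1)
                small[small_cnt++] = i;
        }
        // Every column starts out paired with itself.
        paired.resize(size);
        for (int i = 0; i < size; ++i)
            paired[i] = i;
        for (int i = 0; i < size; ++i) {
            double w = weight[i];
            if (w > 1) {
                // Use the excess of i to top up small columns, taken from the
                // end of the list, until i is no longer above average.
                int j; int p;
                for (j = small_cnt - 1; j >= 0 && w > 1; --j) {
                    p = small[j];
                    paired[p] = i;
                    w -= (1 - weight[p]);
                }
                // i now counts as small; it reuses the last consumed slot.
                small[j + 1] = i;
                small_cnt = j + 2;
                weight[i] = w;
            }
        }
        u = uniform_real_distribution<double>(0.0, size);
    }
    template<class RNG>
    int operator() (RNG &rng) {
        // The integer part of a selects a column; the fractional part decides
        // between the column's own index and its paired index.
        const double a = u(rng);
        int int_part = (int)a;
        if (a - int_part >= weight[int_part])
            return paired[int_part];
        else
            return int_part;
    }
};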
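
As a cross-check on the idea in the sample above, here is a separate sketch of
an alias table that keeps two worklists (underfull and overfull columns) and
normalizes the weights itself, so they need not sum to 1. It is a variant
written for illustration, not the code proposed above and not GSL's
implementation; the class name AliasTable and its probability() helper are
hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical name; an illustrative alias table, not libstdc++ or GSL code.
class AliasTable {
public:
    explicit AliasTable(const std::vector<double>& weights)
        : prob_(weights.size()), alias_(weights.size()) {
        const std::size_t n = weights.size();
        assert(n > 0);
        const double total = std::accumulate(weights.begin(), weights.end(), 0.0);
        assert(total > 0.0);

        // Scale so that the average column height is 1, then split the
        // columns into underfull (small) and overfull-or-exact (large).
        std::vector<std::size_t> small, large;
        for (std::size_t i = 0; i < n; ++i) {
            prob_[i] = weights[i] * static_cast<double>(n) / total;
            alias_[i] = i;
            (prob_[i] < 1.0 ? small : large).push_back(i);
        }

        // Top up each underfull column from an overfull one.
        while (!small.empty() && !large.empty()) {
            const std::size_t s = small.back();
            small.pop_back();
            const std::size_t l = large.back();
            alias_[s] = l;
            prob_[l] -= 1.0 - prob_[s];
            if (prob_[l] < 1.0) {
                large.pop_back();
                small.push_back(l);
            }
        }
        // Whatever is left differs from 1 only by rounding error.
        for (std::size_t i : small) prob_[i] = 1.0;
        for (std::size_t i : large) prob_[i] = 1.0;
    }

    // O(1) per sample: one uniform draw, one comparison.
    template <class RNG>
    std::size_t operator()(RNG& rng) const {
        std::uniform_real_distribution<double> u(0.0, static_cast<double>(prob_.size()));
        const double a = u(rng);
        std::size_t column = static_cast<std::size_t>(a);
        if (column >= prob_.size())
            column = prob_.size() - 1;  // guard against a rounding up to n
        return (a - static_cast<double>(column) < prob_[column]) ? column : alias_[column];
    }

    // Probability of outcome k implied by the table (O(n), for checking only).
    double probability(std::size_t k) const {
        double p = 0.0;
        for (std::size_t i = 0; i < prob_.size(); ++i) {
            if (i == k) p += prob_[i];
            if (alias_[i] == k) p += 1.0 - prob_[i];
        }
        return p / static_cast<double>(prob_.size());
    }

private:
    std::vector<double> prob_;        // chance that a column returns its own index
    std::vector<std::size_t> alias_;  // index returned otherwise
};
```

For weights {1, 2, 1}, columns 0 and 2 each keep a threshold of 0.75 and are
paired with index 1, while column 1 always returns 1. The implied
probabilities are therefore 0.25, 0.5, and 0.25, which probability() can be
used to confirm without any random sampling.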