This is the mail archive of the gcc-bugs@gcc.gnu.org mailing list for the GCC project.



[Bug c++/57925] New: discrete_distribution can be improved to O(1) per sampling


http://gcc.gnu.org/bugzilla/show_bug.cgi?id=57925

            Bug ID: 57925
           Summary: discrete_distribution can be improved to O(1) per
                    sampling
           Product: gcc
           Version: 4.8.1
            Status: UNCONFIRMED
          Severity: enhancement
          Priority: P3
         Component: c++
          Assignee: unassigned at gcc dot gnu.org
          Reporter: yangzhe1990 at gmail dot com

The current implementation of discrete_distribution uses a straightforward
algorithm based on std::partial_sum and std::lower_bound, which costs O(log n)
per sample. It can be improved to O(1) per sample by pairing each outcome with
an alias (the alias method), as GSL does.
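For comparison, here is a rough sketch of the partial_sum/lower_bound approach described above. It is a simplified illustration, not libstdc++'s actual source, and the helper names `cumulative_table` and `sample_index` are hypothetical. It assumes a non-empty weight list with a positive sum. Building the table is O(n); each draw then performs a binary search over the cumulative weights, which is where the O(log n) per-sample cost comes from.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical helper: normalized cumulative sums of the weights.
std::vector<double> cumulative_table(const std::vector<double>& w) {
    std::vector<double> cdf(w.size());
    std::partial_sum(w.begin(), w.end(), cdf.begin());
    const double total = cdf.back();
    for (double& c : cdf) c /= total;
    cdf.back() = 1.0;  // guard against rounding in the last entry
    return cdf;
}

// Hypothetical helper: each draw is a binary search, O(log n).
template <class RNG>
std::size_t sample_index(const std::vector<double>& cdf, RNG& rng) {
    std::uniform_real_distribution<double> u(0.0, 1.0);
    const double x = u(rng);
    // First cumulative value >= x selects the outcome.
    auto it = std::lower_bound(cdf.begin(), cdf.end(), x);
    return static_cast<std::size_t>(it - cdf.begin());
}
```

For example, with weights {1, 2, 1} the table is {0.25, 0.75, 1.0}, and sample_index returns 1 roughly half the time.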

Even for n = 3, the log n factor from the binary search in std::lower_bound is significant.

The algorithm is relatively simple, and you will definitely enjoy implementing
it. Here is my sample code. It is quite informal, but I hope it convinces you
that this improvement takes little effort and would be a great addition to the
STL.

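#include <algorithm>
#include <numeric>
#include <random>
#include <vector>
using namespace std;

// Alias method: O(n) setup, O(1) per sample.
class my_discrete_distribution {
private:
    vector<int> paired;     // outcome returned when column i is not kept
    vector<double> weight;  // probability of keeping column i
    uniform_real_distribution<double> u;
public:
    template<class InputIterator>
    my_discrete_distribution(InputIterator begin, InputIterator end)
    {
        for (; begin != end; ++begin)
            weight.push_back(*begin);
        int size = weight.size();
        // Normalize (as with std::discrete_distribution, the weights need not
        // sum to 1) and scale so that the average column height is 1.
        double total = accumulate(weight.begin(), weight.end(), 0.0);
        vector<int> small(size);
        int small_cnt = 0;
        for (int i = 0; i < size; ++i) {
            weight[i] = weight[i] / total * size;
            if (weight[i] <= 1)
                small[small_cnt++] = i;
        }
        paired.resize(size);
        for (int i = 0; i < size; ++i)
            paired[i] = i;
        // Each column taller than 1 fills the unused top of small columns
        // (popped from the stack) until its own height drops to at most 1;
        // it is then pushed onto the stack as a small column itself.
        for (int i = 0; i < size; ++i) {
            double w = weight[i];
            if (w > 1) {
                int j;
                for (j = small_cnt - 1; j >= 0 && w > 1; --j) {
                    int p = small[j];
                    paired[p] = i;
                    w -= (1 - weight[p]);
                }
                small[j + 1] = i;
                small_cnt = j + 2;
                weight[i] = w;
            }
        }
        u = uniform_real_distribution<double>(0.0, size);
    }

    template<class RNG>
    int operator() (RNG &rng) {
        const double a = u(rng);
        // Column index; clamped in case rounding yields a == size.
        int int_part = min((int)a, (int)weight.size() - 1);
        // Keep the column with probability weight[int_part], else take its alias.
        if (a - int_part >= weight[int_part])
            return paired[int_part];
        else
            return int_part;
    }
};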
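A table built this way can be checked without drawing any samples: outcome k receives weight[k]/n from its own column plus (1 - weight[p])/n from every column p whose alias is k, and these sums should reproduce the normalized input weights. The sketch below repeats the same stack-based pairing as the constructor above, with the weights normalized first, as a free function so it is self-contained. The names `AliasTable`, `build_alias_table` and `implied_probabilities` are hypothetical, and it assumes non-negative weights with a positive sum.

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical container for the two arrays built by the alias method.
struct AliasTable {
    std::vector<double> weight;  // chance of keeping column i
    std::vector<int> paired;     // outcome returned otherwise
};

// Same stack-based pairing as the sample constructor, after normalizing
// the weights so that the average column height is 1.
AliasTable build_alias_table(const std::vector<double>& w) {
    const int size = static_cast<int>(w.size());
    const double total = std::accumulate(w.begin(), w.end(), 0.0);
    AliasTable t;
    t.weight.resize(size);
    t.paired.resize(size);
    std::vector<int> small(size);
    int small_cnt = 0;
    for (int i = 0; i < size; ++i) {
        t.weight[i] = w[i] / total * size;
        t.paired[i] = i;
        if (t.weight[i] <= 1)
            small[small_cnt++] = i;
    }
    for (int i = 0; i < size; ++i) {
        double rest = t.weight[i];
        if (rest <= 1) continue;
        int j = small_cnt - 1;
        // Column i fills the unused top of small columns until its excess is gone.
        for (; j >= 0 && rest > 1; --j) {
            const int p = small[j];
            t.paired[p] = i;
            rest -= 1 - t.weight[p];
        }
        // Column i now has height rest (<= 1 up to rounding) and joins the stack.
        small[j + 1] = i;
        small_cnt = j + 2;
        t.weight[i] = rest;
    }
    return t;
}

// Exact probability of each outcome under the rule "pick column k uniformly;
// keep k with probability weight[k], otherwise return paired[k]".
std::vector<double> implied_probabilities(const AliasTable& t) {
    const std::size_t n = t.weight.size();
    std::vector<double> prob(n, 0.0);
    for (std::size_t k = 0; k < n; ++k) {
        const double keep = t.weight[k] < 1 ? t.weight[k] : 1.0;
        prob[k] += keep / n;
        prob[t.paired[k]] += (1 - keep) / n;
    }
    return prob;
}
```

For weights {1, 2, 1} the table comes out as weight = {0.75, 1.0, 0.75} and paired = {1, 1, 1}, and implied_probabilities returns {0.25, 0.5, 0.25} up to rounding.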

