This is the mail archive of the java-patches@gcc.gnu.org mailing list for the Java project.


Index Nav: [Date Index] [Subject Index] [Author Index] [Thread Index]
Message Nav: [Date Prev] [Date Next] [Thread Prev] [Thread Next]
Other format: [Raw text]

[PATCH] speed up java.util.Random


I saw a PR some time ago about the slowness of java.util.Random in GNU
Classpath.  I can neither remember nor find the PR number, and in fact I
cannot reproduce the slowness when comparing the Classpath implementation
against Sun's.  The figures below are six timings in milliseconds, sorted
in ascending order, followed by their mean:

Sun JDK 1.4.2         4586  6038  6216  6254  6875  6952    6153.50
current classpath     3398  3586  3996  4169  4228  4962    4056.50

The large variance is because I compiled and ran the benchmarks with Sun's JIT-compiling virtual machine.

Still, by inlining the next(int) method into its callers, while keeping
the same results and the synchronized behavior of java.util.Random, I was
able to improve performance by about 20%:

current classpath     3398  3586  3996  4169  4228  4962    4056.50
patched classpath     2585  2618  3465  3538  3606  3807    3269.83

I am not sure how to get this contributed to GNU Classpath.  I attach a
patch with a ChangeLog entry, together with the tests I ran to validate my
implementation and to compare the relative speeds.

Paolo

2005-11-24 Paolo Bonzini <bonzini@gnu.org>

        * java/util/Random.java: Rewrite, inlining the next(int) method
        and synchronization blocks into the callers.


/**
 * Checks that NewRandom is a drop-in replacement for OldRandom.
 *
 * Both generators are created with the same seed and every method is
 * called on both in the same order, so each pair of results must be
 * identical.  Every comparison prints "true" or "false" as before; at the
 * end a summary is printed and the JVM exits with status 1 if any pair
 * differed, so the test no longer depends on reading every output line.
 */
public class RandomTest
{
  public static OldRandom or = new OldRandom (12345678);
  public static NewRandom nr = new NewRandom (12345678);

  /** Number of comparisons that did not match. */
  private static int failures;

  /** Prints one comparison result and records it when it is a mismatch. */
  private static void check (boolean same)
  {
    System.out.println (same);
    if (!same)
      failures++;
  }

  /** Prints a section banner: a blank line, the title, then a rule. */
  private static void section (String title)
  {
    System.out.println ("\n" + title + "\n==========================");
  }

  public static void main (String[] args)
  {
    section ("next (int)");
    check (or.next (1) == nr.next (1));
    check (or.next (16) == nr.next (16));
    check (or.next (24) == nr.next (24));
    check (or.next (32) == nr.next (32));

    section ("nextBytes");
    byte[] ob = new byte[7]; or.nextBytes (ob);
    byte[] nb = new byte[7]; nr.nextBytes (nb);
    for (int i = 0; i < ob.length; i++)
      check (ob[i] == nb[i]);

    // Also cover an empty array, lengths that are a multiple of 4 (no
    // tail), and a one-byte tail: nextBytes fills whole 4-byte groups and
    // the leftover tail on separate paths.
    int[] lengths = { 0, 1, 4, 8, 9 };
    for (int k = 0; k < lengths.length; k++)
      {
        ob = new byte[lengths[k]]; or.nextBytes (ob);
        nb = new byte[lengths[k]]; nr.nextBytes (nb);
        check (java.util.Arrays.equals (ob, nb));
      }

    section ("nextInt");
    for (int i = 0; i < 2; i++)
      check (or.nextInt () == nr.nextInt ());

    section ("nextInt(int)");
    // Powers of two take the multiply-and-shift path; the other bounds go
    // through the rejection loop.
    int[] bounds = { 256, 256, 32768, 32768, 32769, 32769,
                     0x70654321, 0x70654321 };
    for (int k = 0; k < bounds.length; k++)
      check (or.nextInt (bounds[k]) == nr.nextInt (bounds[k]));

    section ("nextLong()");
    for (int i = 0; i < 4; i++)
      check (or.nextLong () == nr.nextLong ());

    section ("nextBoolean()");
    for (int i = 0; i < 6; i++)
      check (or.nextBoolean () == nr.nextBoolean ());

    section ("nextFloat()");
    for (int i = 0; i < 4; i++)
      check (or.nextFloat () == nr.nextFloat ());

    section ("nextDouble()");
    for (int i = 0; i < 4; i++)
      check (or.nextDouble () == nr.nextDouble ());

    section ("nextGaussian()");
    for (int i = 0; i < 8; i++)
      check (or.nextGaussian () == nr.nextGaussian ());

    System.out.println ();
    if (failures == 0)
      System.out.println ("All comparisons passed.");
    else
      {
        System.out.println (failures + " comparison(s) FAILED.");
        System.exit (1);
      }
  }
}


/**
 * Micro-benchmark comparing the JDK's java.util.Random ("sun") with the old
 * and new GNU Classpath implementations (OldRandom, NewRandom).
 *
 * main runs three interleaved rounds of sun/old/new.  For each run it
 * prints the label followed by the elapsed wall-clock time in milliseconds.
 * Every test method drives the same fixed mix of calls through its
 * generator.
 */
public class RandomSpeed
{
  public static OldRandom or = new OldRandom (12345678);
  public static NewRandom nr = new NewRandom (12345678);
  public static java.util.Random r = new java.util.Random (12345678);
  public static byte[] b = new byte[7];

  /** How many interleaved sun/old/new rounds main performs. */
  private static final int ROUNDS = 3;

  /** How many times each test method repeats its call mix. */
  private static final int ITERATIONS = 100000;

  public static void main (String[] args)
  {
    for (int round = 0; round < ROUNDS; round++)
      {
        long start = System.currentTimeMillis ();
        System.out.print ("sun ");
        testSun ();
        System.out.println (System.currentTimeMillis () - start);

        start = System.currentTimeMillis ();
        System.out.print ("old ");
        testOld ();
        System.out.println (System.currentTimeMillis () - start);

        start = System.currentTimeMillis ();
        System.out.print ("new ");
        testNew ();
        System.out.println (System.currentTimeMillis () - start);
      }
  }

  /** Runs the call mix against the old Classpath implementation. */
  public static void testOld ()
  {
    for (int remaining = ITERATIONS; remaining > 0; remaining--)
      {
        or.nextBytes (b);

        or.nextInt (); or.nextInt (); or.nextInt ();
        or.nextInt (); or.nextInt (); or.nextInt ();

        // Power-of-two bounds first, then bounds needing the rejection loop.
        or.nextInt (256); or.nextInt (256);
        or.nextInt (32768); or.nextInt (32768);
        or.nextInt (32769); or.nextInt (32769);
        or.nextInt (0x70654321); or.nextInt (0x70654321);

        or.nextLong (); or.nextLong (); or.nextLong (); or.nextLong ();

        or.nextBoolean (); or.nextBoolean (); or.nextBoolean ();
        or.nextBoolean (); or.nextBoolean (); or.nextBoolean ();

        or.nextFloat (); or.nextFloat (); or.nextFloat (); or.nextFloat ();

        or.nextDouble (); or.nextDouble (); or.nextDouble (); or.nextDouble ();

        or.nextGaussian (); or.nextGaussian (); or.nextGaussian ();
        or.nextGaussian (); or.nextGaussian (); or.nextGaussian ();
        or.nextGaussian (); or.nextGaussian ();
      }
  }

  /** Runs the call mix against the new (inlined) Classpath implementation. */
  public static void testNew ()
  {
    for (int remaining = ITERATIONS; remaining > 0; remaining--)
      {
        nr.nextBytes (b);

        nr.nextInt (); nr.nextInt (); nr.nextInt ();
        nr.nextInt (); nr.nextInt (); nr.nextInt ();

        // Power-of-two bounds first, then bounds needing the rejection loop.
        nr.nextInt (256); nr.nextInt (256);
        nr.nextInt (32768); nr.nextInt (32768);
        nr.nextInt (32769); nr.nextInt (32769);
        nr.nextInt (0x70654321); nr.nextInt (0x70654321);

        nr.nextLong (); nr.nextLong (); nr.nextLong (); nr.nextLong ();

        nr.nextBoolean (); nr.nextBoolean (); nr.nextBoolean ();
        nr.nextBoolean (); nr.nextBoolean (); nr.nextBoolean ();

        nr.nextFloat (); nr.nextFloat (); nr.nextFloat (); nr.nextFloat ();

        nr.nextDouble (); nr.nextDouble (); nr.nextDouble (); nr.nextDouble ();

        nr.nextGaussian (); nr.nextGaussian (); nr.nextGaussian ();
        nr.nextGaussian (); nr.nextGaussian (); nr.nextGaussian ();
        nr.nextGaussian (); nr.nextGaussian ();
      }
  }

  /** Runs the call mix against the JVM's own java.util.Random. */
  public static void testSun ()
  {
    for (int remaining = ITERATIONS; remaining > 0; remaining--)
      {
        r.nextBytes (b);

        r.nextInt (); r.nextInt (); r.nextInt ();
        r.nextInt (); r.nextInt (); r.nextInt ();

        // Power-of-two bounds first, then bounds needing the rejection loop.
        r.nextInt (256); r.nextInt (256);
        r.nextInt (32768); r.nextInt (32768);
        r.nextInt (32769); r.nextInt (32769);
        r.nextInt (0x70654321); r.nextInt (0x70654321);

        r.nextLong (); r.nextLong (); r.nextLong (); r.nextLong ();

        r.nextBoolean (); r.nextBoolean (); r.nextBoolean ();
        r.nextBoolean (); r.nextBoolean (); r.nextBoolean ();

        r.nextFloat (); r.nextFloat (); r.nextFloat (); r.nextFloat ();

        r.nextDouble (); r.nextDouble (); r.nextDouble (); r.nextDouble ();

        r.nextGaussian (); r.nextGaussian (); r.nextGaussian ();
        r.nextGaussian (); r.nextGaussian (); r.nextGaussian ();
        r.nextGaussian (); r.nextGaussian ();
      }
  }
}


--- Random.java.old	2005-11-24 22:21:44.000000000 +0100
+++ Random.java	2005-11-24 23:20:19.000000000 +0100
@@ -148,7 +148,7 @@   public Random(long seed)
    */
   public synchronized void setSeed(long seed)
   {
-    this.seed = (seed ^ 0x5DEECE66DL) & ((1L << 48) - 1);
+    this.seed = (seed ^ 0x5DEECE66DL) & 0xFFFFFFFFFFFFL;
     haveNextNextGaussian = false;
   }
 
@@ -170,8 +170,8 @@   public synchronized void setSeed(long 
    */
   protected synchronized int next(int bits)
   {
-    seed = (seed * 0x5DEECE66DL + 0xBL) & ((1L << 48) - 1);
-    return (int) (seed >>> (48 - bits));
+    long rand48 = seed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL;
+    return (int) (rand48 >>> (48 - bits));
   }
 
   /**
@@ -198,26 +198,30 @@   public synchronized void setSeed(long 
    */
   public void nextBytes(byte[] bytes)
   {
-    int random;
-    // Do a little bit unrolling of the above algorithm.
-    int max = bytes.length & ~0x3;
-    for (int i = 0; i < max; i += 4)
-      {
-        random = next(32);
-        bytes[i] = (byte) random;
-        bytes[i + 1] = (byte) (random >> 8);
-        bytes[i + 2] = (byte) (random >> 16);
-        bytes[i + 3] = (byte) (random >> 24);
-      }
-    if (max < bytes.length)
-      {
-        random = next(32);
-        for (int j = max; j < bytes.length; j++)
-          {
-            bytes[j] = (byte) random;
-            random >>= 8;
-          }
-      }
+    synchronized (this) {
+      int random;
+      // Do a little bit unrolling of the above algorithm.
+      int max = bytes.length & ~0x3;
+      for (int i = 0; i < max; i += 4)
+        {
+          long rand48 = seed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL;
+          random = (int) (rand48 >>> 16);
+          bytes[i] = (byte) random;
+          bytes[i + 1] = (byte) (random >> 8);
+          bytes[i + 2] = (byte) (random >> 16);
+          bytes[i + 3] = (byte) (random >> 24);
+        }
+      if (max < bytes.length)
+        {
+          long rand48 = seed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL;
+          random = (int) (rand48 >>> 16);
+          for (int j = max; j < bytes.length; j++)
+            {
+              bytes[j] = (byte) random;
+              random >>= 8;
+            }
+        }
+    }
   }
 
   /**
@@ -235,7 +239,11 @@   public void nextBytes(byte[] bytes)
    */
   public int nextInt()
   {
-    return next(32);
+    long rand48;
+    synchronized (this) {
+      rand48 = seed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL;
+    }
+    return (int) (rand48 >>> 16);
   }
 
   /**
@@ -289,15 +297,23 @@   public int nextInt(int n)
   {
     if (n <= 0)
       throw new IllegalArgumentException("n must be positive");
-    if ((n & -n) == n) // i.e., n is a power of 2
-      return (int) ((n * (long) next(31)) >> 31);
     int bits, val;
-    do
-      {
-        bits = next(31);
-        val = bits % n;
-      }
-    while (bits - val + (n - 1) < 0);
+    synchronized (this) {
+      if ((n & -n) == n) // i.e., n is a power of 2
+        {
+          long rand48 = seed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL;
+          long random = rand48 >>> 17;
+          val = (int) ((n * random) >> 31);
+        }
+      else
+        do
+          {
+            long rand48 = seed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL;
+            bits = (int) (rand48 >>> 17);
+            val = bits % n;
+          }
+        while (bits - val + (n - 1) < 0);
+    }
     return val;
   }
 
@@ -315,7 +331,14 @@   public int nextInt(int n)
    */
   public long nextLong()
   {
-    return ((long) next(32) << 32) + next(32);
+    long high, low;
+    synchronized (this) {
+      high = seed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL;
+      low = seed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL;
+    }
+    high <<= 16;
+    low >>>= 16;
+    return (high & (0xFFFFFFFFL << 32)) + (int) low;
   }
 
   /**
@@ -332,7 +355,11 @@   public long nextLong()
    */
   public boolean nextBoolean()
   {
-    return next(1) != 0;
+    long rand48;
+    synchronized (this) {
+      rand48 = seed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL;
+    }
+    return (rand48 & 0x800000000000L) != 0;
   }
 
   /**
@@ -349,7 +376,11 @@   public boolean nextBoolean()
    */
   public float nextFloat()
   {
-    return next(24) / (float) (1 << 24);
+    long rand48;
+    synchronized (this) {
+      rand48 = seed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL;
+    }
+    return ((int) (rand48 >>> 24)) / ((float) (1 << 24));
   }
 
   /**
@@ -366,7 +397,19 @@   public float nextFloat()
    */
   public double nextDouble()
   {
-    return (((long) next(26) << 27) + next(27)) / (double) (1L << 53);
+    long high, low;
+
+    synchronized (this) {
+      high = seed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL;
+      low = seed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL;
+    }
+
+    // Get the high 26 bits
+    high = high << 5 & 0x1FFFFFF8000000L;
+
+    // Get the low 27 bits
+    low = low >>> 21;
+    return (high + low) / (double) (1L << 53);
   }
 
   /**


Index Nav: [Date Index] [Subject Index] [Author Index] [Thread Index]
Message Nav: [Date Prev] [Date Next] [Thread Prev] [Thread Next]