近似計算の比較

計算を行う際、一部の処理を高速化するために近似計算で済ませる方法があります。 こちらのブログで計算方法が紹介されているので拝借します。 https://martin.ankerl.com/2007/10/04/optimized-pow-approximation-for-java-and-c-c/

package math;

import org.apache.commons.math3.util.FastMath;

/**
 * Micro-benchmark that times {@code Math.pow}, {@code FastMath.pow} and four
 * hand-written bit-level approximations of {@code pow}.
 *
 * <p>Every implementation is run for three rounds. Each round sums one million
 * calls with base {@code 0..999999} and exponent {@code 2}, then prints the
 * label, the sum and the elapsed wall-clock time in milliseconds.
 */
public class PowTest1 {
    /** How many timed rounds are run per implementation. */
    private static final int ROUNDS = 3;

    /** How many calls are summed inside one timed round. */
    private static final int CALLS_PER_ROUND = 1000000;

    /** Report line: label padded to 13 columns, sum without decimals, elapsed ms. */
    private static final String REPORT_LINE = "%-13s sum= %8.0f %dms\n";

    /** Offset applied to the upper 32 bits of a double; equals {@code 1072693248 - 60801}. */
    private static final int UPPER_WORD_OFFSET = 1072632447;

    /** {@code 1072632447L << 32}: the same offset positioned for arithmetic on all 64 bits. */
    private static final long ALL_BITS_OFFSET = 4606921280493453312L;

    /**
     * Runs the six benchmarks one after another and prints one line per round.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Each candidate has its own copy of the timing loop so that the measured
        // call is written directly in the loop body.
        for (int round = 0; round < ROUNDS; round++) {
            double total = 0;
            final long begin = System.currentTimeMillis();
            for (int base = 0; base < CALLS_PER_ROUND; base++) {
                total += Math.pow(base, 2);
            }
            final long elapsedMs = System.currentTimeMillis() - begin;
            System.out.printf(REPORT_LINE, "Math.pow", total, elapsedMs);
        }

        for (int round = 0; round < ROUNDS; round++) {
            double total = 0;
            final long begin = System.currentTimeMillis();
            for (int base = 0; base < CALLS_PER_ROUND; base++) {
                total += FastMath.pow(base, 2);
            }
            final long elapsedMs = System.currentTimeMillis() - begin;
            System.out.printf(REPORT_LINE, "FastMath.pow", total, elapsedMs);
        }

        for (int round = 0; round < ROUNDS; round++) {
            double total = 0;
            final long begin = System.currentTimeMillis();
            for (int base = 0; base < CALLS_PER_ROUND; base++) {
                total += pow1(base, 2);
            }
            final long elapsedMs = System.currentTimeMillis() - begin;
            System.out.printf(REPORT_LINE, "pow1", total, elapsedMs);
        }

        for (int round = 0; round < ROUNDS; round++) {
            double total = 0;
            final long begin = System.currentTimeMillis();
            for (int base = 0; base < CALLS_PER_ROUND; base++) {
                total += pow2(base, 2);
            }
            final long elapsedMs = System.currentTimeMillis() - begin;
            System.out.printf(REPORT_LINE, "pow2", total, elapsedMs);
        }

        for (int round = 0; round < ROUNDS; round++) {
            double total = 0;
            final long begin = System.currentTimeMillis();
            for (int base = 0; base < CALLS_PER_ROUND; base++) {
                total += pow3(base, 2);
            }
            final long elapsedMs = System.currentTimeMillis() - begin;
            System.out.printf(REPORT_LINE, "pow3", total, elapsedMs);
        }

        for (int round = 0; round < ROUNDS; round++) {
            double total = 0;
            final long begin = System.currentTimeMillis();
            for (int base = 0; base < CALLS_PER_ROUND; base++) {
                total += pow4(base, 2);
            }
            final long elapsedMs = System.currentTimeMillis() - begin;
            System.out.printf(REPORT_LINE, "pow4", total, elapsedMs);
        }
    }

    /**
     * Approximates {@code a} raised to {@code b} by linearly rescaling the upper
     * 32 bits of {@code a}'s bit pattern. The lower 32 bits of the result are zero.
     *
     * @param a the base
     * @param b the exponent
     * @return the approximated power
     */
    public static double pow1(final double a, final double b) {
        final long rawBits = Double.doubleToLongBits(a);
        final int upperWord = (int) (rawBits >> 32);
        // The subtraction is done in int arithmetic before widening to double.
        final int centered = upperWord - UPPER_WORD_OFFSET;
        final int rescaled = (int) (b * centered + UPPER_WORD_OFFSET);
        return Double.longBitsToDouble(((long) rescaled) << 32);
    }

    /**
     * Approximates {@code a} raised to {@code b} using the same rescaling as
     * {@link #pow1(double, double)} but applied to all 64 bits of the pattern.
     *
     * @param a the base
     * @param b the exponent
     * @return the approximated power
     */
    public static double pow2(final double a, final double b) {
        final long centered = Double.doubleToLongBits(a) - ALL_BITS_OFFSET;
        // Only the product is truncated to long; the offset is added back as a long.
        final long scaled = (long) (b * centered);
        return Double.longBitsToDouble(scaled + ALL_BITS_OFFSET);
    }

    /**
     * Approximates {@code a} raised to {@code b} by dividing the centered upper
     * word by 1512775, multiplying by {@code b}, and scaling back by 1512775.
     * All intermediate arithmetic is done in {@code double}.
     *
     * @param a the base
     * @param b the exponent
     * @return the approximated power
     */
    public static double pow3(final double a, final double b) {
        final double upperWord = (Double.doubleToLongBits(a) >> 32);
        // Operation order is kept as multiply, divide, multiply, add so the
        // floating-point rounding stays the same.
        final double scaledUp = 1512775 * (upperWord - 1072632447);
        final double scaledBack = scaledUp / 1512775;
        final long resultWord = (long) (scaledBack * b + (1072693248 - 60801));
        return Double.longBitsToDouble(resultWord << 32);
    }

    /**
     * Approximates {@code a} raised to {@code b}; performs the same computation
     * as {@link #pow1(double, double)} and is benchmarked separately.
     *
     * @param a the base
     * @param b the exponent
     * @return the approximated power
     */
    public static double pow4(final double a, final double b) {
        final int hi = (int) (Double.doubleToLongBits(a) >> 32);
        final double shifted = b * (hi - UPPER_WORD_OFFSET) + UPPER_WORD_OFFSET;
        final int hiResult = (int) shifted;
        return Double.longBitsToDouble(((long) hiResult) << 32);
    }
}

macOS Oracle JDK 12.0.2で実行しました。

Math.pow      sum= 333332833333127550 7ms
Math.pow      sum= 333332833333127550 5ms
Math.pow      sum= 333332833333127550 4ms
FastMath.pow  sum= 333332833333127550 165ms
FastMath.pow  sum= 333332833333127550 148ms
FastMath.pow  sum= 333332833333127550 144ms
pow1          sum= 332595653188566270 10ms
pow1          sum= 332595653188566270 9ms
pow1          sum= 332595653188566270 9ms
pow2          sum= 332595653188566270 9ms
pow2          sum= 332595653188566270 7ms
pow2          sum= 332595653188566270 8ms
pow3          sum= 332595653188566270 16ms
pow3          sum= 332595653188566270 16ms
pow3          sum= 332595653188566270 15ms
pow4          sum= 332595653188566270 10ms
pow4          sum= 332595653188566270 10ms
pow4          sum= 332595653188566270 10ms

Math.powが最速でした。FastMath.powはかなり遅いです。 pow1〜pow4は計算結果(sum)はすべて同じですが、速度には差が出ています。

次はexpです。

package math;

import org.apache.commons.math3.util.FastMath;

/**
 * Micro-benchmark that times {@code Math.exp}, {@code FastMath.exp} and a
 * bit-level approximation of {@code exp}.
 *
 * <p>Every implementation is run for five rounds. Each round sums one million
 * calls with argument {@code j * 0.0001f} for {@code j = 0..999999}, then prints
 * the label, the sum and the elapsed wall-clock time in milliseconds.
 */
public class ExpTest1 {
    /** How many timed rounds are run per implementation. */
    private static final int ROUNDS = 5;

    /** How many calls are summed inside one timed round. */
    private static final int CALLS_PER_ROUND = 1000000;

    /** Step between consecutive arguments; a float, so the product is computed in float. */
    private static final float STEP = 0.0001f;

    /** Report line: label padded to 13 columns, sum without decimals, elapsed ms. */
    private static final String REPORT_LINE = "%-13s sum= %8.0f %dms\n";

    /**
     * Runs the three benchmarks one after another and prints one line per round.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Each candidate has its own copy of the timing loop so that the measured
        // call is written directly in the loop body.
        for (int round = 0; round < ROUNDS; round++) {
            double total = 0;
            final long begin = System.currentTimeMillis();
            for (int k = 0; k < CALLS_PER_ROUND; k++) {
                total += Math.exp(k * STEP);
            }
            final long elapsedMs = System.currentTimeMillis() - begin;
            System.out.printf(REPORT_LINE, "Math.exp", total, elapsedMs);
        }

        for (int round = 0; round < ROUNDS; round++) {
            double total = 0;
            final long begin = System.currentTimeMillis();
            for (int k = 0; k < CALLS_PER_ROUND; k++) {
                total += FastMath.exp(k * STEP);
            }
            final long elapsedMs = System.currentTimeMillis() - begin;
            System.out.printf(REPORT_LINE, "FastMath.exp", total, elapsedMs);
        }

        for (int round = 0; round < ROUNDS; round++) {
            double total = 0;
            final long begin = System.currentTimeMillis();
            for (int k = 0; k < CALLS_PER_ROUND; k++) {
                total += exp1(k * STEP);
            }
            final long elapsedMs = System.currentTimeMillis() - begin;
            System.out.printf(REPORT_LINE, "exp1", total, elapsedMs);
        }
    }

    /**
     * Approximates {@code e} raised to {@code val} by computing
     * {@code 1512775 * val + (1072693248 - 60801)}, truncating it to a long and
     * using it as the upper 32 bits of the result's bit pattern. The lower
     * 32 bits of the result are zero.
     *
     * @param val the exponent
     * @return the approximated exponential
     */
    public static double exp1(double val) {
        // 1072693248 is 1023 << 20, the upper word of the bit pattern of 1.0.
        final double upperWord = 1512775 * val + (1072693248 - 60801);
        final long truncated = (long) upperWord;
        return Double.longBitsToDouble(truncated << 32);
    }
}

結果です。

Math.exp      sum= 268797601492947300000000000000000000000000000000 15ms
Math.exp      sum= 268797601492947300000000000000000000000000000000 24ms
Math.exp      sum= 268797601492947300000000000000000000000000000000 23ms
Math.exp      sum= 268797601492947300000000000000000000000000000000 23ms
Math.exp      sum= 268797601492947300000000000000000000000000000000 23ms
FastMath.exp  sum= 268797601492947300000000000000000000000000000000 45ms
FastMath.exp  sum= 268797601492947300000000000000000000000000000000 30ms
FastMath.exp  sum= 268797601492947300000000000000000000000000000000 30ms
FastMath.exp  sum= 268797601492947300000000000000000000000000000000 31ms
FastMath.exp  sum= 268797601492947300000000000000000000000000000000 31ms
exp1          sum= 267998235537123060000000000000000000000000000000 9ms
exp1          sum= 267998235537123060000000000000000000000000000000 9ms
exp1          sum= 267998235537123060000000000000000000000000000000 8ms
exp1          sum= 267998235537123060000000000000000000000000000000 8ms
exp1          sum= 267998235537123060000000000000000000000000000000 9ms

今回はexp1が最速でした。FastMath.expは計算結果こそMath.expと一致していますが、一番遅かったです。

最後はsqrtです。

package math;

import org.apache.commons.math3.util.FastMath;

/**
 * Micro-benchmark that times {@code Math.sqrt}, {@code FastMath.sqrt} and a
 * bit-level approximation of {@code sqrt}.
 *
 * <p>Every implementation is run for four rounds. Each round sums one million
 * calls with argument {@code 0..999999}, then prints the label, the sum and the
 * elapsed wall-clock time in milliseconds.
 */
public class SqrtTest1 {
    /**
     * Runs the three benchmarks one after another and prints one line per round.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        for (int i = 0;i < 4;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += Math.sqrt(j);
            }
            // The label names the function that this loop actually measures.
            System.out.format("%-13s sum= %8.0f %dms\n", "Math.sqrt", sum, (System.currentTimeMillis() - start));
        }

        for (int i = 0;i < 4;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += FastMath.sqrt(j);
            }
            System.out.format("%-13s sum= %8.0f %dms\n", "FastMath.sqrt", sum, (System.currentTimeMillis() - start));
        }

        for (int i = 0;i < 4;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += sqrt1(j);
            }
            System.out.format("%-13s sum= %8.0f %dms\n", "sqrt1", sum, (System.currentTimeMillis() - start));
        }
    }

    /**
     * Approximates the square root of {@code a} by taking the upper 32 bits of
     * its bit pattern, adding 1072632448 and shifting the sum left by 31 bits.
     *
     * @param a the value whose square root is approximated
     * @return the approximated square root
     */
    public static double sqrt1(final double a) {
        final long x = Double.doubleToLongBits(a) >> 32;
        // Shifting by 31 instead of 32 places half of the sum in the upper word.
        double y = Double.longBitsToDouble((x + 1072632448) << 31);
        return y;
    }
}

結果です。

Math.exp      sum= 666666166 7ms
Math.exp      sum= 666666166 9ms
Math.exp      sum= 666666166 5ms
Math.exp      sum= 666666166 3ms
FastMath.sqrt sum= 666666166 10ms
FastMath.sqrt sum= 666666166 8ms
FastMath.sqrt sum= 666666166 7ms
FastMath.sqrt sum= 666666166 4ms
sqrt1         sum= 666888545 4ms
sqrt1         sum= 666888545 4ms
sqrt1         sum= 666888545 4ms
sqrt1         sum= 666888545 4ms

最速はMath.sqrtでした(上の出力ではラベルが「Math.exp」になっていますが、実際に計測しているのはMath.sqrtです)。 アルゴリズム的には似通ったものになるのか、差はほとんどありませんでした。