Java Vector API を試す

https://nowokay.hatenablog.com/entry/2019/09/05/015537

こちらの記事を参考にビルドします。

CのコードとJavaのコードを比較していきます。

# vector.h
float dotProduct512(float* vec1, float* vec2, int num);
float dotProduct256(float* vec1, float* vec2, int num);
float dotProduct(float* vec1, float* vec2, int num);
#vector.c
#include "vector.h"
#include <immintrin.h>

float dotProduct512(float* vec1, float* vec2, int num)
{
    __m512 avx_sum = _mm512_setzero_ps();
    int limit = num - num % 16;
    for (int i = 0;i < limit;i += 16) {
        const __m512 a512 = _mm512_loadu_ps((float*)&vec1[i]);
        const __m512 b512 = _mm512_loadu_ps((float*)&vec2[i]);
        avx_sum = _mm512_fmadd_ps(a512, b512, avx_sum);
    }

    float __attribute__((aligned(32))) out[16] = {};
    _mm512_storeu_ps(out, avx_sum);
    float sum = 0;
    for (int i = 0;i < 16;i++) {
        sum += out[i];
    }
    for (int i = limit;i < num;i++) {
        sum += vec1[i] * vec2[i];
    }
    return sum;
}

float dotProduct256(float* vec1, float* vec2, int num)
{
    __m256 avx_sum = _mm256_setzero_ps();
    int limit = num - num % 8;
    for (int i = 0;i < limit;i += 8) {
        const __m256 a256 = _mm256_loadu_ps((float*)&vec1[i]);
        const __m256 b256 = _mm256_loadu_ps((float*)&vec2[i]);
        avx_sum = _mm256_fmadd_ps(a256, b256, avx_sum);
    }

    float __attribute__((aligned(32))) out[16] = {};
    _mm256_store_ps(out, avx_sum);
    float sum = 0;
    for (int i = 0;i < 8;i++) {
        sum += out[i];
    }
    for (int i = limit;i < num;i++) {
        sum += vec1[i] * vec2[i];
    }
    return sum;
}

float dotProduct(float* vec1, float* vec2, int num)
{
    float sum = 0;
    for (int i = 0;i < num;i++) {
        sum += vec1[i] * vec2[i];
    }
    return sum;
}
# vec_test.c

#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include "vector.h"

#define NUM 100 * 1000 * 1000

unsigned long getMicroSec()
{
    struct timespec time1;
    clock_gettime(CLOCK_REALTIME,&time1);

    unsigned long micros = time1.tv_sec * 1000000;
    micros += time1.tv_nsec / 1000;
    return micros;
}

void bench1()
{

    float *vec_a, *vec_b;
    vec_a = (float*)malloc(sizeof(float) * NUM);
    vec_b = (float*)malloc(sizeof(float) * NUM);

    for(int i = 0;i < NUM;i++) {
        vec_a[i] = ((float)rand()) / NUM;
        vec_b[i] = ((float)rand()) / NUM;
    }

    unsigned long start = getMicroSec();
    float sum = dotProduct(vec_a, vec_b, NUM);
    printf("bench1 - %lu ms\n", (getMicroSec() - start) / 1000);

    free(vec_a);
    free(vec_b);
}

void bench2()
{

    float *vec_a, *vec_b;
    vec_a = (float*)malloc(sizeof(float) * NUM);
    vec_b = (float*)malloc(sizeof(float) * NUM);

    for(int i = 0;i < NUM;i++) {
        vec_a[i] = ((float)rand()) / NUM;
        vec_b[i] = ((float)rand()) / NUM;
    }
    unsigned long start = getMicroSec();

    float sum = dotProduct256(vec_a, vec_b, NUM);
    printf("bench2 - %lu ms\n", (getMicroSec() - start) / 1000);

    free(vec_a);
    free(vec_b);
}

int main(void)
{
    bench1();
    bench2();
    return 0;
}

以下のコマンドでビルドします。

gcc -O2 -mavx512f vec_test.c  -o vec_test vector
gcc -O2 vec_test.c  -o vec_test vector

実行結果です。

 $ ./vec_test
bench1 - 135 ms
bench2 - 52 ms

FMA命令を使用した方が早くなります。 次にJavaです。

// VectorTest2.java
import jdk.incubator.vector.*;

public class VectorTest2 {
    private final static int NUM = 100 * 1000 * 1000;

    private static float dotProduct256(float[] vec_a, float[] vec_b) {
        var SP = FloatVector.SPECIES_256;

        int limit = vec_a.length - vec_a.length % 8;
        var fv_sum = FloatVector.fromValues(SP, 0, 0, 0, 0, 0, 0, 0, 0);
        for (int i = 0; i < limit; i += 8) {
            var fv_a = FloatVector.fromArray(SP, vec_a, i);
            var fv_b = FloatVector.fromArray(SP, vec_b, i);
            fv_sum = fv_a.fma(fv_b, fv_sum);
        }

        float[] outArray = new float[8];
        fv_sum.intoArray(outArray, 0);

        float sum = 0;
        for (float f: outArray) {
            sum += f;
        }

        for (int i = limit; i < vec_a.length; i += 1) {
            sum += vec_a[i] * vec_b[i];
        }
        return sum;
    }
    private static float dotProduct(float[] vec_a, float[] vec_b) {
        float sum = 0;
        for (int i = 0; i < vec_a.length; i++) {
            sum += vec_a[i] * vec_b[i];
        }
        return sum;
    }

    private static void bench1() {
        float[] vec_a = new float[NUM];
        float[] vec_b = new float[NUM];
        for (int i = 0;i < NUM;i++) {
            vec_a[i] = (float)Math.random();
            vec_b[i] = (float)Math.random();
        }
        long start = System.currentTimeMillis();
        float sum = dotProduct(vec_a, vec_b);
        System.out.format("bench1 - %d ms\n", (System.currentTimeMillis() - start));
    }
    private static void bench2() {
        float[] vec_a = new float[NUM];
        float[] vec_b = new float[NUM];
        for (int i = 0;i < NUM;i++) {
            vec_a[i] = (float)Math.random();
            vec_b[i] = (float)Math.random();
        }
        long start = System.currentTimeMillis();
        float sum = dotProduct256(vec_a, vec_b);
        System.out.format("bench2 - %d ms\n", (System.currentTimeMillis() - start));
    }
    public static void main(String[] args) {
        for(int i = 0;i < 20;i++) {
            bench1();
        }
        for(int i = 0;i < 20;i++) {
            bench2();
        }
    }
}

ビルドして実行します。

$ javac14 src/main/java/VectorTest2.java \
                    --add-modules jdk.incubator.vector \
                    -d out/
$ java14 -cp out/ VectorTest2
bench1 - 114 ms
bench1 - 127 ms
bench1 - 116 ms
bench1 - 128 ms
bench1 - 108 ms
bench1 - 140 ms
bench1 - 117 ms
bench1 - 114 ms
bench1 - 112 ms
bench1 - 109 ms
bench1 - 144 ms
bench1 - 120 ms
bench1 - 120 ms
bench1 - 110 ms
bench1 - 117 ms
bench1 - 107 ms
bench1 - 107 ms
bench1 - 110 ms
bench1 - 107 ms
bench1 - 109 ms
bench2 - 447 ms
bench2 - 125 ms
bench2 - 163 ms
bench2 - 110 ms
bench2 - 117 ms
bench2 - 117 ms
bench2 - 117 ms
bench2 - 116 ms
bench2 - 123 ms
bench2 - 116 ms
bench2 - 129 ms
bench2 - 116 ms
bench2 - 116 ms
bench2 - 116 ms
bench2 - 117 ms
bench2 - 115 ms
bench2 - 117 ms
bench2 - 116 ms
bench2 - 118 ms
bench2 - 116 ms

VectorAPIを使用しない場合107msが最速ですが、使用した場合116msが最速です。 なぜ使用すると遅くなるのかは不明ですが、まだまだ最適化が必要なのでしょう。


Intel公式資料

Java Doc