「線形代数の基礎」をJavaで実装してみる

線形代数の基礎」はこちらのページです。 https://tutorials.chainer.org/ja/05_Basics_of_Linear_Algebra.html

テンソル

public class Tensor {
    protected final int order;
    public Tensor(int order) {
        this.order = order;
    }
}
public class LinearAlgebraTest {
    public static void main(String[] args) {
        Tensor o1 = new Tensor(1); // 1階のテンソル
        Tensor o2 = new Tensor(2); // 2階のテンソル
    }
}

order は N階のテンソルを表します。 Java的にはN次元の配列ということになります。

ベクトル

ベクトルクラスを定義します。

public class Vector extends Tensor {
    private final float[] scalars;
    public Vector(float[] scalars) {
        super(1); // ベクトルは1階のテンソル
        this.scalars = scalars;
    }
}

ベクトルはテンソルを継承して、order は 1 固定です。

加算を実装します。

Vectorに以下を追加

    public Vector add(Vector object) {
        float[] scalars = new float[this.scalars.length];
        for (int i = 0;i < this.scalars.length;i++) {
            scalars[i] = this.scalars[i] + object.scalars[i];
        }
        return new Vector(scalars);
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("[");
        if (scalars.length > 0) {
            sb.append(scalars[0]);
            for (int i = 1;i < scalars.length;i++) {
                sb.append(", ").append(scalars[i]);
            }
        }
        sb.append("]");
        return sb.toString();
    }

呼び出し

        Vector v1 = new Vector(new float[]{1, 2, 3});
        Vector v2 = new Vector(new float[]{4, 5, 6});

        System.out.println(v1);
        System.out.println(v2);

        Vector v3 = v1.add(v2);
        System.out.println(v3);

実行結果

[1.0, 2.0, 3.0]
[4.0, 5.0, 6.0]
[5.0, 7.0, 9.0]

行列

こんな感じの実装にしてみます。

public class Matrix extends Tensor {
    private final float[][] o2scalars;
    public Matrix(float[][] o2scalars) {
        super(2); // 行列は2階のテンソル
        this.o2scalars = o2scalars;
    }
}

加算を実装してみます。

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(o2scalars.length).append(" x ").append(o2scalars[0].length).append("\n");
        for (int i = 0;i < o2scalars.length;i++) {
            sb.append("|");
            for (int j = 0;j < o2scalars[i].length;j++) {
                sb.append(String.format("% 5.1f|", o2scalars[i][j]));
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    public Matrix add(Matrix object) {
        float[][] o2scalars = new float[this.o2scalars.length][this.o2scalars[0].length];

        for (int i = 0;i < this.o2scalars.length;i++) {
            for (int j = 0;j < this.o2scalars[i].length;j++) {
                o2scalars[i][j] = this.o2scalars[i][j] + object.o2scalars[i][j];
            }
        }
        return new Matrix(o2scalars);
    }
        Matrix m1 = new Matrix(new float[][] {
                {1,2,3},
                {4,5,6},
        });
        Matrix m2 = new Matrix(new float[][] {
                {7,8,9},
                {10,11,12},
        });
        System.out.println(m1);
        System.out.println(m2);

        Matrix m3 = m1.add(m2);
        System.out.println(m3);

実行結果です。

2 x 3
|  1.0|  2.0|  3.0|
|  4.0|  5.0|  6.0|

2 x 3
|  7.0|  8.0|  9.0|
| 10.0| 11.0| 12.0|

2 x 3
|  8.0| 10.0| 12.0|
| 14.0| 16.0| 18.0|

次回は行列の積を実装してみたいと思います。