「線形代数の基礎」はこちらのページです。 https://tutorials.chainer.org/ja/05_Basics_of_Linear_Algebra.html
テンソル
public class Tensor { protected final int order; public Tensor(int order) { this.order = order; } }
public class LinearAlgebraTest { public static void main(String[] args) { Tensor o1 = new Tensor(1); // 1階のテンソル Tensor o2 = new Tensor(2); // 2階のテンソル } }
order は N階のテンソルを表します。 Java的にはN次元の配列ということになります。
ベクトル
ベクトルクラスを定義します。
public class Vector extends Tensor { private final float[] scalars; public Vector(float[] scalars) { super(1); // ベクトルは1階のテンソル this.scalars = scalars; } }
ベクトルはテンソルを継承して、order は 1
固定です。
加算を実装します。
Vectorに以下を追加
public Vector add(Vector object) { float[] scalars = new float[this.scalars.length]; for (int i = 0;i < this.scalars.length;i++) { scalars[i] = this.scalars[i] + object.scalars[i]; } return new Vector(scalars); } @Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append("["); if (scalars.length > 0) { sb.append(scalars[0]); for (int i = 1;i < scalars.length;i++) { sb.append(", ").append(scalars[i]); } } sb.append("]"); return sb.toString(); }
呼び出し
Vector v1 = new Vector(new float[]{1, 2, 3}); Vector v2 = new Vector(new float[]{4, 5, 6}); System.out.println(v1); System.out.println(v2); Vector v3 = v1.add(v2); System.out.println(v3);
実行結果
[1.0, 2.0, 3.0] [4.0, 5.0, 6.0] [5.0, 7.0, 9.0]
行列
こんな感じの実装にしてみます。
public class Matrix extends Tensor { private final float[][] o2scalars; public Matrix(float[][] o2scalars) { super(2); // 行列は2階のテンソル this.o2scalars = o2scalars; } }
加算を実装してみます。
@Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append(o2scalars.length).append(" x ").append(o2scalars[0].length).append("\n"); for (int i = 0;i < o2scalars.length;i++) { sb.append("|"); for (int j = 0;j < o2scalars[i].length;j++) { sb.append(String.format("% 5.1f|", o2scalars[i][j])); } sb.append("\n"); } return sb.toString(); } public Matrix add(Matrix object) { float[][] o2scalars = new float[this.o2scalars.length][this.o2scalars[0].length]; for (int i = 0;i < this.o2scalars.length;i++) { for (int j = 0;j < this.o2scalars[i].length;j++) { o2scalars[i][j] = this.o2scalars[i][j] + object.o2scalars[i][j]; } } return new Matrix(o2scalars); }
Matrix m1 = new Matrix(new float[][] { {1,2,3}, {4,5,6}, }); Matrix m2 = new Matrix(new float[][] { {7,8,9}, {10,11,12}, }); System.out.println(m1); System.out.println(m2); Matrix m3 = m1.add(m2); System.out.println(m3);
実行結果です。
2 x 3 | 1.0| 2.0| 3.0| | 4.0| 5.0| 6.0| 2 x 3 | 7.0| 8.0| 9.0| | 10.0| 11.0| 12.0| 2 x 3 | 8.0| 10.0| 12.0| | 14.0| 16.0| 18.0|
次回は行列の積を実装してみたいと思います。