| /* |
| * Copyright (C) 2017 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package com.android.server.wifi.util; |
| |
| /** |
| * Utility for doing basic matix calculations |
| */ |
| public class Matrix { |
| public final int n; |
| public final int m; |
| public final double[] mem; |
| |
| /** |
| * Creates a new matrix, initialized to zeros |
| * |
| * @param rows - number of rows (n) |
| * @param cols - number of columns (m) |
| */ |
| public Matrix(int rows, int cols) { |
| n = rows; |
| m = cols; |
| mem = new double[rows * cols]; |
| } |
| |
| /** |
| * Creates a new matrix using the provided array of values |
| * <p> |
| * Values are in row-major order. |
| * |
| * @param stride is the number of columns. |
| * @param values is the array of values. |
| * @throws IllegalArgumentException if length of values array not a multiple of stride |
| */ |
| public Matrix(int stride, double[] values) { |
| n = (values.length + stride - 1) / stride; |
| m = stride; |
| mem = values; |
| if (mem.length != n * m) throw new IllegalArgumentException(); |
| } |
| |
| /** |
| * Creates a new matrix duplicating the given one |
| * |
| * @param that is the source Matrix. |
| */ |
| public Matrix(Matrix that) { |
| n = that.n; |
| m = that.m; |
| mem = new double[that.mem.length]; |
| for (int i = 0; i < mem.length; i++) { |
| mem[i] = that.mem[i]; |
| } |
| } |
| |
| /** |
| * Gets the matrix coefficient from row i, column j |
| * |
| * @param i row number |
| * @param j column number |
| * @return Coefficient at i,j |
| * @throws IndexOutOfBoundsException if an index is out of bounds |
| */ |
| public double get(int i, int j) { |
| if (!(0 <= i && i < n && 0 <= j && j < m)) throw new IndexOutOfBoundsException(); |
| return mem[i * m + j]; |
| } |
| |
| /** |
| * Store a matrix coefficient in row i, column j |
| * |
| * @param i row number |
| * @param j column number |
| * @param v Coefficient to store at i,j |
| * @throws IndexOutOfBoundsException if an index is out of bounds |
| */ |
| public void put(int i, int j, double v) { |
| if (!(0 <= i && i < n && 0 <= j && j < m)) throw new IndexOutOfBoundsException(); |
| mem[i * m + j] = v; |
| } |
| |
| /** |
| * Forms the sum of two matrices, this and that |
| * |
| * @param that is the other matrix |
| * @return newly allocated matrix representing the sum of this and that |
| * @throws IllegalArgumentException if shapes differ |
| */ |
| public Matrix plus(Matrix that) { |
| return plus(that, new Matrix(n, m)); |
| |
| } |
| |
| /** |
| * Forms the sum of two matrices, this and that |
| * |
| * @param that is the other matrix |
| * @param result is space to hold the result |
| * @return result, filled with the matrix sum |
| * @throws IllegalArgumentException if shapes differ |
| */ |
| public Matrix plus(Matrix that, Matrix result) { |
| if (!(this.n == that.n && this.m == that.m && this.n == result.n && this.m == result.m)) { |
| throw new IllegalArgumentException(); |
| } |
| for (int i = 0; i < mem.length; i++) { |
| result.mem[i] = this.mem[i] + that.mem[i]; |
| } |
| return result; |
| } |
| |
| /** |
| * Forms the difference of two matrices, this and that |
| * |
| * @param that is the other matrix |
| * @return newly allocated matrix representing the difference of this and that |
| * @throws IllegalArgumentException if shapes differ |
| */ |
| public Matrix minus(Matrix that) { |
| return minus(that, new Matrix(n, m)); |
| } |
| |
| /** |
| * Forms the difference of two matrices, this and that |
| * |
| * @param that is the other matrix |
| * @param result is space to hold the result |
| * @return result, filled with the matrix difference |
| * @throws IllegalArgumentException if shapes differ |
| */ |
| public Matrix minus(Matrix that, Matrix result) { |
| if (!(this.n == that.n && this.m == that.m && this.n == result.n && this.m == result.m)) { |
| throw new IllegalArgumentException(); |
| } |
| for (int i = 0; i < mem.length; i++) { |
| result.mem[i] = this.mem[i] - that.mem[i]; |
| } |
| return result; |
| } |
| |
| /** |
| * Forms a scalar product |
| * |
| * @param scalar is the value to multiply by |
| * @return newly allocated matrix representing the product this and scalar |
| */ |
| public Matrix times(double scalar) { |
| return times(scalar, new Matrix(n, m)); |
| } |
| |
| /** |
| * Forms a scalar product |
| * |
| * @param scalar is the value to multiply by |
| * @param result is space to hold the result |
| * @return result, filled with the matrix difference |
| * @throws IllegalArgumentException if shapes differ |
| */ |
| public Matrix times(double scalar, Matrix result) { |
| if (!(this.n == result.n && this.m == result.m)) { |
| throw new IllegalArgumentException(); |
| } |
| for (int i = 0; i < mem.length; i++) { |
| result.mem[i] = this.mem[i] * scalar; |
| } |
| return result; |
| } |
| |
| /** |
| * Forms the matrix product of two matrices, this and that |
| * |
| * @param that is the other matrix |
| * @return newly allocated matrix representing the matrix product of this and that |
| * @throws IllegalArgumentException if shapes are not conformant |
| */ |
| public Matrix dot(Matrix that) { |
| return dot(that, new Matrix(this.n, that.m)); |
| } |
| |
| /** |
| * Forms the matrix product of two matrices, this and that |
| * <p> |
| * Caller supplies an object to contain the result, as well as scratch space |
| * |
| * @param that is the other matrix |
| * @param result is space to hold the result |
| * @return result, filled with the matrix product |
| * @throws IllegalArgumentException if shapes are not conformant |
| */ |
| public Matrix dot(Matrix that, Matrix result) { |
| if (!(this.n == result.n && this.m == that.n && that.m == result.m)) { |
| throw new IllegalArgumentException(); |
| } |
| for (int i = 0; i < n; i++) { |
| for (int j = 0; j < that.m; j++) { |
| double s = 0.0; |
| for (int k = 0; k < m; k++) { |
| s += this.get(i, k) * that.get(k, j); |
| } |
| result.put(i, j, s); |
| } |
| } |
| return result; |
| } |
| |
| /** |
| * Forms the matrix transpose |
| * |
| * @return newly allocated transpose matrix |
| */ |
| public Matrix transpose() { |
| return transpose(new Matrix(m, n)); |
| } |
| |
| /** |
| * Forms the matrix transpose |
| * <p> |
| * Caller supplies an object to contain the result |
| * |
| * @param result is space to hold the result |
| * @return result, filled with the matrix transpose |
| * @throws IllegalArgumentException if result shape is wrong |
| */ |
| public Matrix transpose(Matrix result) { |
| if (!(this.n == result.m && this.m == result.n)) throw new IllegalArgumentException(); |
| for (int i = 0; i < n; i++) { |
| for (int j = 0; j < m; j++) { |
| result.put(j, i, get(i, j)); |
| } |
| } |
| return result; |
| } |
| |
| /** |
| * Forms the inverse of a square matrix |
| * |
| * @return newly allocated matrix representing the matrix inverse |
| * @throws ArithmeticException if the matrix is not invertible |
| */ |
| public Matrix inverse() { |
| return inverse(new Matrix(n, m), new Matrix(n, 2 * m)); |
| } |
| |
| /** |
| * Forms the inverse of a square matrix |
| * |
| * @param result is space to hold the result |
| * @param scratch is workspace of dimension n by 2*n |
| * @return result, filled with the matrix inverse |
| * @throws ArithmeticException if the matrix is not invertible |
| * @throws IllegalArgumentException if shape of scratch or result is wrong |
| */ |
| public Matrix inverse(Matrix result, Matrix scratch) { |
| if (!(n == m && n == result.n && m == result.m && n == scratch.n && 2 * m == scratch.m)) { |
| throw new IllegalArgumentException(); |
| } |
| |
| for (int i = 0; i < n; i++) { |
| for (int j = 0; j < m; j++) { |
| scratch.put(i, j, get(i, j)); |
| scratch.put(i, m + j, i == j ? 1.0 : 0.0); |
| } |
| } |
| |
| for (int i = 0; i < n; i++) { |
| int ibest = i; |
| double vbest = Math.abs(scratch.get(ibest, ibest)); |
| for (int ii = i + 1; ii < n; ii++) { |
| double v = Math.abs(scratch.get(ii, i)); |
| if (v > vbest) { |
| ibest = ii; |
| vbest = v; |
| } |
| } |
| if (ibest != i) { |
| for (int j = 0; j < scratch.m; j++) { |
| double t = scratch.get(i, j); |
| scratch.put(i, j, scratch.get(ibest, j)); |
| scratch.put(ibest, j, t); |
| } |
| } |
| double d = scratch.get(i, i); |
| if (d == 0.0) throw new ArithmeticException("Singular matrix"); |
| for (int j = 0; j < scratch.m; j++) { |
| scratch.put(i, j, scratch.get(i, j) / d); |
| } |
| for (int ii = i + 1; ii < n; ii++) { |
| d = scratch.get(ii, i); |
| for (int j = 0; j < scratch.m; j++) { |
| scratch.put(ii, j, scratch.get(ii, j) - d * scratch.get(i, j)); |
| } |
| } |
| } |
| for (int i = n - 1; i >= 0; i--) { |
| for (int ii = 0; ii < i; ii++) { |
| double d = scratch.get(ii, i); |
| for (int j = 0; j < scratch.m; j++) { |
| scratch.put(ii, j, scratch.get(ii, j) - d * scratch.get(i, j)); |
| } |
| } |
| } |
| for (int i = 0; i < result.n; i++) { |
| for (int j = 0; j < result.m; j++) { |
| result.put(i, j, scratch.get(i, m + j)); |
| } |
| } |
| return result; |
| } |
| /** |
| * Forms the matrix product with the transpose of a second matrix |
| * |
| * @param that is the other matrix |
| * @return newly allocated matrix representing the matrix product of this and that.transpose() |
| * @throws IllegalArgumentException if shapes are not conformant |
| */ |
| public Matrix dotTranspose(Matrix that) { |
| return dotTranspose(that, new Matrix(this.n, that.n)); |
| } |
| |
| /** |
| * Forms the matrix product with the transpose of a second matrix |
| * <p> |
| * Caller supplies an object to contain the result, as well as scratch space |
| * |
| * @param that is the other matrix |
| * @param result is space to hold the result |
| * @return result, filled with the matrix product of this and that.transpose() |
| * @throws IllegalArgumentException if shapes are not conformant |
| */ |
| public Matrix dotTranspose(Matrix that, Matrix result) { |
| if (!(this.n == result.n && this.m == that.m && that.n == result.m)) { |
| throw new IllegalArgumentException(); |
| } |
| for (int i = 0; i < n; i++) { |
| for (int j = 0; j < that.n; j++) { |
| double s = 0.0; |
| for (int k = 0; k < m; k++) { |
| s += this.get(i, k) * that.get(j, k); |
| } |
| result.put(i, j, s); |
| } |
| } |
| return result; |
| } |
| |
| /** |
| * Tests for equality |
| */ |
| @Override |
| public boolean equals(Object that) { |
| if (this == that) return true; |
| if (!(that instanceof Matrix)) return false; |
| Matrix other = (Matrix) that; |
| if (n != other.n) return false; |
| if (m != other.m) return false; |
| for (int i = 0; i < mem.length; i++) { |
| if (mem[i] != other.mem[i]) return false; |
| } |
| return true; |
| } |
| |
| /** |
| * Calculates a hash code |
| */ |
| @Override |
| public int hashCode() { |
| int h = n * 101 + m; |
| for (int i = 0; i < mem.length; i++) { |
| h = h * 37 + Double.hashCode(mem[i]); |
| } |
| return h; |
| } |
| |
| /** |
| * Makes a string representation |
| * |
| * @return string like "[a, b; c, d]" |
| */ |
| @Override |
| public String toString() { |
| StringBuilder sb = new StringBuilder(n * m * 8); |
| sb.append("["); |
| for (int i = 0; i < mem.length; i++) { |
| if (i > 0) sb.append(i % m == 0 ? "; " : ", "); |
| sb.append(mem[i]); |
| } |
| sb.append("]"); |
| return sb.toString(); |
| } |
| |
| } |