1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.distribution;
18
19 import java.util.Arrays;
20 import org.apache.commons.statistics.distribution.ContinuousDistribution;
21 import org.apache.commons.statistics.distribution.NormalDistribution;
22 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23 import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
24 import org.apache.commons.math4.legacy.linear.EigenDecomposition;
25 import org.apache.commons.math4.legacy.linear.NonPositiveDefiniteMatrixException;
26 import org.apache.commons.math4.legacy.linear.RealMatrix;
27 import org.apache.commons.math4.legacy.linear.SingularMatrixException;
28 import org.apache.commons.rng.UniformRandomProvider;
29 import org.apache.commons.math4.core.jdkmath.JdkMath;
30
31
32
33
34
35
36
37
38
39
40
41 public class MultivariateNormalDistribution
42 extends AbstractMultivariateRealDistribution {
43
44 private final double[] means;
45
46 private final RealMatrix covarianceMatrix;
47
48 private final RealMatrix covarianceMatrixInverse;
49
50 private final double covarianceMatrixDeterminant;
51
52 private final RealMatrix samplingMatrix;
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72 public MultivariateNormalDistribution(final double[] means,
73 final double[][] covariances)
74 throws SingularMatrixException,
75 DimensionMismatchException,
76 NonPositiveDefiniteMatrixException {
77 super(means.length);
78
79 final int dim = means.length;
80
81 if (covariances.length != dim) {
82 throw new DimensionMismatchException(covariances.length, dim);
83 }
84
85 for (int i = 0; i < dim; i++) {
86 if (dim != covariances[i].length) {
87 throw new DimensionMismatchException(covariances[i].length, dim);
88 }
89 }
90
91 this.means = Arrays.copyOf(means, means.length);
92
93 covarianceMatrix = new Array2DRowRealMatrix(covariances);
94
95
96 final EigenDecomposition covMatDec = new EigenDecomposition(covarianceMatrix);
97
98
99 covarianceMatrixInverse = covMatDec.getSolver().getInverse();
100
101 covarianceMatrixDeterminant = covMatDec.getDeterminant();
102
103
104 final double[] covMatEigenvalues = covMatDec.getRealEigenvalues();
105
106 for (int i = 0; i < covMatEigenvalues.length; i++) {
107 if (covMatEigenvalues[i] < 0) {
108 throw new NonPositiveDefiniteMatrixException(covMatEigenvalues[i], i, 0);
109 }
110 }
111
112
113 final Array2DRowRealMatrix covMatEigenvectors = new Array2DRowRealMatrix(dim, dim);
114 for (int v = 0; v < dim; v++) {
115 final double[] evec = covMatDec.getEigenvector(v).toArray();
116 covMatEigenvectors.setColumn(v, evec);
117 }
118
119 final RealMatrix tmpMatrix = covMatEigenvectors.transpose();
120
121
122 for (int row = 0; row < dim; row++) {
123 final double factor = JdkMath.sqrt(covMatEigenvalues[row]);
124 for (int col = 0; col < dim; col++) {
125 tmpMatrix.multiplyEntry(row, col, factor);
126 }
127 }
128
129 samplingMatrix = covMatEigenvectors.multiply(tmpMatrix);
130 }
131
132
133
134
135
136
137 public double[] getMeans() {
138 return Arrays.copyOf(means, means.length);
139 }
140
141
142
143
144
145
146 public RealMatrix getCovariances() {
147 return covarianceMatrix.copy();
148 }
149
150
151 @Override
152 public double density(final double[] vals) throws DimensionMismatchException {
153 final int dim = getDimension();
154 if (vals.length != dim) {
155 throw new DimensionMismatchException(vals.length, dim);
156 }
157
158 return JdkMath.pow(2 * JdkMath.PI, -0.5 * dim) *
159 JdkMath.pow(covarianceMatrixDeterminant, -0.5) *
160 getExponentTerm(vals);
161 }
162
163
164
165
166
167
168
169 public double[] getStandardDeviations() {
170 final int dim = getDimension();
171 final double[] std = new double[dim];
172 final double[][] s = covarianceMatrix.getData();
173 for (int i = 0; i < dim; i++) {
174 std[i] = JdkMath.sqrt(s[i][i]);
175 }
176 return std;
177 }
178
179
180 @Override
181 public MultivariateRealDistribution.Sampler createSampler(final UniformRandomProvider rng) {
182 return new MultivariateRealDistribution.Sampler() {
183
184 private final ContinuousDistribution.Sampler gauss = NormalDistribution.of(0, 1).createSampler(rng);
185
186
187 @Override
188 public double[] sample() {
189 final int dim = getDimension();
190 final double[] normalVals = new double[dim];
191
192 for (int i = 0; i < dim; i++) {
193 normalVals[i] = gauss.sample();
194 }
195
196 final double[] vals = samplingMatrix.operate(normalVals);
197
198 for (int i = 0; i < dim; i++) {
199 vals[i] += means[i];
200 }
201
202 return vals;
203 }
204 };
205 }
206
207
208
209
210
211
212
213 private double getExponentTerm(final double[] values) {
214 final double[] centered = new double[values.length];
215 for (int i = 0; i < centered.length; i++) {
216 centered[i] = values[i] - means[i];
217 }
218 final double[] preMultiplied = covarianceMatrixInverse.preMultiply(centered);
219 double sum = 0;
220 for (int i = 0; i < preMultiplied.length; i++) {
221 sum += preMultiplied[i] * centered[i];
222 }
223 return JdkMath.exp(-0.5 * sum);
224 }
225 }