Skip to content

Instantly share code, notes, and snippets.

@zaki50
Created December 8, 2012 07:22
Show Gist options
  • Save zaki50/4239101 to your computer and use it in GitHub Desktop.
Save zaki50/4239101 to your computer and use it in GitHub Desktop.
与えられた座標を3次スプライン補間します。
/*
* Copyright (C) 2012 Makoto Yamazaki <zaki@uphyca.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.zakky.interpolator;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
/**
* http://www.akita-nct.ac.jp/yamamoto/lecture/2004/5E/interpolation/text/html/node3.html 読んで書いた
* http://commons.apache.org/math から commons-math3-3.0 を持ってきてクラスパスを通してください。
*
* 3次スプライン補間なので、N + 1 個の (x, y) 座標の組から N 個の3次関数を作成して補完します。
* また、与えられた座標において、隣接する3次関数の1次導関数と2次導関数が等しくなります。
*/
public class SplineInterpolator {
/**
* 座標列によって区切られる区間(3次関数)の数
*/
private final int N;
/**
* X 座標列
*/
private final double[] mXCoordinates;
/**
* Y 座標列
*/
private final double[] mYCoordinates;
/*
* 補間のための3次関数群の係数情報
*
* 使用する関数のインデックスを j とすると、 0 <= j < N で、
* y = mA[j](x-x[j])^3 + mB[j](x-x[j])^2 + mC[j](x-x[j])^1 + mD[j]
*/
/**
* 3次の項に対する係数列。長さ N。
*/
private final double[] mA;
/**
* 2次の項に対する係数列。長さ N。
*/
private final double[] mB;
/**
* 1次の項に対する係数列。長さ N。
*/
private final double[] mC;
/**
* 0次の項に対する係数列。長さ N。
*/
private final double[] mD;
public SplineInterpolator(double[] xCoordinates, double[] yCoordinates) {
this(xCoordinates, yCoordinates, xCoordinates.length);
}
public SplineInterpolator(double[] xCoordinates, double[] yCoordinates, int length) {
super();
ensureInputsAreValid(xCoordinates, yCoordinates, length);
N = length - 1;
mXCoordinates = new double[N + 1];
System.arraycopy(xCoordinates, 0, mXCoordinates, 0, length);
mYCoordinates = new double[N + 1];
System.arraycopy(yCoordinates, 0, mYCoordinates, 0, length);
mA = new double[N];
mB = new double[N];
mC = new double[N];
mD = new double[N];
// fill mA, mB, mC, mD
calculate();
}
/**
* 与えられた {@code x} に対する y の値を返します。
*
* @param x x値。
* @return (補間によって計算された) y の値。
*/
public double get(double x) {
final double[] xCoordinates = mXCoordinates;
// FIXME 区間がたくさんある場合は NavigableMap とか使った方がいい
int targetFuntionIndex = N - 1; // 見つからない場合は最後のものを使う
for (int j = 0; j < N; j++) {
if (x <= xCoordinates[j + 1]) {
targetFuntionIndex = j;
break;
}
}
// (x - x_j)^1
final double x1 = x - xCoordinates[targetFuntionIndex];
// (x - x_j)^2
final double x2 = x1 * x1;
// (x - x_j)^3
final double x3 = x2 * x1;
final double y = mA[targetFuntionIndex] * x3
+ mB[targetFuntionIndex] * x2
+ mC[targetFuntionIndex] * x1
+ mD[targetFuntionIndex];
return y;
}
private static void ensureInputsAreValid(double[] xCoordinates, double[] yCoordinates,
int length) {
if (length < 2) {
throw new IllegalArgumentException("'length' must be 2 or more.");
}
if (xCoordinates.length < length) {
throw new IllegalArgumentException(
"length of 'xCoordinates' must not be less than 'length'");
}
if (yCoordinates.length < length) {
throw new IllegalArgumentException(
"length of 'yCoordinates' must not be less than 'length'");
}
// x が単調増加であることと、座標にNaN や無限大が含まれないことを確認
double prevX = Double.NEGATIVE_INFINITY;
for (int i = 0; i < length; i++) {
final double x = xCoordinates[i];
final double y = yCoordinates[i];
if (Double.isNaN(x) || Double.isInfinite(x)) {
throw new IllegalArgumentException(
"'xCoodinates' must not contain NaN nor INFINITY");
}
if (Double.isNaN(y) || Double.isInfinite(y)) {
throw new IllegalArgumentException(
"'yCoodinates' must not contain NaN nor INFINITY");
}
if (x <= prevX) {
throw new IllegalArgumentException(
"elements in 'xCoodinates' must monotonically increase");
}
}
}
/**
* 補間に必要な値を計算します。
*/
private void calculate() {
final double[] xCoordinates = mXCoordinates;
final double[] yCoordinates = mYCoordinates;
final double[] h = new double[N];
for (int j = 0; j < N; j++) {
h[j] = xCoordinates[j + 1] - xCoordinates[j];
}
final RealMatrix coefficients = buildCoefficients(h);
final RealVector constants = buildConstants(h);
// 行列式を解いて u[1] から u[N - 1] までを求める。
final DecompositionSolver solver = new LUDecomposition(coefficients).getSolver();
// u の index はずれているので、値を取り出す時は getUAt(int) を使うこと
final RealVector u = solver.solve(constants);
// mA, mB, mC, mD の値を求める
for (int j = 0, length = N; j < length; j++) {
final double u_j = getUAt(u, j);
final double u_j1 = getUAt(u, j + 1);
final double y_j = yCoordinates[j];
final double y_j1 = yCoordinates[j + 1];
mA[j] = (u_j1 - u_j) / (6d * (h[j]));
mB[j] = u_j / 2d;
mC[j] = ((y_j1 - y_j) / h[j]) - ((h[j] * (2d * u_j + u_j1)) / 6d);
mD[j] = y_j;
}
}
/**
* 補間関数の係数を求める際に使用する u[] を計算する際の行列式の係数行列(AU = B の A)
* を構築します。
*
* @param h {@code x[j + 1] - x[j]}の配列
* @return 係数行列。
*/
private RealMatrix buildCoefficients(final double[] h) {
final RealMatrix coefficients = new Array2DRowRealMatrix(N - 1, N - 1);
for (int rowIndex = 0; rowIndex < N - 1; rowIndex++) {
final double targetH = h[rowIndex];
final double nextH = h[rowIndex + 1];
if (rowIndex != 0) {
coefficients.setEntry(rowIndex, rowIndex - 1, targetH);
}
coefficients.setEntry(rowIndex, rowIndex, 2.0 * (targetH + nextH));
if (rowIndex != N - 2) {
coefficients.setEntry(rowIndex, rowIndex + 1, nextH);
}
}
return coefficients;
}
/**
* 補間関数の係数を求める際に使用する u[] を計算する際の行列式の定数列(AU = B の B)
* を構築します。
*
* @param h {@code x[j + 1] - x[j]}の配列
* @return 定数列。
*/
private RealVector buildConstants(final double[] h) {
final double[] yCoordinates = mYCoordinates;
final double[] v = new double[N]; // v[0] は使わない
double pv = (yCoordinates[1] - yCoordinates[0]) / h[0];
for (int j = 1; j < N; j++) {
double temp = (yCoordinates[j + 1] - yCoordinates[j]) / h[j];
v[j] = 6.0 * (temp - pv);
pv = temp;
}
final RealVector constants = new ArrayRealVector(N - 1);
for (int j = 1; j < N; j++) {
constants.setEntry(j - 1, v[j]);
}
return constants;
}
/**
* {@code} u が保持する値のインデックスは、他の計算部分に使用するものとズレが
* あるので、ズレを吸収してアクセスするためのユーティリティメソッドです。
* u に含まれていない値については、 natural spline になるように補います。
*
* @param u u[1] から u[N - 1] の N - 1 個分の u値を保持する {@link RealVector}。
* @param j 計算式上での u のインデックス。
* @return 計算式上での u[j] の値。
*/
private static double getUAt(RealVector u, int j) {
if (j == 0 || u.getDimension() < j) {
// natural spline なので 0。
return 0d;
}
return u.getEntry(j - 1);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment