Created
October 19, 2017 02:52
-
-
Save mnarayan/590de9feccb3e6e156f0c94f5152aa6e to your computer and use it in GitHub Desktop.
How to compute canonical variates and canonical correlations across training and test splits
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function [cca_rho cca_v cca_cv] = sample_canonical_correlation(X,Y,varargin)
% SAMPLE_CANONICAL_CORRELATION  Canonical variates and correlations of (X,Y)
% under given (typically training-set) projection weights.
%
% Usage:
%   [cca_rho, cca_v, cca_cv] = sample_canonical_correlation(X, Y)
%   [cca_rho, cca_v, cca_cv] = sample_canonical_correlation(X, Y, options)
%
% Inputs:
%   - X is the (test set) data matrix of n_samples x p features
%   - Y is the (test set) data matrix of n_samples x r features
%   - options (optional struct) with fields
%       options.W_X   p x t linear projection matrix for X
%       options.W_Y   r x t linear projection matrix for Y
%       options.mu_X  1 x p training mean of X
%       options.mu_Y  1 x r training mean of Y
%     The lower-case spellings options.mu_x / options.mu_y are accepted as
%     aliases. W_X and W_Y must have the same number of columns.
%
%   where (X * W_X, Y * W_Y) produces up to t canonical variates,
%   and t <= min(p,r).
%
%   When options is omitted, the means and projection matrices are
%   estimated from X and Y themselves (in-sample CCA).
%
% Outputs:
%   - cca_rho  t x 1 vector; entry j is the normalized inner product
%              (cosine) between the j-th pair of canonical variates.
%   - cca_v    n_samples x t x 2 array; cca_v(:,:,1) holds the X variates
%              and cca_v(:,:,2) holds the Y variates.
%   - cca_cv   1 x t vector of cumulative sums of cca_rho.^2.
%
% Description:
%
%   W_X should be Shat_XX^{-1/2} times the leading eigenvectors of
%     R_X = Shat_XX^{-1/2} * Shat_XY * Shat_YY^{-1} * Shat_YX * Shat_XX^{-1/2}
%
%   W_Y should be Shat_YY^{-1/2} times the leading eigenvectors of
%     R_Y = Shat_YY^{-1/2} * Shat_YX * Shat_XX^{-1} * Shat_XY * Shat_YY^{-1/2}
%   or, equivalently up to column scaling, Shat_YY^{-1} * Shat_YX * W_X.
%
%   The sample covariances (Shat_XX, Shat_YY, Shat_XY) are expected to have
%   been obtained on the training data.
%
% References
%   Modern Multivariate Statistical Techniques by Izenman
%
% Combo17-Galaxy Example
%   if ~exist('Combo17Galaxy.csv','file')
%       websave('Combo17Galaxy.csv','https://csvshare.com/view/4y6tCe-Tm.csv');
%   end
%   Combo17 = readtable('Combo17Galaxy.csv');
%   Yidx = [1 2 4 5 6 8 9]; Xidx = [10:2:16 30:2:65];
%   exclude_idx = find(sum(isnan(table2array(Combo17)),2)>0)';
%   include_idx = setdiff(1:height(Combo17),exclude_idx);
%   n_samples = length(include_idx);
%   X = table2array(Combo17(include_idx,Xidx));
%   Y = table2array(Combo17(include_idx,Yidx));
%   cvobj = cvpartition(n_samples,'Kfold',10);
%
%   for foldNo=1:cvobj.NumTestSets
%       cca_opts = struct();
%       cca_opts.mu_X = mean(X(cvobj.training(foldNo),:),1);
%       cca_opts.mu_Y = mean(Y(cvobj.training(foldNo),:),1);
%       [cca_opts.W_X, cca_opts.W_Y] = ...
%           canoncorr(X(cvobj.training(foldNo),:),Y(cvobj.training(foldNo),:));
%       [cca_rho{foldNo}, cca_v{foldNo}, cca_cv{foldNo}] = ...
%           sample_canonical_correlation(...
%               X(cvobj.test(foldNo),:), ...
%               Y(cvobj.test(foldNo),:), ...
%               cca_opts);
%   end

% Copyright 2017, Manjari Narayan

if nargin < 3
    % No projection supplied: estimate everything from (X, Y).
    n_samples = size(X,1);
    if n_samples < 2
        error('sample_canonical_correlation:tooFewSamples', ...
              'At least 2 samples are required to estimate covariances.');
    end

    options = struct();
    options.mu_X = mean(X,1);
    options.mu_Y = mean(Y,1);

    Xc = bsxfun(@minus, X, options.mu_X);
    Yc = bsxfun(@minus, Y, options.mu_Y);

    % All three covariance blocks use the same (n-1) normalization.
    Shat_XX = (Xc' * Xc) / (n_samples - 1);
    Shat_YY = (Yc' * Yc) / (n_samples - 1);
    Shat_XY = (Xc' * Yc) / (n_samples - 1);

    Shatinv_YY = pinv(Shat_YY);
    % real() drops round-off-level imaginary parts that sqrtm can produce
    % when the pseudo-inverse has eigenvalues that are numerically zero.
    Shatinvhalf_XX = real(sqrtm(pinv(Shat_XX)));

    R_X = Shatinvhalf_XX * Shat_XY * Shatinv_YY * Shat_XY' * Shatinvhalf_XX;
    R_X = (R_X + R_X') / 2;   % enforce exact symmetry so eig stays real

    % eig does not return eigenvalues in descending order; sort explicitly
    % so the first t columns are the leading canonical directions.
    [V, D] = eig(R_X);
    [~, order] = sort(diag(D), 'descend');
    t = min(size(X,2), size(Y,2));

    % Canonical weights are the whitened eigenvectors.
    options.W_X = Shatinvhalf_XX * V(:, order(1:t));
    % Derive the Y weights from the X weights so each pair is matched in
    % order and sign (column scale is irrelevant; it is normalized below).
    options.W_Y = Shatinv_YY * Shat_XY' * options.W_X;
else
    options = varargin{1};

    % Accept the lower-case field names used in earlier documentation.
    if isfield(options,'mu_x') && ~isfield(options,'mu_X')
        options.mu_X = options.mu_x;
    end
    if isfield(options,'mu_y') && ~isfield(options,'mu_Y')
        options.mu_Y = options.mu_y;
    end

    required = {'mu_X','mu_Y','W_X','W_Y'};
    for k = 1:numel(required)
        if ~isfield(options, required{k})
            error('sample_canonical_correlation:missingOption', ...
                  'options.%s is required.', required{k});
        end
    end
end

% Center with the supplied (training) means, not the means of X and Y.
Xcenter = bsxfun(@minus, X, options.mu_X);
Ycenter = bsxfun(@minus, Y, options.mu_Y);

% Canonical Variates (U, V)
U_CCA = Xcenter * options.W_X;
V_CCA = Ycenter * options.W_Y;
cca_v = cat(3, U_CCA, V_CCA);

% Column-wise norms and inner products; equivalent to
% diag(inv(diag(sdX)) * U'*V * inv(diag(sdY))) without forming inverses.
cca_sdX = sqrt(sum(U_CCA.^2, 1))';
cca_sdY = sqrt(sum(V_CCA.^2, 1))';
cca_rho = sum(U_CCA .* V_CCA, 1)' ./ (cca_sdX .* cca_sdY);

% Show at most three values; fewer when fewer variates are available.
n_show = min(3, numel(cca_rho));
disp('Top 3 canonical correlations')
disp(cca_rho(1:n_show))

% Cumulative sum of squared correlations, returned as a row vector.
t = min([size(X,2) size(Y,2) size(options.W_X,2)]);
cca_cv = cumsum(cca_rho(1:t).^2).';
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment