@mnarayan
Created October 19, 2017 02:52
How to compute canonical variates and canonical correlations across training and test splits
function [cca_rho, cca_v, cca_cv] = sample_canonical_correlation(X,Y,varargin)
% SAMPLE_CANONICAL_CORRELATION
%
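% Usage: [cca_rho, cca_v, cca_cv] = sample_canonical_correlation(X, Y, options)
%
% Inputs:
% - X is the test set data matrix of n_samples x p features
% - Y is the test set data matrix of n_samples x r features
% - options.W_X is the linear projection matrix for X
% - options.W_Y is the linear projection matrix for Y
% - options.mu_X is the training mean of X
% - options.mu_Y is the training mean of Y
%
% where,
% (X * W_X, Y * W_Y) produces up to t canonical variates,
% and t <= min(p,r)
%
% If options is omitted, the means and projections are estimated from
% (X, Y) themselves.
%
% Outputs:
% - cca_rho is the vector of correlations between paired canonical variates
% - cca_v is an n_samples x t x 2 array with the canonical variates of X
% in cca_v(:,:,1) and those of Y in cca_v(:,:,2)
% - cca_cv is the cumulative sum of the squared canonical correlations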
%
% Description:
%
% W_X should be Shat_XX^{-1/2} times the leading eigenvectors of R_X, where
% R_X = Shat_XX^{-1/2} * Shat_XY * Shat_YY^{-1} * Shat_YX * Shat_XX^{-1/2}
%
% W_Y should be Shat_YY^{-1/2} times the leading eigenvectors of R_Y, where
% R_Y = Shat_YY^{-1/2} * Shat_YX * Shat_XX^{-1} * Shat_XY * Shat_YY^{-1/2}
%
% The sample covariances (Shat_XX, Shat_YY, Shat_XY) are expected to have
% been estimated on the training data.
%
%
% References
% Modern Multivariate Statistical Techniques by Izenman
%
% Combo17-Galaxy Example
% if ~exist('Combo17Galaxy.csv','file')
% websave('Combo17Galaxy.csv','https://csvshare.com/view/4y6tCe-Tm.csv');
% end
% Combo17 = readtable('Combo17Galaxy.csv');
% Yidx = [1 2 4 5 6 8 9]; Xidx = [10:2:16 30:2:65];
% exclude_idx = find(sum(isnan(table2array(Combo17)),2)>0)';
% include_idx = setdiff(1:height(Combo17),exclude_idx);
% n_samples = length(include_idx);
% X = table2array(Combo17(include_idx,Xidx));
% Y = table2array(Combo17(include_idx,Yidx));
% cvobj = cvpartition(n_samples,'Kfold',10);
%
% for foldNo=1:cvobj.NumTestSets
% cca_opts = struct();
% cca_opts.mu_X = mean(X(cvobj.training(foldNo),:));
% cca_opts.mu_Y = mean(Y(cvobj.training(foldNo),:));
% [cca_opts.W_X cca_opts.W_Y] = ...
% canoncorr(X(cvobj.training(foldNo),:),Y(cvobj.training(foldNo),:));
% [cca_rho{foldNo} cca_v{foldNo} cca_cv{foldNo}] = ...
% sample_canonical_correlation(...
% X(cvobj.test(foldNo),:), ...
% Y(cvobj.test(foldNo),:), ...
% cca_opts);
% end
% Copyright 2017, Manjari Narayan
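options = struct();
if nargin==2
% No projections supplied: estimate the means and weights from (X, Y) directly.
% Field names are case-sensitive and must match options.mu_X / options.mu_Y below.
options.mu_X = mean(X);
options.mu_Y = mean(Y);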
Shat_XX = cov(X);
Shatinv_XX = pinv(Shat_XX);
Shat_YY = cov(Y);
Shatinv_YY = pinv(Shat_YY);
% Cross-covariance, scaled by (n_samples - 1) to match cov()
Shat_XY = bsxfun(@minus, X, options.mu_X)' * ...
bsxfun(@minus,Y, options.mu_Y) / (size(X,1) - 1);
R_X = sqrtm(Shatinv_XX)*Shat_XY*Shatinv_YY*Shat_XY'*sqrtm(Shatinv_XX);
R_Y = sqrtm(Shatinv_YY)*Shat_XY'*Shatinv_XX*Shat_XY*sqrtm(Shatinv_YY);
t = min(size(X,2),size(Y,2));
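% eig does not return eigenvalues in descending order, so sort explicitly;
% symmetrizing first guards against round-off asymmetry.
[V, D] = eig((R_X + R_X')/2);
[~, order] = sort(diag(D),'descend');
% Canonical weights are the leading eigenvectors mapped back by Shat_XX^{-1/2}
options.W_X = sqrtm(Shatinv_XX) * V(:,order(1:t));
[V, D] = eig((R_Y + R_Y')/2);
[~, order] = sort(diag(D),'descend');
options.W_Y = sqrtm(Shatinv_YY) * V(:,order(1:t));
% Eigenvector signs are arbitrary: flip columns of W_Y so that each pair of
% canonical variates is positively correlated
pair_sign = sign(diag(options.W_X' * Shat_XY * options.W_Y))';
pair_sign(pair_sign == 0) = 1;
options.W_Y = bsxfun(@times, options.W_Y, pair_sign);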
else
options = varargin{1};
end
Xcenter = bsxfun(@minus, X, options.mu_X);
Ycenter = bsxfun(@minus, Y, options.mu_Y);
% Canonical Variates (U, V)
U_CCA = Xcenter * options.W_X;
V_CCA = Ycenter * options.W_Y;
cca_v = cat(3,U_CCA, V_CCA);
cca_sdX = sqrt(diag(U_CCA' * U_CCA));
cca_sdY = sqrt(diag(V_CCA' * V_CCA));
cca_cov = U_CCA' * V_CCA;
cca_rho = diag(cca_cov) ./ (cca_sdX .* cca_sdY);
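% Show at most three values so that fewer than 3 components does not error
n_show = min(3, numel(cca_rho));
fprintf('Top %d canonical correlations\n', n_show)
disp(cca_rho(1:n_show))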
% Cumulative sum of squared canonical correlations over the first t components
t = min([size(X,2) size(Y,2) size(options.W_X,2)]);
cca_cv = cumsum(cca_rho(1:t).^2)';
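
% -------------------------------------------------------------------------
% Illustrative sketch, not part of the original gist: one possible way to
% estimate the training-set inputs (W_X, W_Y, mu_X, mu_Y) without canoncorr,
% following the whitening step described in the help text above. The name
% train_cca_sketch and its outputs are hypothetical. The training
% covariances are assumed to be full rank; pinv is only a fallback.
%
% Possible usage:
% [opts, rho_train] = train_cca_sketch(Xtrain, Ytrain);
% rho_test = sample_canonical_correlation(Xtest, Ytest, opts);
function [opts, rho_train] = train_cca_sketch(Xtrain, Ytrain)
opts = struct();
opts.mu_X = mean(Xtrain, 1);
opts.mu_Y = mean(Ytrain, 1);
n = size(Xtrain, 1);
Xc = bsxfun(@minus, Xtrain, opts.mu_X);
Yc = bsxfun(@minus, Ytrain, opts.mu_Y);
% Training covariances, all normalized by (n - 1)
S_XX = (Xc' * Xc) / (n - 1);
S_YY = (Yc' * Yc) / (n - 1);
S_XY = (Xc' * Yc) / (n - 1);
% Inverse square roots Shat_XX^{-1/2} and Shat_YY^{-1/2}
S_XX_isqrt = real(sqrtm(pinv(S_XX)));
S_YY_isqrt = real(sqrtm(pinv(S_YY)));
% SVD of the whitened cross-covariance K: the left and right singular
% vectors are the eigenvectors of R_X = K*K' and R_Y = K'*K, and the
% singular values (returned in descending order) are the training
% canonical correlations.
[Uw, S, Vw] = svd(S_XX_isqrt * S_XY * S_YY_isqrt, 'econ');
t = min(size(Xtrain, 2), size(Ytrain, 2));
% Map the singular vectors back to the original feature space
opts.W_X = S_XX_isqrt * Uw(:, 1:t);
opts.W_Y = S_YY_isqrt * Vw(:, 1:t);
rho_train = diag(S);
rho_train = rho_train(1:t);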