Skip to content

Commit

Permalink
added mxGaussVbPred
Browse files Browse the repository at this point in the history
  • Loading branch information
sth4nth committed Feb 22, 2016
1 parent 4f217ee commit 6ab1c9d
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 34 deletions.
32 changes: 32 additions & 0 deletions chapter03/linRegPred.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
function [y, sigma, p] = linRegPred(model, X, t)
% Compute linear regression model reponse y = w'*X+w0 and likelihood
% Input:
% model: trained model structure
% X: d x n testing data
% t (optional): 1 x n testing response
% Output:
% y: 1 x n prediction
% sigma: variance
% p: 1 x n likelihood of t
% Written by Mo Chen ([email protected]).
w = model.w;
w0 = model.w0;
y = w'*X+w0;
%% probability prediction
if nargout > 1
beta = model.beta;
if isfield(model,'U')
U = model.U; % 3.54
Xo = bsxfun(@minus,X,model.xbar);
XU = U'\Xo;
sigma = sqrt((1+dot(XU,XU,1))/beta); % 3.59
else
sigma = sqrt(1/beta)*ones(1,size(X,2));
end
end

if nargin == 3 && nargout == 3
p = exp(logGauss(t,y,sigma));
% p = exp(-0.5*(((t-y)./sigma).^2+log(2*pi))-log(sigma));
end

70 changes: 38 additions & 32 deletions chapter10/demo.m
Original file line number Diff line number Diff line change
@@ -1,36 +1,42 @@
% demos for ch10
% chapter10/12: prediction functions for VB
%% regression
clear; close all;

d = 100;
beta = 1e-1;
X = rand(1,d);
w = randn;
b = randn;
t = w'*X+b+beta*randn(1,d);
x = linspace(min(X),max(X),d); % test data

[model,llh] = linRegVb(X,t);
% [model,llh] = rvmRegVb(X,t);
plot(llh);
[y, sigma] = linRegPred(model,x,t);
figure
plotCurveBar(x,y,sigma);
hold on;
plot(X,t,'o');
hold off
%% Variational Bayesian for linear\RVM regression
% clear; close all;
%
% d = 100;
% beta = 1e-1;
% X = rand(1,d);
% w = randn;
% b = randn;
% t = w'*X+b+beta*randn(1,d);
% x = linspace(min(X),max(X),d); % test data
%
% [model,llh] = linRegVb(X,t);
% % [model,llh] = rvmRegVb(X,t);
% plot(llh);
% [y, sigma] = linRegPred(model,x,t);
% figure
% plotCurveBar(x,y,sigma);
% hold on;
% plot(X,t,'o');
% hold off
%% Variational Bayesian for Gaussian Mixture Model
% close all; clear;
% d = 2;
% k = 3;
% n = 2000;
% [X,label] = mixGaussRnd(d,k,n);
% plotClass(X,label);
% [y, model, L] = mixGaussVb(X,10);
% figure;
% plotClass(X,y);
% figure;
% plot(L)

close all; clear;
d = 2;
k = 3;
n = 2000;
[X,z] = mixGaussRnd(d,k,n);
plotClass(X,z);
Xt = X(:,n/2+1:end);
X = X(:,1:n/2);
% VB fitting
[y, model, L] = mixGaussVb(X,10);
figure;
plotClass(X,y);
figure;
plot(L)
% Predict testing data
[yt, R] = mixGaussVbPred(model,Xt);
figure;
plotClass(Xt,yt);

28 changes: 26 additions & 2 deletions chapter10/mixGaussVbPred.m
Original file line number Diff line number Diff line change
@@ -1,9 +1,33 @@
function [label, R] = mixGaussVbPred(X, model)
function [z, R] = mixGaussVbPred(model, X)
% Predict label and responsibility for Gaussian mixture model trained by VB.
% Input:
% X: d x n data matrix
% model: trained model structure outputed by the EM algirthm
% Output:
% label: 1 x n cluster label
% R: k x n responsibility
% Written by Mo Chen ([email protected]).
% Written by Mo Chen ([email protected]).
alpha = model.alpha; % Dirichlet
kappa = model.kappa; % Gaussian
m = model.m; % Gasusian
v = model.v; % Whishart
U = model.U; % Whishart
logW = model.logW;
n = size(X,2);
[d,k] = size(m);

EQ = zeros(n,k);
for i = 1:k
Q = (U(:,:,i)'\bsxfun(@minus,X,m(:,i)));
EQ(:,i) = d/kappa(i)+v(i)*dot(Q,Q,1); % 10.64
end
ElogLambda = sum(psi(0,0.5*bsxfun(@minus,v+1,(1:d)')),1)+d*log(2)+logW; % 10.65
Elogpi = psi(0,alpha)-psi(0,sum(alpha)); % 10.66
logRho = -0.5*bsxfun(@minus,EQ,ElogLambda-d*log(2*pi)); % 10.46
logRho = bsxfun(@plus,logRho,Elogpi); % 10.46
logR = bsxfun(@minus,logRho,logsumexp(logRho,2)); % 10.49
R = exp(logR);
z = zeros(1,n);
[~,z(:)] = max(R,[],2);
[~,~,z(:)] = unique(z);

17 changes: 17 additions & 0 deletions common/plotCurveBar.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
function plotCurveBar( x, y, sigma )
% Plot 1d curve and variance
% Input:
% x: 1 x n
% y: 1 x n
% sigma: 1 x n or scaler
% Written by Mo Chen ([email protected]).
color = [255,228,225]/255; %pink
[x,idx] = sort(x);
y = y(idx);
sigma = sigma(idx);

fill([x,fliplr(x)],[y+sigma,fliplr(y-sigma)],color);
hold on;
plot(x,y,'r-');
hold off

0 comments on commit 6ab1c9d

Please sign in to comment.