forked from PRML/PRMLT
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
113 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
function [y, sigma, p] = linRegPred(model, X, t) | ||
% Compute linear regression model reponse y = w'*X+w0 and likelihood | ||
% Input: | ||
% model: trained model structure | ||
% X: d x n testing data | ||
% t (optional): 1 x n testing response | ||
% Output: | ||
% y: 1 x n prediction | ||
% sigma: variance | ||
% p: 1 x n likelihood of t | ||
% Written by Mo Chen ([email protected]). | ||
w = model.w; | ||
w0 = model.w0; | ||
y = w'*X+w0; | ||
%% probability prediction | ||
if nargout > 1 | ||
beta = model.beta; | ||
if isfield(model,'U') | ||
U = model.U; % 3.54 | ||
Xo = bsxfun(@minus,X,model.xbar); | ||
XU = U'\Xo; | ||
sigma = sqrt((1+dot(XU,XU,1))/beta); % 3.59 | ||
else | ||
sigma = sqrt(1/beta)*ones(1,size(X,2)); | ||
end | ||
end | ||
|
||
if nargin == 3 && nargout == 3 | ||
p = exp(logGauss(t,y,sigma)); | ||
% p = exp(-0.5*(((t-y)./sigma).^2+log(2*pi))-log(sigma)); | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,42 @@ | ||
% demos for ch10 | ||
% chapter10/12: prediction functions for VB | ||
%% regression | ||
clear; close all; | ||
|
||
d = 100; | ||
beta = 1e-1; | ||
X = rand(1,d); | ||
w = randn; | ||
b = randn; | ||
t = w'*X+b+beta*randn(1,d); | ||
x = linspace(min(X),max(X),d); % test data | ||
|
||
[model,llh] = linRegVb(X,t); | ||
% [model,llh] = rvmRegVb(X,t); | ||
plot(llh); | ||
[y, sigma] = linRegPred(model,x,t); | ||
figure | ||
plotCurveBar(x,y,sigma); | ||
hold on; | ||
plot(X,t,'o'); | ||
hold off | ||
%% Variational Bayesian for linear\RVM regression | ||
% clear; close all; | ||
% | ||
% d = 100; | ||
% beta = 1e-1; | ||
% X = rand(1,d); | ||
% w = randn; | ||
% b = randn; | ||
% t = w'*X+b+beta*randn(1,d); | ||
% x = linspace(min(X),max(X),d); % test data | ||
% | ||
% [model,llh] = linRegVb(X,t); | ||
% % [model,llh] = rvmRegVb(X,t); | ||
% plot(llh); | ||
% [y, sigma] = linRegPred(model,x,t); | ||
% figure | ||
% plotCurveBar(x,y,sigma); | ||
% hold on; | ||
% plot(X,t,'o'); | ||
% hold off | ||
%% Variational Bayesian for Gaussian Mixture Model | ||
% close all; clear; | ||
% d = 2; | ||
% k = 3; | ||
% n = 2000; | ||
% [X,label] = mixGaussRnd(d,k,n); | ||
% plotClass(X,label); | ||
% [y, model, L] = mixGaussVb(X,10); | ||
% figure; | ||
% plotClass(X,y); | ||
% figure; | ||
% plot(L) | ||
|
||
close all; clear; | ||
d = 2; | ||
k = 3; | ||
n = 2000; | ||
[X,z] = mixGaussRnd(d,k,n); | ||
plotClass(X,z); | ||
Xt = X(:,n/2+1:end); | ||
X = X(:,1:n/2); | ||
% VB fitting | ||
[y, model, L] = mixGaussVb(X,10); | ||
figure; | ||
plotClass(X,y); | ||
figure; | ||
plot(L) | ||
% Predict testing data | ||
[yt, R] = mixGaussVbPred(model,Xt); | ||
figure; | ||
plotClass(Xt,yt); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,33 @@ | ||
function [label, R] = mixGaussVbPred(X, model) | ||
function [z, R] = mixGaussVbPred(model, X) | ||
% Predict label and responsibility for Gaussian mixture model trained by VB. | ||
% Input: | ||
% X: d x n data matrix | ||
% model: trained model structure outputed by the EM algirthm | ||
% Output: | ||
% label: 1 x n cluster label | ||
% R: k x n responsibility | ||
% Written by Mo Chen ([email protected]). | ||
% Written by Mo Chen ([email protected]). | ||
alpha = model.alpha; % Dirichlet | ||
kappa = model.kappa; % Gaussian | ||
m = model.m; % Gasusian | ||
v = model.v; % Whishart | ||
U = model.U; % Whishart | ||
logW = model.logW; | ||
n = size(X,2); | ||
[d,k] = size(m); | ||
|
||
EQ = zeros(n,k); | ||
for i = 1:k | ||
Q = (U(:,:,i)'\bsxfun(@minus,X,m(:,i))); | ||
EQ(:,i) = d/kappa(i)+v(i)*dot(Q,Q,1); % 10.64 | ||
end | ||
ElogLambda = sum(psi(0,0.5*bsxfun(@minus,v+1,(1:d)')),1)+d*log(2)+logW; % 10.65 | ||
Elogpi = psi(0,alpha)-psi(0,sum(alpha)); % 10.66 | ||
logRho = -0.5*bsxfun(@minus,EQ,ElogLambda-d*log(2*pi)); % 10.46 | ||
logRho = bsxfun(@plus,logRho,Elogpi); % 10.46 | ||
logR = bsxfun(@minus,logRho,logsumexp(logRho,2)); % 10.49 | ||
R = exp(logR); | ||
z = zeros(1,n); | ||
[~,z(:)] = max(R,[],2); | ||
[~,~,z(:)] = unique(z); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
function plotCurveBar( x, y, sigma ) | ||
% Plot 1d curve and variance | ||
% Input: | ||
% x: 1 x n | ||
% y: 1 x n | ||
% sigma: 1 x n or scaler | ||
% Written by Mo Chen ([email protected]). | ||
color = [255,228,225]/255; %pink | ||
[x,idx] = sort(x); | ||
y = y(idx); | ||
sigma = sigma(idx); | ||
|
||
fill([x,fliplr(x)],[y+sigma,fliplr(y-sigma)],color); | ||
hold on; | ||
plot(x,y,'r-'); | ||
hold off | ||
|