Skip to content

Commit

Permalink
refined lingRegVb
Browse files Browse the repository at this point in the history
  • Loading branch information
sth4nth committed Feb 15, 2016
1 parent 84372cf commit 5dc0656
Showing 1 changed file with 16 additions and 18 deletions.
34 changes: 16 additions & 18 deletions chapter10/linRegVb.m
Original file line number Diff line number Diff line change
Expand Up @@ -27,36 +27,34 @@

maxiter = 100;
energy = -inf(1,maxiter+1);
idx = (1:m)';
dg = sub2ind([m,m],idx,idx);
I = eye(m);
tol = 1e-8;

a = a0+m/2;
a = a0+m/2; % 10.94
c = c0+n/2;
Ealpha = 1e-4;
Ebeta = 1e-4;
for iter = 2:maxiter
invS = Ebeta*XX;
invS(dg) = invS(dg)+Ealpha;
% q(w)
invS = diag(Ealpha)+Ebeta*XX; % 10.101
U = chol(invS);
Ew = Ebeta*(U\(U'\Xt));

Ew = Ebeta*(U\(U'\Xt)); % 10.100
KLw = -sum(log(diag(U)));
% q(alpha)
w2 = dot(Ew,Ew);
e2 = sum((t-Ew'*X).^2);
invU = U\I;
trS = dot(invU(:),invU(:));
b = b0+0.5*(w2+trS); % 10.95
Ealpha = a/b; % 10.102
KLalpha = -a*log(b);
% q(beta)
e2 = sum((t-Ew'*X).^2);
invUX = U\X;
trXSX = dot(invUX(:),invUX(:));

b = b0+0.5*(w2+trS);
d = d0+0.5*(e2+trXSX);

Ealpha = a/b;
Ebeta = c/d;
logdetS = -2*sum(log(diag(U)));
energy(iter) = -a*log(b)-c*log(d)+0.5*logdetS;
KLbeta = -c*log(d);
% lower bound
energy(iter) = KLalpha+KLbeta+KLw;
if energy(iter)-energy(iter-1) < tol*abs(energy(iter-1)); break; end
end
const = gammaln(a)-gammaln(a0)+gammaln(c)-gammaln(c0)+a0*log(b0)+c0*log(d0)+0.5*(m-n*log(2*pi));
Expand All @@ -65,8 +63,8 @@

model.w0 = w0;
model.w = Ew;
model.Ealpha = Ealpha;
model.Ebeta = Ebeta;
model.alpha = Ealpha;
model.beta = Ebeta;
model.a = a;
model.b = b;
model.c = c;
Expand Down

0 comments on commit 5dc0656

Please sign in to comment.