Skip to content

Commit

Permalink
add discrete MRF BP and EP
Browse files Browse the repository at this point in the history
  • Loading branch information
sth4nth committed May 28, 2017
1 parent 3aedcb4 commit 8ddf99a
Show file tree
Hide file tree
Showing 7 changed files with 299 additions and 0 deletions.
63 changes: 63 additions & 0 deletions chapter08/belProp.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
function [nodeBel, edgeBel] = belProp(A, nodePot, edgePot, epoch)
% Belief propagation for MRF
% Assuming egdePot is symmetric
% Input:
% A: n x n adjacent matrix of undirected graph, where value is edge index
% nodePot: k x n node potential
% edgePot: k x k x m edge potential
% Output:
% nodeBel: k x n node belief
% edgeBel: k x k x m edge belief
% L: variational lower bound (Bethe energy)
% Written by Mo Chen ([email protected])
nodePot = exp(-nodePot);
edgePot = exp(-edgePot);

tol = 0;
if nargin < 4
epoch = 10;
tol = 1e-4;
end
[k,n] = size(nodePot);
m = size(edgePot,3);

[s,t,e] = find(tril(A));
A = sparse([s;t],[t;s],[e;e+m]); % digraph adjacent matrix, where value is message index
mu = ones(k,2*m)/k; % message
for iter = 1:epoch
mu0 = mu;
for i = 1:n
in = nonzeros(A(:,i)); % incoming message index
nb = nodePot(:,i).*prod(mu(:,in),2); % product of incoming message
for l = in'
ep = edgePot(:,:,ud(l,m));
mu(:,rd(l,m)) = normalize(ep*(nb./mu(:,l)));
end
end
if max(abs(mu(:)-mu0(:))) < tol; break; end
end

nodeBel = zeros(k,n);
for i = 1:n
nodeBel(:,i) = nodePot(:,i).*prod(mu(:,nonzeros(A(:,i))),2);
end
nodeBel = normalize(nodeBel,1);

edgeBel = zeros(k,k,m);
for l = 1:m
eij = e(l);
eji = eij+m;
ep = edgePot(:,:,eij);
nbt = nodeBel(:,t(l))./mu(:,eij);
nbs = nodeBel(:,s(l))./mu(:,eji);
eb = (nbt*nbs').*ep;
edgeBel(:,:,eij) = eb./sum(eb(:));
end

function i = rd(i, m)
% reverse direction edge index
i = mod(i+m-1,2*m)+1;

function i = ud(i, m)
% undirected edge index
i = mod(i-1,m)+1;
63 changes: 63 additions & 0 deletions chapter08/belProp0.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
function [nodeBel, edgeBel] = belProp0(A, nodePot, edgePot, epoch)
% Belief propagation for MRF, calculation in log scale
% Assuming egdePot is symmetric
% Input:
% A: n x n adjacent matrix of undirected graph, where value is edge index
% nodePot: k x n node potential
% edgePot: k x k x m edge potential
% Output:
% nodeBel: k x n node belief
% edgeBel: k x k x m edge belief
% L: variational lower bound (Bethe energy)
% Written by Mo Chen ([email protected])
tol = 0;
if nargin < 4
epoch = 10;
tol = 1e-4;
end
[k,n] = size(nodePot);
m = size(edgePot,3);

[s,t,e] = find(tril(A));
A = sparse([s;t],[t;s],[e;e+m]); % digraph adjacent matrix, where value is message index
mu = zeros(k,2*m)-log(k); % message
for iter = 1:epoch
mu0 = mu;
for i = 1:n
in = nonzeros(A(:,i)); % incoming message index
nb = -nodePot(:,i)+sum(mu(:,in),2); % product of incoming message
for l = in'
ep = edgePot(:,:,ud(l,m));
mut = logsumexp(-ep+(nb-mu(:,l)),1);
mu(:,rd(l,m)) = mut-logsumexp(mut);
end
end
if max(abs(mu(:)-mu0(:))) < tol; break; end
end

nodeBel = zeros(k,n);
for i = 1:n
nb = -nodePot(:,i)+sum(mu(:,nonzeros(A(:,i))),2);
nodeBel(:,i) = nb-logsumexp(nb);
end

edgeBel = zeros(k,k,m);
for l = 1:m
eij = e(l);
eji = eij+m;
ep = edgePot(:,:,eij);
nbt = nodeBel(:,t(l))-mu(:,eij);
nbs = nodeBel(:,s(l))-mu(:,eji);
eb = (nbt+nbs')-ep;
edgeBel(:,:,eij) = eb-logsumexp(eb(:));
end
nodeBel = exp(nodeBel);
edgeBel = exp(edgeBel);

function i = rd(i, m)
% reverse direction edge index
i = mod(i+m-1,2*m)+1;

function i = ud(i, m)
% undirected edge index
i = mod(i-1,m)+1;
59 changes: 59 additions & 0 deletions chapter08/expProp.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
function [nodeBel, edgeBel] = expProp(A, nodePot, edgePot, epoch)
% Expectation propagation for MRF
% Assuming egdePot is symmetric
% Another implementation with precompute nodeBel and update during iterations
% Input:
% A: n x n adjacent matrix of undirected graph, where value is edge index
% nodePot: k x n node potential
% edgePot: k x k x m edge potential
% Output:
% nodeBel: k x n node belief
% edgeBel: k x k x m edge belief
% L: variational lower bound (Bethe energy)
% Written by Mo Chen ([email protected])

% working in exp domain
nodePot = exp(-nodePot);
edgePot = exp(-edgePot);

tol = 0;
if nargin < 4
epoch = 10;
tol = 1e-4;
end
k = size(nodePot,1);
m = size(edgePot,3);

[s,t,e] = find(tril(A));
mu = ones(k,2*m)/k; % message
nodeBel = normalize(nodePot,1);
for iter = 1:epoch
mu0 = mu;
for l = 1:m
i = s(l);
j = t(l);
eij = e(l);
eji = eij+m;
ep = edgePot(:,:,eij);

nodeBel(:,j) = nodeBel(:,j)./mu(:,eij);
mu(:,eij) = normalize(ep*(nodeBel(:,i)./mu(:,eji)));
nodeBel(:,j) = normalize(nodeBel(:,j).*mu(:,eij));

nodeBel(:,i) = nodeBel(:,i)./mu(:,eji);
mu(:,eji) = normalize(ep*(nodeBel(:,j)./mu(:,eij)));
nodeBel(:,i) = normalize(nodeBel(:,i).*mu(:,eji));
end
if max(abs(mu(:)-mu0(:))) < tol; break; end
end

edgeBel = zeros(k,k,m);
for l = 1:m
eij = e(l);
eji = eij+m;
ep = edgePot(:,:,eij);
nbt = nodeBel(:,t(l))./mu(:,eij);
nbs = nodeBel(:,s(l))./mu(:,eji);
eb = (nbt*nbs').*ep;
edgeBel(:,:,eij) = eb./sum(eb(:));
end
60 changes: 60 additions & 0 deletions chapter08/expProp0.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
function [nodeBel, edgeBel] = expProp0(A, nodePot, edgePot, epoch)
% Expectation propagation for MRF, calculation in log scale
% Assuming egdePot is symmetric
% Another implementation with precompute nodeBel and update during iterations
% Input:
% A: n x n adjacent matrix of undirected graph, where value is edge index
% nodePot: k x n node potential
% edgePot: k x k x m edge potential
% Output:
% nodeBel: k x n node belief
% edgeBel: k x k x m edge belief
% L: variational lower bound (Bethe energy)
% Written by Mo Chen ([email protected])
tol = 0;
if nargin < 4
epoch = 10;
tol = 1e-4;
end
k = size(nodePot,1);
m = size(edgePot,3);

[s,t,e] = find(tril(A));
mu = zeros(k,2*m)-log(k);
nodeBel = -nodePot-logsumexp(-nodePot,1);
for iter = 1:epoch
mu0 = mu;
for l = 1:m
i = s(l);
j = t(l);
eij = e(l);
eji = eij+m;
ep = edgePot(:,:,eij);

nodeBel(:,j) = nodeBel(:,j)-mu(:,eij);
mut = logsumexp(-ep+(nodeBel(:,i)-mu(:,eji)),1);
mu(:,eij) = mut-logsumexp(mut);
nb = nodeBel(:,j)+mu(:,eij);
nodeBel(:,j) = nb-logsumexp(nb);

nodeBel(:,i) = nodeBel(:,i)-mu(:,eji);
mut = logsumexp(-ep+(nodeBel(:,j)-mu(:,eij)),1);
mu(:,eji) = mut-logsumexp(mut);
nb = nodeBel(:,i)+mu(:,eji);
nodeBel(:,i) = nb-logsumexp(nb);
end
if max(abs(mu(:)-mu0(:))) < tol; break; end
end

edgeBel = zeros(k,k,m);
for l = 1:m
eij = e(l);
eji = eij+m;
ep = edgePot(:,:,eij);
nbt = nodeBel(:,t(l))-mu(:,eij);
nbs = nodeBel(:,s(l))-mu(:,eji);
eb = (nbt+nbs')-ep;
edgeBel(:,:,eij) = eb-logsumexp(eb(:));
end
nodeBel = exp(nodeBel);
edgeBel = exp(edgeBel);
18 changes: 18 additions & 0 deletions chapter08/imageMeanField.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
function nodeBel = imageMeanField(M, N, nodePot, edgePot, epoch)
if nargin < 5
epoch = 10;
end
stride = [-1,1,-M,M];
nodeBel = softmax(-nodePot,1);
for t = 1:epoch
for j = 1:N
for i = 1:M
pos = i + M*(j-1);
ne = pos + stride;
ne([i,i,j,j] == [1,M,1,N]) = [];
nodeBel(:,pos) = softmax(-edgePot*sum(nodeBel(:,ne),2)-nodePot(:,pos));
end
end
end


18 changes: 18 additions & 0 deletions chapter08/isingMeanField.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
function mu = isingMeanField(J, h, epoch)
if nargin < 3
epoch = 10;
end
[M,N] = size(h);
mu = tanh(h);
stride = [-1,1,-M,M];
for t = 1:epoch
for j = 1:N
for i = 1:M
pos = i + M*(j-1);
ne = pos + stride;
ne([i,i,j,j] == [1,M,1,N]) = [];
mu(i,j) = tanh(J*sum(mu(ne)) + h(i,j));
end
end
end

18 changes: 18 additions & 0 deletions chapter08/isingMeanField0.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
function mu = isingMeanField0(J, h, epoch)
% use padding trick
if nargin < 3
epoch = 10;
end
mu = zeros(size(h)+2); % padding
[m,n] = size(mu);
mu(2:m-1,2:n-1) = tanh(h); % init
stride = [-1,1,-m,m];
for t = 1:epoch
for j = 2:n-1
for i = 2:m-1
ne = i + m*(j-1) + stride;
mu(i,j) = tanh(J*sum(mu(ne))+h(i-1,j-1));
end
end
end
mu = mu(2:m-1,2:n-1);

0 comments on commit 8ddf99a

Please sign in to comment.