Skip to content

Commit

Permalink
nb not done
Browse files Browse the repository at this point in the history
  • Loading branch information
sth4nth committed Mar 4, 2016
1 parent f58d55f commit f795d91
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 63 deletions.
10 changes: 10 additions & 0 deletions chapter08/demo.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
% demo for ch08

%% Naive Bayes with Gauss
d = 2;
k = 3;
n = 1000;
[X, t] = kmeansRnd(d,k,n);
plotClass(X,t);

model = nbGauss(X,t);
1 change: 0 additions & 1 deletion chapter08/nb.m

This file was deleted.

21 changes: 21 additions & 0 deletions chapter08/nbGauss.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
function model = nbGauss(X, t)
% Naive bayes classifier with indepenet Gauss
% Input:
% X: d x n data matrix
% t: 1 x n label (1~k)
% Output:
% model: trained model structure
% Written by Mo Chen ([email protected]).
n = size(X,2);
k = max(t);
E = sparse(t,1:n,1,k,n,n);
nk = sum(E,2);
a = nk/n;
z = spdiags(1./nk,0,k,k);
mu = X*E'*z;
mm = bsxfun(@times,X*E',1./nk');
sigma = sqdist(mu,X)*z;

model.mu = mu;
model.sigma = sigma;
model.a = a;
14 changes: 14 additions & 0 deletions chapter08/nbGaussPred.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
function [y, R] = nbGaussPred(model, X)


mu = model.mu;
sigma = model.sigma;
a = model.a;


lambda = 1./sigma;
ml = mu.*lambda;
q = bsxfun(@plus,X'.^2*lambda-2*X'*ml,dot(mu,ml,1)); % M distance
c = d*log(2*pi)+2*sum(log(sigma),1); % normalization constant
R = -0.5*bsxfun(@plus,q,c);
y = max(R,[],1);
129 changes: 67 additions & 62 deletions chapter09/demo.m
Original file line number Diff line number Diff line change
@@ -1,71 +1,76 @@
% demos for ch09

%% Empirical Bayesian linear regression via EM
close all; clear;
d = 5;
n = 200;
[x,t] = linRnd(d,n);
[model,llh] = linRegEm(x,t);
plot(llh);

%% RVM classification via EM
clear; close all
k = 2;
d = 2;
n = 1000;
[X,t] = kmeansRnd(d,k,n);
[x1,x2] = meshgrid(linspace(min(X(1,:)),max(X(1,:)),n), linspace(min(X(2,:)),max(X(2,:)),n));

[model, llh] = rvmBinEm(X,t-1);
plot(llh);
y = rvmBinPred(model,X)+1;
figure;
binPlot(model,X,y);
% close all; clear;
% d = 5;
% n = 200;
% [x,t] = linRnd(d,n);
% [model,llh] = linRegEm(x,t);
% plot(llh);
%
% %% RVM classification via EM
% clear; close all
% k = 2;
% d = 2;
% n = 1000;
% [X,t] = kmeansRnd(d,k,n);
% [x1,x2] = meshgrid(linspace(min(X(1,:)),max(X(1,:)),n), linspace(min(X(2,:)),max(X(2,:)),n));
%
% [model, llh] = rvmBinEm(X,t-1);
% plot(llh);
% y = rvmBinPred(model,X)+1;
% figure;
% binPlot(model,X,y);
%% kmeans
close all; clear;
d = 2;
k = 3;
n = 500;
d = 20;
k = 6;
n = 5000;
[X,label] = kmeansRnd(d,k,n);
y = kmeans(X,k);
tic;
y = kmeans_(X,k);
toc
tic
y = kmeans(X',k);
toc
% y = kmedoids(X,k);
plotClass(X,label);
figure;
plotClass(X,y);
% plotClass(X,label);
% figure;
% plotClass(X,y);

%% Gausssian Mixture via EM
close all; clear;
d = 2;
k = 3;
n = 1000;
[X,label] = mixGaussRnd(d,k,n);
plotClass(X,label);

m = floor(n/2);
X1 = X(:,1:m);
X2 = X(:,(m+1):end);
% train
[z1,model,llh] = mixGaussEm(X1,k);
figure;
plot(llh);
figure;
plotClass(X1,z1);
% predict
z2 = mixGaussPred(X2,model);
figure;
plotClass(X2,z2);
%% Gauss mixture initialized by kmeans
close all; clear;
d = 2;
k = 3;
n = 500;
[X,label] = mixGaussRnd(d,k,n);
init = kmeans(X,k);
[z,model,llh] = mixGaussEm(X,init);
plotClass(X,label);
figure;
plotClass(X,init);
figure;
plotClass(X,z);
figure;
plot(llh);
% close all; clear;
% d = 2;
% k = 3;
% n = 1000;
% [X,label] = mixGaussRnd(d,k,n);
% plotClass(X,label);
%
% m = floor(n/2);
% X1 = X(:,1:m);
% X2 = X(:,(m+1):end);
% % train
% [z1,model,llh] = mixGaussEm(X1,k);
% figure;
% plot(llh);
% figure;
% plotClass(X1,z1);
% % predict
% z2 = mixGaussPred(X2,model);
% figure;
% plotClass(X2,z2);
% %% Gauss mixture initialized by kmeans
% close all; clear;
% d = 2;
% k = 3;
% n = 500;
% [X,label] = mixGaussRnd(d,k,n);
% init = kmeans(X,k);
% [z,model,llh] = mixGaussEm(X,init);
% plotClass(X,label);
% figure;
% plotClass(X,init);
% figure;
% plotClass(X,z);
% figure;
% plot(llh);

0 comments on commit f795d91

Please sign in to comment.