-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_frnn.m
executable file
·90 lines (74 loc) · 2.7 KB
/
main_frnn.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
% Put the project helper functions (data loaders, FRNN routines) on the path.
addpath('util');

% Previous dataset configuration, kept for reference:
% input_dir = '../data/easy2-c2-p1';
% train_dir = fullfile(input_dir, 'train');
% train_name = 'test2-c2-p1';
% train_name_list = 'train';

% Feature root directory; the training shards live in its 'train' sub-folder.
input_dir = '../feature_1_100_reduce/Vec';
train_dir = fullfile(input_dir, 'train');

% Shard numbers to train on; shard k is read from 'train_%03d.mat'.
data_list = 1;
% Hand-chosen tag used later in the model directory name (not derived from data_list).
train_name = 'train_001'; % define by yourself
% One shard filename per entry of data_list.
train_name_list = arrayfun(@(shard) sprintf('train_%03d.mat', shard), ...
    data_list, 'UniformOutput', false);
% ---- Load the training set (toc prints the elapsed load time) ----
tic
%[Y_train, X_train] = rnn_load_data(train_dir, train_name_list);
[Y_train, X_train] = rnn_load_binary_data(train_dir, train_name_list);
toc

% Read the numeric class-mapping table and derive per-label classes for training.
class_filename = 'class_map/class_mapping_100';
class_map = dlmread(class_filename);
C_train = map_y_to_class(Y_train, class_map);

% ---- Network / training hyper-parameters ----
%opts.num_class = find_max_label(Y_train);
opts.num_label     = 3784;                  % hard-coded label count; NOTE(review): confirm it matches the data
opts.num_data      = length(Y_train);
opts.num_class     = max(class_map);        % assumes class_map is a vector so this is a scalar - TODO confirm
opts.num_dim       = size(X_train{1}, 2);   % input width = column count of the first training item
opts.learning_rate = 0.01;
opts.epoch         = 100;
opts.epoch_to_save = 10;
opts.weight_decay  = 0.005;
opts.momentum      = 0.0;
opts.rmsprop_alpha = 0.9;
opts.bptt_depth    = 3;
opts.gradient_thr  = 0.5;
opts.hidden        = 50;
% Layer sizes: input width, hidden width, and an output of num_label + num_class units.
opts.structure     = [opts.num_dim, opts.hidden, opts.num_label + opts.num_class];
opts.activation    = 'sigmoid';             % options: sigmoid, relu
opts.update_grad   = 'sgd';                 % options: sgd, rmsprop
% Name the output directory after the key hyper-parameters so runs with
% different settings write their models to different places.
parameter = ['frnn_', train_name, ...
             '_hidden', sprintf('%d', opts.hidden), ...
             '_lr', num2str(opts.learning_rate), ...
             '_wd', num2str(opts.weight_decay), ...
             '_m', num2str(opts.momentum), ...
             '_bptt', num2str(opts.bptt_depth), ...
             '_thr', num2str(opts.gradient_thr)];
opts.model_dir = fullfile('../model', parameter);

% Create the directory on first use (exist(..., 'dir') returns 7 for an existing folder).
if exist(opts.model_dir, 'dir') ~= 7
    fprintf('mkdir %s\n', opts.model_dir);
    mkdir(opts.model_dir);
end

% Build the network from the options and class map, then train it on the loaded data.
model = frnn_init(opts, class_map);
model = frnn_train(model, X_train, Y_train, C_train);
% ---- Run the trained model on every file in <input_dir>/test ----
test_dir = fullfile(input_dir, 'test');
fprintf('Load test data from %s\n', test_dir);

% Fail early with a clear message instead of "testing" on nothing and
% failing later with an unrelated error.
if exist(test_dir, 'dir') ~= 7
    error('Test directory not found: %s', test_dir);
end

% Keep regular files only: skip '.', '..', other dot-prefixed entries and
% sub-directories, none of which the per-file loader below can read.
% NOTE(review): result order follows dir() order; confirm it matches the
% order expected by the answer file / Kaggle submission.
L = dir(test_dir);
test_names = {L.name};
test_list = test_names(~[L.isdir] & ~strncmp(test_names, '.', 1));

N = length(test_list);
if N == 0
    error('No test files found in %s', test_dir);
end

result = zeros(N, 1);   % one prediction value per test file
Y_pred = cell(N, 1);    % per-file y_pred output of frnn_test
fprintf('FRNN testing...\n');
for t = 1:N
    [Y_test, X_test] = rnn_load_data(test_dir, test_list{t}, 1);
    [res, y_pred, ~] = frnn_test(model, X_test, Y_test);   % third output (cost) is unused
    result(t) = res;
    Y_pred{t} = y_pred;
end
% ---- Save predictions and score them against the reference answers ----
filename = fullfile(opts.model_dir, sprintf('epoch%d.csv', opts.epoch));
save_kaggle_csv(filename, result);

%answer = dlmread(fullfile(input_dir, 'testing_ans'))+1;
answer = dlmread('google.ans');

% Compare element-wise as column vectors of equal length. Without this, a
% row-shaped answer vector would be broadcast against the column 'result'
% (implicit expansion, R2016b+) into an N-by-N matrix, and the reported
% accuracy would be silently meaningless.
if numel(answer) ~= numel(result)
    error('google.ans has %d entries but %d predictions were made', ...
        numel(answer), numel(result));
end
acc = mean(answer(:) == result(:))