diff --git a/run_nn.py b/run_nn.py index de2ec437..0f2947cf 100644 --- a/run_nn.py +++ b/run_nn.py @@ -201,6 +201,10 @@ [loss,err,pout] = net(inp,lab,test_flag) + + if multi_gpu: + loss=loss.mean() + err=err.mean() if do_forward: if rnn==1: