Skip to content

Commit

Permalink
added flag for using last params in chooseParams
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaitanya Talnikar committed Apr 24, 2015
1 parent 956af36 commit ec3b421
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions gpExp/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def updateKernelParams(self, paramsIn):
del params['noise']
self.kernel.updateHyperParameters(params)
#
def findOptParamsLogLike(self, pts, evals, paramsStart=None, paramLowerBounds=None, paramUpperBounds=None,useNoise=None, maxiter=40):
def findOptParamsLogLike(self, pts, evals, paramsStart=None, paramLowerBounds=None, paramUpperBounds=None,useNoise=None, maxiter=40, useLastParams=True):
"""
Compute the optimal hyperparameters by maximizing the marginal log likelihood
Expand Down Expand Up @@ -578,33 +578,36 @@ def objFunc(in0, gradIn):
return out


paramsOut, optValue = self.chooseParams(paramLb, paramUb, paramVals, objFunc, maxiter=maxiter)
paramsOut, optValue = self.chooseParams(paramLb, paramUb, paramVals, objFunc, maxiter=maxiter, useLastParams=useLastParams)
bestParams = np.zeros(np.shape(paramsOut))

params = dict(zip(keys,paramsOut))
self.updateKernelParams(params)
return params, optValue

def chooseParams(self, paramLowerBounds, paramUpperBounds, startValues, costFunction,maxiter=40):
def chooseParams(self, paramLowerBounds, paramUpperBounds, startValues, costFunction,maxiter=40, useLastParams=True):

if NLOPT is True:
local_opt = nlopt.opt(nlopt.LN_COBYLA, len(startValues))
local_opt = nlopt.opt(nlopt.LN_COBYLA, len(startValues))

local_opt.set_xtol_rel(1e-3)
local_opt.set_ftol_rel(1e-3)
local_opt.set_ftol_abs(1e-3)
local_opt.set_maxtime(10);
local_opt.set_maxeval(50*len(startValues));
local_opt.set_lower_bounds(paramLowerBounds)
local_opt.set_upper_bounds(paramUpperBounds)
local_opt.set_xtol_rel(1e-3)
local_opt.set_ftol_rel(1e-3)
local_opt.set_ftol_abs(1e-3)
local_opt.set_maxtime(10);
local_opt.set_maxeval(50*len(startValues));

local_opt.set_lower_bounds(paramLowerBounds)
local_opt.set_upper_bounds(paramUpperBounds)

try:
try:
local_opt.set_min_objective(costFunction)
sol = local_opt.optimize(startValues)
except nlopt.RoundoffLimited:
return costFunction.last_x_value, costFunction.last_f_value
return sol, local_opt.last_optimum_value()
except nlopt.RoundoffLimited:
if useLastParams:
return costFunction.last_x_value, costFunction.last_f_value
else:
return startValues, None
return sol, local_opt.last_optimum_value()
else:
maxeval = 100
bounds = zip(paramLowerBounds, paramUpperBounds)
Expand Down

0 comments on commit ec3b421

Please sign in to comment.