Skip to content

Commit

Permalink
Use util.check_random_state
Browse files Browse the repository at this point in the history
See also issue #153
  • Loading branch information
oyamad committed May 29, 2015
1 parent 55bac4d commit a812bf3
Showing 1 changed file with 5 additions and 13 deletions.
18 changes: 5 additions & 13 deletions quantecon/random_mc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from .mc_tools import MarkovChain
from .external import numba_installed, jit
from .util import check_random_state


def random_markov_chain(n, k=None, sparse=False, random_state=None):
Expand Down Expand Up @@ -100,15 +101,8 @@ def random_stochastic_matrix(n, k=None, sparse=False, format='csr',
if not (isinstance(k, int) and 0 < k <= n):
raise ValueError('k must be an integer with 0 < k <= n')

if random_state is None or isinstance(random_state, int):
_random_state = np.random.RandomState(random_state)
elif isinstance(random_state, np.random.RandomState):
_random_state = random_state
else:
raise ValueError

# n prob vectors of dimension k, shape (n, k)
probvecs = random_probvec(n, k, random_state=_random_state)
probvecs = random_probvec(n, k, random_state=random_state)

if k == n:
P = probvecs
Expand All @@ -121,7 +115,7 @@ def random_stochastic_matrix(n, k=None, sparse=False, format='csr',
rows = np.repeat(np.arange(n), k)
cols = \
random_choice_without_replacement(
n, k, num_trials=n, random_state=_random_state
n, k, num_trials=n, random_state=random_state
).ravel()
data = probvecs.ravel()

Expand Down Expand Up @@ -157,8 +151,7 @@ def random_probvec(m, k, random_state=None):
"""
x = np.empty((m, k+1))

if random_state is None:
random_state = np.random
random_state = check_random_state(random_state)
r = random_state.random_sample(size=(m, k-1))

r.sort(axis=-1)
Expand Down Expand Up @@ -210,8 +203,7 @@ def random_choice_without_replacement(n, k, num_trials=None,

m = 1 if num_trials is None else num_trials

if random_state is None:
random_state = np.random
random_state = check_random_state(random_state)
r = random_state.random_sample(size=(m, k))

# Logic taken from random.sample in the standard library
Expand Down

0 comments on commit a812bf3

Please sign in to comment.