Skip to content

Commit

Permalink
add some docs string & refactorization
Browse files Browse the repository at this point in the history
  • Loading branch information
AI-Ahmed committed Jan 29, 2024
1 parent c00ee32 commit 568bb62
Showing 1 changed file with 39 additions and 6 deletions.
45 changes: 39 additions & 6 deletions rl/chapter03/mk_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,18 @@ def dtype(self):
alpha1: float = 0.25 # strength of mean-reversion (non-negative value)
seed: int = 42

def init_keys(self, seed: int):
def generate_keys(self, seed: int) -> jnp.ndarray:
"""
This function generates a set of random keys using JAX's deterministic
random number generation system. This system ensures reproducible results
across different runs by stabilizing numerical outputs.
Parameters:
- seed (int): The seed number used for randomization.
Returns:
- jnp.ndarray: An array of generated random keys.
"""
rng = jax.random.PRNGKey(seed)
key, subkey = jax.random.split(rng)
return key
Expand All @@ -36,7 +47,7 @@ def up_prob(self, state: State) -> float:
return 1. / (1 + jnp.exp(-self.alpha1 * (self.level_param - state.price)))

def next_state(self, state: State) -> State:
key = self.init_keys(self.seed)
key = self.generate_keys(self.seed)
up_move: jnp.ndarray = jax.random.binomial(key, 1, self.up_prob(state))
self.seed = key[0]
return Process1.State(price=state.price + up_move * 2 - 1)
Expand All @@ -51,8 +62,19 @@ class State:

alpha2: float = 0.75 # strength of reverse-pull (value in [0, 1])
seed: int = 42

def generate_keys(self, seed: int) -> jnp.ndarray:
"""
This function generates a set of random keys using JAX's deterministic
random number generation system. This system ensures reproducible results
across different runs by stabilizing numerical outputs.
Parameters:
- seed (int): The seed number used for randomization.
def init_keys(self, seed: int):
Returns:
- jnp.ndarray: An array of generated random keys.
"""
rng = jax.random.PRNGKey(seed)
key, subkey = jax.random.split(rng)
return key
Expand All @@ -69,7 +91,7 @@ def up_prob(self, state: State) -> float:
return 0.5 * (1 + self.alpha2 * (handy_map[state.is_prev_mv_up]))

def next_state(self, state: State) -> State:
k = self.init_keys(self.seed)
k = self.generate_keys(self.seed)
up_move: jnp.ndarray = jax.random.binomial(k, 1, self.up_prob(state))
self.seed = k[0]
return Process2.State(
Expand All @@ -87,7 +109,18 @@ class State:
alpha3: float = 1. # strength of reverse-pull (non-negative value)
seed: int = 42

def init_keys(self, seed: int):
def generate_keys(self, seed: int) -> jnp.ndarray:
"""
This function generates a set of random keys using JAX's deterministic
random number generation system. This system ensures reproducible results
across different runs by stabilizing numerical outputs.
Parameters:
- seed (int): The seed number used for randomization.
Returns:
- jnp.ndarray: An array of generated random keys.
"""
rng = jax.random.PRNGKey(seed)
key, subkey = jax.random.split(rng)
return key
Expand All @@ -99,7 +132,7 @@ def up_prob(self, state: State) -> float:
return 1. / 1 + (1/x - 1) ** self.alpha3 if total else 0.5

def next_state(self, state: State) -> State:
k = self.init_keys(self.seed)
k = self.generate_keys(self.seed)
up_move: int = jax.random.binomial(k, 1, self.up_prob(state))
self.seed = k[0]
return Process3.State(
Expand Down

0 comments on commit 568bb62

Please sign in to comment.