
Commit

Update mk modules and add FiniteMarkovDecisionProcess
AI-Ahmed committed Mar 30, 2024
1 parent c91af1e commit e41a172
Showing 5 changed files with 63 additions and 20 deletions.
13 changes: 6 additions & 7 deletions rl/chapter04/simple_inventory.py
@@ -38,12 +38,10 @@ class SimpleInventoryDeterministicPolicy(
DeterministicPolicy[InventoryState, IntLike]
):
def __init__(self, reorder_point: IntLike):
self.reorder_point: IntLike = reorder_point # r

def action_for(s: InventoryState) -> IntLike:
return max(self.reorder_point - s.inventory_position(), 0)
return max(reorder_point - s.inventory_position(), 0)

super().__init__(action_for)
super().__init__(action_for=action_for, reorder_point=reorder_point)


class SimpleInventoryStochasticPolicy(Policy[InventoryState, IntLike]):
@@ -55,7 +53,8 @@ def action_func(state=state) -> IntLike:
key = jax.random.PRNGKey(random.randint(42, 1234))
reorder_point_sample: IntLike = jax.random.poisson(key,
self.reorder_point_poisson_mean)
return jnp.max(reorder_point_sample - state.state.inventory_position(), 0)
return max(reorder_point_sample - state.state.inventory_position(), 0)
return SampledDistribution(action_func)


@dataclass(frozen=True)
@@ -101,11 +100,11 @@ def fraction_of_days_oos(
count: int = 0
high_fractile: IntLike = np.int32(poisson(self.poisson_lambda).ppf(0.98))
start: InventoryState = random.choice(
[InventoryState(i, 0) for i in range(high_fractile + 1)])
[InventoryState(on_hand=i, on_order=0) for i in range(high_fractile + 1)])

for _ in range(num_traces):
steps = itertools.islice(
impl_mrp.simulate_reward(Constant(NonTerminal(start))),
impl_mrp.simulate_reward(Constant(value=NonTerminal(state=start))),
time_steps
)
for step in steps:
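For orientation, a minimal usage sketch of the keyword-style construction above (not part of the commit; it assumes inventory_position() returns on_hand + on_order and that Constant exposes sample(), as in the book's layout):

from chapter04.simple_inventory import (InventoryState,
                                        SimpleInventoryDeterministicPolicy)
from mk_process import NonTerminal

# Reorder point r = 2; current inventory position is on_hand + on_order = 1,
# so the policy should order max(2 - 1, 0) = 1 unit.
si_dp = SimpleInventoryDeterministicPolicy(reorder_point=2)
state = NonTerminal(state=InventoryState(on_hand=1, on_order=0))
assert si_dp.act(state).sample() == 1   # assumes Constant exposes sample()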
7 changes: 5 additions & 2 deletions rl/chapter04/test.py
@@ -1,6 +1,7 @@
from chapter04.simple_inventory import (SimpleInventoryDeterministicPolicy,
SimpleInventoryMDPNoCap,
SimpleInventoryStochasticPolicy)
from loguru import logger

if __name__ == '__main__':
user_poisson_lambda = 2.0
@@ -17,25 +18,27 @@
holding_cost=user_holding_cost,
stockout_cost=user_stockout_cost)

logger.info("Running Deterministic Policy...")
si_dp = SimpleInventoryDeterministicPolicy(
reorder_point=user_reorder_point
)

oos_frac_dp = si_mdp_nocap.fraction_of_days_oos(policy=si_dp,
time_steps=user_time_steps,
num_traces=user_num_traces)
print(
logger.debug(
f"Deterministic Policy yields {oos_frac_dp * 100:.2f}%"
+ " of Out-Of-Stock days"
)

logger.info("Running the Stochastic Policy...")
si_sp = SimpleInventoryStochasticPolicy(
reorder_point_poisson_mean=user_reorder_point_poisson_mean)

oos_frac_sp = si_mdp_nocap.fraction_of_days_oos(policy=si_sp,
time_steps=user_time_steps,
num_traces=user_num_traces)
print(
logger.debug(
f"Stochastic Policy yields {oos_frac_sp * 100:.2f}%"
+ " of Out-Of-Stock days"
)
50 changes: 45 additions & 5 deletions rl/mk_d_process.py
@@ -1,14 +1,14 @@
from abc import ABC, abstractmethod
from typing import Tuple, Union, Iterable, Generic, TypeVar
from typing import Tuple, Union, Iterable, Generic, TypeVar, Mapping, Sequence, Set

import chex
from chex import dataclass

# import jax.numpy as jnp
import numpy as np

from gen_utils.distribution import Distribution
from mk_process import NonTerminal, MarkovRewardProcess, State
from gen_utils.distribution import Categorical, Distribution, FiniteDistribution, SampledDistribution  # Categorical is needed by FiniteMarkovDecisionProcess below
from mk_process import Terminal, NonTerminal, MarkovRewardProcess, State
from policy import Policy

Array = Union[chex.Array, chex.ArrayNumpy]
@@ -18,6 +18,10 @@
A = TypeVar('A', bound=IntLike)
S = TypeVar('S', bound=Union[IntLike, Array])

StateReward = FiniteDistribution[Tuple[State[S], FloatLike]]
ActionMapping = Mapping[A, StateReward[S]]
StateActionMapping = Mapping[NonTerminal[S], ActionMapping[A, S]]


@dataclass(frozen=True)
class TransitionStep(Generic[S, A]):
@@ -36,7 +40,7 @@ def step(
self,
state: NonTerminal[S],
action: A
) -> Distribution[Tuple[State[S]], FloatLike]:
) -> Distribution[Tuple[State[S], FloatLike]]:
pass

def apply_policy(self, policy: Policy[S, A]) -> MarkovRewardProcess[S]:
@@ -48,7 +52,7 @@ def transition_reward(
state: NonTerminal[S],
) -> Distribution[Tuple[State, FloatLike]]:

actions: Distribution[A] = policy.act(state)
actions: Distribution[A] = policy.act(state)  # TODO: check why this sometimes produces None actions
return actions.apply(lambda a: mdp.step(state, a))
return RewardProcess()

@@ -69,3 +73,39 @@ def simulate_actions(

yield TransitionStep(state, action, next_state, reward)
state = next_state


class FiniteMarkovDecisionProcess(MarkovDecisionProcess[S, A]):
    mapping: StateActionMapping[S, A]
    non_terminal_states: Sequence[NonTerminal[S]]

    def __init__(
        self,
        mapping: Mapping[S, Mapping[A, FiniteDistribution[Tuple[S, FloatLike]]]]  # TODO: rewrite to reduce overhead
    ):
        non_terminals: Set[S] = set(mapping.keys())
        self.mapping = {
            NonTerminal(state=s): {
                a: Categorical(
                    value={(NonTerminal(state=s1) if s1 in non_terminals
                            else Terminal(state=s1), r): p
                           for (s1, r), p in v}
                ) for a, v in d.items()
            } for s, d in mapping.items()
        }
        self.non_terminal_states = list(self.mapping.keys())

    def __repr__(self) -> str:
        display = ""
        for s, d in self.mapping.items():
            display += f"From State {s.state}:\n"
            for a, d1 in d.items():
                display += f"  With Action {a}:\n"
                for (s1, r), p in d1:
                    opt = "Terminal " if isinstance(s1, Terminal) else ""
                    display += f"    To [{opt}State {s1.state} and "\
                        + f"Reward {r:.3f}] with Probability {p:.3f}\n"
        return display

    def step(self, state: NonTerminal[S], action: A) -> StateReward[S]:
        action_map: ActionMapping[A, S] = self.mapping[state]
        return action_map[action]

    def actions(self, state: NonTerminal[S]) -> Iterable[A]:
        return self.mapping[state].keys()
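A minimal construction sketch for the new FiniteMarkovDecisionProcess (illustrative only; it assumes Categorical lives in gen_utils.distribution, accepts the value= mapping used in __init__ above, and iterates as ((next_state, reward), probability) pairs):

from gen_utils.distribution import Categorical
from mk_d_process import FiniteMarkovDecisionProcess
from mk_process import NonTerminal

# Hypothetical two-state MDP: from state 0, action 0 stays (reward 0.0) or
# moves to 1 (reward 1.0); from state 1, action 0 moves to 2, which is not a
# key of the mapping and is therefore wrapped as Terminal by the constructor.
tiny_mdp = FiniteMarkovDecisionProcess(mapping={
    0: {0: Categorical(value={(0, 0.0): 0.7, (1, 1.0): 0.3})},
    1: {0: Categorical(value={(2, 5.0): 1.0})},
})
start = NonTerminal(state=0)
print(list(tiny_mdp.actions(start)))   # [0]
print(tiny_mdp.step(start, 0))         # Categorical over ((next state, reward), prob)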

8 changes: 4 additions & 4 deletions rl/mk_process.py
@@ -141,7 +141,7 @@ def simulate_reward(
next_distribution = self.transition_reward(state)

next_state, reward = next_distribution.sample()
yield TransitionStep(state, next_state, reward)
yield TransitionStep(state=state, next_state=next_state, reward=reward)

state = next_state

@@ -151,7 +151,7 @@ def transition(self, state: NonTerminal[S]) -> Distribution[State[S]]:
def next_state(distribution=distribution):
next_s, _ = distribution.sample()
return next_s
return SampledDistribution(next_state)
return SampledDistribution(sampler=next_state)


StateReward = FiniteDistribution[Tuple[State[S], FloatLike]]
@@ -176,8 +176,8 @@ def __init__(self, transition_reward_map: Mapping[S, StateReward]):

nt: Set[S] = set(transition_reward_map.keys())
self.transition_reward_map = {
NonTerminal(s): Categorical(
{(NonTerminal(s1) if s1 in nt else Terminal(s1), r): p
NonTerminal(state=s): Categorical(distribution=\
{(NonTerminal(state=s1) if s1 in nt else Terminal(s1), r): p
for (s1, r), p in v}
) for s, v in transition_reward_map.items()
}
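For reference, a sketch of how the keyword-style simulate_reward call above composes with apply_policy from mk_d_process.py (illustrative; it assumes SimpleInventoryMDPNoCap subclasses that MarkovDecisionProcess, takes poisson_lambda, holding_cost and stockout_cost keywords, and that Constant lives in gen_utils.distribution):

import itertools

from chapter04.simple_inventory import (InventoryState,
                                        SimpleInventoryDeterministicPolicy,
                                        SimpleInventoryMDPNoCap)
from gen_utils.distribution import Constant
from mk_process import NonTerminal

# Fix a policy on the no-cap inventory MDP to obtain an MRP, then draw a
# single 10-step reward trace starting from an empty inventory.
si_mdp = SimpleInventoryMDPNoCap(poisson_lambda=2.0, holding_cost=1.0,
                                 stockout_cost=10.0)
si_dp = SimpleInventoryDeterministicPolicy(reorder_point=2)
impl_mrp = si_mdp.apply_policy(si_dp)
start = InventoryState(on_hand=0, on_order=0)
trace = itertools.islice(
    impl_mrp.simulate_reward(Constant(value=NonTerminal(state=start))),
    10
)
for step in trace:
    # each yielded TransitionStep carries state, next_state and reward,
    # matching the keyword construction above
    print(step.state.state, step.next_state.state, step.reward)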
5 changes: 3 additions & 2 deletions rl/policy.py
@@ -27,14 +27,15 @@ def act(self, state: NonTerminal[S]) -> Distribution[A]:
@dataclass(frozen=True)
class DeterministicPolicy(Policy[S, A]):
action_for: Callable[[S], A]
reorder_point: IntLike

def act(self, state: NonTerminal[S]) -> Constant[A]:
return Constant(self.action_for(state.state))
return Constant(value=self.action_for(state.state))  # TODO: this should use apply


@dataclass(frozen=True)
class UniformPolicy(Policy[S, A]):
valid_actions: Callable[[S], A]

def act(self, state: NonTerminal[S]) -> Choose[A]:
return Choose(self.valid_actions(state.state))
return Choose(value=self.valid_actions(state.state))
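A small sketch of the keyword-style policy construction after this change (illustrative; the lambda rules are hypothetical, and the generic DeterministicPolicy now also requires the new reorder_point field):

from mk_process import NonTerminal
from policy import DeterministicPolicy, UniformPolicy

# Hypothetical "order up to 3" rule over integer states.
dp = DeterministicPolicy(action_for=lambda s: max(3 - s, 0), reorder_point=3)
assert dp.act(NonTerminal(state=1)).sample() == 2   # assumes Constant exposes sample()

# UniformPolicy chooses uniformly among the valid actions for a state.
up = UniformPolicy(valid_actions=lambda s: list(range(s + 1)))
print(up.act(NonTerminal(state=2)).sample())        # one of 0, 1, 2

Since reorder_point is inventory-specific, callers of the generic DeterministicPolicy in other domains now have to supply it as well.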
