forked from tkn-tub/ns3-gym
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
094fa59
commit 3f198dc
Showing
36 changed files
with
4,884 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
|
||
import gym | ||
import tensorflow as tf | ||
import tensorflow.contrib.slim as slim | ||
import numpy as np | ||
import matplotlib as mpl | ||
import matplotlib.pyplot as plt | ||
from tensorflow import keras | ||
from ns3gym import ns3env | ||
|
||
env = gym.make('ns3-v0') | ||
ob_space = env.observation_space | ||
ac_space = env.action_space | ||
print("Observation space: ", ob_space, ob_space.dtype) | ||
print("Action space: ", ac_space, ac_space.n) | ||
|
||
s_size = ob_space.shape[0] | ||
a_size = ac_space.n | ||
model = keras.Sequential() | ||
model.add(keras.layers.Dense(s_size, input_shape=(s_size,), activation='relu')) | ||
model.add(keras.layers.Dense(a_size, activation='softmax')) | ||
model.compile(optimizer=tf.train.AdamOptimizer(0.001), | ||
loss='categorical_crossentropy', | ||
metrics=['accuracy']) | ||
|
||
total_episodes = 200 | ||
max_env_steps = 100 | ||
env._max_episode_steps = max_env_steps | ||
|
||
epsilon = 1.0 # exploration rate | ||
epsilon_min = 0.01 | ||
epsilon_decay = 0.999 | ||
|
||
time_history = [] | ||
rew_history = [] | ||
|
||
for e in range(total_episodes): | ||
|
||
state = env.reset() | ||
state = np.reshape(state, [1, s_size]) | ||
rewardsum = 0 | ||
for time in range(max_env_steps): | ||
|
||
# Choose action | ||
if np.random.rand(1) < epsilon: | ||
action = np.random.randint(a_size) | ||
else: | ||
action = np.argmax(model.predict(state)[0]) | ||
|
||
# Step | ||
next_state, reward, done, _ = env.step(action) | ||
|
||
if done: | ||
print("episode: {}/{}, time: {}, rew: {}, eps: {:.2}" | ||
.format(e, total_episodes, time, rewardsum, epsilon)) | ||
break | ||
|
||
next_state = np.reshape(next_state, [1, s_size]) | ||
|
||
# Train | ||
target = reward | ||
if not done: | ||
target = (reward + 0.95 * np.amax(model.predict(next_state)[0])) | ||
|
||
target_f = model.predict(state) | ||
target_f[0][action] = target | ||
model.fit(state, target_f, epochs=1, verbose=0) | ||
|
||
state = next_state | ||
rewardsum += reward | ||
if epsilon > epsilon_min: epsilon *= epsilon_decay | ||
|
||
time_history.append(time) | ||
rew_history.append(rewardsum) | ||
|
||
#for n in range(2 ** s_size): | ||
# state = [n >> i & 1 for i in range(0, 2)] | ||
# state = np.reshape(state, [1, s_size]) | ||
# print("state " + str(state) | ||
# + " -> prediction " + str(model.predict(state)[0]) | ||
# ) | ||
|
||
#print(model.get_config()) | ||
#print(model.to_json()) | ||
#print(model.get_weights()) | ||
|
||
print("Plot Learning Performance") | ||
mpl.rcdefaults() | ||
mpl.rcParams.update({'font.size': 16}) | ||
|
||
fig, ax = plt.subplots(figsize=(10,4)) | ||
plt.grid(True, linestyle='--') | ||
plt.title('Learning Performance') | ||
plt.plot(range(len(time_history)), time_history, label='Steps', marker="^", linestyle=":")#, color='red') | ||
plt.plot(range(len(rew_history)), rew_history, label='Reward', marker="", linestyle="-")#, color='k') | ||
plt.xlabel('Episode') | ||
plt.ylabel('Time') | ||
plt.legend(prop={'size': 12}) | ||
|
||
plt.savefig('learning.pdf', bbox_inches='tight') | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ | ||
/* | ||
* Copyright (c) 2018 Technische Universität Berlin | ||
* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License version 2 as | ||
* published by the Free Software Foundation; | ||
* | ||
* This program is distributed in the hope that it will be useful, | ||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
* GNU General Public License for more details. | ||
* | ||
* You should have received a copy of the GNU General Public License | ||
* along with this program; if not, write to the Free Software | ||
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA | ||
* | ||
* Author: Piotr Gawlowicz <[email protected]> | ||
*/ | ||
|
||
#include "mygym.h" | ||
#include "ns3/object.h" | ||
#include "ns3/core-module.h" | ||
#include "ns3/wifi-module.h" | ||
#include "ns3/node-list.h" | ||
#include "ns3/log.h" | ||
#include <sstream> | ||
#include <iostream> | ||
|
||
namespace ns3 { | ||
|
||
NS_LOG_COMPONENT_DEFINE ("MyGymEnv"); | ||
|
||
NS_OBJECT_ENSURE_REGISTERED (MyGymEnv); | ||
|
||
MyGymEnv::MyGymEnv () | ||
{ | ||
NS_LOG_FUNCTION (this); | ||
m_currentNode = 0; | ||
m_currentChannel = 0; | ||
m_collisionTh = 3; | ||
m_channelNum = 1; | ||
m_channelOccupation.clear(); | ||
} | ||
|
||
MyGymEnv::MyGymEnv (uint32_t channelNum) | ||
{ | ||
NS_LOG_FUNCTION (this); | ||
m_currentNode = 0; | ||
m_currentChannel = 0; | ||
m_collisionTh = 3; | ||
m_channelNum = channelNum; | ||
m_channelOccupation.clear(); | ||
} | ||
|
||
MyGymEnv::~MyGymEnv () | ||
{ | ||
NS_LOG_FUNCTION (this); | ||
} | ||
|
||
TypeId | ||
MyGymEnv::GetTypeId (void) | ||
{ | ||
static TypeId tid = TypeId ("MyGymEnv") | ||
.SetParent<OpenGymEnv> () | ||
.SetGroupName ("OpenGym") | ||
.AddConstructor<MyGymEnv> () | ||
; | ||
return tid; | ||
} | ||
|
||
void | ||
MyGymEnv::DoDispose () | ||
{ | ||
NS_LOG_FUNCTION (this); | ||
} | ||
|
||
Ptr<OpenGymSpace> | ||
MyGymEnv::GetActionSpace() | ||
{ | ||
NS_LOG_FUNCTION (this); | ||
Ptr<OpenGymDiscreteSpace> space = CreateObject<OpenGymDiscreteSpace> (m_channelNum); | ||
NS_LOG_UNCOND ("GetActionSpace: " << space); | ||
return space; | ||
} | ||
|
||
Ptr<OpenGymSpace> | ||
MyGymEnv::GetObservationSpace() | ||
{ | ||
NS_LOG_FUNCTION (this); | ||
float low = 0.0; | ||
float high = 1.0; | ||
std::vector<uint32_t> shape = {m_channelNum,}; | ||
std::string dtype = TypeNameGet<uint32_t> (); | ||
Ptr<OpenGymBoxSpace> space = CreateObject<OpenGymBoxSpace> (low, high, shape, dtype); | ||
NS_LOG_UNCOND ("GetObservationSpace: " << space); | ||
return space; | ||
} | ||
|
||
bool | ||
MyGymEnv::GetGameOver() | ||
{ | ||
NS_LOG_FUNCTION (this); | ||
bool isGameOver = false; | ||
|
||
uint32_t collisionNum = 0; | ||
for (auto& v : m_collisions) | ||
collisionNum += v; | ||
|
||
if (collisionNum >= m_collisionTh){ | ||
isGameOver = true; | ||
} | ||
NS_LOG_UNCOND ("MyGetGameOver: " << isGameOver); | ||
return isGameOver; | ||
} | ||
|
||
Ptr<OpenGymDataContainer> | ||
MyGymEnv::GetObservation() | ||
{ | ||
NS_LOG_FUNCTION (this); | ||
std::vector<uint32_t> shape = {m_channelNum,}; | ||
Ptr<OpenGymBoxContainer<uint32_t> > box = CreateObject<OpenGymBoxContainer<uint32_t> >(shape); | ||
|
||
for (uint32_t i = 0; i < m_channelOccupation.size(); ++i) { | ||
uint32_t value = m_channelOccupation.at(i); | ||
box->AddValue(value); | ||
} | ||
|
||
NS_LOG_UNCOND ("MyGetObservation: " << box); | ||
return box; | ||
} | ||
|
||
float | ||
MyGymEnv::GetReward() | ||
{ | ||
NS_LOG_FUNCTION (this); | ||
float reward = 1.0; | ||
if (m_channelOccupation.size() == 0){ | ||
return 0.0; | ||
} | ||
uint32_t occupied = m_channelOccupation.at(m_currentChannel); | ||
if (occupied == 1) { | ||
reward = -1.0; | ||
m_collisions.erase(m_collisions.begin()); | ||
m_collisions.push_back(1); | ||
} else { | ||
m_collisions.erase(m_collisions.begin()); | ||
m_collisions.push_back(0); | ||
} | ||
NS_LOG_UNCOND ("MyGetReward: " << reward); | ||
return reward; | ||
} | ||
|
||
std::string | ||
MyGymEnv::GetExtraInfo() | ||
{ | ||
NS_LOG_FUNCTION (this); | ||
std::string myInfo = "info"; | ||
NS_LOG_UNCOND("MyGetExtraInfo: " << myInfo); | ||
return myInfo; | ||
} | ||
|
||
bool | ||
MyGymEnv::ExecuteActions(Ptr<OpenGymDataContainer> action) | ||
{ | ||
NS_LOG_FUNCTION (this); | ||
Ptr<OpenGymDiscreteContainer> discrete = DynamicCast<OpenGymDiscreteContainer>(action); | ||
uint32_t nextChannel = discrete->GetValue(); | ||
m_currentChannel = nextChannel; | ||
|
||
NS_LOG_UNCOND ("Current Channel: " << m_currentChannel); | ||
return true; | ||
} | ||
|
||
void | ||
MyGymEnv::CollectChannelOccupation(uint32_t chanId, uint32_t occupied) | ||
{ | ||
NS_LOG_FUNCTION (this); | ||
m_channelOccupation.push_back(occupied); | ||
} | ||
|
||
bool | ||
MyGymEnv::CheckIfReady() | ||
{ | ||
NS_LOG_FUNCTION (this); | ||
return m_channelOccupation.size() == m_channelNum; | ||
} | ||
|
||
void | ||
MyGymEnv::ClearObs() | ||
{ | ||
NS_LOG_FUNCTION (this); | ||
m_channelOccupation.clear(); | ||
} | ||
|
||
void | ||
MyGymEnv::PerformCca (Ptr<MyGymEnv> entity, uint32_t channelId, Ptr<const SpectrumValue> avgPowerSpectralDensity) | ||
{ | ||
double power = Integral (*(avgPowerSpectralDensity)); | ||
double powerDbW = 10 * std::log10(power); | ||
double threshold = -60; | ||
uint32_t busy = powerDbW > threshold; | ||
NS_LOG_UNCOND("Channel: " << channelId << " CCA: " << busy << " RxPower: " << powerDbW); | ||
|
||
entity->CollectChannelOccupation(channelId, busy); | ||
|
||
if (entity->CheckIfReady()){ | ||
entity->Notify(); | ||
entity->ClearObs(); | ||
} | ||
} | ||
|
||
} // ns3 namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ | ||
/* | ||
* Copyright (c) 2018 Technische Universität Berlin | ||
* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License version 2 as | ||
* published by the Free Software Foundation; | ||
* | ||
* This program is distributed in the hope that it will be useful, | ||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
* GNU General Public License for more details. | ||
* | ||
* You should have received a copy of the GNU General Public License | ||
* along with this program; if not, write to the Free Software | ||
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA | ||
* | ||
* Author: Piotr Gawlowicz <[email protected]> | ||
*/ | ||
|
||
|
||
#ifndef MY_GYM_ENTITY_H | ||
#define MY_GYM_ENTITY_H | ||
|
||
#include "ns3/stats-module.h" | ||
#include "ns3/opengym-module.h" | ||
#include "ns3/spectrum-module.h" | ||
|
||
namespace ns3 { | ||
|
||
class Node; | ||
class WifiMacQueue; | ||
class Packet; | ||
|
||
class MyGymEnv : public OpenGymEnv | ||
{ | ||
public: | ||
MyGymEnv (); | ||
MyGymEnv (uint32_t channelNum); | ||
virtual ~MyGymEnv (); | ||
static TypeId GetTypeId (void); | ||
virtual void DoDispose (); | ||
|
||
Ptr<OpenGymSpace> GetActionSpace(); | ||
Ptr<OpenGymSpace> GetObservationSpace(); | ||
bool GetGameOver(); | ||
Ptr<OpenGymDataContainer> GetObservation(); | ||
float GetReward(); | ||
std::string GetExtraInfo(); | ||
bool ExecuteActions(Ptr<OpenGymDataContainer> action); | ||
|
||
// the function has to be static to work with MakeBoundCallback | ||
// that is why we pass pointer to MyGymEnv instance to be able to store the context (node, etc) | ||
static void PerformCca(Ptr<MyGymEnv> entity, uint32_t channelId, Ptr<const SpectrumValue> avgPowerSpectralDensity); | ||
void CollectChannelOccupation(uint32_t chanId, uint32_t occupied); | ||
bool CheckIfReady(); | ||
void ClearObs(); | ||
|
||
private: | ||
void ScheduleNextStateRead(); | ||
Ptr<WifiMacQueue> GetQueue(Ptr<Node> node); | ||
bool SetCw(Ptr<Node> node, uint32_t cwMinValue=0, uint32_t cwMaxValue=0); | ||
|
||
Time m_interval = Seconds(0.1); | ||
Ptr<Node> m_currentNode; | ||
uint64_t m_rxPktNum; | ||
uint32_t m_channelNum; | ||
std::vector<uint32_t> m_channelOccupation; | ||
uint32_t m_currentChannel; | ||
|
||
uint32_t m_collisionTh; | ||
std::vector<uint32_t> m_collisions = {0,0,0,0,0,0,0,0,0,0,}; | ||
}; | ||
|
||
} | ||
|
||
|
||
#endif // MY_GYM_ENTITY_H |
Oops, something went wrong.