Skip to content

Commit

Permalink
add scratch examples
Browse files Browse the repository at this point in the history
  • Loading branch information
pgawlowicz committed Oct 28, 2018
1 parent 094fa59 commit 3f198dc
Show file tree
Hide file tree
Showing 36 changed files with 4,884 additions and 0 deletions.
103 changes: 103 additions & 0 deletions scratch/interference-pattern/cognitive-agent-v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import gym
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from tensorflow import keras
from ns3gym import ns3env

env = gym.make('ns3-v0')
ob_space = env.observation_space
ac_space = env.action_space
print("Observation space: ", ob_space, ob_space.dtype)
print("Action space: ", ac_space, ac_space.n)

s_size = ob_space.shape[0]
a_size = ac_space.n
model = keras.Sequential()
model.add(keras.layers.Dense(s_size, input_shape=(s_size,), activation='relu'))
model.add(keras.layers.Dense(a_size, activation='softmax'))
model.compile(optimizer=tf.train.AdamOptimizer(0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])

total_episodes = 200
max_env_steps = 100
env._max_episode_steps = max_env_steps

epsilon = 1.0 # exploration rate
epsilon_min = 0.01
epsilon_decay = 0.999

time_history = []
rew_history = []

for e in range(total_episodes):

state = env.reset()
state = np.reshape(state, [1, s_size])
rewardsum = 0
for time in range(max_env_steps):

# Choose action
if np.random.rand(1) < epsilon:
action = np.random.randint(a_size)
else:
action = np.argmax(model.predict(state)[0])

# Step
next_state, reward, done, _ = env.step(action)

if done:
print("episode: {}/{}, time: {}, rew: {}, eps: {:.2}"
.format(e, total_episodes, time, rewardsum, epsilon))
break

next_state = np.reshape(next_state, [1, s_size])

# Train
target = reward
if not done:
target = (reward + 0.95 * np.amax(model.predict(next_state)[0]))

target_f = model.predict(state)
target_f[0][action] = target
model.fit(state, target_f, epochs=1, verbose=0)

state = next_state
rewardsum += reward
if epsilon > epsilon_min: epsilon *= epsilon_decay

time_history.append(time)
rew_history.append(rewardsum)

#for n in range(2 ** s_size):
# state = [n >> i & 1 for i in range(0, 2)]
# state = np.reshape(state, [1, s_size])
# print("state " + str(state)
# + " -> prediction " + str(model.predict(state)[0])
# )

#print(model.get_config())
#print(model.to_json())
#print(model.get_weights())

print("Plot Learning Performance")
mpl.rcdefaults()
mpl.rcParams.update({'font.size': 16})

fig, ax = plt.subplots(figsize=(10,4))
plt.grid(True, linestyle='--')
plt.title('Learning Performance')
plt.plot(range(len(time_history)), time_history, label='Steps', marker="^", linestyle=":")#, color='red')
plt.plot(range(len(rew_history)), rew_history, label='Reward', marker="", linestyle="-")#, color='k')
plt.xlabel('Episode')
plt.ylabel('Time')
plt.legend(prop={'size': 12})

plt.savefig('learning.pdf', bbox_inches='tight')
plt.show()
213 changes: 213 additions & 0 deletions scratch/interference-pattern/mygym.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
/*
* Copyright (c) 2018 Technische Universität Berlin
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 as
* published by the Free Software Foundation;
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*
* Author: Piotr Gawlowicz <[email protected]>
*/

#include "mygym.h"
#include "ns3/object.h"
#include "ns3/core-module.h"
#include "ns3/wifi-module.h"
#include "ns3/node-list.h"
#include "ns3/log.h"
#include <sstream>
#include <iostream>

namespace ns3 {

NS_LOG_COMPONENT_DEFINE ("MyGymEnv");

NS_OBJECT_ENSURE_REGISTERED (MyGymEnv);

MyGymEnv::MyGymEnv ()
{
NS_LOG_FUNCTION (this);
m_currentNode = 0;
m_currentChannel = 0;
m_collisionTh = 3;
m_channelNum = 1;
m_channelOccupation.clear();
}

MyGymEnv::MyGymEnv (uint32_t channelNum)
{
NS_LOG_FUNCTION (this);
m_currentNode = 0;
m_currentChannel = 0;
m_collisionTh = 3;
m_channelNum = channelNum;
m_channelOccupation.clear();
}

MyGymEnv::~MyGymEnv ()
{
NS_LOG_FUNCTION (this);
}

TypeId
MyGymEnv::GetTypeId (void)
{
static TypeId tid = TypeId ("MyGymEnv")
.SetParent<OpenGymEnv> ()
.SetGroupName ("OpenGym")
.AddConstructor<MyGymEnv> ()
;
return tid;
}

void
MyGymEnv::DoDispose ()
{
NS_LOG_FUNCTION (this);
}

Ptr<OpenGymSpace>
MyGymEnv::GetActionSpace()
{
NS_LOG_FUNCTION (this);
Ptr<OpenGymDiscreteSpace> space = CreateObject<OpenGymDiscreteSpace> (m_channelNum);
NS_LOG_UNCOND ("GetActionSpace: " << space);
return space;
}

Ptr<OpenGymSpace>
MyGymEnv::GetObservationSpace()
{
NS_LOG_FUNCTION (this);
float low = 0.0;
float high = 1.0;
std::vector<uint32_t> shape = {m_channelNum,};
std::string dtype = TypeNameGet<uint32_t> ();
Ptr<OpenGymBoxSpace> space = CreateObject<OpenGymBoxSpace> (low, high, shape, dtype);
NS_LOG_UNCOND ("GetObservationSpace: " << space);
return space;
}

bool
MyGymEnv::GetGameOver()
{
NS_LOG_FUNCTION (this);
bool isGameOver = false;

uint32_t collisionNum = 0;
for (auto& v : m_collisions)
collisionNum += v;

if (collisionNum >= m_collisionTh){
isGameOver = true;
}
NS_LOG_UNCOND ("MyGetGameOver: " << isGameOver);
return isGameOver;
}

Ptr<OpenGymDataContainer>
MyGymEnv::GetObservation()
{
NS_LOG_FUNCTION (this);
std::vector<uint32_t> shape = {m_channelNum,};
Ptr<OpenGymBoxContainer<uint32_t> > box = CreateObject<OpenGymBoxContainer<uint32_t> >(shape);

for (uint32_t i = 0; i < m_channelOccupation.size(); ++i) {
uint32_t value = m_channelOccupation.at(i);
box->AddValue(value);
}

NS_LOG_UNCOND ("MyGetObservation: " << box);
return box;
}

float
MyGymEnv::GetReward()
{
NS_LOG_FUNCTION (this);
float reward = 1.0;
if (m_channelOccupation.size() == 0){
return 0.0;
}
uint32_t occupied = m_channelOccupation.at(m_currentChannel);
if (occupied == 1) {
reward = -1.0;
m_collisions.erase(m_collisions.begin());
m_collisions.push_back(1);
} else {
m_collisions.erase(m_collisions.begin());
m_collisions.push_back(0);
}
NS_LOG_UNCOND ("MyGetReward: " << reward);
return reward;
}

std::string
MyGymEnv::GetExtraInfo()
{
NS_LOG_FUNCTION (this);
std::string myInfo = "info";
NS_LOG_UNCOND("MyGetExtraInfo: " << myInfo);
return myInfo;
}

bool
MyGymEnv::ExecuteActions(Ptr<OpenGymDataContainer> action)
{
NS_LOG_FUNCTION (this);
Ptr<OpenGymDiscreteContainer> discrete = DynamicCast<OpenGymDiscreteContainer>(action);
uint32_t nextChannel = discrete->GetValue();
m_currentChannel = nextChannel;

NS_LOG_UNCOND ("Current Channel: " << m_currentChannel);
return true;
}

void
MyGymEnv::CollectChannelOccupation(uint32_t chanId, uint32_t occupied)
{
NS_LOG_FUNCTION (this);
m_channelOccupation.push_back(occupied);
}

bool
MyGymEnv::CheckIfReady()
{
NS_LOG_FUNCTION (this);
return m_channelOccupation.size() == m_channelNum;
}

void
MyGymEnv::ClearObs()
{
NS_LOG_FUNCTION (this);
m_channelOccupation.clear();
}

void
MyGymEnv::PerformCca (Ptr<MyGymEnv> entity, uint32_t channelId, Ptr<const SpectrumValue> avgPowerSpectralDensity)
{
double power = Integral (*(avgPowerSpectralDensity));
double powerDbW = 10 * std::log10(power);
double threshold = -60;
uint32_t busy = powerDbW > threshold;
NS_LOG_UNCOND("Channel: " << channelId << " CCA: " << busy << " RxPower: " << powerDbW);

entity->CollectChannelOccupation(channelId, busy);

if (entity->CheckIfReady()){
entity->Notify();
entity->ClearObs();
}
}

} // ns3 namespace
78 changes: 78 additions & 0 deletions scratch/interference-pattern/mygym.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
/*
* Copyright (c) 2018 Technische Universität Berlin
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 as
* published by the Free Software Foundation;
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*
* Author: Piotr Gawlowicz <[email protected]>
*/


#ifndef MY_GYM_ENTITY_H
#define MY_GYM_ENTITY_H

#include "ns3/stats-module.h"
#include "ns3/opengym-module.h"
#include "ns3/spectrum-module.h"

namespace ns3 {

class Node;
class WifiMacQueue;
class Packet;

class MyGymEnv : public OpenGymEnv
{
public:
MyGymEnv ();
MyGymEnv (uint32_t channelNum);
virtual ~MyGymEnv ();
static TypeId GetTypeId (void);
virtual void DoDispose ();

Ptr<OpenGymSpace> GetActionSpace();
Ptr<OpenGymSpace> GetObservationSpace();
bool GetGameOver();
Ptr<OpenGymDataContainer> GetObservation();
float GetReward();
std::string GetExtraInfo();
bool ExecuteActions(Ptr<OpenGymDataContainer> action);

// the function has to be static to work with MakeBoundCallback
// that is why we pass pointer to MyGymEnv instance to be able to store the context (node, etc)
static void PerformCca(Ptr<MyGymEnv> entity, uint32_t channelId, Ptr<const SpectrumValue> avgPowerSpectralDensity);
void CollectChannelOccupation(uint32_t chanId, uint32_t occupied);
bool CheckIfReady();
void ClearObs();

private:
void ScheduleNextStateRead();
Ptr<WifiMacQueue> GetQueue(Ptr<Node> node);
bool SetCw(Ptr<Node> node, uint32_t cwMinValue=0, uint32_t cwMaxValue=0);

Time m_interval = Seconds(0.1);
Ptr<Node> m_currentNode;
uint64_t m_rxPktNum;
uint32_t m_channelNum;
std::vector<uint32_t> m_channelOccupation;
uint32_t m_currentChannel;

uint32_t m_collisionTh;
std::vector<uint32_t> m_collisions = {0,0,0,0,0,0,0,0,0,0,};
};

}


#endif // MY_GYM_ENTITY_H
Loading

0 comments on commit 3f198dc

Please sign in to comment.