Skip to content

Commit

Permalink
multi-agent env example
Browse files Browse the repository at this point in the history
  • Loading branch information
pgawlowicz committed Mar 20, 2020
1 parent bce54b7 commit 904f758
Show file tree
Hide file tree
Showing 6 changed files with 497 additions and 0 deletions.
56 changes: 56 additions & 0 deletions scratch/multi-agent/agent1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
from ns3gym import ns3env

__author__ = "Piotr Gawlowicz"
__copyright__ = "Copyright (c) 2020, Technische Universität Berlin"
__version__ = "0.1.0"
__email__ = "[email protected]"


# Agent 1 driver script: attaches to the ns3-gym environment on port 5555,
# then runs `iterationNum` episodes of randomly sampled actions, pausing for
# the user after every step.
port = 5555
# startSim=False: this script does not launch ns-3 itself
# (presumably the simulation is started separately -- confirm against ns3gym docs).
env = ns3env.Ns3Env(port=port, startSim=False)
env.reset()

ob_space = env.observation_space
ac_space = env.action_space
for label, space in (("Observation space: ", ob_space),
                     ("Action space: ", ac_space)):
    print(label, space, space.dtype)


stepIdx = 0        # global step counter, never reset between episodes
currIt = 0         # number of finished episodes
iterationNum = 3   # number of episodes to run

try:
    while currIt < iterationNum:
        # Start a new episode and show the initial observation.
        obs = env.reset()
        print("Step: ", stepIdx)
        print("---obs: ", obs)

        done = False
        while not done:
            stepIdx += 1
            action = env.action_space.sample()
            print("---action: ", action)

            print("Step: ", stepIdx)
            obs, reward, done, info = env.step(action)
            print("---obs, reward, done, info: ", obs, reward, done, info)

            # Manual pacing: wait for the user before the next step.
            input("press enter....")

        currIt += 1

except KeyboardInterrupt:
    print("Ctrl-C -> Exit")
finally:
    # Always release the environment, even on Ctrl-C or an error.
    env.close()
    print("Done")
56 changes: 56 additions & 0 deletions scratch/multi-agent/agent2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
from ns3gym import ns3env

__author__ = "Piotr Gawlowicz"
__copyright__ = "Copyright (c) 2020, Technische Universität Berlin"
__version__ = "0.1.0"
__email__ = "[email protected]"


# Agent 2 driver script: attaches to the ns3-gym environment on port 5556,
# then runs `iterationNum` episodes of randomly sampled actions, pausing for
# the user after every step.
port = 5556
# startSim=False: this script does not launch ns-3 itself
# (presumably the simulation is started separately -- confirm against ns3gym docs).
env = ns3env.Ns3Env(port=port, startSim=False)
env.reset()

ob_space = env.observation_space
ac_space = env.action_space
for label, space in (("Observation space: ", ob_space),
                     ("Action space: ", ac_space)):
    print(label, space, space.dtype)


stepIdx = 0        # global step counter, never reset between episodes
currIt = 0         # number of finished episodes
iterationNum = 3   # number of episodes to run

try:
    while currIt < iterationNum:
        # Start a new episode and show the initial observation.
        obs = env.reset()
        print("Step: ", stepIdx)
        print("---obs: ", obs)

        done = False
        while not done:
            stepIdx += 1
            action = env.action_space.sample()
            print("---action: ", action)

            print("Step: ", stepIdx)
            obs, reward, done, info = env.step(action)
            print("---obs, reward, done, info: ", obs, reward, done, info)

            # Manual pacing: wait for the user before the next step.
            input("press enter....")

        currIt += 1

except KeyboardInterrupt:
    print("Ctrl-C -> Exit")
finally:
    # Always release the environment, even on Ctrl-C or an error.
    env.close()
    print("Done")
223 changes: 223 additions & 0 deletions scratch/multi-agent/mygym.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
/*
* Copyright (c) 2018 Technische Universität Berlin
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 as
* published by the Free Software Foundation;
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*
* Author: Piotr Gawlowicz <[email protected]>
*/

#include "mygym.h"
#include "ns3/object.h"
#include "ns3/core-module.h"
#include "ns3/wifi-module.h"
#include "ns3/node-list.h"
#include "ns3/log.h"
#include <sstream>
#include <iostream>

namespace ns3 {

NS_LOG_COMPONENT_DEFINE ("MyGymEnv");

NS_OBJECT_ENSURE_REGISTERED (MyGymEnv);

/*
Default constructor (the one registered through AddConstructor in GetTypeId).
Uses agent id 0 and a 0.1 s step interval, and schedules the first state
read at simulation time 0.
*/
MyGymEnv::MyGymEnv ()
  : m_agentId (0),                // previously left uninitialized although every log line prints it
    m_interval (Seconds (0.1))
{
  NS_LOG_FUNCTION (this);

  Simulator::Schedule (Seconds (0.0), &MyGymEnv::ScheduleNextStateRead, this);
}

/*
Construct an environment for one agent.
agentId  - identifier printed in this environment's log lines
stepTime - period between two consecutive state reads
The first state read is scheduled at simulation time 0.
*/
MyGymEnv::MyGymEnv (uint32_t agentId, Time stepTime)
  : m_agentId (agentId),
    m_interval (stepTime)
{
  NS_LOG_FUNCTION (this);

  Simulator::Schedule (Seconds (0.0), &MyGymEnv::ScheduleNextStateRead, this);
}

/*
Periodic tick: re-schedules itself m_interval from now, then calls Notify().
*/
void
MyGymEnv::ScheduleNextStateRead ()
{
  NS_LOG_FUNCTION (this);

  // Keep this order: the next tick is queued before Notify() runs.
  Simulator::Schedule (m_interval, &MyGymEnv::ScheduleNextStateRead, this);
  this->Notify ();
}

/*
Destructor: only traces the call.
*/
MyGymEnv::~MyGymEnv ()
{
  NS_LOG_FUNCTION (this);
}

/*
Register MyGymEnv with the ns-3 type system: child of OpenGymEnv, group
"OpenGym", constructible through the default constructor.
*/
TypeId
MyGymEnv::GetTypeId (void)
{
  static TypeId typeId = TypeId ("MyGymEnv")
    .SetParent<OpenGymEnv> ()
    .SetGroupName ("OpenGym")
    .AddConstructor<MyGymEnv> ();
  return typeId;
}

void
MyGymEnv::DoDispose ()
{
NS_LOG_FUNCTION (this);
}

/*
Define observation space

Returns a dict space with two entries:
  "box"      - box space, bounds [0, 10], shape {5}, element type uint32_t
  "discrete" - discrete space of size 5
*/
Ptr<OpenGymSpace>
MyGymEnv::GetObservationSpace()
{
  uint32_t numNodes = 5;
  float lowerBound = 0.0;
  float upperBound = 10.0;
  std::vector<uint32_t> boxShape = {numNodes,};
  std::string elemType = TypeNameGet<uint32_t> ();

  Ptr<OpenGymDiscreteSpace> discreteSpace = CreateObject<OpenGymDiscreteSpace> (numNodes);
  Ptr<OpenGymBoxSpace> boxSpace =
    CreateObject<OpenGymBoxSpace> (lowerBound, upperBound, boxShape, elemType);

  Ptr<OpenGymDictSpace> dictSpace = CreateObject<OpenGymDictSpace> ();
  dictSpace->Add("box", boxSpace);
  dictSpace->Add("discrete", discreteSpace);

  NS_LOG_UNCOND ("AgentID: " << m_agentId << " MyGetObservationSpace: " << dictSpace);
  return dictSpace;
}

/*
Define action space

Returns a dict space with two entries:
  "box"      - box space, bounds [0, 10], shape {5}, element type uint32_t
  "discrete" - discrete space of size 5
These are the keys ExecuteActions() looks up in the received action.
*/
Ptr<OpenGymSpace>
MyGymEnv::GetActionSpace()
{
  uint32_t numNodes = 5;
  float lowerBound = 0.0;
  float upperBound = 10.0;
  std::vector<uint32_t> boxShape = {numNodes,};
  std::string elemType = TypeNameGet<uint32_t> ();

  Ptr<OpenGymDiscreteSpace> discreteSpace = CreateObject<OpenGymDiscreteSpace> (numNodes);
  Ptr<OpenGymBoxSpace> boxSpace =
    CreateObject<OpenGymBoxSpace> (lowerBound, upperBound, boxShape, elemType);

  Ptr<OpenGymDictSpace> dictSpace = CreateObject<OpenGymDictSpace> ();
  dictSpace->Add("box", boxSpace);
  dictSpace->Add("discrete", discreteSpace);

  NS_LOG_UNCOND ("AgentID: " << m_agentId << " MyGetActionSpace: " << dictSpace);
  return dictSpace;
}

/*
Define game over condition

Counts calls and would report game over on the 10th call, but only when
the local `test` switch is enabled. The switch is hard-coded to false, so
this currently always returns false.
*/
bool
MyGymEnv::GetGameOver()
{
  // Demo switch for the termination rule below; disabled.
  const bool test = false;

  // Integer counter, so the == comparison below is exact (was a float).
  // Function-local static: one counter shared by every MyGymEnv instance.
  static uint32_t stepCounter = 0;
  ++stepCounter;

  const bool isGameOver = test && (stepCounter == 10);

  NS_LOG_UNCOND ("AgentID: " << m_agentId << " MyGetGameOver: " << isGameOver);
  return isGameOver;
}

/*
Collect observations

Builds a tuple container holding:
  [0] a uint32_t box container with 5 random integers drawn from [0, 10]
  [1] a discrete container of size 5 with a random value drawn from [0, 4]

NOTE(review): GetObservationSpace() in this file advertises a dict space
with keys "box"/"discrete", while this method returns a tuple container.
Confirm which of the two is intended.
*/
Ptr<OpenGymDataContainer>
MyGymEnv::GetObservation()
{
  uint32_t nodeNum = 5;
  const uint32_t low = 0;
  const uint32_t high = 10;
  Ptr<UniformRandomVariable> rngInt = CreateObject<UniformRandomVariable> ();

  std::vector<uint32_t> shape = {nodeNum,};
  Ptr<OpenGymBoxContainer<uint32_t> > box = CreateObject<OpenGymBoxContainer<uint32_t> >(shape);

  // generate random data
  for (uint32_t i = 0; i < nodeNum; ++i)
    {
      box->AddValue (rngInt->GetInteger (low, high));
    }

  Ptr<OpenGymDiscreteContainer> discrete = CreateObject<OpenGymDiscreteContainer>(nodeNum);
  // A discrete space of size nodeNum only admits 0 .. nodeNum-1; the value
  // used to be drawn from [low, high] and could fall outside that range.
  uint32_t value = rngInt->GetInteger (0, nodeNum - 1);
  discrete->SetValue (value);

  Ptr<OpenGymTupleContainer> data = CreateObject<OpenGymTupleContainer> ();
  data->Add(box);
  data->Add(discrete);

  // Print data from tuple (also shows how to read the elements back)
  Ptr<OpenGymBoxContainer<uint32_t> > mbox = DynamicCast<OpenGymBoxContainer<uint32_t> >(data->Get(0));
  Ptr<OpenGymDiscreteContainer> mdiscrete = DynamicCast<OpenGymDiscreteContainer>(data->Get(1));
  NS_LOG_UNCOND ("AgentID: " << m_agentId << " MyGetObservation: " << data);
  NS_LOG_UNCOND ("---" << mbox);
  NS_LOG_UNCOND ("---" << mdiscrete);

  return data;
}

/*
Define reward function

Returns a running total that grows by 1 on every call (1, 2, 3, ...).
NOTE(review): the total is a function-local static, so it is shared by
every MyGymEnv instance rather than kept per agent. Confirm that is
intended for the multi-agent setup.
*/
float
MyGymEnv::GetReward()
{
  static float runningReward = 0.0;
  return runningReward += 1;
}

/*
Define extra info. Optional

Returns the fixed string "testInfo|123" and logs it together with the
agent id.
*/
std::string
MyGymEnv::GetExtraInfo()
{
  std::string info ("testInfo");
  info.append ("|123");
  NS_LOG_UNCOND("AgentID: " << m_agentId << " MyGetExtraInfo: " << info);
  return info;
}

/*
Execute received actions

Expects `action` to be a dict container with the keys "box" and
"discrete" (the layout declared by GetActionSpace()), and logs it.
Returns true when the action was a dict container. Returns false, without
touching it, when the cast to a dict container fails.
NOTE(review): how the caller treats the return value is not visible here.
*/
bool
MyGymEnv::ExecuteActions(Ptr<OpenGymDataContainer> action)
{
  Ptr<OpenGymDictContainer> dict = DynamicCast<OpenGymDictContainer>(action);
  if (!dict)
    {
      // The cast failed (null or non-dict action): calling dict->Get()
      // would dereference a null Ptr.
      NS_LOG_UNCOND ("AgentID: " << m_agentId << " MyExecuteActions: action is not a dict container");
      return false;
    }

  Ptr<OpenGymBoxContainer<uint32_t> > box = DynamicCast<OpenGymBoxContainer<uint32_t> >(dict->Get("box"));
  Ptr<OpenGymDiscreteContainer> discrete = DynamicCast<OpenGymDiscreteContainer>(dict->Get("discrete"));

  NS_LOG_UNCOND ("AgentID: " << m_agentId << " MyExecuteActions: " << action);
  NS_LOG_UNCOND ("---" << box);
  NS_LOG_UNCOND ("---" << discrete);
  return true;
}

} // ns3 namespace
57 changes: 57 additions & 0 deletions scratch/multi-agent/mygym.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
/*
* Copyright (c) 2018 Technische Universität Berlin
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 as
* published by the Free Software Foundation;
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*
* Author: Piotr Gawlowicz <[email protected]>
*/


#ifndef MY_GYM_ENTITY_H
#define MY_GYM_ENTITY_H

#include "ns3/opengym-module.h"
#include "ns3/nstime.h"

namespace ns3 {

/**
 * Example OpenGym environment used by the multi-agent demo.
 *
 * Each instance carries an agent id (printed in its log output) and a step
 * interval. Every interval it calls Notify(); the callbacks below supply
 * randomly generated observations and log the received actions.
 */
class MyGymEnv : public OpenGymEnv
{
public:
  /** Default construction: 0.1 s step interval. */
  MyGymEnv ();
  /**
   * \param agentId  identifier printed in this environment's log lines
   * \param stepTime period between two consecutive state reads
   */
  MyGymEnv (uint32_t agentId, Time stepTime);
  virtual ~MyGymEnv ();
  static TypeId GetTypeId (void);
  virtual void DoDispose ();

  /** \return dict space with a "box" and a "discrete" entry */
  Ptr<OpenGymSpace> GetActionSpace();
  /** \return dict space with a "box" and a "discrete" entry */
  Ptr<OpenGymSpace> GetObservationSpace();
  /** \return whether the episode is over */
  bool GetGameOver();
  /** \return container with a freshly generated random observation */
  Ptr<OpenGymDataContainer> GetObservation();
  /** \return reward value; grows by 1 on every call */
  float GetReward();
  /** \return free-form info string */
  std::string GetExtraInfo();
  /** Log the received action. \param action container received from the agent */
  bool ExecuteActions(Ptr<OpenGymDataContainer> action);

private:
  /** Periodic tick: re-schedules itself after m_interval and calls Notify(). */
  void ScheduleNextStateRead();

  uint32_t m_agentId {0};  //!< agent identifier; defaults to 0 so it is never read uninitialized
  Time m_interval;         //!< period between two state reads
};

}


#endif // MY_GYM_ENTITY_H
Loading

0 comments on commit 904f758

Please sign in to comment.