-
Notifications
You must be signed in to change notification settings - Fork 0
/
experiment.py
104 lines (89 loc) · 3.69 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gym
import numpy as np
import pandas as pd
import os
# from tqdm import tqdm
from yaw_planner import Oxford, LookAhead, NoControl, Rotating, Owl, LookGoal
from datetime import datetime
from utils import *
# import matplotlib.pyplot as plt
from envs.drone_v2 import Drone2DEnv2
# Registry of yaw planners: maps the ``params.gaze_method`` string to the
# planner class that ``Experiment`` looks up and drives.
policy_list = dict(
    LookAhead=LookAhead,
    NoControl=NoControl,
    Oxford=Oxford,
    Rotating=Rotating,
    Owl=Owl,
    LookGoal=LookGoal,
)
def add_to_csv(dir, value):
    """Append one row to an existing CSV file that already has a header line.

    Only the header is read, to learn the column order. The new row is then
    appended in place. The previous approach re-read and rewrote the whole
    file on every call, so each call cost O(rows), and a crash during the
    rewrite could truncate all earlier results.

    Args:
        dir: Path of the CSV file. The name shadows the ``dir`` builtin; it is
            kept so keyword callers keep working.
        value: Sequence with exactly one entry per CSV column, in header order.

    Raises:
        ValueError: If ``value`` does not have one entry per column. Nothing
            is written in that case.
    """
    columns = pd.read_csv(dir, index_col=False, nrows=0).columns
    # Building the one-row frame validates the column count before any write.
    row = pd.DataFrame([value], columns=columns)
    row.to_csv(dir, mode="a", header=False, index=False)
class Experiment:
    """Run one simulated flight episode and optionally log its outcome to CSV.

    The environment is created with ``gym.make(params.env, params=params)``.
    The yaw planner is selected from ``policy_list`` by ``params.gaze_method``.
    When ``params.record`` is true, one summary row per finished episode is
    appended to the results CSV. The file is created with a header if it does
    not exist yet.
    """

    # Header of the results CSV. ``_result_row`` returns values in exactly
    # this order; keeping both in one class prevents silent misalignment.
    RESULT_COLUMNS = (
        'Method',
        'Planner',
        'Motion Profile',
        'Map ID',
        'Agent size',
        'Number of agents',
        'Number of pillars',
        'Agent speed',
        'Drone speed',
        'Depth variance',
        'Initial position',
        'Target position',
        'Flight time',
        'Grid discovered',
        'Agent tracked',
        'Agent tracked time',
        'Success',
        'Static Collision',
        'Dynamic Collision',
        'Freezing',
        'Dead Lock',
        'state machine',
    )

    def __init__(self, params, dir):
        """Build the environment and planner, and create the results file.

        Args:
            params: Experiment configuration object. Its attributes are read
                here and in ``run`` (e.g. ``env``, ``dt``, ``gaze_method``,
                ``record``, ``render``).
            dir: Path of the results CSV file. The name shadows the builtin
                but is kept for caller compatibility.
        """
        self.params = params
        self.env = gym.make(params.env, params=params)
        self.dt = params.dt
        # NOTE(review): the planner *class* itself is used as the policy
        # object. Both ``__init__`` and ``plan`` are called with the class
        # passed as ``self``, so planner state lives on the class and is
        # shared by every Experiment in this process. It is kept as-is
        # because planner implementations may rely on this calling
        # convention. Confirm before switching to
        # ``policy_list[params.gaze_method](params)``.
        self.policy = policy_list[params.gaze_method]
        self.policy.__init__(self.policy, params)
        self.result_dir = dir
        if (not os.path.isfile(dir)) and params.record:
            # Create the results file with only its header row.
            pd.DataFrame(columns=list(self.RESULT_COLUMNS)).to_csv(dir, index=False)

    def run(self):
        """Run one episode until the environment reports ``done``.

        Each step asks the planner for an action based on ``self.env.info``
        and steps the environment. On the final step, the summary row is
        appended to the results CSV (if ``params.record``). The environment
        is rendered after every step (if ``params.render``).
        """
        self.env.reset()
        done = False
        while not done:
            action = self.policy.plan(self.policy, self.env.info)
            _, _, done, info = self.env.step(action)
            if done and self.params.record:
                add_to_csv(self.result_dir, self._result_row(info))
            if self.params.render:
                self.env.render()

    def _result_row(self, info):
        """Return the CSV row, ordered as ``RESULT_COLUMNS``, for a finished episode."""
        params = self.params
        trackers = info['tracker_buffer']
        n_tracked = len(trackers)
        # NOTE(review): each tracker sample is counted as 0.1 s. This is
        # hard-coded rather than taken from ``self.dt``; confirm which is
        # intended.
        tracking_time = np.array([len(tracker.ts) * 0.1 for tracker in trackers]).sum()
        # Mean tracked time per tracker. NaN when nothing was tracked; the
        # explicit guard avoids numpy's 0/0 RuntimeWarning.
        mean_tracking_time = tracking_time / n_tracked if n_tracked else float('nan')
        grid_map = info['drone'].map.grid_map
        # Cells equal to 0 are treated as undiscovered; every other cell counts.
        grid_discovered = grid_map.shape[0] * grid_map.shape[1] - np.count_nonzero(grid_map == 0)
        return (
            params.gaze_method,
            params.planner,
            params.motion_profile,
            params.map_id,
            params.agent_radius,
            params.agent_number,
            params.pillar_number,
            params.agent_max_speed,
            params.drone_max_speed,
            params.var_cam,
            params.init_position,
            params.target_list[0],
            info['flight_time'],
            grid_discovered,
            n_tracked,
            mean_tracking_time,
            1 if info['state_machine'] == state_machine['GOAL_REACHED'] else 0,
            1 if info['collision_flag'] == 1 else 0,
            1 if info['collision_flag'] == 2 else 0,
            info['freezing_flag'],
            info['dead_lock_flag'],
            info['state_machine'],
        )