Skip to content

Commit

Permalink
major tracking improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Tobias Grosse-Puppendahl committed Aug 23, 2020
1 parent 529e3eb commit 9b02e96
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 16 deletions.
21 changes: 18 additions & 3 deletions simple_filters/polynomial_filter_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ class PolynomialFilterStrategy(FilterStrategy):
to the median, multiplied by the outlier_rejection_ratio
"""

def __init__(self, poly_degree=3, reject_outliers=True, outlier_rejection_ratio=2.0):
def __init__(self, poly_degree=3, reject_outliers=True, outlier_rejection_ratio=2.0, filter_weight=1.0, max_items=None):
super().__init__()

self.poly_degree = poly_degree
self.reject_outliers = reject_outliers
self.outlier_rejection_ratio = outlier_rejection_ratio
self.max_items = max_items
self.history = None
self.filter_weight = filter_weight

self.__poly_fn = None

Expand All @@ -33,16 +35,28 @@ def eval(self, time=0):
history_size = self.history.shape[0]
offset_time = history_size + time - 1

# for debugging purposes
if self.poly_degree == 0:
return self.history[history_size - 1]

# in the case that the equation is underdetermined, we cannot predict a polynomial
# simply return the last state in the history
if history_size < self.poly_degree + 1:
return self.history[history_size - 1]

# if the polynomial functions are not existent, calculate them
if self.__poly_fn is None:
self.__update_polynomials()

return self.__eval_polynomials(offset_time)
predictions = self.__eval_polynomials(offset_time)

# finally applying a weight to the prediction
if time <= 0:
result = (self.filter_weight * predictions) + ((1 - self.filter_weight) * self.history[offset_time])
else:
result = predictions

return result

def __eval_polynomials(self, t):
length = self.history.shape[1]
Expand All @@ -69,6 +83,7 @@ def __calc_polynomial(self, x):
rel_delta = delta / np.median(delta)

mask = rel_delta < self.outlier_rejection_ratio

x = x[mask]
y = y[mask]

Expand Down
39 changes: 29 additions & 10 deletions simple_filters/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,21 @@ class TrackedObject(Filter):
Acts as a simple proxy to the actual filter provided in the initialization
"""

def __init__(self, id, filter):
def __init__(self, id, filter, max_time_to_live):
self.id = id
self.filter = filter
self.time_to_live = 0
self.time_to_live = 1
self.max_time_to_live = max_time_to_live

def increase_time_to_live(self):
if self.time_to_live < self.max_time_to_live:
self.time_to_live += 1

def decrease_time_to_live(self):
if self.time_to_live > 0:
self.time_to_live -= 1
else:
self.time_to_live = 0

def update(self, state):
self.filter.update(state)
Expand All @@ -31,32 +42,35 @@ class Tracker:
"""

def __init__(self, filter_prototype,
time_to_live=0,
max_time_to_live=1,
time_to_birth=0,
distance_threshold=1.0,
distance_function=lambda x1, x2: np.linalg.norm(x1 - x2)):
self.distance_threshold = distance_threshold
self.time_to_live = time_to_live
self.max_time_to_live = max_time_to_live
self.time_to_birth = time_to_birth

self.object_counter = 0

self.__distance_function = distance_function
self.__filter_prototype = filter_prototype
self.__tracked_objects = []

def get_tracked_objects(self):
return self.__tracked_objects
return list(filter(lambda x: x.time_to_live > self.time_to_birth, self.__tracked_objects))

def to_numpy_array(self, raw=False):
"""
Returns the tracking id, plus the filtered object state if raw is False
"""
m = []
for t in self.__tracked_objects:
for t in self.get_tracked_objects():
if raw:
state = t.raw()
else:
state = t.eval()

m.append(np.insert(state, 0, t.id))
m.append(np.insert(np.array(t.id, dtype=np.float32), 0, state))

return np.array(m)

Expand Down Expand Up @@ -118,6 +132,7 @@ def update(self, states):
if t in objects_to_match and s in states_to_match:
objects_to_match.remove(t)
states_to_match.remove(s)
self.__tracked_objects[t].increase_time_to_live()
self.__tracked_objects[t].update(states[s])

## Delete objects
Expand All @@ -127,9 +142,9 @@ def update(self, states):
removals = []
for i in objects_to_match:
tracked_object = self.__tracked_objects[i]
tracked_object.time_to_live += 1
tracked_object.decrease_time_to_live()

if tracked_object.time_to_live > self.time_to_live:
if tracked_object.time_to_live < 1:
removals.append(tracked_object)
else:
# update the object with the next predicted state
Expand All @@ -142,6 +157,10 @@ def update(self, states):
# now go through all unmatched objects and create new objects
for i in states_to_match:
self.object_counter += 1
added_object = TrackedObject(self.object_counter, deepcopy(self.__filter_prototype))
added_object = TrackedObject(
self.object_counter,
deepcopy(self.__filter_prototype),
max_time_to_live=self.max_time_to_live
)
added_object.update(states[i])
self.__tracked_objects.append(added_object)
2 changes: 1 addition & 1 deletion tests/test_single_object_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_new_obj(self):
def test_interpolate_object_with_ttl(self):
strategy = PolynomialFilterStrategy(poly_degree=1, reject_outliers=False)
filter_prototype = Filter(strategy, history_size=3)
tracker = Tracker(filter_prototype, distance_threshold=1., time_to_live=1)
tracker = Tracker(filter_prototype, distance_threshold=1., max_time_to_live=2)

states = [
np.array([[1.0, 1.0]]),
Expand Down
5 changes: 3 additions & 2 deletions tests/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,16 @@ def test_tracker_delete_objects(self):
self.static_update_and_assert(1, 1, 2)

def test_tracker_delete_objects_time_to_live(self):
self.tracker.time_to_live = 1
self.tracker.max_time_to_live = 2

self.static_update_and_assert(2, 2, 2)
self.static_update_and_assert(2, 2, 2) # increment ttl counter to 2
self.static_update_and_assert(1, 2, 2) # object should be retained, even if it doesn't appear
self.static_update_and_assert(1, 1, 2) # object should be removed after this update

def test_tracker_mapping(self):
# TODO: This assumes that the order is retained, but makes it easier for testing
reference_matrix = np.array([[1., 1., 2.], [2., 2., 3.]])
reference_matrix = np.array([[1., 2., 1.], [2., 3., 2.]])

self.tracker.update(self.generate_static_states(2))
self.assertTrue((self.tracker.to_numpy_array() == reference_matrix).all())
Expand Down

0 comments on commit 9b02e96

Please sign in to comment.