Comment out track_queries_placeholder_mask
timmeinhardt committed Apr 28, 2022
1 parent cadde18 commit b2f64c5
Showing 4 changed files with 67 additions and 52 deletions.
32 changes: 22 additions & 10 deletions src/trackformer/engine.py
@@ -26,24 +26,36 @@ def make_results(outputs, targets, postprocessors, tracking, return_only_orig=Tr
     orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
 
     # remove placeholder track queries
-    results_mask = None
-    if tracking:
-        results_mask = [~t['track_queries_placeholder_mask'] for t in targets]
-        for target, res_mask in zip(targets, results_mask):
-            target['track_queries_mask'] = target['track_queries_mask'][res_mask]
-            target['track_queries_fal_pos_mask'] = target['track_queries_fal_pos_mask'][res_mask]
+    # results_mask = None
+    # if tracking:
+    #     results_mask = [~t['track_queries_placeholder_mask'] for t in targets]
+    #     for target, res_mask in zip(targets, results_mask):
+    #         target['track_queries_mask'] = target['track_queries_mask'][res_mask]
+    #         target['track_queries_fal_pos_mask'] = target['track_queries_fal_pos_mask'][res_mask]
+
+    # results = None
+    # if not return_only_orig:
+    #     results = postprocessors['bbox'](outputs, target_sizes, results_mask)
+    # results_orig = postprocessors['bbox'](outputs, orig_target_sizes, results_mask)
+
+    # if 'segm' in postprocessors:
+    #     results_orig = postprocessors['segm'](
+    #         results_orig, outputs, orig_target_sizes, target_sizes, results_mask)
+    #     if not return_only_orig:
+    #         results = postprocessors['segm'](
+    #             results, outputs, target_sizes, target_sizes, results_mask)
 
     results = None
     if not return_only_orig:
-        results = postprocessors['bbox'](outputs, target_sizes, results_mask)
-    results_orig = postprocessors['bbox'](outputs, orig_target_sizes, results_mask)
+        results = postprocessors['bbox'](outputs, target_sizes)
+    results_orig = postprocessors['bbox'](outputs, orig_target_sizes)
 
     if 'segm' in postprocessors:
         results_orig = postprocessors['segm'](
-            results_orig, outputs, orig_target_sizes, target_sizes, results_mask)
+            results_orig, outputs, orig_target_sizes, target_sizes)
         if not return_only_orig:
             results = postprocessors['segm'](
-                results, outputs, target_sizes, target_sizes, results_mask)
+                results, outputs, target_sizes, target_sizes)
 
     if results is None:
         return results_orig, results
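For context, the block commented out above built results_mask as the inverse of each target's track_queries_placeholder_mask and used it to drop placeholder entries before postprocessing. A minimal runnable restatement of that filtering (the toy target dict below is illustrative, not TrackFormer's real data):

    import torch

    def drop_placeholder_queries(targets):
        # invert the placeholder mask: True marks entries to keep
        results_mask = [~t['track_queries_placeholder_mask'] for t in targets]
        for target, res_mask in zip(targets, results_mask):
            # keep only mask entries that belong to real (non-placeholder) queries
            target['track_queries_mask'] = target['track_queries_mask'][res_mask]
            target['track_queries_fal_pos_mask'] = target['track_queries_fal_pos_mask'][res_mask]
        return results_mask

    # toy example: four queries, the first being a placeholder
    target = {
        'track_queries_placeholder_mask': torch.tensor([True, False, False, False]),
        'track_queries_mask': torch.tensor([True, True, False, False]),
        'track_queries_fal_pos_mask': torch.tensor([False, True, False, False]),
    }
    drop_placeholder_queries([target])
    print(target['track_queries_mask'])  # tensor([ True, False, False])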
32 changes: 17 additions & 15 deletions src/trackformer/models/detr.py
@@ -196,10 +196,10 @@ def loss_labels(self, outputs, targets, indices, _, log=True):
                 target_classes = target_classes.clone()
                 target_classes[i, target['track_queries_fal_pos_mask']] = 0
 
-        weight = None
-        if self.tracking:
-            weight = torch.stack([~t['track_queries_placeholder_mask'] for t in targets]).float()
-            loss_ce *= weight
+        # weight = None
+        # if self.tracking:
+        #     weight = torch.stack([~t['track_queries_placeholder_mask'] for t in targets]).float()
+        #     loss_ce *= weight
 
         loss_ce = loss_ce.sum() / self.empty_weight[target_classes].sum()
 
@@ -229,20 +229,22 @@ def loss_labels_focal(self, outputs, targets, indices, num_boxes, log=True):
 
         target_classes_onehot = target_classes_onehot[:,:,:-1]
 
-        query_mask = None
-        if self.tracking:
-            query_mask = torch.stack([~t['track_queries_placeholder_mask'] for t in targets])[..., None]
-            query_mask = query_mask.repeat(1, 1, self.num_classes)
+        # query_mask = None
+        # if self.tracking:
+        #     query_mask = torch.stack([~t['track_queries_placeholder_mask'] for t in targets])[..., None]
+        #     query_mask = query_mask.repeat(1, 1, self.num_classes)
 
         loss_ce = sigmoid_focal_loss(
             src_logits, target_classes_onehot, num_boxes,
-            alpha=self.focal_alpha, gamma=self.focal_gamma, query_mask=query_mask)
+            alpha=self.focal_alpha, gamma=self.focal_gamma)
+            # , query_mask=query_mask)
 
-        if self.tracking:
-            mean_num_queries = torch.tensor([len(m.nonzero()) for m in query_mask]).float().mean()
-            loss_ce *= mean_num_queries
-        else:
-            loss_ce *= src_logits.shape[1]
+        # if self.tracking:
+        #     mean_num_queries = torch.tensor([len(m.nonzero()) for m in query_mask]).float().mean()
+        #     loss_ce *= mean_num_queries
+        # else:
+        #     loss_ce *= src_logits.shape[1]
+        loss_ce *= src_logits.shape[1]
         losses = {'loss_ce': loss_ce}
 
         if log:
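The removed query_mask argument weighted the focal loss so placeholder queries contributed nothing, and the removed rescaling replaced the plain num_queries factor with the mean count of real queries per image. A sketch of what such a masked focal loss can look like, following the Deformable-DETR formulation; the exact TrackFormer signature and reduction may differ:

    import torch
    import torch.nn.functional as F

    def sigmoid_focal_loss(inputs, targets, num_boxes, alpha=0.25, gamma=2, query_mask=None):
        # inputs/targets (and query_mask, if given): [batch, num_queries, num_classes]
        prob = inputs.sigmoid()
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        p_t = prob * targets + (1 - prob) * (1 - targets)
        loss = ce_loss * ((1 - p_t) ** gamma)
        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            loss = alpha_t * loss
        if query_mask is not None:
            # zero out placeholder queries so they add nothing to the loss
            loss = loss * query_mask.float()
        return loss.mean(1).sum() / num_boxes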
52 changes: 26 additions & 26 deletions src/trackformer/models/detr_tracking.py
@@ -183,38 +183,38 @@ def add_track_queries_to_targets(self, targets, prev_indices, prev_out, add_fals
             ]).bool()
 
         # add placeholder track queries to allow for batch sizes > 1
-        max_track_query_hs_embeds = max([len(t['track_query_hs_embeds']) for t in targets])
-        for i, target in enumerate(targets):
+        # max_track_query_hs_embeds = max([len(t['track_query_hs_embeds']) for t in targets])
+        # for i, target in enumerate(targets):
 
-            num_add = max_track_query_hs_embeds - len(target['track_query_hs_embeds'])
+        #     num_add = max_track_query_hs_embeds - len(target['track_query_hs_embeds'])
 
-            if not num_add:
-                target['track_queries_placeholder_mask'] = torch.zeros_like(target['track_queries_mask']).bool()
-                continue
+        #     if not num_add:
+        #         target['track_queries_placeholder_mask'] = torch.zeros_like(target['track_queries_mask']).bool()
+        #         continue
 
-            raise NotImplementedError
+        #     raise NotImplementedError
 
-            target['track_query_hs_embeds'] = torch.cat(
-                [torch.zeros(num_add, self.hidden_dim).to(device),
-                 target['track_query_hs_embeds']
-                 ])
-            target['track_query_boxes'] = torch.cat(
-                [torch.zeros(num_add, 4).to(device),
-                 target['track_query_boxes']
-                 ])
+        #     target['track_query_hs_embeds'] = torch.cat(
+        #         [torch.zeros(num_add, self.hidden_dim).to(device),
+        #          target['track_query_hs_embeds']
+        #          ])
+        #     target['track_query_boxes'] = torch.cat(
+        #         [torch.zeros(num_add, 4).to(device),
+        #          target['track_query_boxes']
+        #          ])
 
-            target['track_queries_mask'] = torch.cat([
-                torch.tensor([True, ] * num_add).to(device),
-                target['track_queries_mask']
-            ]).bool()
+        #     target['track_queries_mask'] = torch.cat([
+        #         torch.tensor([True, ] * num_add).to(device),
+        #         target['track_queries_mask']
+        #     ]).bool()
 
-            target['track_queries_fal_pos_mask'] = torch.cat([
-                torch.tensor([False, ] * num_add).to(device),
-                target['track_queries_fal_pos_mask']
-            ]).bool()
+        #     target['track_queries_fal_pos_mask'] = torch.cat([
+        #         torch.tensor([False, ] * num_add).to(device),
+        #         target['track_queries_fal_pos_mask']
+        #     ]).bool()
 
-            target['track_queries_placeholder_mask'] = torch.zeros_like(target['track_queries_mask']).bool()
-            target['track_queries_placeholder_mask'][:num_add] = True
+        #     target['track_queries_placeholder_mask'] = torch.zeros_like(target['track_queries_mask']).bool()
+        #     target['track_queries_placeholder_mask'][:num_add] = True
 
     def forward(self, samples: NestedTensor, targets: list = None, prev_features=None):
         if targets is not None and not self._tracking:
@@ -266,7 +266,7 @@ def forward(self, samples: NestedTensor, targets: list = None, prev_features=Non
                 device = target['boxes'].device
 
                 target['track_query_hs_embeds'] = torch.zeros(0, self.hidden_dim).float().to(device)
-                target['track_queries_placeholder_mask'] = torch.zeros(self.num_queries).bool().to(device)
+                # target['track_queries_placeholder_mask'] = torch.zeros(self.num_queries).bool().to(device)
                 target['track_queries_mask'] = torch.zeros(self.num_queries).bool().to(device)
                 target['track_queries_fal_pos_mask'] = torch.zeros(self.num_queries).bool().to(device)
                 target['track_query_boxes'] = torch.zeros(0, 4).to(device)
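The commented-out padding prepended zero embeddings and boxes so every sample in a batch carried the same number of track queries; note it sat behind a raise NotImplementedError, so everything past the num_add == 0 early-out was effectively dead code. A condensed, runnable sketch of that intent, with key names taken from the diff:

    import torch

    def pad_track_queries(targets, hidden_dim):
        # pad every target up to the largest track-query count in the batch
        max_queries = max(len(t['track_query_hs_embeds']) for t in targets)
        for target in targets:
            device = target['track_query_hs_embeds'].device
            num_add = max_queries - len(target['track_query_hs_embeds'])
            target['track_query_hs_embeds'] = torch.cat(
                [torch.zeros(num_add, hidden_dim, device=device),
                 target['track_query_hs_embeds']])
            target['track_query_boxes'] = torch.cat(
                [torch.zeros(num_add, 4, device=device),
                 target['track_query_boxes']])
            # placeholders count as track queries but never as false positives
            target['track_queries_mask'] = torch.cat(
                [torch.ones(num_add, dtype=torch.bool, device=device),
                 target['track_queries_mask']])
            target['track_queries_fal_pos_mask'] = torch.cat(
                [torch.zeros(num_add, dtype=torch.bool, device=device),
                 target['track_queries_fal_pos_mask']])
            # the first num_add entries are the padding placeholders
            placeholder_mask = torch.zeros_like(target['track_queries_mask'])
            placeholder_mask[:num_add] = True
            target['track_queries_placeholder_mask'] = placeholder_mask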
3 changes: 2 additions & 1 deletion src/trackformer/models/matcher.py
@@ -111,7 +111,8 @@ def forward(self, outputs, targets):
 
                 prop_i = 0
                 for j in range(cost_matrix.shape[1]):
-                    if target['track_queries_fal_pos_mask'][j] or target['track_queries_placeholder_mask'][j]:
+                    # if target['track_queries_fal_pos_mask'][j] or target['track_queries_placeholder_mask'][j]:
+                    if target['track_queries_fal_pos_mask'][j]:
                         # false positive and placeholder track queries should not
                         # be matched to any target
                         cost_matrix[i, j] = np.inf
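With the placeholder check gone, only false-positive track queries are barred from matching: their cost entries are set to np.inf so the Hungarian solver never assigns them a ground-truth object. A toy illustration of the mechanism with SciPy (shapes and numbers invented; one image's cost slice with queries as rows and targets as columns):

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    # rows = queries, cols = ground-truth boxes of one image
    cost_matrix = np.array([[0.2, 0.8],
                            [0.9, 0.3],
                            [0.1, 0.7]])

    # query 2 is a false-positive track query: forbid matching it
    fal_pos_mask = np.array([False, False, True])
    cost_matrix[fal_pos_mask, :] = np.inf

    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    print(row_ind, col_ind)  # [0 1] [0 1]: query 2 is never assigned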
