Skip to content

Commit

Permalink
Fixed several bugs and documented the code
Browse files Browse the repository at this point in the history
  • Loading branch information
pablo committed Sep 3, 2024
1 parent 281202a commit 1bd827f
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 54 deletions.
3 changes: 2 additions & 1 deletion Sewformer/configs/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
experiment:
project_name: Garments-Reconstruction
run_name: Detr2d-V6-final-dif-ce-focal-schd-agp
run_id:
run_id:
local_dir: ./garment_outputs/Detr2d-V6-final-dif-ce-focal-schd-agp
is_training: False

# ----- Dataset-related properties -----
dataset:
Expand Down
3 changes: 2 additions & 1 deletion Sewformer/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,13 @@ def load_detr_dataset(self, data_root, eval_config={}, unseen=False, batch_size=

# Dataset
data_class = getattr(data, data_config['class'])
dataset = data_class(data_root, data_config, gt_caching=True, feature_caching=False)
dataset = data_class(data_root, sim_root={}, start_config=data_config, gt_caching=True, feature_caching=False)

datawrapper = data.RealisticDatasetDetrWrapper(dataset, known_split=split, batch_size=batch_size)
return dataset, datawrapper

def load_detr_model(self, data_config, others=False):
# Load Model
model, criterion = models.build_former(self.in_config)
device = 'cuda:0' if torch.cuda.is_available() else "cpu"
model = nn.DataParallel(model, device_ids=[device])
Expand Down
4 changes: 3 additions & 1 deletion Sewformer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# My modules
import sys, os
root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
pkg_path = "{}/SewFactory/packages".format(root_path)
pkg_path = "{}/sewformer/SewFactory/packages".format(root_path)
print(pkg_path)
sys.path.insert(0, pkg_path)

Expand Down Expand Up @@ -131,8 +131,10 @@ def is_img_file(fn):
if __name__ == "__main__":

np.set_printoptions(precision=4, suppress=True)
# Load system info from system.json
system_info = customconfig.Properties('./system.json')

# Load configuration and arguments from configs/test.yaml
config, args = get_values_from_args()

shape_experiment = ExperimentWrappper(config, system_info['wandb_username']) # finished experiment
Expand Down
89 changes: 55 additions & 34 deletions Sewformer/models/garment_detr_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,47 @@
class GarmentDETRv6(nn.Module):
def __init__(self, backbone, panel_transformer, num_panel_queries, num_edges, num_joints, **edge_kwargs):
super().__init__()
# CNN backbone
self.backbone = backbone

self.num_panel_queries = num_panel_queries
self.num_joint_queries = num_joints
# Panel Transformer ( Encoder + Decoder panel)
self.panel_transformer = panel_transformer

self.hidden_dim = self.panel_transformer.d_model
self.edge_kwargs = edge_kwargs["edge_kwargs"]

self.panel_embed = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 2)
# Edge Decoder
self.build_edge_decoder(self.hidden_dim, self.edge_kwargs["nheads"],
self.hidden_dim, self.edge_kwargs["dropout"],
"relu", self.edge_kwargs["pre_norm"],
self.edge_kwargs["dec_layers"])

self.panel_joints_query_embed = nn.Embedding(self.num_panel_queries + self.num_joint_queries, self.hidden_dim)

self.num_panel_queries = num_panel_queries # Number of panel queries [23]
self.num_joint_queries = num_joints # Number of joint queries [22]
self.num_edges = num_edges # Number of edges [14]
self.num_edge_queries = self.num_panel_queries * num_edges # Number of edge queries [23*14=322]

# Convolution 1x1 to reduce the dimension of output of CNN backbone
self.input_proj = nn.Conv2d(backbone.num_channels, self.hidden_dim, kernel_size=1)
self.panel_rt_decoder = MLP(self.hidden_dim, self.hidden_dim, 7, 2)
self.joints_decoder = MLP(self.hidden_dim, self.hidden_dim, 6, 2)

self.num_edges = num_edges
self.num_edge_queries = self.num_panel_queries * num_edges
self.edge_kwargs = edge_kwargs["edge_kwargs"]
# Panel Decoder: Panel Queries [23+22]
self.panel_joints_query_embed = nn.Embedding(self.num_panel_queries + self.num_joint_queries, self.hidden_dim)
# Panel Decoder: output MLP ( panel + joint tokens)
self.panel_embed = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 2)

# Panel Decoder: MLP for Rotation and Translation
self.panel_rt_decoder = MLP(self.hidden_dim, self.hidden_dim, 7, 2)
# TODO MLP for Joints
self.joints_decoder = MLP(self.hidden_dim, self.hidden_dim, 6, 2)

# Expanded tokens
self.panel_decoder = MLP(self.hidden_dim, self.hidden_dim, self.num_edges * 4, 2)
self.edge_query_mlp = MLP(self.hidden_dim + 4, self.hidden_dim, self.hidden_dim, 1)

self.build_edge_decoder(self.hidden_dim, self.edge_kwargs["nheads"],
self.hidden_dim, self.edge_kwargs["dropout"],
"relu", self.edge_kwargs["pre_norm"],
self.edge_kwargs["dec_layers"])

# MLP Edge Decoder
self.edge_embed = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 2)
# MLP
self.edge_cls = MLP(self.hidden_dim, self.hidden_dim // 2, 1, 2)
# MLP to predict edges
self.edge_decoder = MLP(self.hidden_dim, self.hidden_dim, 4, 2)

def build_edge_decoder(self, d_model, nhead, dim_feedforward, dropout, activation, normalize_before, num_layers):
Expand All @@ -72,41 +82,49 @@ def _reset_parameters(self, ):
def forward(self, samples, gt_stitches=None, gt_edge_mask=None, return_stitches=False):
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
# Forward CNN [features = (1,2048,24,24), pos = (1,256,24,24)]
features, panel_pos = self.backbone(samples)

src, mask = features[-1].decompose()
B = src.shape[0]
assert mask is not None
panel_joint_hs, panel_memory, _ = self.panel_transformer(self.input_proj(src), mask, self.panel_joints_query_embed.weight, panel_pos[-1])
# Panel Transformer [Panel tokens + Joint tokens, Visual Tokens ]
panel_joint_hs, visual_tokens, _ = self.panel_transformer(self.input_proj(src), mask, self.panel_joints_query_embed.weight, panel_pos[-1])

# Output MLP Panel Decoder: Panel tokens + Joint tokens
panel_joint_hs = self.panel_embed(panel_joint_hs)
panel_hs = panel_joint_hs[:, :, :self.num_panel_queries, :]
joint_hs = panel_joint_hs[:, :, self.num_panel_queries:, :]
output_panel_rt = self.panel_rt_decoder(panel_hs)
# Panel tokens matrix, Joint tokens matrix
panel_tokens, joint_hs = panel_joint_hs[:, :, :self.num_panel_queries, :], panel_joint_hs[:, :, self.num_panel_queries:, :]

output_rotations = output_panel_rt[:, :, :, :4]
output_translations = output_panel_rt[:, :, :, 4:]
# MLP Rotation + translation matrix (Panel Decoder)
output_panel_rt = self.panel_rt_decoder(panel_tokens)
# Output: Rotation and Translation vectors
output_rotations, output_translations = output_panel_rt[:, :, :, :4], output_panel_rt[:, :, :, 4:]
out = {"rotations": output_rotations[-1], "translations": output_translations[-1]}

out = {"rotations": output_rotations[-1],
"translations": output_translations[-1]}

# TODO: MLP joints decoder
output_joints = self.joints_decoder(joint_hs)
out.update({"smpl_joints": output_joints[-1]})

edge_output = self.panel_decoder(panel_hs)[-1].view(B, self.num_panel_queries, self.num_edges, 4)
edge_query = self.edge_query_mlp(torch.cat((panel_joint_hs[-1, :, :self.num_panel_queries, :].unsqueeze(2).expand(-1, -1, self.num_edges, -1), edge_output), dim=-1)).reshape(B, -1, self.hidden_dim).permute(1, 0, 2)
# Expanded tokens
edge_output = self.panel_decoder(panel_tokens)[-1].view(B, self.num_panel_queries, self.num_edges, 4) # (6, B, 23, 256) -> (B, 23, 14, 4)
aux = torch.cat((panel_tokens[-1, :, :, :].unsqueeze(2).expand(-1, -1, self.num_edges, -1), edge_output), dim=-1)
# Edge Query: MLP (322, B, 256)
edge_query = self.edge_query_mlp(aux).reshape(B, -1, self.hidden_dim).permute(1, 0, 2)

# Edge Decoder transformer
tgt = torch.zeros_like(edge_query)
memory = panel_memory.view(B, self.hidden_dim, -1).permute(2, 0, 1) # flatten NxCxHxW to HWxNxC
edge_hs = self.edge_trans_decoder(tgt, memory,
memory = visual_tokens.view(B, self.hidden_dim, -1).permute(2, 0, 1) # flatten NxCxHxW to HWxNxC
edge_hs = self.edge_trans_decoder(tgt, memory,
memory_key_padding_mask=mask.flatten(1),
query_pos=edge_query).transpose(1, 2)

# Edge Tokens
output_edge_embed = self.edge_embed(edge_hs)[-1]

# TODO
output_edge_cls = self.edge_cls(output_edge_embed)
# Predicted Edges
output_edges = self.edge_decoder(output_edge_embed) + edge_output.view(B, -1, 4)


out.update({"outlines": output_edges, "edge_cls": output_edge_cls})

if return_stitches:
Expand Down Expand Up @@ -352,13 +370,16 @@ def eval(self):


def build(args):
    """Assemble the SewFormer model and its training criterion from a config dict.

    Args:
        args: nested configuration with "dataset", "trainer" and "NN" sections.

    Returns:
        Tuple ``(model, criterion)``; the criterion is moved to the target device.
    """
    # Maximum number of panels in a pattern -> number of panel queries.
    num_classes = args["dataset"]["max_pattern_len"]

    # Resolve the target device; the config may hold a single id or a list of ids.
    device_cfg = args["trainer"]["devices"]
    if isinstance(device_cfg, list):
        device_cfg = device_cfg[0]
    devices = torch.device(device_cfg)

    # Visual encoder (ResNet backbone + positional encoding).
    backbone = build_backbone(args)
    # Two-level transformer (encoder + panel decoder).
    panel_transformer = build_transformer(args)

    # Full SewFormer network: 14 edges per panel, 22 body joints.
    model = GarmentDETRv6(backbone, panel_transformer, num_classes, 14, 22, edge_kwargs=args["NN"])

    # Matcher-free set criterion used as the loss function.
    criterion = SetCriterionWithOutMatcher(args["dataset"], args["NN"]["loss"])
    criterion.to(devices)
    return model, criterion
33 changes: 17 additions & 16 deletions Sewformer/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
return_intermediate_dec=False):
super().__init__()

encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None

self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
return_intermediate=return_intermediate_dec)
Expand All @@ -45,19 +44,20 @@ def forward(self, src, mask, query_embed, pos_embed, return_self_attns=False):
# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1) # Positional Embedding
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # Panel Query
mask = mask.flatten(1)

tgt = torch.zeros_like(query_embed)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
pos=pos_embed, query_pos=query_embed)
# Transformers encoder: Visual tokens
visual_tokens = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
# Panel Decoder: Panel tokens
hs = self.decoder(tgt, visual_tokens, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
if return_self_attns:
self_attns = self.encoder.layers[-1].self_attn
else:
self_attn = None
return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w), self_attn
return hs.transpose(1, 2), visual_tokens.permute(1, 2, 0).view(bs, c, h, w), self_attn


class TransformerEncoder(nn.Module):
Expand All @@ -75,8 +75,7 @@ def forward(self, src,
output = src

for layer in self.layers:
output = layer(output, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, pos=pos)
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
if self.norm is not None:
output = self.norm(output)

Expand Down Expand Up @@ -227,7 +226,7 @@ def forward_post(self, tgt, memory,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(slef.dropout(self.activation(self.linear1(tgt))))
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
Expand All @@ -240,9 +239,11 @@ def forward_pre(self, tgt, memory,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt2 = self.self_attn(self.with_pos_embed(tgt2, query_pos),
self.with_pos_embed(tgt2, query_pos),
value=tgt2,
attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]

tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
Expand Down
2 changes: 1 addition & 1 deletion Sewformer/system.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"output": "path to put the training logs",
"datasets_path": "path to the sewfactory",
"sim_root": "path to the sim2real images",
"wandb_username": ""
"wandb_username": "pabloriosn"
}

0 comments on commit 1bd827f

Please sign in to comment.