Skip to content

Commit

Permalink
shape completion works
Browse files Browse the repository at this point in the history
  • Loading branch information
fxia22 committed Apr 24, 2017
1 parent 51450e3 commit d73a932
Show file tree
Hide file tree
Showing 7 changed files with 512 additions and 11 deletions.
39 changes: 30 additions & 9 deletions datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@
import progressbar
import sys
import torchvision.transforms as transforms
import utils
import argparse
import json


class PartDataset(data.Dataset):
def __init__(self, root, npoints = 2500, classification = False, class_choice = None, train = True, parts_also = False):
def __init__(self, root, npoints = 2500, classification = False, class_choice = None, train = True, parts_also = False, shape_comp = False):
self.npoints = npoints
self.root = root
self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
self.cat = {}
self.parts_also = parts_also
self.shape_comp = shape_comp

self.classification = classification

Expand Down Expand Up @@ -91,21 +91,36 @@ def __getitem__(self, index):
#print(part.shape)
part = torch.from_numpy(part)



if self.shape_comp:
num_seg = len(np.unique(seg))
j = np.random.randint(num_seg) + 1
#print(len(point_set))
incomp = point_set[seg != j]
#print(len(incomp))

choice2 = np.random.choice(incomp.shape[0], 4 * self.npoints/5, replace=True)
incomp = incomp[choice2, :]
#print(part.shape)
incomp = torch.from_numpy(incomp)

point_set = torch.from_numpy(point_set)
seg = torch.from_numpy(seg)
cls = torch.from_numpy(np.array([cls]).astype(np.int64))


if self.parts_also:
if self.shape_comp:
return point_set, incomp
elif self.parts_also:
return point_set, part

if self.classification:
elif self.classification:

return point_set, cls
else:
return point_set, seg






def __len__(self):
Expand All @@ -124,7 +139,13 @@ def __len__(self):
ps, cls = d[0]
print(ps.size(), ps.type(), cls.size(),cls.type())

d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, parts_also = True)
d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', parts_also = True)
print(len(d))
ps, cls = d[0]
print(ps.size(), ps.type(), cls.size(),cls.type())
print(ps.size(), ps.type(), cls.size(),cls.type())


d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', shape_comp = True)
print(len(d))
ps, inc = d[0]
print(ps.size(), ps.type(), inc.size(),inc.type())
2 changes: 1 addition & 1 deletion nndistance/modules/nnd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

class NNDModule(Module):
def forward(self, input1, input2):
return NNDFunction()(input1, input2)
return NNDFunction()(input1, input2)
44 changes: 43 additions & 1 deletion pointnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import numpy as np
import matplotlib.pyplot as plt
import pdb
import utils
import torch.nn.functional as F


Expand Down Expand Up @@ -170,6 +169,49 @@ def forward(self, x):
return x


class PointGenComp(nn.Module):
def __init__(self, num_points = 2500):
super(PointGenComp, self).__init__()
self.fc1 = nn.Linear(2048, 256)
self.fc2 = nn.Linear(256, 512)
self.fc3 = nn.Linear(512, 1024)
self.fc4 = nn.Linear(1024, 500 * 3)
self.encoder = PointNetfeat(num_points = 2000)
self.th = nn.Tanh()
def forward(self, x, noise):
batchsize = x.size()[0]
x, _ = self.encoder(x)
x = torch.cat([x, noise], 1)

x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = self.th(self.fc4(x))
x = x.view(batchsize, 3, 500)
return x

class PointGenComp2(nn.Module):
def __init__(self, num_points = 2500):
super(PointGenComp2, self).__init__()
self.fc1 = nn.Linear(2048, 256)
self.fc2 = nn.Linear(256, 512)
self.fc3 = nn.Linear(512, 1024)
self.fc4 = nn.Linear(1024, 2500 * 3)
self.encoder = PointNetfeat(num_points = 2000)
self.th = nn.Tanh()
def forward(self, x, noise):
batchsize = x.size()[0]
x, _ = self.encoder(x)
x = torch.cat([x, noise], 1)

x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = self.th(self.fc4(x))
x = x.view(batchsize, 3, 2500)
return x


class PointGenR(nn.Module):
def __init__(self, num_points = 2500):
super(PointGenR, self).__init__()
Expand Down
72 changes: 72 additions & 0 deletions show_gan_comp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import print_function
from show3d_balls import *
import argparse
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from datasets import PartDataset
from pointnet import PointGen, PointGenC, PointGenComp
import torch.nn.functional as F
import matplotlib.pyplot as plt


#showpoints(np.random.randn(2500,3), c1 = np.random.uniform(0,1,size = (2500)))

parser = argparse.ArgumentParser()

parser.add_argument('--model', type=str, default = '', help='model path')



opt = parser.parse_args()
print (opt)


dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', class_choice = ['Chair'], shape_comp = True)




gen = PointGenComp()
gen.load_state_dict(torch.load(opt.model))

ld = len(dataset)

idx = np.random.randint(ld)

print(ld, idx)

_,part = dataset[idx]

sim_noise = Variable(torch.randn(2, 1024))
sim_noises = Variable(torch.zeros(30,1024))
for i in range(30):
x = i/30.0
sim_noises[i] = sim_noise[0] * x + sim_noise[1] * (1-x)

part = Variable(part.view(1,2000,3).transpose(2,1)).repeat(30,1,1)

points = gen(part, sim_noises)
print(points.size(), part.size())
points = torch.cat([points, part], 2)

cmap = plt.cm.get_cmap("hsv", 10)
cmap = np.array([cmap(i) for i in range(10)])[:,:3]

color = cmap[np.array([0] * 500 + [2] * 2000), :]

point_np = points.transpose(2,1).data.numpy()

showpoints(point_np, color)


72 changes: 72 additions & 0 deletions show_gan_comp2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import print_function
from show3d_balls import *
import argparse
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from datasets import PartDataset
from pointnet import PointGen, PointGenC, PointGenComp2
import torch.nn.functional as F
import matplotlib.pyplot as plt


#showpoints(np.random.randn(2500,3), c1 = np.random.uniform(0,1,size = (2500)))

parser = argparse.ArgumentParser()

parser.add_argument('--model', type=str, default = '', help='model path')



opt = parser.parse_args()
print (opt)


dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', class_choice = ['Chair'], shape_comp = True)




gen = PointGenComp2()
gen.load_state_dict(torch.load(opt.model))

ld = len(dataset)

idx = np.random.randint(ld)

print(ld, idx)

_,part = dataset[idx]

sim_noise = Variable(torch.randn(2, 1024))
sim_noises = Variable(torch.zeros(30,1024))
for i in range(30):
x = i/30.0
sim_noises[i] = sim_noise[0] * x + sim_noise[1] * (1-x)

part = Variable(part.view(1,2000,3).transpose(2,1)).repeat(30,1,1)

points = gen(part, sim_noises)
print(points.size(), part.size())
points = torch.cat([points, part], 2)

cmap = plt.cm.get_cmap("hsv", 10)
cmap = np.array([cmap(i) for i in range(10)])[:,:3]

color = cmap[np.array([0] * 2500 + [2] * 2000), :]

point_np = points.transpose(2,1).data.numpy()

showpoints(point_np, color)


Loading

0 comments on commit d73a932

Please sign in to comment.