Skip to content

Commit

Permalink
adding code
Browse files Browse the repository at this point in the history
  • Loading branch information
ranahanocka committed May 25, 2020
1 parent 782ab12 commit 767ac0e
Show file tree
Hide file tree
Showing 29 changed files with 1,889 additions and 2 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
.idea/*
*.pyc
data/
checkpoints/
94 changes: 94 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
<img src='docs/images/lizard2.gif' align="right" width=325>
<br><br><br>

# Point2Mesh in PyTorch


### SIGGRAPH 2020 [[Paper]](https://arxiv.org/abs/2005.11084) [[Project Page]](https://ranahanocka.github.io/point2mesh/)<br>

Point2Mesh is a technique for reconstructing a surface mesh from an input point cloud.
This approach "learns" from a single object, by optimizing the weights of a CNN to deform some initial mesh to shrink-wrap the input point cloud.
The argument for going this route is: since the (local) convolutional kernels are optimized globally across the entire shape,
this encourages local-scale geometric self-similarity across the reconstructed shape surface.

<img src="docs/images/global_anky.gif" align="center" width="250px"> <br>

The code was written by [Rana Hanocka](https://www.cs.tau.ac.il/~hanocka/) and Gal Metzer.

# Getting Started

### Installation
- Clone this repo:
```bash
git clone https://github.com/ranahanocka/point2mesh.git
cd point2mesh
```
#### Setup Conda Environment
- Relies on [PyTorch](https://pytorch.org/) version 1.4 (or 1.5) and [PyTorch3D](https://github.com/facebookresearch/pytorch3d) version 0.2.0.
Install via conda environment `conda env create -f environment.yml` (creates an environment called point2mesh)

#### Install "Manifold" Software
This code relies on the [Robust Watertight Manifold Software](https://github.com/hjwdzh/Manifold).
First ```cd``` into the location you wish to install the software. For example, we used ```cd ~/code```.
Then follow the installation instructions in the Watertight README.
If you installed Manifold in a different path than ```~/code/Manifold/build```, please update ```options.py``` accordingly.

# Running Examples

### Get Data
Download our example data
```bash
bash ./scripts/get_data.sh
```

### Running Reconstruction
First, if using conda env first activate env e.g. ```source activate point2mesh```.
All the scripts can be found in ```./scripts/examples```.
Here are a few examples:

#### Giraffe
```bash
bash ./scripts/examples/giraffe.sh
```

#### Bull
```bash
bash ./scripts/examples/bull.sh
```

#### Tiki
```bash
bash ./scripts/examples/tiki.sh
```

#### Noisy Guitar
```bash
bash ./scripts/examples/noisy_guitar.sh
```
... and more.
#### All the examples
To run all the examples in this repo:
```bash
bash ./scripts/run_all_examples.sh
```

# Citation
If you find this code useful, please consider citing our paper
```
@article{Hanocka2020p2m,
title = {Point2Mesh: A Self-Prior for Deformable Meshes},
author = {Hanocka, Rana and Metzer, Gal and Giryes, Raja and Cohen-Or, Daniel},
year = {2020},
issue_date = {July 2020},
publisher = {Association for Computing Machinery},
volume = {39},
number = {4},
issn = {0730-0301},
url = {https://doi.org/10.1145/3386569.3392415},
doi = {10.1145/3386569.3392415},
journal = {ACM Trans. Graph.},
}
```

# Questions / Issues
If you have questions or issues running this code, please open an issue.
Binary file added docs/images/anky_prior.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/images/anky_resize_17.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/images/global_anky.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
26 changes: 24 additions & 2 deletions docs/index.html
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

<!DOCTYPE html>
<head>
<meta charset="utf-8"/>
Expand Down Expand Up @@ -67,7 +68,7 @@ <h1 class="subheader"> A Self-Prior for Deformable Meshes</h1>
<div class="p2m_authors_list_single w-row">
<div class="w-col w-col-4 w-col-small-3 w-col-tiny-4">
<a class="authors" href="" target="_blank">
<a href="https://drive.google.com/file/d/1qn0In94-ZgkJv6k4ebZJvYu9f6rZlTQj/view" target="_blank"><i
<a href="https://arxiv.org/abs/2005.11084" target="_blank"><i
class="far fa-4x fa-file text-primary mb-3 "></i></a>
</a></div>

Expand Down Expand Up @@ -146,4 +147,25 @@ <h1 class="subheader"> A Self-Prior for Deformable Meshes</h1>
</div>


</body></html>
<div class="white_section">
<div class="w-container"><h2 class="grey-heading">Point2Mesh Overview</h2>
<p class="paragraph-3 the_text">
Point2Mesh is a technique for reconstructing a surface mesh from an input point cloud.
This approach "learns" from a single object, by optimizing the weights of a CNN to deform some initial mesh to shrink-wrap the input point cloud:
</p>
<div><span class="center"><img src="images/anky_resize_17.gif"></span></div>
<p class="paragraph-3 the_text">
The optimized CNN weights act as a <i>prior</i>, which encode the expected shape properties, which we refer to as a <i>self-prior</i>.
The premise is that shapes are <i>not</i> random, and contain strong self-correlation across multiple scales.
</p>
<div><span class="center"><img src="images/global_anky.gif"></span></div>

<p class="paragraph-3 the_text">
Central to the self-prior is the weight-sharing structure of a CNN, which inherently models recurring and
correlated structures and, hence, is weak in modeling noise and outliers, which have non-recurring geometries.
</p>
<div><span class="center"><img src="images/anky_prior.gif"></span></div>
</div>
</div>

</body></html>
6 changes: 6 additions & 0 deletions docs/min.css
Original file line number Diff line number Diff line change
Expand Up @@ -4580,4 +4580,10 @@ img {
.text-primary:hover {
color: #4cacdc !important;
opacity: 1.0;
}

.center {
display: block;
width: 100%;
text-align: center
}
17 changes: 17 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name: point2mesh
channels:
- pytorch
- defaults
- conda-forge
- fvcore
- pytorch3d
dependencies:
- python=3.8.2
- numpy=1.18.1
- pytorch=1.4.0
- torchvision=0.5.0
- fvcore=0.1
- pytorch3d=0.2.0
- pip
- pip:
- pytest==5.4.2
83 changes: 83 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import torch
from models.layers.mesh import Mesh, PartMesh
from models.networks import init_net, sample_surface, local_nonuniform_penalty
import utils
import numpy as np
from models.losses import chamfer_distance
from options import Options
import time
import os

options = Options()
opts = options.args

torch.manual_seed(opts.torch_seed)
device = torch.device('cuda:{}'.format(opts.gpu) if torch.cuda.is_available() else torch.device('cpu'))
print('device: {}'.format(device))

# initial mesh
mesh = Mesh(opts.initial_mesh, device=device, hold_history=True)

# input point cloud
input_xyz, input_normals = utils.read_pts(opts.input_pc)
# normalize point cloud based on initial mesh
input_xyz /= mesh.scale
input_xyz += mesh.translations[None, :]
input_xyz = torch.Tensor(input_xyz).type(options.dtype()).to(device)[None, :, :]
input_normals = torch.Tensor(input_normals).type(options.dtype()).to(device)[None, :, :]

part_mesh = PartMesh(mesh, num_parts=options.get_num_parts(len(mesh.faces)), bfs_depth=opts.overlap)
print(f'number of parts {part_mesh.n_submeshes}')
net, optimizer, rand_verts, scheduler = init_net(mesh, part_mesh, device, opts)

for i in range(opts.iterations):
num_samples = options.get_num_samples(i % opts.upsamp)
if opts.global_step:
optimizer.zero_grad()
start_time = time.time()
for part_i, est_verts in enumerate(net(rand_verts, part_mesh)):
if not opts.global_step:
optimizer.zero_grad()
part_mesh.update_verts(est_verts[0], part_i)
num_samples = options.get_num_samples(i % opts.upsamp)
recon_xyz, recon_normals = sample_surface(part_mesh.main_mesh.faces, part_mesh.main_mesh.vs.unsqueeze(0), num_samples)
# calc chamfer loss w/ normals
recon_xyz, recon_normals = recon_xyz.type(options.dtype()), recon_normals.type(options.dtype())
xyz_chamfer_loss, normals_chamfer_loss = chamfer_distance(recon_xyz, input_xyz, x_normals=recon_normals, y_normals=input_normals,
unoriented=opts.unoriented)
loss = (xyz_chamfer_loss + (opts.ang_wt * normals_chamfer_loss))
if opts.local_non_uniform > 0:
loss += opts.local_non_uniform * local_nonuniform_penalty(part_mesh.main_mesh).float()
loss.backward()
if not opts.global_step:
optimizer.step()
scheduler.step()
part_mesh.main_mesh.vs.detach_()
if opts.global_step:
optimizer.step()
scheduler.step()
end_time = time.time()

if i % 1 == 0:
print(f'{os.path.basename(opts.input_pc)}; iter: {i} out of: {opts.iterations}; loss: {loss.item():.4f};'
f' sample count: {num_samples}; time: {end_time - start_time:.2f}')
if i % opts.export_interval == 0 and i > 0:
print('exporting reconstruction... current LR: {}'.format(optimizer.param_groups[0]['lr']))
with torch.no_grad():
part_mesh.export(os.path.join(opts.save_path, f'recon_iter:{i}.obj'))

if (i > 0 and (i + 1) % opts.upsamp == 0):
mesh = part_mesh.main_mesh
num_faces = int(np.clip(len(mesh.faces) * 1.5, len(mesh.faces), opts.max_faces))

if num_faces > len(mesh.faces):
mesh = utils.manifold_upsample(mesh, opts.save_path, Mesh,
num_faces=min(num_faces, opts.max_faces),
res=opts.manifold_res, simplify=True)

part_mesh = PartMesh(mesh, num_parts=options.get_num_parts(len(mesh.faces)), bfs_depth=opts.overlap)
print(f'upsampled to {len(mesh.faces)} faces; number of parts {part_mesh.n_submeshes}')
net, optimizer, rand_verts, scheduler = init_net(mesh, part_mesh, device, opts)

with torch.no_grad():
mesh.export(os.path.join(opts.save_path, 'last_recon.obj'))
Empty file added models/__init__.py
Empty file.
Empty file added models/layers/__init__.py
Empty file.
Loading

0 comments on commit 767ac0e

Please sign in to comment.