import os
import os.path as osp

import torch
from torch_geometric.data import (Data, InMemoryDataset, download_url,
                                  extract_zip)
from torch_geometric.io import read_txt_array


class PCPNetDataset(InMemoryDataset):
    r"""The PCPNet dataset from the `"PCPNet: Learning Local Shape Properties
    from Raw Point Clouds" <https://arxiv.org/abs/1710.04954>`_ paper,
    consisting of 30 shapes, each given as a point cloud densely sampled with
    100k points.
    For each shape, surface normals and local curvatures are given as node
    features.

    Args:
        root (string): Root directory where the dataset should be saved.
        category (string): The training set category (one of :obj:`"NoNoise"`,
            :obj:`"Noisy"`, :obj:`"VarDensity"`, :obj:`"NoisyAndVarDensity"`
            for :obj:`split="train"` or :obj:`split="val"`,
            or one of :obj:`"All"`, :obj:`"NoNoise"`, :obj:`"LowNoise"`,
            :obj:`"MedNoise"`, :obj:`"HighNoise"`, :obj:`"VarDensityStriped"`,
            :obj:`"VarDensityGradient"` for :obj:`split="test"`).
        split (string): If :obj:`"train"`, loads the training dataset.
            If :obj:`"val"`, loads the validation dataset.
            If :obj:`"test"`, loads the test dataset.
            (default: :obj:`"train"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
    """
    url = 'http://geometry.cs.ucl.ac.uk/projects/2018/pcpnet/pclouds.zip'

    category_files_train = {
        'NoNoise': 'trainingset_no_noise.txt',
        'Noisy': 'trainingset_whitenoise.txt',
        'VarDensity': 'trainingset_vardensity.txt',
        'NoisyAndVarDensity': 'trainingset_vardensity_whitenoise.txt'
    }

    category_files_val = {
        'NoNoise': 'validationset_no_noise.txt',
        'Noisy': 'validationset_whitenoise.txt',
        'VarDensity': 'validationset_vardensity.txt',
        'NoisyAndVarDensity': 'validationset_vardensity_whitenoise.txt'
    }

    category_files_test = {
        'All': 'testset_all.txt',
        'NoNoise': 'testset_no_noise.txt',
        'LowNoise': 'testset_low_noise.txt',
        'MedNoise': 'testset_med_noise.txt',
        'HighNoise': 'testset_high_noise.txt',
        'VarDensityStriped': 'testset_vardensity_striped.txt',
        'VarDensityGradient': 'testset_vardensity_gradient.txt'
    }
    def __init__(self, root, category, split='train', transform=None,
                 pre_transform=None, pre_filter=None):
        assert split in ['train', 'val', 'test']
        if split == 'train':
            assert category in self.category_files_train.keys()
        elif split == 'val':
            assert category in self.category_files_val.keys()
        else:
            assert category in self.category_files_test.keys()
        self.category = category
        self.split = split
        super(PCPNetDataset, self).__init__(root, transform, pre_transform,
                                            pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])
    @property
    def raw_file_names(self):
        if self.split == 'train':
            return self.category_files_train[self.category]
        elif self.split == 'val':
            return self.category_files_val[self.category]
        else:
            return self.category_files_test[self.category]

    @property
    def processed_file_names(self):
        return self.split + '_' + self.category + '.pt'
    def download(self):
        path = download_url(self.url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.unlink(path)
    def process(self):
        # The raw split file lists one shape name per line.
        with open(self.raw_paths[0], 'r') as f:
            filenames = f.read().split('\n')[:-1]

        data_list = []
        for filename in filenames:
            # Each shape provides point positions, surface normals, local
            # curvatures and the indices of the points used for evaluation.
            pos_path = osp.join(self.raw_dir, filename + '.xyz')
            normal_path = osp.join(self.raw_dir, filename + '.normals')
            curv_path = osp.join(self.raw_dir, filename + '.curv')
            idx_path = osp.join(self.raw_dir, filename + '.pidx')

            pos = read_txt_array(pos_path)
            normals = read_txt_array(normal_path)
            curv = read_txt_array(curv_path)
            normals_and_curv = torch.cat([normals, curv], dim=1)
            test_idx = read_txt_array(idx_path, dtype=torch.long)

            # Positions become `pos`; normals and curvatures are concatenated
            # into the node feature matrix `x`.
            data = Data(pos=pos, x=normals_and_curv)
            data.test_idx = test_idx

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[0])
    def __repr__(self):
        return '{}({}, category={})'.format(self.__class__.__name__,
                                            len(self), self.category)
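

if __name__ == '__main__':
    # Minimal usage sketch (added for illustration, not part of the original
    # module). The root path './data/PCPNet' is an arbitrary example; running
    # this block downloads and extracts the PCPNet point-cloud archive into
    # that directory on first use, then loads the processed training split.
    dataset = PCPNetDataset(root='./data/PCPNet', category='NoNoise',
                            split='train')
    data = dataset[0]
    # `data.pos` holds the 100k sampled point coordinates, `data.x` the
    # per-point normals and curvatures, and `data.test_idx` the evaluation
    # point indices read from the corresponding '.pidx' file.
    print(dataset, data)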