Source code for torch_geometric.datasets.pcpnet_dataset

import os
import os.path as osp

import torch
from torch_geometric.data import (Data, InMemoryDataset, download_url,
                                  extract_zip)
from torch_geometric.io import read_txt_array


[docs]class PCPNetDataset(InMemoryDataset): r"""The PCPNet dataset from the `"PCPNet: Learning Local Shape Properties from Raw Point Clouds" <https://arxiv.org/abs/1710.04954>`_ paper, consisting of 30 shapes, each given as a point cloud, densely sampled with 100k points. For each shape, surface normals and local curvatures are given as node features. Args: root (string): Root directory where the dataset should be saved. category (string): The training set category (one of :obj:`"NoNoise"`, :obj:`"Noisy"`, :obj:`"VarDensity"`, :obj:`"NoisyAndVarDensity"` for :obj:`split="train"` or :obj:`split="val"`, or one of :obj:`"All"`, :obj:`"LowNoise"`, :obj:`"MedNoise"`, :obj:`"HighNoise", :obj:`"VarDensityStriped", :obj:`"VarDensityGradient"` for :obj:`split="test"`). split (string): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) """ url = 'http://geometry.cs.ucl.ac.uk/projects/2018/pcpnet/pclouds.zip' category_files_train = { 'NoNoise': 'trainingset_no_noise.txt', 'Noisy': 'trainingset_whitenoise.txt', 'VarDensity': 'trainingset_vardensity.txt', 'NoisyAndVarDensity': 'trainingset_vardensity_whitenoise.txt' } category_files_val = { 'NoNoise': 'validationset_no_noise.txt', 'Noisy': 'validationset_whitenoise.txt', 'VarDensity': 'validationset_vardensity.txt', 'NoisyAndVarDensity': 'validationset_vardensity_whitenoise.txt' } category_files_test = { 'All': 'testset_all.txt', 'NoNoise': 'testset_no_noise.txt', 'LowNoise': 'testset_low_noise.txt', 'MedNoise': 'testset_med_noise.txt', 'HighNoise': 'testset_high_noise.txt', 'VarDensityStriped': 'testset_vardensity_striped.txt', 'VarDensityGradient': 'testset_vardensity_gradient.txt' } def __init__(self, root, category, split='train', transform=None, pre_transform=None, pre_filter=None): assert split in ['train', 'val', 'test'] if split == 'train': assert category in self.category_files_train.keys() elif split == 'val': assert category in self.category_files_val.keys() else: assert category in self.category_files_test.keys() self.category = category self.split = split super(PCPNetDataset, self).__init__(root, transform, pre_transform, pre_filter) self.data, self.slices = torch.load(self.processed_paths[0]) @property def raw_file_names(self): if self.split == 'train': return self.category_files_train[self.category] elif self.split == 'val': return self.category_files_val[self.category] else: return self.category_files_test[self.category] @property def processed_file_names(self): return self.split + '_' + self.category + '.pt' def download(self): path = download_url(self.url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) def process(self): path_file = self.raw_paths with open(path_file[0], "r") as f: filenames = f.read().split('\n')[:-1] data_list = [] for filename in filenames: pos_path = osp.join(self.raw_dir, filename + '.xyz') normal_path = osp.join(self.raw_dir, filename + '.normals') curv_path = osp.join(self.raw_dir, filename + '.curv') idx_path = osp.join(self.raw_dir, filename + '.pidx') pos = read_txt_array(pos_path) normals = read_txt_array(normal_path) curv = read_txt_array(curv_path) normals_and_curv = torch.cat([normals, curv], dim=1) test_idx = read_txt_array(idx_path, dtype=torch.long) data = Data(pos=pos, x=normals_and_curv) data.test_idx = test_idx if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) torch.save(self.collate(data_list), self.processed_paths[0]) def __repr__(self): return '{}({}, category={})'.format(self.__class__.__name__, len(self), self.category)