import torch
import torch.nn.functional as F
from torch_scatter import scatter_add
[docs]class GenerateMeshNormals(object):
r"""Generate normal vectors for each mesh node based on neighboring
faces."""
def __call__(self, data):
assert 'face' in data
pos, face = data.pos, data.face
vec1 = pos[face[1]] - pos[face[0]]
vec2 = pos[face[2]] - pos[face[0]]
face_norm = F.normalize(vec1.cross(vec2), p=2, dim=-1) # [F, 3]
idx = torch.cat([face[0], face[1], face[2]], dim=0)
face_norm = face_norm.repeat(3, 1)
norm = scatter_add(face_norm, idx, dim=0, dim_size=pos.size(0))
norm = F.normalize(norm, p=2, dim=-1) # [N, 3]
data.norm = norm
return data
def __repr__(self):
return '{}()'.format(self.__class__.__name__)