import os
import os.path as osp
import torch
import torch.nn.functional as F
from torch_sparse import coalesce
from torch_geometric.data import (InMemoryDataset, download_url, extract_zip,
Data)
try:
import rdkit
from rdkit import Chem
from rdkit import rdBase
from rdkit.Chem.rdchem import HybridizationType
from rdkit import RDConfig
from rdkit.Chem import ChemicalFeatures
from rdkit.Chem.rdchem import BondType as BT
rdBase.DisableLog('rdApp.error')
except ImportError:
rdkit = None
class QM9(InMemoryDataset):
    r"""The QM9 dataset from the `"MoleculeNet: A Benchmark for Molecular
    Machine Learning" <https://arxiv.org/abs/1703.00564>`_ paper, consisting of
    about 130,000 molecules with 16 regression targets.
    Each molecule includes complete spatial information for the single low
    energy conformation of the atoms in the molecule.
    In addition, we provide the atom features from the `"Neural Message
    Passing for Quantum Chemistry" <https://arxiv.org/abs/1704.01212>`_ paper.

    When the raw data is processed locally (requires :obj:`rdkit`), every
    molecule becomes a :obj:`torch_geometric.data.Data` object with

    * :obj:`x`: 13 features per atom: a one-hot encoding of the atom type
      (H, C, N, O, F), followed by the atomic number, acceptor flag, donor
      flag, aromatic flag, sp / sp2 / sp3 hybridization flags and the total
      number of attached hydrogens,
    * :obj:`pos`: the 3D coordinates of every atom,
    * :obj:`edge_index` / :obj:`edge_attr`: every bond stored as two directed
      edges, with a one-hot encoding of the bond type (single, double,
      triple, aromatic),
    * :obj:`y`: the regression targets listed below, shape :obj:`[1, 16]`,
    * :obj:`name`: the molecule name taken from the SDF record.

    .. list-table::
        :widths: 10 25 45 20
        :header-rows: 1

        * - Target
          - Property
          - Description
          - Unit
        * - 0
          - :math:`\mu`
          - Dipole moment
          - :math:`\textrm{D}`
        * - 1
          - :math:`\alpha`
          - Isotropic polarizability
          - :math:`{a_0}^3`
        * - 2
          - :math:`\epsilon_{\textrm{HOMO}}`
          - Highest occupied molecular orbital energy
          - :math:`E_{\textrm{h}}`
        * - 3
          - :math:`\epsilon_{\textrm{LUMO}}`
          - Lowest unoccupied molecular orbital energy
          - :math:`E_{\textrm{h}}`
        * - 4
          - :math:`\Delta \epsilon`
          - Gap between :math:`\epsilon_{\textrm{HOMO}}` and :math:`\epsilon_{\textrm{LUMO}}`
          - :math:`E_{\textrm{h}}`
        * - 5
          - :math:`\langle R^2 \rangle`
          - Electronic spatial extent
          - :math:`{a_0}^2`
        * - 6
          - :math:`\textrm{ZPVE}`
          - Zero point vibrational energy
          - :math:`E_{\textrm{h}}`
        * - 7
          - :math:`U_0`
          - Internal energy at 0K
          - :math:`E_{\textrm{h}}`
        * - 8
          - :math:`U`
          - Internal energy at 298.15K
          - :math:`E_{\textrm{h}}`
        * - 9
          - :math:`H`
          - Enthalpy at 298.15K
          - :math:`E_{\textrm{h}}`
        * - 10
          - :math:`G`
          - Free energy at 298.15K
          - :math:`E_{\textrm{h}}`
        * - 11
          - :math:`c_{\textrm{v}}`
          - Heat capacity at 298.15K
          - :math:`\frac{\textrm{cal}}{\textrm{mol K}}`
        * - 12
          - :math:`U_0^{\textrm{ATOM}}`
          - Atomization energy at 0K
          - :math:`\frac{\textrm{kcal}}{\textrm{mol}}`
        * - 13
          - :math:`U^{\textrm{ATOM}}`
          - Atomization energy at 298.15K
          - :math:`\frac{\textrm{kcal}}{\textrm{mol}}`
        * - 14
          - :math:`H^{\textrm{ATOM}}`
          - Atomization enthalpy at 298.15K
          - :math:`\frac{\textrm{kcal}}{\textrm{mol}}`
        * - 15
          - :math:`G^{\textrm{ATOM}}`
          - Atomization free energy at 298.15K
          - :math:`\frac{\textrm{kcal}}{\textrm{mol}}`

    Args:
        root (string): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """  # noqa: E501

    # Raw SDF + CSV archive, processed locally when rdkit is available.
    raw_url = ('https://s3-us-west-1.amazonaws.com/deepchem.io/datasets/'
               'molnet_publish/qm9.zip')
    # Pre-collated fallback used when rdkit is missing (served over HTTP).
    processed_url = 'http://www.roemisch-drei.de/qm9.zip'

    if rdkit is not None:
        # Positions of atom symbols / bond types in the one-hot encodings.
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

    def __init__(self, root, transform=None, pre_transform=None,
                 pre_filter=None):
        super(QM9, self).__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Without rdkit the raw directory holds a single pre-collated file;
        # with rdkit it holds the SDF structures and the CSV of targets.
        return 'qm9.pt' if rdkit is None else ['gdb9.sdf', 'gdb9.sdf.csv']

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        """Download and extract the archive matching the available backend,
        then delete the downloaded zip file."""
        url = self.processed_url if rdkit is None else self.raw_url
        file_path = download_url(url, self.raw_dir)
        extract_zip(file_path, self.raw_dir)
        os.unlink(file_path)

    def process(self):
        """Build ``processed_paths[0]`` from the raw files.

        Without rdkit, the downloaded pre-collated dataset is re-filtered and
        re-transformed. With rdkit, every SDF record is converted to a
        :obj:`Data` object and paired with its row of regression targets.
        """
        if rdkit is None:
            self._process_preprocessed()
            return

        target = self._read_targets(self.raw_paths[1])
        suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False)
        fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
        factory = ChemicalFeatures.BuildFeatureFactory(fdef_name)

        data_list = []
        for i, mol in enumerate(suppl):
            # SDMolSupplier yields None for records it cannot parse. `i` still
            # counts them, so target rows stay aligned with SDF records.
            if mol is None:
                continue
            data = self._mol_to_data(mol, suppl.GetItemText(i),
                                     target[i].unsqueeze(0), factory)
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[0])

    def _process_preprocessed(self):
        """Re-collate the downloaded pre-processed dataset, applying
        :obj:`pre_filter` and :obj:`pre_transform`."""
        print('Using a pre-processed version of the dataset. Please '
              'install `rdkit` to alternatively process the raw data.')
        self.data, self.slices = torch.load(self.raw_paths[0])
        # Use `get` instead of iterating `self`: indexing the dataset applies
        # `self.transform` ("before every access"), which must not be baked
        # into the file written to disk.
        data_list = [self.get(i) for i in range(len(self))]
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]
        torch.save(self.collate(data_list), self.processed_paths[0])

    @staticmethod
    def _read_targets(path):
        """Read CSV columns 4-19 of every data row into a float tensor (one
        row per molecule; 16 targets for the standard ``gdb9.sdf.csv``)."""
        with open(path, 'r') as f:
            lines = f.read().splitlines()[1:]  # Skip the header row.
        # Ignore blank lines so the presence or absence of a trailing newline
        # can neither drop the last molecule nor add an empty row.
        target = [[float(x) for x in line.split(',')[4:20]]
                  for line in lines if line.strip()]
        return torch.tensor(target, dtype=torch.float)

    def _mol_to_data(self, mol, text, y, factory):
        """Convert one RDKit molecule (plus its raw SDF record ``text`` and
        target row ``y``) into a :obj:`Data` object."""
        N = mol.GetNumAtoms()

        # In the MDL mol block, lines 0-2 are the header and line 3 the counts
        # line, so the N atom lines (x, y, z first) start at index 4.
        pos = text.split('\n')[4:4 + N]
        pos = [[float(x) for x in line.split()[:3]] for line in pos]
        pos = torch.tensor(pos, dtype=torch.float)

        type_idx, atomic_number, aromatic = [], [], []
        sp, sp2, sp3, num_hs = [], [], [], []
        for atom in mol.GetAtoms():
            type_idx.append(self.types[atom.GetSymbol()])
            atomic_number.append(atom.GetAtomicNum())
            aromatic.append(1 if atom.GetIsAromatic() else 0)
            hybridization = atom.GetHybridization()
            sp.append(1 if hybridization == HybridizationType.SP else 0)
            sp2.append(1 if hybridization == HybridizationType.SP2 else 0)
            sp3.append(1 if hybridization == HybridizationType.SP3 else 0)
            num_hs.append(atom.GetTotalNumHs(includeNeighbors=True))

        # Donor/acceptor flags come from RDKit's feature factory
        # (BaseFeatures.fdef); atoms not named by any feature stay 0.
        donor = [0] * len(type_idx)
        acceptor = [0] * len(type_idx)
        for feat in factory.GetFeaturesForMol(mol):
            family = feat.GetFamily()
            if family == 'Donor':
                for k in feat.GetAtomIds():
                    donor[k] = 1
            elif family == 'Acceptor':
                for k in feat.GetAtomIds():
                    acceptor[k] = 1

        x1 = F.one_hot(torch.tensor(type_idx, dtype=torch.long),
                       num_classes=len(self.types))
        x2 = torch.tensor([
            atomic_number, acceptor, donor, aromatic, sp, sp2, sp3, num_hs
        ], dtype=torch.float).t().contiguous()
        x = torch.cat([x1.to(torch.float), x2], dim=-1)

        # Every bond is added as two directed edges (start->end, end->start).
        row, col, bond_idx = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            bond_idx += 2 * [self.bonds[bond.GetBondType()]]

        edge_index = torch.tensor([row, col], dtype=torch.long)
        # Explicit long dtype: `torch.tensor([])` defaults to float, which
        # `F.one_hot` rejects when a molecule has no bonds.
        edge_attr = F.one_hot(torch.tensor(bond_idx, dtype=torch.long),
                              num_classes=len(self.bonds)).to(torch.float)
        edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)

        return Data(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr,
                    y=y, name=mol.GetProp('_Name'))