Unverified Commit 37da8e39 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2167 from ncfrey/qm9_fix

QM9 loader bugfix
parents 651df9f9 89ac2c4b
Loading
Loading
Loading
Loading
+4 −8
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ def load_qm9(featurizer='CoulombMatrix',
  QM9 is a comprehensive dataset that provides geometric, energetic, 
  electronic and thermodynamic properties for a subset of GDB-17 database, 
  comprising 134 thousand stable organic molecules with up to 9 heavy atoms.
  All moleucles are modeled using density functional theory
  All molecules are modeled using density functional theory
  (B3LYP/6-31G(2df,p) based DFT).

  Random splitting is recommended for this dataset.
@@ -119,11 +119,7 @@ def load_qm9(featurizer='CoulombMatrix',
    elif featurizer == 'MP':
      featurizer = deepchem.feat.WeaveFeaturizer(
          graph_distance=False, explicit_H=True)
    loader = deepchem.data.SDFLoader(
        tasks=qm9_tasks,
        smiles_field="smiles",
        mol_field="mol",
        featurizer=featurizer)
    loader = deepchem.data.SDFLoader(tasks=qm9_tasks, featurizer=featurizer)
  else:
    if featurizer == 'ECFP':
      featurizer = deepchem.feat.CircularFingerprint(size=1024)
@@ -137,9 +133,9 @@ def load_qm9(featurizer='CoulombMatrix',
      featurizer = deepchem.feat.SmilesToImage(
          img_size=img_size, img_spec=img_spec)
    loader = deepchem.data.CSVLoader(
        tasks=qm9_tasks, smiles_field="smiles", featurizer=featurizer)
        tasks=qm9_tasks, feature_field="smiles", featurizer=featurizer)

  dataset = loader.featurize(dataset_file)
  dataset = loader.create_dataset(dataset_file)
  if split == None:
    raise ValueError()

+11 −0
Original line number Diff line number Diff line
mol_id,smiles,A,B,C,mu,alpha,homo,lumo,gap,r2,zpve,u0,u298,h298,g298,cv,u0_atom,u298_atom,h298_atom,g298_atom
gdb_1,C,157.7118,157.70997,157.70699,0.0,13.21,-0.3877,0.1171,0.5048,35.3641,0.044749000000000004,-40.47893,-40.476062,-40.475117,-40.498597,6.468999999999999,-395.99959459400003,-398.643290011,-401.01464652199996,-372.471772148
gdb_2,N,293.60975,293.54111,191.39397,1.6256,9.46,-0.257,0.0829,0.3399,26.1563,0.034358,-56.525887,-56.523026,-56.522082,-56.544961,6.316,-276.861363363,-278.62027109,-280.399259105,-259.338802047
gdb_3,O,799.58812,437.90385999999995,282.94545,1.8511,6.31,-0.2928,0.0687,0.3615,19.0002,0.021375,-76.404702,-76.40186700000001,-76.400922,-76.422349,6.002000000000001,-213.08762369299998,-213.97429391,-215.15965841099998,-201.407171167
gdb_4,C#C,0.0,35.6100361,35.6100361,0.0,16.28,-0.2845,0.0506,0.3351,59.5248,0.026841000000000004,-77.30842700000001,-77.305527,-77.304583,-77.32742900000001,8.574,-385.501996533,-387.23768642699997,-389.01604693300004,-365.800723969
gdb_5,C#N,0.0,44.593883,44.593883,2.8937,12.99,-0.3604,0.0191,0.3796,48.7476,0.016600999999999998,-93.411888,-93.40937,-93.408425,-93.431246,6.278,-301.820533838,-302.906751917,-304.091488909,-288.720028445
gdb_6,C=O,285.48839,38.9823,34.29892,2.1089,14.18,-0.267,-0.0406,0.2263,59.9891,0.026602999999999998,-114.48361299999999,-114.480746,-114.479802,-114.50526799999999,6.412999999999999,-358.756935444,-360.512705626,-362.29106613199997,-340.464420585
gdb_7,CC,80.46225,19.906489999999998,19.90633,0.0,23.95,-0.3385,0.1041,0.4426,109.5031,0.074542,-79.764152,-79.760666,-79.759722,-79.787269,10.097999999999999,-670.78829573,-675.7104763259999,-679.860820852,-626.927299157
gdb_8,CO,127.83497,24.85872,23.978720000000003,1.5258,16.97,-0.2653,0.0784,0.3437,83.794,0.051208000000000004,-115.67913600000001,-115.675816,-115.674872,-115.701876,8.751,-481.10675773699995,-484.35537183,-487.319724346,-450.124128371
gdb_9,CC#C,160.28041000000002,8.59323,8.593210000000001,0.7156,28.78,-0.2609,0.0613,0.3222,177.1963,0.05541,-116.609549,-116.60555,-116.604606,-116.633775,12.482000000000001,-670.268090769,-673.980434013,-677.537155025,-631.346845044
gdb_10,CC#N,159.03566999999998,9.22327,9.223239999999999,3.8266,24.45,-0.3264,0.0376,0.364,160.7223,0.045286,-132.71815,-132.714563,-132.713619,-132.742149,10.287,-589.8120243340001,-592.893721033,-595.85744604,-557.125708033
+25 −0
Original line number Diff line number Diff line
"""
Tests for qm9 loader.
"""

import os
import numpy as np
from deepchem.molnet import load_qm9


def test_qm9_loader():
  """Check the output of load_qm9 when featurizing with ECFP.

  The directory containing this test file is passed as ``data_dir`` and
  ``reload`` is disabled. The test then verifies the number of returned
  tasks, the name of the first task, and the shape of the feature matrix
  of the first returned dataset.
  """
  test_dir = os.path.dirname(os.path.abspath(__file__))
  split_options = dict(seed=42, frac_train=0.6, frac_valid=0.2, frac_test=0.2)

  # The loader is expected to return exactly three values; only the first
  # two are inspected below.
  task_names, loaded, _transformers = load_qm9(
      reload=False,
      data_dir=test_dir,
      featurizer='ECFP',
      splitter_kwargs=split_options)

  assert len(task_names) == 12
  assert task_names[0] == 'mu'

  first_dataset = loaded[0]
  assert first_dataset.X.shape == (8, 1024)