Commit aa183374 authored by miaecle's avatar miaecle
Browse files

indice splitting

parent 7f4328a3
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.fingerprints import CircularFingerprint
from deepchem.feat.basic import RDKitDescriptors
from deepchem.feat.coulomb_matrices import CoulombMatrixEig
from deepchem.feat.coulomb_matrices import CoulombMatrix
from deepchem.feat.grid_featurizer import GridFeaturizer
from deepchem.feat.nnscore_utils import hydrogenate_and_compute_partial_charges
from deepchem.feat.binding_pocket_features import BindingPocketFeaturizer
+1 −0
Original line number Diff line number Diff line
@@ -10,5 +10,6 @@ from deepchem.splits.splitters import *
from deepchem.splits.splitters import ScaffoldSplitter
from deepchem.splits.splitters import SpecifiedSplitter
from deepchem.splits.splitters import IndexSplitter
from deepchem.splits.splitters import IndiceSplitter
from deepchem.splits.task_splitter import merge_fold_datasets
from deepchem.splits.task_splitter import TaskSplitter
+40 −0
Original line number Diff line number Diff line
@@ -327,6 +327,46 @@ class IndexSplitter(Splitter):
    return (indices[:train_cutoff], indices[train_cutoff:valid_cutoff],
            indices[valid_cutoff:])

class IndiceSplitter(Splitter):
  """
  Class for simple order based splits. 
  """
  def __init__(self, verbose=False, valid_indices=None, test_indices=None):
    """
    Parameters
    -----------
    valid_indices: list of int
        indices of samples in the valid set
    test_indices: list of int
        indices of samples in the test set
    """
    self.verbose = verbose
    self.valid_indices = valid_indices
    self.test_indices = test_indices
    
  def split(self, dataset, seed=None, frac_train=.8, frac_valid=.1,
            frac_test=.1, log_every_n=None):
    """
    Splits internal compounds into train/validation/test in designated order.
        
    """
    num_datapoints = len(dataset)
    indices = np.arange(num_datapoints).tolist()
    if self.valid_indices is None:
      self.valid_indices = []
    else:
      for indice in indices:
        if indice in self.valid_indices:
          indices.remove(indice)
    if self.test_indices is None:
      self.test_indices = []
    else:
      for indice in indices:
        if indice in self.valid_indices:
          indices.remove(indice)

    return (indices, self.valid_indices, self.test_indices)


class ScaffoldSplitter(Splitter):
  """
+18 −5
Original line number Diff line number Diff line
@@ -9,8 +9,9 @@ import os
import numpy as np
import shutil
import deepchem as dc
import csv

def load_gdb7(featurizer=None, split='random'):
def load_gdb7(featurizer=None, split='indice'):
  """Load gdb7 datasets."""
  # Featurize gdb7 dataset
  print("About to featurize gdb7 dataset.")
@@ -19,7 +20,7 @@ def load_gdb7(featurizer=None, split='random'):
      current_dir, "./gdb7.sdf")
  gdb7_tasks = ["u0_atom"]
  if featurizer is None:
    featurizer = dc.feat.CoulombMatrix(23)
    featurizer = dc.feat.CoulombMatrixEig(23)
  else:
    raise ValueError('Only support Coulomb Matrix featurizer')
  loader = dc.data.SDFLoader(tasks=gdb7_tasks, smiles_field="smiles", 
@@ -28,16 +29,28 @@ def load_gdb7(featurizer=None, split='random'):
 
  # Initialize transformers 
  transformers = [
      dc.trans.NormalizationTransformer(transform_X=True, dataset=dataset),
      dc.trans.NormalizationTransformer(transform_y=True, dataset=dataset)]

  print("About to transform data")
  for transformer in transformers:
    dataset = transformer.transform(dataset)
  
  split_file = os.path.join(
      current_dir, "./gdb7_splits.csv")

  split_indices = []
  with open(split_file, 'r') as f:
    reader = csv.reader(f)
    for row in reader:
      row_int = (np.asarray(list(map(int, row)))-1).tolist()
      split_indices.append(row_int)
  
  
  splitters = {'index': dc.splits.IndexSplitter(),
               'random': dc.splits.RandomSplitter()}
               'random': dc.splits.RandomSplitter(),
               'indice': dc.splits.IndiceSplitter(valid_indices=split_indices[1])}
  splitter = splitters[split]
  train, valid, test = splitter.train_valid_test_split(dataset)
 
  print(valid.X.shape)
  print(train.X.shape)
  return gdb7_tasks, (train, valid, test), transformers
+2 −2
Original line number Diff line number Diff line
@@ -18,8 +18,8 @@ train_dataset, valid_dataset, test_dataset = datasets
regression_metric = dc.metrics.Metric(dc.metrics.mean_absolute_error, 
                                      mode="regression")
model = dc.models.TensorflowMultiTaskRegressor(
    n_tasks=len(gdb7_tasks), n_features=276,
    learning_rate=.0001, momentum=.8, batch_size=512,
    n_tasks=len(gdb7_tasks), n_features=23,
    learning_rate=.0002, momentum=.8, batch_size=512,
    weight_init_stddevs=[1/np.sqrt(2000),1/np.sqrt(800),1/np.sqrt(800),1/np.sqrt(1000)],
    bias_init_consts=[0.,0.,0.,0.], layer_sizes=[2000,800,800,1000], 
    dropouts=[0.1,0.1,0.1,0.1], seed=123)