Commit 0bcad71e authored by ZHENQIN WU's avatar ZHENQIN WU
Browse files

Merge remote-tracking branch 'remotes/mine/temp'

parents a38ee74f 28a7e59e
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -140,8 +140,9 @@ CheckFeaturizer = {
    ('sampl', 'weave_regression'): ['Weave', 75],
    ('kaggle', 'tf_regression'): [None, 14293],
    ('kaggle', 'rf_regression'): [None, 14293],
    ('pdbbind', 'tf_regression'): ['grid', 2052],
    ('pdbbind', 'rf_regression'): ['grid', 2052],
    ('pdbbind', 'tf_regression'): ['ECFP', 1024],
    ('pdbbind', 'rf_regression'): ['ECFP', 1024],
    ('pdbbind', 'graphconvreg'): ['GraphConv', 75],
    ('qm7', 'tf_regression'): ['ECFP', 1024],
    ('qm7', 'rf_regression'): ['ECFP', 1024],
    ('qm7', 'graphconvreg'): ['GraphConv', 75],
+72 −17
Original line number Diff line number Diff line
@@ -52,9 +52,10 @@ def featurize_pdbbind(data_dir=None, feat="grid", subset="core"):

def load_pdbbind_grid(split="random",
                      featurizer="grid",
                      subset="full",
                      subset="core",
                      reload=True):
  """Load PDBBind datasets. Does not do train/test split"""
  if featurizer == 'grid':
    dataset, tasks = featurize_pdbbind(feat=featurizer, subset=subset)

    splitters = {
@@ -71,5 +72,59 @@ def load_pdbbind_grid(split="random",
      valid = transformer.transform(valid)
    for transformer in transformers:
      test = transformer.transform(test)
  else:
    if "DEEPCHEM_DATA_DIR" in os.environ:
      data_dir = os.environ["DEEPCHEM_DATA_DIR"]
    else:
      data_dir = "/tmp"
    if reload:
      save_dir = os.path.join(data_dir, "pdbbind_" + subset + "/" + featurizer + "/" + split)

    dataset_file = os.path.join(data_dir, subset + "_smiles_labels.csv")

    if not os.path.exists(dataset_file):
      os.system(
          'wget -P ' + data_dir +
          ' http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/' + subset + "_smiles_labels.csv"
      )

    tasks = ["-logKd/Ki"]
    if reload:
      loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
          save_dir)
      if loaded:
        return tasks, all_dataset, transformers

    if featurizer == 'ECFP':
      featurizer = deepchem.feat.CircularFingerprint(size=1024)
    elif featurizer == 'GraphConv':
      featurizer = deepchem.feat.ConvMolFeaturizer()
    elif featurizer == 'Weave':
      featurizer = deepchem.feat.WeaveFeaturizer()
    elif featurizer == 'Raw':
      featurizer = deepchem.feat.RawFeaturizer()

    loader = deepchem.data.CSVLoader(
        tasks=tasks, smiles_field="smiles", featurizer=featurizer)
    dataset = loader.featurize(dataset_file, shard_size=8192)
    transformers = [
        deepchem.trans.NormalizationTransformer(
            transform_y=True, dataset=dataset)
    ]

    for transformer in transformers:
      dataset = transformer.transform(dataset)

    splitters = {
        'index': deepchem.splits.IndexSplitter(),
        'random': deepchem.splits.RandomSplitter(),
        'scaffold': deepchem.splits.ScaffoldSplitter()
    }
    splitter = splitters[split]
    train, valid, test = splitter.train_valid_test_split(dataset)

    if reload:
      deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
                                               transformers)

  return tasks, (train, valid, test), transformers