Unverified commit 369d9b51, authored by Bharath Ramsundar and committed by GitHub
Browse files

Merge pull request #1660 from VIGS25/molnet-addons-chemception

Stratified splitters, and minor changes for MolNet
Parents: 371163ff, a1589b46
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -280,7 +280,7 @@ class ChemCeption(KerasModel):
    avg_pooling_out = GlobalAveragePooling2D()(inceptionC_out)

    if self.mode == "classification":
      logits = Dense(self.n_tasks * self.n_classes)(rnn_embeddings)
      logits = Dense(self.n_tasks * self.n_classes)(avg_pooling_out)
      logits = Reshape((self.n_tasks, self.n_classes))(logits)
      if self.n_classes == 2:
        output = Activation(activation='sigmoid')(logits)
+22 −4
Original line number Diff line number Diff line
@@ -90,11 +90,20 @@ def load_bace_regression(featurizer='ECFP',
  splitters = {
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
      'scaffold': deepchem.splits.ScaffoldSplitter()
      'scaffold': deepchem.splits.ScaffoldSplitter(),
      'stratified': deepchem.splits.SingletaskStratifiedSplitter()
  }
  splitter = splitters[split]
  logger.info("About to split data using {} splitter".format(split))
  train, valid, test = splitter.train_valid_test_split(dataset)
  frac_train = kwargs.get("frac_train", 0.8)
  frac_valid = kwargs.get('frac_valid', 0.1)
  frac_test = kwargs.get('frac_test', 0.1)

  train, valid, test = splitter.train_valid_test_split(
      dataset,
      frac_train=frac_train,
      frac_valid=frac_valid,
      frac_test=frac_test)

  transformers = [
      deepchem.trans.NormalizationTransformer(
@@ -182,12 +191,21 @@ def load_bace_classification(featurizer='ECFP',
  splitters = {
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
      'scaffold': deepchem.splits.ScaffoldSplitter()
      'scaffold': deepchem.splits.ScaffoldSplitter(),
      'stratified': deepchem.splits.RandomStratifiedSplitter()
  }

  splitter = splitters[split]
  logger.info("About to split data using {} splitter".format(split))
  train, valid, test = splitter.train_valid_test_split(dataset)
  frac_train = kwargs.get("frac_train", 0.8)
  frac_valid = kwargs.get('frac_valid', 0.1)
  frac_test = kwargs.get('frac_test', 0.1)

  train, valid, test = splitter.train_valid_test_split(
      dataset,
      frac_train=frac_train,
      frac_valid=frac_valid,
      frac_test=frac_test)

  transformers = [
      deepchem.trans.BalancingTransformer(transform_w=True, dataset=train)
+20 −4
Original line number Diff line number Diff line
@@ -43,7 +43,7 @@ def load_bbbc001(split='index',
    save_dir = DEFAULT_DIR

  if reload:
    save_folder = os.path.join(save_dir, "bbbc001-featurized/" + str(split))
    save_folder = os.path.join(save_dir, "bbbc001-featurized", str(split))
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_folder)
    if loaded:
@@ -86,7 +86,15 @@ def load_bbbc001(split='index',
  splitter = splitters[split]

  logger.info("About to split dataset with {} splitter.".format(split))
  train, valid, test = splitter.train_valid_test_split(dataset)
  frac_train = kwargs.get("frac_train", 0.8)
  frac_valid = kwargs.get('frac_valid', 0.1)
  frac_test = kwargs.get('frac_test', 0.1)

  train, valid, test = splitter.train_valid_test_split(
      dataset,
      frac_train=frac_train,
      frac_valid=frac_valid,
      frac_test=frac_test)
  transformers = []
  all_dataset = (train, valid, test)
  if reload:
@@ -117,7 +125,7 @@ def load_bbbc002(split='index',
    save_dir = DEFAULT_DIR

  if reload:
    save_folder = os.path.join(save_dir, "bbbc002-featurized/" + str(split))
    save_folder = os.path.join(save_dir, "bbbc002-featurized", str(split))
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_folder)
    if loaded:
@@ -161,7 +169,15 @@ def load_bbbc002(split='index',
  splitter = splitters[split]

  logger.info("About to split dataset with {} splitter.".format(split))
  train, valid, test = splitter.train_valid_test_split(dataset)
  frac_train = kwargs.get("frac_train", 0.8)
  frac_valid = kwargs.get('frac_valid', 0.1)
  frac_test = kwargs.get('frac_test', 0.1)

  train, valid, test = splitter.train_valid_test_split(
      dataset,
      frac_train=frac_train,
      frac_valid=frac_valid,
      frac_test=frac_test)
  all_dataset = (train, valid, test)
  transformers = []
  if reload:
+9 −1
Original line number Diff line number Diff line
@@ -83,7 +83,15 @@ def load_bbbp(featurizer='ECFP',
  }
  splitter = splitters[split]
  logger.info("About to split data with {} splitter.".format(split))
  train, valid, test = splitter.train_valid_test_split(dataset)
  frac_train = kwargs.get("frac_train", 0.8)
  frac_valid = kwargs.get('frac_valid', 0.1)
  frac_test = kwargs.get('frac_test', 0.1)

  train, valid, test = splitter.train_valid_test_split(
      dataset,
      frac_train=frac_train,
      frac_valid=frac_valid,
      frac_test=frac_test)

  # Initialize transformers
  transformers = [
+10 −3
Original line number Diff line number Diff line
@@ -36,8 +36,7 @@ def load_cell_counting(split=None,
  # For now images are loaded directly by ImageLoader
  featurizer = ""
  if reload:
    save_folder = os.path.join(save_dir,
                               "cell_counting-featurized/" + str(split))
    save_folder = os.path.join(save_dir, "cell_counting-featurized", str(split))
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_folder)
    if loaded:
@@ -64,7 +63,15 @@ def load_cell_counting(split=None,
  splitter = splitters[split]

  logger.info("About to split dataset with {} splitter.".format(split))
  train, valid, test = splitter.train_valid_test_split(dataset)
  frac_train = kwargs.get("frac_train", 0.8)
  frac_valid = kwargs.get('frac_valid', 0.1)
  frac_test = kwargs.get('frac_test', 0.1)

  train, valid, test = splitter.train_valid_test_split(
      dataset,
      frac_train=frac_train,
      frac_valid=frac_valid,
      frac_test=frac_test)
  transformers = []
  all_dataset = (train, valid, test)
  if reload:
Loading