Commit 14855aa7 authored by cc
Browse files

Increased ChEMBL dataset size, and added year split

parent 8f04d1dd
Loading
Loading
Loading
Loading

datasets/chembl.csv

deleted 100644 → 0
+0 −19107

File deleted.

Preview size limit exceeded, changes collapsed.

+988 KiB

File added.

No diff preview for this file type.

+8.11 MiB

File added.

No diff preview for this file type.

+54 −30
Original line number Diff line number Diff line
@@ -14,26 +14,8 @@ import sys
sys.path.append(".")
from chembl_tasks import chembl_tasks

def remove_missing_entries(dataset):
    """Drop all-zero feature rows from every shard of ``dataset``, in place.

    A row counts as missing when every value in its feature vector is
    zero/falsy. For each shard yielded by ``dataset.itershards()`` the
    number of such rows is printed, and the shard is rewritten through
    ``dataset.set_shard`` with the matching rows removed from the
    features, labels, weights and ids alike.

    Returns nothing; the dataset is modified through ``set_shard``.
    """
    for shard_num, shard in enumerate(dataset.itershards()):
        features, labels, weights, identifiers = shard

        # True for rows that contain at least one non-zero feature.
        keep = features.any(axis=1)
        n_missing = np.count_nonzero(~keep)
        print("Shard %d has %d missing entries." % (shard_num, n_missing))

        # The same mask filters all four arrays so they stay row-aligned.
        dataset.set_shard(
            shard_num,
            features[keep],
            labels[keep],
            weights[keep],
            identifiers[keep])


# Set shard size low to avoid memory problems.
def load_chembl(shard_size=2000, featurizer="ECFP", split='random'):
    """Load KAGGLE datasets. Does not do train/test split"""
def load_chembl(shard_size=2000, featurizer="ECFP", set="5thresh", split="random"):
    ############################################################## TIMING
    time1 = time.time()
    ############################################################## TIMING
@@ -42,10 +24,19 @@ def load_chembl(shard_size=2000, featurizer="ECFP", split='random'):

    # Load dataset
    print("About to load ChEMBL dataset.")
    if split == "year":
        train_datasets, valid_datasets, test_datasets = [], [], []
        train_files = os.path.join(current_dir,
                                   "year_sets/chembl_%s_ts_train.csv.gz" % set)
        valid_files = os.path.join(current_dir,
                                   "year_sets/chembl_%s_ts_valid.csv.gz" % set)
        test_files = os.path.join(current_dir,
                                  "year_sets/chembl_%s_ts_test.csv.gz" % set)
    else:
        dataset_path = os.path.join(
        current_dir, "../../datasets/chembl.csv")
            current_dir, "../../datasets/chembl_%s.csv.gz" % set)

    # Featurize KAGGLE dataset
    # Featurize ChEMBL dataset
    print("About to featurize ChEMBL dataset.")
    if featurizer == 'ECFP':
        featurizer = dc.feat.CircularFingerprint(size=1024)
@@ -55,10 +46,30 @@ def load_chembl(shard_size=2000, featurizer="ECFP", split='random'):
    loader = dc.data.CSVLoader(
        tasks=chembl_tasks, smiles_field="smiles", featurizer=featurizer)

    if split == "year":
        print("Featurizing train datasets")
        train_dataset = loader.featurize(
            train_files, shard_size=shard_size)

        print("Featurizing valid datasets")
        valid_dataset = loader.featurize(
            valid_files, shard_size=shard_size)

        print("Featurizing test datasets")
        test_dataset = loader.featurize(
            test_files, shard_size=shard_size)
    else:
        dataset = loader.featurize(dataset_path, shard_size=shard_size)

    # Initialize transformers
    print("About to transform data")
    if split == "year":
        transformers = [
            dc.trans.NormalizationTransformer(transform_y=True, dataset=train_dataset)]
        for transformer in transformers:
            for dataset in [train_dataset, valid_dataset, test_dataset]:
                transformer.transform(dataset)
    else:
        transformers = [
            dc.trans.NormalizationTransformer(transform_y=True, dataset=dataset)]
        for transformer in transformers:
@@ -67,8 +78,21 @@ def load_chembl(shard_size=2000, featurizer="ECFP", split='random'):
    splitters = {'index': dc.splits.IndexSplitter(),
                 'random': dc.splits.RandomSplitter(),
                 'scaffold': dc.splits.ScaffoldSplitter()}
    if split in splitters:
        splitter = splitters[split]
        print("Performing new split.")
        train, valid, test = splitter.train_valid_test_split(dataset)
    elif split == "year":
        print("Featurizing train datasets")
        train = loader.featurize(
            train_files, shard_size=shard_size)

        print("Featurizing valid datasets")
        valid = loader.featurize(
            valid_files, shard_size=shard_size)

        print("Featurizing test datasets")
        test = loader.featurize(
            test_files, shard_size=shard_size)

    return chembl_tasks, (train, valid, test), transformers
+3 −3
Original line number Diff line number Diff line
@@ -20,8 +20,8 @@ K.set_session(sess)

with g.as_default():
  tf.set_random_seed(123)
  chembl_tasks, datasets, transformers = load_chembl(
      featurizer='GraphConv', split='index')
  chembl_tasks, datasets, transformers = load_chembl(shard_size=2000,
    featurizer="ECFP", set="5thresh", split="random")
  train_dataset, valid_dataset, test_dataset = datasets

  # Fit models
@@ -52,7 +52,7 @@ with g.as_default():
      optimizer_type="adam", beta1=.9, beta2=.999)

    # Fit trained model
    model.fit(train_dataset, nb_epoch=25)
    model.fit(train_dataset, nb_epoch=20)

    print("Evaluating model")
    train_scores = model.evaluate(train_dataset, [metric], transformers)
Loading