Commit 5358af0f authored by Yutong Zhao's avatar Yutong Zhao
Browse files

Add more data folder options to ANI

parent d418c4d2
Loading
Loading
Loading
Loading
+21 −11
Original line number Diff line number Diff line
@@ -12,6 +12,16 @@ def convert_species_to_atomic_nums(s):
    res.append(PERIODIC_TABLE[k])
  return np.array(res, dtype=np.float32)


# replace with your own scratch directory
data_dir = "/media/yutong/datablob/deepchem"

all_dir = os.path.join(data_dir, "all")
test_dir = os.path.join(data_dir, "test")
fold_dir = os.path.join(data_dir, "fold")
train_dir = os.path.join(fold_dir, "train")
valid_dir = os.path.join(fold_dir, "valid")

def load_roiterberg_ANI(mode="atomization"):
  """
  Load the ANI dataset.
@@ -49,8 +59,8 @@ def load_roiterberg_ANI(mode="atomization"):
      'ani_gdb_s04.h5',
      'ani_gdb_s05.h5',
      'ani_gdb_s06.h5',
      # 'ani_gdb_s07.h5',
      # 'ani_gdb_s08.h5'
      'ani_gdb_s07.h5',
      'ani_gdb_s08.h5'
  ]

  hdf5files = [os.path.join(base_dir, f) for f in hdf5files]
@@ -59,7 +69,7 @@ def load_roiterberg_ANI(mode="atomization"):

  def shard_generator():

    shard_size = 1024 * 64
    shard_size = 4096 * 64

    row_idx = 0
    group_idx = 0
@@ -151,13 +161,13 @@ def load_roiterberg_ANI(mode="atomization"):
      yield np.array(X_cache), np.array(y_cache), np.array(w_cache), np.array(ids_cache)

  tasks = ["ani"]
  dataset = dc.data.DiskDataset.create_dataset(shard_generator(), tasks=tasks)
  dataset = dc.data.DiskDataset.create_dataset(shard_generator(), tasks=tasks, data_dir=all_dir)

  print("Number of groups", np.amax(groups))
  splitter = dc.splits.RandomGroupSplitter(groups)

  train_dataset, test_dataset = splitter.train_test_split(
      dataset, frac_train=.8)
      dataset, train_dir=fold_dir, test_dir=test_dir, frac_train=.8)

  print(train_dataset.y)

@@ -186,7 +196,7 @@ if __name__ == "__main__":
      dc.metrics.Metric(dc.metrics.pearson_r2_score, mode="regression")
  ]

  model_dir = "/tmp/ani3.pkl"
  model_dir = "/tmp/ani8.pkl"

  if os.path.exists(model_dir):
    print("Restoring existing model...")
@@ -198,11 +208,11 @@ if __name__ == "__main__":

    splitter = dc.splits.RandomGroupSplitter(broadcast(train_valid_dataset, all_groups))

    print("performing 1-fold split...")

    # for fold in range(n_folds):
    print("Folding once....")
    train_dataset, valid_dataset = splitter.train_test_split(train_valid_dataset)
    print("Performing 1-fold split...")
    train_dataset, valid_dataset = splitter.train_test_split(
      train_valid_dataset,
      train_dir=train_dir,
      test_dir=valid_dir)

    transformers = [
        dc.trans.NormalizationTransformer(