Commit 155e3fba authored by miaecle's avatar miaecle
Browse files

low_data benchmark

parent 8545a632
Loading
Loading
Loading
Loading
+110 −2
Original line number Diff line number Diff line
@@ -28,7 +28,7 @@ def benchmark_classification(
    model,
    test=False,
    hyper_parameters=None,
    seed=123,):
    seed=123):
  """
  Calculate performance of different models on the specific dataset & tasks
  
@@ -259,7 +259,7 @@ def benchmark_regression(
    model,
    test=False,
    hyper_parameters=None,
    seed=123,):
    seed=123):
  """
  Calculate performance of different models on the specific dataset & tasks
  
@@ -428,3 +428,111 @@ def benchmark_regression(
    test_scores[model_name] = model.evaluate(test_dataset, metric, transformers)

  return train_scores, valid_scores, test_scores

def low_data_benchmark_classification(
    train_dataset,
    valid_dataset,
    test_dataset,
    n_features,
    metric,
    model='siamese',
    hyper_parameters=None,
    seed=123):
  """
  Calculate low data benchmark performance

  Builds a support graph (graph convolution + pooling layers followed by
  dense layers), optionally joins an LSTM embedding on top of it, fits a
  SupportGraphClassifier on the training set and evaluates it on the
  validation set.

  Parameters
  ----------
  train_dataset : dataset struct
      loaded dataset, ConvMol struct, used for training
  valid_dataset : dataset struct
      loaded dataset, ConvMol struct, used for validation
  test_dataset : dataset struct
      loaded dataset, ConvMol struct, used for test
      (accepted for interface symmetry; not used by this function)
  n_features : integer
      number of features, or length of binary fingerprints
  metric: list of dc.metrics.Metric objects
      metrics used for evaluation
  model : string,  optional (default='siamese')
      choice of which model to use, should be: siamese, attn, res
  hyper_parameters: dict, optional (default=None)
      hyper parameters for designated model, None = use preset values.
      Required keys: n_pos, n_neg, test_batch_size, n_filters,
      n_fully_connected_nodes, nb_epochs, n_train_trials, n_eval_trials,
      learning_rate; 'attn' and 'res' additionally require max_depth.
  seed : integer, optional (default=123)
      random seed handed to tf.set_random_seed

  Returns
  -------
  scores : dict
      evaluation result on the valid set, keyed by the model name

  """
  scores = {}

  assert model in ['siamese', 'attn', 'res']
  if hyper_parameters is None:
    hyper_parameters = hps[model]

  # Loading hyperparameters
  # num positive/negative ligands
  n_pos = hyper_parameters['n_pos']
  n_neg = hyper_parameters['n_neg']
  # Set batch sizes for network
  test_batch_size = hyper_parameters['test_batch_size']
  support_batch_size = n_pos + n_neg
  # Model structure
  n_filters = hyper_parameters['n_filters']
  n_fully_connected_nodes = hyper_parameters['n_fully_connected_nodes']

  # Training settings
  nb_epochs = hyper_parameters['nb_epochs']
  n_train_trials = hyper_parameters['n_train_trials']
  n_eval_trials = hyper_parameters['n_eval_trials']

  learning_rate = hyper_parameters['learning_rate']

  tf.set_random_seed(seed)
  support_graph = dc.nn.SequentialSupportGraph(n_features)
  # Width of the previous layer's output, fed as input size to the next one.
  prev_features = n_features
  for n_filter in n_filters:
    support_graph.add(
        dc.nn.GraphConv(int(n_filter), prev_features, activation='relu'))
    support_graph.add(dc.nn.GraphPool())
    prev_features = int(n_filter)

  for n_fcnode in n_fully_connected_nodes:
    support_graph.add(
        dc.nn.Dense(int(n_fcnode), prev_features, activation='tanh'))
    prev_features = int(n_fcnode)

  support_graph.add_test(
      dc.nn.GraphGather(test_batch_size, activation='tanh'))
  support_graph.add_support(
      dc.nn.GraphGather(support_batch_size, activation='tanh'))

  # 'siamese' uses the plain support graph; the other two models join an
  # LSTM embedding onto it.
  if model == 'attn':
    max_depth = hyper_parameters['max_depth']
    support_graph.join(
        dc.nn.AttnLSTMEmbedding(test_batch_size, support_batch_size,
                                max_depth))
  elif model == 'res':
    max_depth = hyper_parameters['max_depth']
    support_graph.join(
        dc.nn.ResiLSTMEmbedding(test_batch_size, support_batch_size,
                                max_depth))

  # Training and evaluation must run for every model choice, so this block
  # lives at function level rather than inside the 'res' branch.
  with tf.Session() as sess:
    model_low_data = dc.models.SupportGraphClassifier(
        sess,
        support_graph,
        test_batch_size=test_batch_size,
        support_batch_size=support_batch_size,
        learning_rate=learning_rate)

    print('-------------------------------------')
    print('Start fitting by low data model: ' + model)
    # Fit trained model
    model_low_data.fit(
        train_dataset,
        nb_epochs=nb_epochs,
        n_episodes_per_epoch=n_train_trials,
        n_pos=n_pos,
        n_neg=n_neg,
        log_every_n_samples=50)
    # Evaluate the low data model with the caller-supplied metric
    scores[model] = model_low_data.evaluate(
        valid_dataset, metric, n_pos, n_neg, n_trials=n_eval_trials)

  return scores