Commit 5531c218 authored by miaecle's avatar miaecle
Browse files

irv&lr fix

parent 75d6a018
Loading
Loading
Loading
Loading
+2 −4
Original line number Diff line number Diff line
@@ -71,7 +71,6 @@ class TensorflowLogisticRegression(TensorflowGraphModel):
            bias_init=tf.constant(value=bias_init_consts[0],
                                  shape=[1]))
        lg_list.append(lg)

    return lg_list
    
  def add_label_placeholders(self, graph, name_scopes):
@@ -173,9 +172,8 @@ class TensorflowLogisticRegression(TensorflowGraphModel):
        # transfer 2D prediction tensor to 2D x n_classes(=2) 
        complimentary = np.ones(np.shape(batch_outputs))
        complimentary = complimentary - batch_outputs
        batch_outputs = np.squeeze(np.stack(arrays = [complimentary,
						      batch_outputs],
                                            axis = 2))
        batch_outputs = np.concatenate([complimentary, batch_outputs],
                                            axis = batch_outputs.ndim-1)
        # reshape to batch_size x n_tasks x ...
        if batch_outputs.ndim == 3:
          batch_outputs = batch_outputs.transpose((1, 0, 2))
+45 −0
Original line number Diff line number Diff line
"""
Script that trains multitask models on hiv dataset.
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import numpy as np
import deepchem as dc
from hiv_datasets import load_hiv

# Only for debug!
np.random.seed(123)

# Load hiv dataset
n_features = 512
hiv_tasks, hiv_datasets, transformers = load_hiv()
train_dataset, valid_dataset, test_dataset = hiv_datasets

# Fit models
metric = dc.metrics.Metric(dc.metrics.roc_auc_score, np.mean)

transformer = dc.trans.IRVTransformer(10, len(hiv_tasks), train_dataset)
train_dataset = transformer.transform(train_dataset)
valid_dataset = transformer.transform(valid_dataset)

model = dc.models.TensorflowMultiTaskIRVClassifier(
        len(hiv_tasks),
        K=10,
        batch_size=50,
        learning_rate=0.001)

# Fit trained model
model.fit(train_dataset)
model.save()

print("Evaluating model")
train_scores = model.evaluate(train_dataset, [metric], transformers)
valid_scores = model.evaluate(valid_dataset, [metric], transformers)

print("Train scores")
print(train_scores)

print("Validation scores")
print(valid_scores)