Unverified Commit 3ce5d7cf authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1958 from deepchem/irv

Rename TensorflowMultitaskIRVClassifier to MultitaskIRVClassifier
parents 347335b0 a476d75c
Loading
Loading
Loading
Loading
+14 −3
Original line number Diff line number Diff line
@@ -79,7 +79,7 @@ class Slice(Layer):
    return tf.slice(inputs, [0] * axis + [slice_num], [-1] * axis + [1])


class TensorflowMultitaskIRVClassifier(KerasModel):
class MultitaskIRVClassifier(KerasModel):

  def __init__(self,
               n_tasks,
@@ -87,7 +87,7 @@ class TensorflowMultitaskIRVClassifier(KerasModel):
               penalty=0.0,
               mode="classification",
               **kwargs):
    """Initialize TensorflowMultitaskIRVClassifier
    """Initialize MultitaskIRVClassifier

    Parameters
    ----------
@@ -119,8 +119,19 @@ class TensorflowMultitaskIRVClassifier(KerasModel):
        if len(logits) == 1 else Concatenate(axis=1)(logits)
    ]
    model = tf.keras.Model(inputs=[mol_features], outputs=outputs)
    super(TensorflowMultitaskIRVClassifier, self).__init__(
    super(MultitaskIRVClassifier, self).__init__(
        model,
        SigmoidCrossEntropy(),
        output_types=['prediction', 'loss'],
        **kwargs)


class TensorflowMultitaskIRVClassifier(MultitaskIRVClassifier):

  def __init__(self, *args, **kwargs):

    warnings.warn(
        "TensorflowMultitaskIRVClassifier is deprecated and has been renamed to MultitaskIRVClassifier",
        FutureWarning)

    super(TensorflowMultitaskIRVClassifier, self).__init__(*args, **kwargs)
+2 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ from deepchem.models.callbacks import ValidationCallback
from deepchem.models.fcnet import MultitaskRegressor
from deepchem.models.fcnet import MultitaskClassifier
from deepchem.models.fcnet import MultitaskFitTransformRegressor
from deepchem.models.IRV import TensorflowMultitaskIRVClassifier
from deepchem.models.IRV import MultitaskIRVClassifier
from deepchem.models.robust_multitask import RobustMultitaskClassifier
from deepchem.models.robust_multitask import RobustMultitaskRegressor
from deepchem.models.progressive_multitask import ProgressiveMultitaskRegressor, ProgressiveMultitaskClassifier
@@ -29,3 +29,4 @@ from deepchem.models.chemnet_models import Smiles2Vec, ChemCeption

from deepchem.models.text_cnn import TextCNNTensorGraph
from deepchem.models.graph_models import WeaveTensorGraph, DTNNTensorGraph, DAGTensorGraph, GraphConvTensorGraph, MPNNTensorGraph
from deepchem.models.IRV import TensorflowMultitaskIRVClassifier
+1 −1
Original line number Diff line number Diff line
@@ -429,7 +429,7 @@ class TestOverfit(test_util.TensorFlowTestCase):
    dataset_trans = IRV_transformer.transform(dataset)
    classification_metric = dc.metrics.Metric(
        dc.metrics.accuracy_score, task_averager=np.mean)
    model = dc.models.TensorflowMultitaskIRVClassifier(
    model = dc.models.MultitaskIRVClassifier(
        n_tasks, K=5, learning_rate=0.01, batch_size=n_samples)

    # Fit trained model