Commit 2aac4331 authored by miaecle's avatar miaecle
Browse files

update progressive

parent 68798d1f
Loading
Loading
Loading
Loading
+12 −2
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ from deepchem.utils.save import log
from deepchem.metrics import to_one_hot
from deepchem.metrics import from_one_hot
from deepchem.models.tensorgraph.tensor_graph import TensorGraph, TFWrapper
from deepchem.models.tensorgraph.layers import Feature, Label, Weights, \
from deepchem.models.tensorgraph.layers import Layer, Feature, Label, Weights, \
    WeightedError, Dense, Dropout, WeightDecay, Reshape, SoftMaxCrossEntropy, \
    L2Loss, ReduceSum, Concat, Stack

@@ -157,6 +157,16 @@ class ProgressiveMultitaskRegressor(TensorGraph):
          layer = layer + lateral_contrib
      outputs.append(layer)
    output = Concat(in_layers=outputs)
    self.add_output(output)
    labels = Label(shape=(None, n_tasks))
    weights = Weights(shape=(None, n_tasks))
    weighted_loss = ReduceSum(L2Loss(in_layers=[labels, output, weights]))
    if weight_decay_penalty != 0.0:
      weighted_loss = WeightDecay(
          weight_decay_penalty,
          weight_decay_penalty_type,
          in_layers=[weighted_loss])
    self.set_loss(weighted_loss)