Commit 6edf2440 authored by leswing

Remove Names from Tox21 GraphConv

parent 709ed42f
+19 −26
@@ -28,62 +28,55 @@ model_dir = "/tmp/graph_conv"
 def graph_conv_model(batch_size, tasks):
   model = TensorGraph(
       model_dir=model_dir, batch_size=batch_size, use_queue=False)
-  atom_features = Feature(shape=(None, 75), name="ATOMFEATURES")
-  degree_slice = Feature(shape=(None, 2), dtype=tf.int32, name="DEGREE_SLICE")
-  membership = Feature(shape=(None,), dtype=tf.int32, name="MEMBERSHIP")
+  atom_features = Feature(shape=(None, 75))
+  degree_slice = Feature(shape=(None, 2), dtype=tf.int32)
+  membership = Feature(shape=(None,), dtype=tf.int32)
 
   deg_adjs = []
   for i in range(0, 10 + 1):
-    deg_adj = Feature(shape=(None, i + 1), dtype=tf.int32, name="DEGADJ%i" % i)
+    deg_adj = Feature(shape=(None, i + 1), dtype=tf.int32)
     deg_adjs.append(deg_adj)
   gc1 = GraphConvLayer(
       64,
       activation_fn=tf.nn.relu,
-      in_layers=[atom_features, degree_slice, membership] + deg_adjs,
-      name="CONV1")
-  batch_norm1 = BatchNormLayer(in_layers=[gc1], name="BATCHNORM1")
+      in_layers=[atom_features, degree_slice, membership] + deg_adjs)
+  batch_norm1 = BatchNormLayer(in_layers=[gc1])
   gp1 = GraphPoolLayer(
-      in_layers=[batch_norm1, degree_slice, membership] + deg_adjs,
-      name="POOL1")
+      in_layers=[batch_norm1, degree_slice, membership] + deg_adjs)
   gc2 = GraphConvLayer(
       64,
       activation_fn=tf.nn.relu,
-      in_layers=[gp1, degree_slice, membership] + deg_adjs,
-      name="CONV2")
-  batch_norm2 = BatchNormLayer(in_layers=[gc2], name="BATCHNORM2")
+      in_layers=[gp1, degree_slice, membership] + deg_adjs)
+  batch_norm2 = BatchNormLayer(in_layers=[gc2])
   gp2 = GraphPoolLayer(
-      in_layers=[batch_norm2, degree_slice, membership] + deg_adjs,
-      name="POOL2")
+      in_layers=[batch_norm2, degree_slice, membership] + deg_adjs)
   dense = Dense(
-      out_channels=128, activation_fn=None, in_layers=[gp2], name="DENSE1")
-  batch_norm3 = BatchNormLayer(in_layers=[dense], name="BATCHNORM3")
+      out_channels=128, activation_fn=None, in_layers=[gp2])
+  batch_norm3 = BatchNormLayer(in_layers=[dense])
   gg1 = GraphGather(
       batch_size=batch_size,
       activation_fn=tf.nn.tanh,
-      in_layers=[batch_norm3, degree_slice, membership] + deg_adjs,
-      name="GATHER")
+      in_layers=[batch_norm3, degree_slice, membership] + deg_adjs)
 
   costs = []
   labels = []
   for task in tasks:
     classification = Dense(
         out_channels=2,
-        name="GUESS_%s" % task,
         activation_fn=None,
         in_layers=[gg1])
 
-    softmax = SoftMax(name="SOFTMAX_%s" % task, in_layers=[classification])
+    softmax = SoftMax(in_layers=[classification])
     model.add_output(softmax)
 
-    label = Label(shape=(None, 2), name="LABEL_%s" % task)
+    label = Label(shape=(None, 2))
     labels.append(label)
-    cost = SoftMaxCrossEntropy(
-        name="COST_%s" % task, in_layers=[label, classification])
+    cost = SoftMaxCrossEntropy(in_layers=[label, classification])
     costs.append(cost)
 
-  entropy = Concat(name="ENT", in_layers=costs)
-  task_weights = Weights(shape=(None, len(tasks)), name="W")
-  loss = WeightedError(name="ERROR", in_layers=[entropy, task_weights])
+  entropy = Concat(in_layers=costs)
+  task_weights = Weights(shape=(None, len(tasks)))
+  loss = WeightedError(in_layers=[entropy, task_weights])
   model.set_loss(loss)
 
   def feed_dict_generator(dataset, batch_size, epochs=1):
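
Note on the change: dropping the explicit name= arguments is safe because TensorGraph layers assign themselves a default name when none is supplied, so the graph wiring above is unchanged. The snippet below is a minimal sketch of that fallback pattern, not DeepChem's actual implementation; the Layer base class, the global counter, and the "ClassName_N" format are assumptions for illustration only.

import itertools

class Layer(object):
  # Sketch only: a shared counter that keeps auto-generated names
  # unique. The real TensorGraph layers do their own bookkeeping.
  _name_counter = itertools.count()

  def __init__(self, in_layers=None, name=None):
    # When the caller omits name (as the diff above now does),
    # fall back to a unique "ClassName_N" identifier.
    if name is None:
      name = "%s_%d" % (type(self).__name__, next(Layer._name_counter))
    self.name = name
    self.in_layers = list(in_layers) if in_layers is not None else []

class GraphConvLayer(Layer):
  pass

gc = GraphConvLayer()
print(gc.name)  # e.g. "GraphConvLayer_0" on first construction

With a fallback like this, the explicit per-task names (e.g. "GUESS_%s" % task) only mattered for readability when inspecting the graph, which is why removing them does not affect training or the loss.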