Commit 604e7830 authored by leswing's avatar leswing
Browse files

readability

parent 4708ba13
Loading
Loading
Loading
Loading
+4 −5
Original line number Diff line number Diff line
@@ -63,7 +63,7 @@ class WeaveTensorGraph(TensorGraph):
    self.pair_split = Feature(shape=(None,), dtype=tf.int32)
    self.atom_split = Feature(shape=(None,), dtype=tf.int32)
    self.atom_to_pair = Feature(shape=(None, 2), dtype=tf.int32)
    weave_layer1 = WeaveLayerFactory(
    weave_layer1A, weave_layer1P = WeaveLayerFactory(
        n_atom_input_feat=self.n_atom_feat,
        n_pair_input_feat=self.n_pair_feat,
        n_atom_output_feat=self.n_hidden,
@@ -72,20 +72,19 @@ class WeaveTensorGraph(TensorGraph):
            self.atom_features, self.pair_features, self.pair_split,
            self.atom_to_pair
        ])
    weave_layer2 = WeaveLayerFactory(
    weave_layer2A, weave_layer2P = WeaveLayerFactory(
        n_atom_input_feat=self.n_hidden,
        n_pair_input_feat=self.n_hidden,
        n_atom_output_feat=self.n_hidden,
        n_pair_output_feat=self.n_hidden,
        update_pair=False,
        in_layers=[
            weave_layer1[0], weave_layer1[1], self.pair_split, self.atom_to_pair
            weave_layer1A, weave_layer1P, self.pair_split, self.atom_to_pair
        ])
    separated = weave_layer2[0]
    dense1 = Dense(
        out_channels=self.n_graph_feat,
        activation_fn=tf.nn.tanh,
        in_layers=separated)
        in_layers=weave_layer2A)
    batch_norm1 = BatchNormalization(epsilon=1e-5, mode=1, in_layers=[dense1])
    weave_gather = WeaveGather(
        self.batch_size,