Commit 003c685c authored by leswing's avatar leswing
Browse files

Fix Tests

parent 3f3ec446
Loading
Loading
Loading
Loading
+2 −14
Original line number Diff line number Diff line
@@ -382,26 +382,14 @@ def test_WeightedError_pickle():
  tg.save()


def test_Combine_Separate_AP_pickle():
  tg = TensorGraph()
  atom_feature = Feature(shape=(None, 10))
  pair_feature = Feature(shape=(None, 5))
  C_AP = Combine_AP(in_layers=[atom_feature, pair_feature])
  S_AP = Separate_AP(in_layers=[C_AP])
  tg.add_output(S_AP)
  tg.set_loss(S_AP)
  tg.build()
  tg.save()


def test_Weave_pickle():
  tg = TensorGraph()
  atom_feature = Feature(shape=(None, 75))
  pair_feature = Feature(shape=(None, 14))
  pair_split = Feature(shape=(None,), dtype=tf.int32)
  atom_to_pair = Feature(shape=(None, 2), dtype=tf.int32)
  C_AP = Combine_AP(in_layers=[atom_feature, pair_feature])
  weave = WeaveLayer(in_layers=[C_AP, pair_split, atom_to_pair])
  weave = WeaveLayer(
      in_layers=[atom_feature, pair_feature, pair_split, atom_to_pair])
  tg.add_output(weave)
  tg.set_loss(weave)
  tg.build()