Commit 3f8be5df authored by Vignesh's avatar Vignesh
Browse files

Changed feature shapes to deal with tf.feature_column

parent ec30df0f
Loading
Loading
Loading
Loading
+12 −4
Original line number Diff line number Diff line
@@ -114,6 +114,8 @@ class BPSymmetryFunctionRegression(TensorGraph):
          feed_dict[self.task_weights[0]] = w_b

        atom_feats, atom_flags = self.compute_features_on_batch(X_b)
        atom_feats = atom_feats.reshape(-1, self.max_atoms * self.n_feat)
        atom_flags = atom_flags.reshape(-1, self.max_atoms * self.max_atoms)
        feed_dict[self.atom_feats] = atom_feats
        feed_dict[self.atom_flags] = atom_flags

@@ -214,10 +216,13 @@ class ANIRegression(TensorGraph):
      feed_dict = dict()
      X = dataset.X
      flags = np.sign(np.array(X[:upper_lim, :, 0]))
      feed_dict[self.atom_flags] = np.stack([flags]*self.max_atoms, axis=2)*\
      atom_flags = np.stack([flags]*self.max_atoms, axis=2)*\
          np.stack([flags]*self.max_atoms, axis=1)
      feed_dict[self.atom_numbers] = np.array(X[:upper_lim, :, 0], dtype=int)
      feed_dict[self.atom_feats] = np.array(X[:upper_lim, :, :], dtype=float)
      feed_dict[self.atom_flags] = atom_flags.reshape(-1, self.max_atoms * self.max_atoms)
      atom_numbers = np.array(X[:upper_lim, :, 0], dtype=int)
      feed_dict[self.atom_numbers] = atom_numbers
      atom_feats = np.array(X[:upper_lim, :, :], dtype=float)
      feed_dict[self.atom_feats] = atom_feats.reshape(-1, self.max_atoms * 4)
      return self.session.run([self.grad], feed_dict=feed_dict)

  def pred_one(self, X, atomic_nums, constraints=None):
@@ -284,7 +289,8 @@ class ANIRegression(TensorGraph):
    X = Z
    inp = np.array(X).reshape((1, self.max_atoms, 4))
    dd = dc.data.NumpyDataset(inp, np.array([1]), np.array([1]))
    res = self.compute_grad(dd)[0][0][0]
    res = self.compute_grad(dd)[0][0]
    res = res.reshape(self.max_atoms, 4)
    res = res[:num_atoms, 1:]

    if constraints is not None:
@@ -402,6 +408,8 @@ class ANIRegression(TensorGraph):

        atom_feats, atom_numbers, atom_flags = self.compute_features_on_batch(
            X_b)
        atom_feats = atom_feats.reshape(-1, self.max_atoms * 4)
        atom_flags = atom_flags.reshape(-1, self.max_atoms * self.max_atoms)
        feed_dict[self.atom_feats] = atom_feats
        feed_dict[self.atom_numbers] = atom_numbers
        feed_dict[self.atom_flags] = atom_flags