Commit 875e0ac5 authored by evanfeinberg's avatar evanfeinberg
Browse files

fixed pytorch graph conv issues

parent e9140421
Loading
Loading
Loading
Loading
+15 −15
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ import torch.optim as optim
import random
import numpy as np
from sklearn.metrics import roc_auc_score

import scipy


def symmetric_normalize_adj(adj):
+134 −128
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ from deepchem.feat.base_classes import Featurizer
from deepchem.feat.graph_features import atom_features
from scipy.sparse import csr_matrix


def get_atom_type(atom):
  elem = atom.GetAtomicNum()
  hyb = str(atom.GetHybridization).lower()
@@ -70,10 +71,14 @@ def get_atom_type(atom):
    return (21)
  return (22)

def get_atom_adj_matrices(mol, n_atom_types, max_n_atoms=200,

def get_atom_adj_matrices(mol,
                          n_atom_types,
                          max_n_atoms=200,
                          max_valence=4,
                          graph_conv_features=True,
                          nxn=True):
  if not graph_conv_features:
    bond_matrix = np.zeros((max_n_atoms, 4 * max_valence)).astype(np.uint8)

  if nxn:
@@ -82,8 +87,9 @@ def get_atom_adj_matrices(mol, n_atom_types, max_n_atoms=200,
    adj_matrix = np.zeros((max_n_atoms, max_valence)).astype(np.uint8)
    adj_matrix += (adj_matrix.shape[0] - 1)

    #atom_matrix = np.zeros((max_n_atoms, n_atom_types+3)).astype(np.uint8)
    #atom_matrix[:,atom_matrix.shape[1]-1] = 1
  if not graph_conv_features:
    atom_matrix = np.zeros((max_n_atoms, n_atom_types + 3)).astype(np.uint8)
    atom_matrix[:, atom_matrix.shape[1] - 1] = 1

  atom_arrays = []
  for a_idx in range(0, mol.GetNumAtoms()):
@@ -102,7 +108,8 @@ def get_atom_adj_matrices(mol, n_atom_types, max_n_atoms=200,
        adj_matrix[a_idx][a_idx] = 1
      else:
        adj_matrix[a_idx][n_idx] = neighbor.GetIdx()
            """

      if not graph_conv_features:
        bond = mol.GetBondBetweenAtoms(a_idx, neighbor.GetIdx())
        bond_type = str(bond.GetBondType()).lower()
        if "single" in bond_type:
@@ -114,7 +121,6 @@ def get_atom_adj_matrices(mol, n_atom_types, max_n_atoms=200,
        elif "aromatic" in bond_type:
          bond_order = 3
        bond_matrix[a_idx][(4 * n_idx) + bond_order] = 1
            """

  if graph_conv_features:
    n_feat = len(atom_arrays[0])
@@ -122,21 +128,23 @@ def get_atom_adj_matrices(mol, n_atom_types, max_n_atoms=200,
    for idx, atom_array in enumerate(atom_arrays):
      atom_matrix[idx, :] = atom_array
  else:
        atom_matrix = np.concatenate([atom_matrix, bond_matrix], axis=1).astype(np.uint8)
    atom_matrix = np.concatenate(
        [atom_matrix, bond_matrix], axis=1).astype(np.uint8)

  return (adj_matrix.astype(np.uint8), atom_matrix.astype(np.uint8))


def featurize_mol(mol,
                  n_atom_types,
                  max_n_atoms,
                  max_valence):
def featurize_mol(mol, n_atom_types, max_n_atoms, max_valence):

    adj_matrix, atom_matrix = get_atom_adj_matrices(mol, n_atom_types, max_n_atoms, max_valence)
  adj_matrix, atom_matrix = get_atom_adj_matrices(mol, n_atom_types,
                                                  max_n_atoms, max_valence)
  return ((adj_matrix, atom_matrix))


class AdjacencyFingerprint(Featurizer):
    def __init__(self, n_atom_types=23,

  def __init__(self,
               n_atom_types=23,
               max_n_atoms=200,
               add_hydrogens=False,
               max_valence=4):
@@ -151,9 +159,7 @@ class AdjacencyFingerprint(Featurizer):
    for idx, mol in enumerate(rdkit_mols):
      if self.add_hydrogens:
        mol = Chem.AddHs(mol)
            featurized_mol = featurize_mol(mol,
                                           self.n_atom_types,
                                           self.max_n_atoms,
      featurized_mol = featurize_mol(mol, self.n_atom_types, self.max_n_atoms,
                                     self.max_valence)
      featurized_mols[idx] = featurized_mol
    return (featurized_mols)
+2 −2

File changed.

Contains only whitespace changes.