Commit 39363f44 authored by miaecle's avatar miaecle
Browse files

fix

parents 1009c2b2 882feccf
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ from deepchem.feat.base_classes import Featurizer
from deepchem.feat.base_classes import ComplexFeaturizer
from deepchem.feat.base_classes import UserDefinedFeaturizer
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.graph_features import WeaveFeaturizer
from deepchem.feat.fingerprints import CircularFingerprint
from deepchem.feat.basic import RDKitDescriptors
from deepchem.feat.coulomb_matrices import CoulombMatrix
+175 −42
Original line number Diff line number Diff line
@@ -7,7 +7,8 @@ from rdkit import Chem
import itertools, operator

from deepchem.feat import Featurizer
from deepchem.feat.mol_graphs import ConvMol
from deepchem.feat.mol_graphs import ConvMol, WeaveMol


def one_of_k_encoding(x, allowable_set):
  if x not in allowable_set:
@@ -15,12 +16,14 @@ def one_of_k_encoding(x, allowable_set):
        "input {0} not in allowable set{1}:".format(x, allowable_set))
  return list(map(lambda s: x == s, allowable_set))


def one_of_k_encoding_unk(x, allowable_set):
  """One-hot encodes `x`, mapping values outside the set to its last element.

  Returns a list with one boolean per entry of `allowable_set`; the entry
  equal to `x` (or the last entry, when `x` is not in the set) is True.
  """
  target = x if x in allowable_set else allowable_set[-1]
  return [target == candidate for candidate in allowable_set]


def get_intervals(l):
  """For list of lists, gets the cumulative products of the lengths"""
  intervals = len(l) * [0]
@@ -31,6 +34,7 @@ def get_intervals(l):

  return intervals


def safe_index(l, e):
  """Gets the index of e in l, providing an index of len(l) if not found"""
  try:
@@ -38,22 +42,26 @@ def safe_index(l, e):
  except:
    return len(l)

# Allowed values for each categorical atom property. The lists are gathered
# into `reference_lists`, and `intervals` is derived from their lengths via
# `get_intervals`.
# NOTE(review): `possible_atom_list`, `possible_hybridization_list` and
# `reference_lists` are each assigned twice below with the same contents
# (hand-wrapped layout first, reformatted layout second); the later assignment
# is the one left in effect.
possible_atom_list = ['C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Mg', 'Na', 'Br',
                      'Fe', 'Ca', 'Cu', 'Mc', 'Pd', 'Pb',
                      'K','I','Al','Ni','Mn']

possible_atom_list = [
    'C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Mg', 'Na', 'Br', 'Fe', 'Ca', 'Cu',
    'Mc', 'Pd', 'Pb', 'K', 'I', 'Al', 'Ni', 'Mn'
]
possible_numH_list = [0, 1, 2, 3, 4]
possible_valence_list = [0, 1, 2, 3, 4, 5, 6]
possible_formal_charge_list = [-3, -2, -1, 0, 1, 2, 3]
possible_hybridization_list = [Chem.rdchem.HybridizationType.SP,
                               Chem.rdchem.HybridizationType.SP2,
                               Chem.rdchem.HybridizationType.SP3,
                               Chem.rdchem.HybridizationType.SP3D,
                               Chem.rdchem.HybridizationType.SP3D2]
possible_hybridization_list = [
    Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2
]
possible_number_radical_e_list = [0, 1, 2]

# Order matters: position i in this list corresponds to position i in
# `intervals`.
reference_lists = [possible_atom_list, possible_numH_list,
                   possible_valence_list, possible_formal_charge_list,
                   possible_number_radical_e_list, possible_hybridization_list]
reference_lists = [
    possible_atom_list, possible_numH_list, possible_valence_list,
    possible_formal_charge_list, possible_number_radical_e_list,
    possible_hybridization_list
]

# Spacings passed to `features_to_id` by `atom_to_id`.
intervals = get_intervals(reference_lists)

@@ -70,6 +78,7 @@ def get_feature_list(atom):

  return features


def features_to_id(features, intervals):
  """Convert list of features into index using spacings provided in intervals"""
  id = 0
@@ -80,6 +89,7 @@ def features_to_id(features, intervals):
  id = id + 1
  return id


def id_to_features(id, intervals):
  features = 6 * [0]

@@ -94,47 +104,133 @@ def id_to_features(id, intervals):
  features[0] = id
  return features


def atom_to_id(atom):
  """Return a unique id corresponding to the atom type.

  The atom's feature list is folded into a single index using the
  module-level `intervals` spacings.
  """
  return features_to_id(get_feature_list(atom), intervals)


def atom_features(atom, bool_id_feat=False):
  if bool_id_feat:
    return np.array([atom_to_id(atom)])
  else:
    return np.array(one_of_k_encoding_unk(
        atom.GetSymbol(),
        ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
         'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb',
         'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H',    # H?
         'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',
         'Cr', 'Pt', 'Hg', 'Pb', 'Unknown']) +
        one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +
        one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]) +
        one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]) +
        [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] +
    return np.array(
        one_of_k_encoding_unk(
            atom.GetHybridization(),
            [Chem.rdchem.HybridizationType.SP,
             Chem.rdchem.HybridizationType.SP2,
             Chem.rdchem.HybridizationType.SP3,
             Chem.rdchem.HybridizationType.SP3D,
             Chem.rdchem.HybridizationType.SP3D2]) +
        [atom.GetIsAromatic()])
            atom.GetSymbol(),
            [
                'C',
                'N',
                'O',
                'S',
                'F',
                'Si',
                'P',
                'Cl',
                'Br',
                'Mg',
                'Na',
                'Ca',
                'Fe',
                'As',
                'Al',
                'I',
                'B',
                'V',
                'K',
                'Tl',
                'Yb',
                'Sb',
                'Sn',
                'Ag',
                'Pd',
                'Co',
                'Se',
                'Ti',
                'Zn',
                'H',  # H?
                'Li',
                'Ge',
                'Cu',
                'Au',
                'Ni',
                'Cd',
                'In',
                'Mn',
                'Zr',
                'Cr',
                'Pt',
                'Hg',
                'Pb',
                'Unknown'
            ]) + one_of_k_encoding(atom.GetDegree(), [
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
            ]) + one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]) +
        one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6])
        + [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] +
        one_of_k_encoding_unk(atom.GetHybridization(), [
            Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.
            SP3D, Chem.rdchem.HybridizationType.SP3D2
        ]) + [atom.GetIsAromatic()])


def bond_features(bond):
  """Encodes a bond as a 6-element numpy array of flags.

  Entries, in order: bond type is SINGLE, DOUBLE, TRIPLE, AROMATIC, then the
  results of `bond.GetIsConjugated()` and `bond.IsInRing()`.
  """
  bt = bond.GetBondType()
  return np.array([
      bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE,
      bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC,
      bond.GetIsConjugated(), bond.IsInRing()
  ])


def pair_features(mol, edge_list, canon_adj_list, bt_len=6, max_distance=7):
  """Builds the atom-pair feature tensor for a molecule.

  Args:
    mol: molecule providing `GetNumAtoms()` and `GetRingInfo().AtomRings()`.
    edge_list: dict mapping a sorted `(i, j)` atom-index tuple to the bond
      feature vector of that bond (length `bt_len`).
    canon_adj_list: per-atom list of neighbouring atom indices.
    bt_len: number of bond features stored at the front of each pair vector.
    max_distance: number of graph-distance columns stored at the end of each
      pair vector (forwarded to `find_distance`).

  Returns:
    Array of shape `(num_atoms, num_atoms, bt_len + max_distance + 1)` laid
    out as `[bond features | same-ring flag | graph-distance indicators]`.
  """
  num_atoms = mol.GetNumAtoms()
  features = np.zeros((num_atoms, num_atoms, bt_len + max_distance + 1))
  rings = mol.GetRingInfo().AtomRings()
  for a1 in range(num_atoms):
    for a2 in canon_adj_list[a1]:
      # first `bt_len` features are bond features (only set for bonded pairs)
      features[a1, a2, :bt_len] = np.asarray(
          edge_list[tuple(sorted((a1, a2)))], dtype=float)
    for ring in rings:
      if a1 in ring:
        # `bt_len`-th feature flags pairs of atoms sharing a ring; an atom is
        # not paired with itself, so the diagonal entry is cleared again.
        features[a1, ring, bt_len] = 1
        features[a1, a1, bt_len] = 0.
    # remaining columns: graph distance between `a1` and every other atom
    distance = find_distance(
        a1, num_atoms, canon_adj_list, max_distance=max_distance)
    features[a1, :, bt_len + 1:] = distance

  return features


def find_distance(a1, num_atoms, canon_adj_list, max_distance=7):
  """Breadth-first graph-distance indicators from atom `a1`.

  Returns an array of shape `(num_atoms, max_distance)`. Column `k` holds a 1
  for every atom first reached in expansion step `k` of the search, so column
  0 marks the direct neighbours of `a1`. Atoms not reached within
  `max_distance` steps keep an all-zero row, as does `a1` itself unless it
  appears in its own neighbour list.
  """
  distance = np.zeros((num_atoms, max_distance))
  # atoms discovered in the current expansion step
  frontier = set(canon_adj_list[a1])
  # atoms already assigned to an earlier step (plus the start atom)
  visited = {a1}
  for step in range(max_distance):
    distance[list(frontier), step] = 1
    visited |= frontier
    # neighbours of the frontier that have not been seen yet
    reachable = {
        neighbour for atom in frontier for neighbour in canon_adj_list[atom]
    }
    frontier = reachable - visited
  return distance


class ConvMolFeaturizer(Featurizer):

  name = ['conv_mol']

  def __init__(self):
    # Since ConvMol is an object and not a numpy array, need to set dtype to
    # object.
@@ -151,7 +247,8 @@ class ConvMolFeaturizer(Featurizer):
    nodes = np.vstack(nodes)

    # Get bond lists with reverse edges included
    edge_list = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()]
    edge_list = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                 for b in mol.GetBonds()]

    # Get canonical adjacency list
    canon_adj_list = [[] for mol_id in range(len(nodes))]
@@ -160,3 +257,39 @@ class ConvMolFeaturizer(Featurizer):
      canon_adj_list[edge[1]].append(edge[0])

    return ConvMol(nodes, canon_adj_list)


class WeaveFeaturizer(Featurizer):
  """Featurizer that encodes molecules as `WeaveMol` objects."""

  name = ['weave_mol']

  def __init__(self):
    # `_featurize` returns WeaveMol objects, so the dtype is object.
    self.dtype = object

  def _featurize(self, mol):
    """Encodes mol as a WeaveMol object."""
    # Per-atom feature vectors, ordered by atom index.
    indexed_feats = sorted(
        (atom.GetIdx(), atom_features(atom)) for atom in mol.GetAtoms())
    _, feature_rows = zip(*indexed_feats)
    nodes = np.vstack(feature_rows)

    # Bond features keyed by the sorted pair of atom indices.
    edge_list = {}
    for bond in mol.GetBonds():
      bond_key = tuple(sorted([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]))
      edge_list[bond_key] = bond_features(bond)

    # Symmetric adjacency list: each bond is recorded on both of its atoms.
    canon_adj_list = [[] for _ in range(len(nodes))]
    for first, second in edge_list:
      canon_adj_list[first].append(second)
      canon_adj_list[second].append(first)

    # Pairwise features for every ordered atom pair.
    pairs = pair_features(mol, edge_list, canon_adj_list, bt_len=6)

    return WeaveMol(nodes, pairs)
+25 −0
Original line number Diff line number Diff line
@@ -388,3 +388,28 @@ class MultiConvMol(object):

  def get_num_molecules(self):
    """Returns the value stored in `self.num_mols`."""
    return self.num_mols


class WeaveMol(object):
  """Container for one molecule's features, as used by weave models.

  Stores the atom feature array (`nodes`) and the pair feature array
  (`pairs`), along with the first two dimensions of `nodes`.
  """

  def __init__(self, nodes, pairs):
    node_shape = nodes.shape
    self.nodes = nodes
    self.pairs = pairs
    # Rows of `nodes` are atoms; columns are per-atom features.
    self.num_atoms = node_shape[0]
    self.n_features = node_shape[1]

  def get_pair_features(self):
    """Returns the pair feature array passed at construction."""
    return self.pairs

  def get_atom_features(self):
    """Returns the atom feature array passed at construction."""
    return self.nodes

  def get_num_atoms(self):
    """Returns the number of rows in the atom feature array."""
    return self.num_atoms

  def get_num_features(self):
    """Returns the number of columns in the atom feature array."""
    return self.n_features
 No newline at end of file
+99 −0
Original line number Diff line number Diff line
@@ -752,6 +752,105 @@ class TestOverfit(test_util.TensorFlowTestCase):

    assert scores[regression_metric.name] > .8

  def test_weave_singletask_classification_overfit(self):
    """Test weave classifier overfits a tiny classification dataset."""
    np.random.seed(123)
    tf.set_random_seed(123)

    # Sizes shared by the graph definition and the model.
    n_tasks = 1
    n_atom_feat = 75
    n_pair_feat = 14
    n_feat = 128
    batch_size = 10
    max_atoms = 50

    # Featurize the example classification CSV with weave features.
    csv_path = os.path.join(self.current_dir, "example_classification.csv")
    loader = dc.data.CSVLoader(
        tasks=["outcome"],
        smiles_field="smiles",
        featurizer=dc.feat.WeaveFeaturizer())
    dataset = loader.featurize(csv_path)

    accuracy_metric = dc.metrics.Metric(dc.metrics.accuracy_score)

    # Weave layer -> concat -> batch norm -> gather.
    graph = dc.nn.SequentialWeaveGraph(
        max_atoms=max_atoms, n_atom_feat=n_atom_feat, n_pair_feat=n_pair_feat)
    graph.add(dc.nn.WeaveLayer(max_atoms, n_atom_feat, n_pair_feat))
    graph.add(dc.nn.WeaveConcat(batch_size, n_output=n_feat))
    graph.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
    graph.add(dc.nn.WeaveGather(batch_size, n_input=n_feat))

    model = dc.models.MultitaskGraphClassifier(
        graph,
        n_tasks,
        n_feat,
        batch_size=batch_size,
        learning_rate=1e-3,
        learning_rate_decay_time=1000,
        optimizer_type="adam",
        beta1=.9,
        beta2=.999)

    # Train on the tiny dataset, then save.
    model.fit(dataset, nb_epoch=20)
    model.save()

    # Score on the same data the model was trained on.
    scores = model.evaluate(dataset, [accuracy_metric])

    assert scores[accuracy_metric.name] > .65

  def test_weave_singletask_regression_overfit(self):
    """Test weave regressor overfits a tiny regression dataset."""
    np.random.seed(123)
    tf.set_random_seed(123)

    # Sizes shared by the graph definition and the model.
    n_tasks = 1
    n_atom_feat = 75
    n_pair_feat = 14
    n_feat = 128
    batch_size = 10
    max_atoms = 50

    # Featurize the example regression CSV with weave features.
    csv_path = os.path.join(self.current_dir, "example_regression.csv")
    loader = dc.data.CSVLoader(
        tasks=["outcome"],
        smiles_field="smiles",
        featurizer=dc.feat.WeaveFeaturizer())
    dataset = loader.featurize(csv_path)

    r2_metric = dc.metrics.Metric(
        dc.metrics.pearson_r2_score, task_averager=np.mean)

    # Weave layer -> concat -> batch norm -> gather.
    graph = dc.nn.SequentialWeaveGraph(
        max_atoms=max_atoms, n_atom_feat=n_atom_feat, n_pair_feat=n_pair_feat)
    graph.add(dc.nn.WeaveLayer(max_atoms, n_atom_feat, n_pair_feat))
    graph.add(dc.nn.WeaveConcat(batch_size, n_output=n_feat))
    graph.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
    graph.add(dc.nn.WeaveGather(batch_size, n_input=n_feat))

    model = dc.models.MultitaskGraphRegressor(
        graph,
        n_tasks,
        n_feat,
        batch_size=batch_size,
        learning_rate=1e-3,
        learning_rate_decay_time=1000,
        optimizer_type="adam",
        beta1=.9,
        beta2=.999)

    # Train on the tiny dataset, then save.
    model.fit(dataset, nb_epoch=40)
    model.save()

    # Score on the same data the model was trained on.
    scores = model.evaluate(dataset, [r2_metric])

    assert scores[r2_metric.name] > .9

  def test_siamese_singletask_classification_overfit(self):
    """Test siamese singletask model overfits tiny data."""
    np.random.seed(123)
+35 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ __license__ = "MIT"

import tensorflow as tf
from deepchem.nn.layers import GraphGather
from deepchem.models.tf_new_models.graph_topology import GraphTopology, DTNNGraphTopology, DAGGraphTopology
from deepchem.models.tf_new_models.graph_topology import GraphTopology, DTNNGraphTopology, DAGGraphTopology, WeaveGraphTopology


class SequentialGraph(object):
@@ -162,6 +162,40 @@ class SequentialDAGGraph(SequentialGraph):
      self.layers.append(layer)


class SequentialWeaveGraph(SequentialGraph):
  """SequentialGraph for Weave models

  Keeps two running outputs: `self.output` (initialised from the atom-feature
  placeholder) and `self.output_P` (initialised from the pair-feature
  placeholder).
  """

  def __init__(self, max_atoms=50, n_atom_feat=75, n_pair_feat=14):
    self.graph = tf.Graph()
    self.max_atoms = max_atoms
    self.n_atom_feat = n_atom_feat
    self.n_pair_feat = n_pair_feat
    with self.graph.as_default():
      topology = WeaveGraphTopology(self.max_atoms, self.n_atom_feat,
                                    self.n_pair_feat)
      self.graph_topology = topology
      self.output = topology.get_atom_features_placeholder()
      self.output_P = topology.get_pair_features_placeholder()
    self.layers = []

  def add(self, layer):
    """Adds a new layer to model."""
    # Dispatch on the layer's class name to choose which inputs it receives.
    layer_kind = type(layer).__name__
    with self.graph.as_default():
      topology = self.graph_topology
      if layer_kind == 'WeaveLayer':
        # Consumes both running outputs plus the topology placeholders and
        # updates both running outputs.
        weave_inputs = [self.output, self.output_P
                       ] + topology.get_topology_placeholders()
        self.output, self.output_P = layer(weave_inputs)
      elif layer_kind == 'WeaveConcat':
        self.output = layer([self.output, topology.atom_mask_placeholder])
      elif layer_kind == 'WeaveGather':
        self.output = layer([self.output, topology.membership_placeholder])
      else:
        # Any other layer is applied to the atom output alone.
        self.output = layer(self.output)
      self.layers.append(layer)


class SequentialSupportGraph(object):
  """An analog of Keras Sequential model for test/support models."""

Loading