Commit 50f076da authored by ErikXiong's avatar ErikXiong
Browse files

update

parent ad4efd21
Loading
Loading
Loading
Loading
+23 −16
Original line number Diff line number Diff line
@@ -50,6 +50,9 @@ possible_atom_list = [
possible_numH_list = [0, 1, 2, 3, 4]
possible_valence_list = [0, 1, 2, 3, 4, 5, 6]
possible_formal_charge_list = [-3, -2, -1, 0, 1, 2, 3]
possible_chirality_list = [
    Chem.ChiralType.CHI_TETRAHEDRAL_CW, Chem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.ChiralType.CHI_UNSPECIFIED]
possible_hybridization_list = [
    Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
@@ -60,14 +63,14 @@ possible_number_radical_e_list = [0, 1, 2]
reference_lists = [
    possible_atom_list, possible_numH_list, possible_valence_list,
    possible_formal_charge_list, possible_number_radical_e_list,
    possible_hybridization_list
    possible_hybridization_list, possible_chirality_list
]

intervals = get_intervals(reference_lists)


def get_feature_list(atom):
  features = 6 * [0]
  features = 7 * [0]
  features[0] = safe_index(possible_atom_list, atom.GetSymbol())
  features[1] = safe_index(possible_numH_list, atom.GetTotalNumHs())
  features[2] = safe_index(possible_valence_list, atom.GetImplicitValence())
@@ -75,7 +78,7 @@ def get_feature_list(atom):
  features[4] = safe_index(possible_number_radical_e_list,
                           atom.GetNumRadicalElectrons())
  features[5] = safe_index(possible_hybridization_list, atom.GetHybridization())

  features[6] = safe_index(possible_chirality_list, atom.GetChiralTag())
  return features


@@ -91,15 +94,15 @@ def features_to_id(features, intervals):


def id_to_features(id, intervals):
  features = 6 * [0]
  features = 7 * [0]

  # Correct for null
  id -= 1

  for k in range(0, 6 - 1):
    # print(6-k-1, id)
    features[6 - k - 1] = id // intervals[6 - k - 1]
    id -= features[6 - k - 1] * intervals[6 - k - 1]
  for k in range(0, 7 - 1):
    # print(7-k-1, id)
    features[7 - k - 1] = id // intervals[7 - k - 1]
    id -= features[7 - k - 1] * intervals[7 - k - 1]
  # Correct for last one
  features[0] = id
  return features
@@ -111,7 +114,7 @@ def atom_to_id(atom):
  return features_to_id(features, intervals)


def atom_features(atom, bool_id_feat=False, explicit_H=False):
def atom_features(atom, bool_id_feat=False, explicit_H=False, use_chirality=False):
  if bool_id_feat:
    return np.array([atom_to_id(atom)])
  else:
@@ -166,9 +169,6 @@ def atom_features(atom, bool_id_feat=False, explicit_H=False):
                             [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + \
              one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]) + \
              [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] + \
              one_of_k_encoding_unk(atom.GetChiralTag(),
                [Chem.ChiralType.CHI_TETRAHEDRAL_CW, Chem.ChiralType.CHI_TETRAHEDRAL_CCW,
                 Chem.ChiralType.CHI_UNSPECIFIED]) + \
              one_of_k_encoding_unk(atom.GetHybridization(), [
                Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
                Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.
@@ -178,6 +178,10 @@ def atom_features(atom, bool_id_feat=False, explicit_H=False):
    if not explicit_H:
      results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(),
                                                [0, 1, 2, 3, 4])
    if use_chirality:
      results = results + one_of_k_encoding_unk(atom.GetChiralTag(),
                [Chem.ChiralType.CHI_TETRAHEDRAL_CW, Chem.ChiralType.CHI_TETRAHEDRAL_CCW,
                 Chem.ChiralType.CHI_UNSPECIFIED]) 

    return np.array(results)

@@ -252,7 +256,7 @@ def find_distance(a1, num_atoms, canon_adj_list, max_distance=7):
class ConvMolFeaturizer(Featurizer):
  name = ['conv_mol']

  def __init__(self, master_atom=False):
  def __init__(self, master_atom=False, use_chirality=False):
    """
    Parameters
    ----------
@@ -269,11 +273,12 @@ class ConvMolFeaturizer(Featurizer):
    """
    self.dtype = object
    self.master_atom = master_atom
    self.use_chirality = use_chirality

  def _featurize(self, mol):
    """Encodes mol as a ConvMol object."""
    # Get the node features
    idx_nodes = [(a.GetIdx(), atom_features(a)) for a in mol.GetAtoms()]
    idx_nodes = [(a.GetIdx(), atom_features(a, use_chirality=self.use_chirality)) for a in mol.GetAtoms()]
    idx_nodes.sort()  # Sort by ind to ensure same order as rd_kit
    idx, nodes = list(zip(*idx_nodes))

@@ -305,7 +310,7 @@ class ConvMolFeaturizer(Featurizer):
class WeaveFeaturizer(Featurizer):
  name = ['weave_mol']

  def __init__(self, graph_distance=True, explicit_H=False):
  def __init__(self, graph_distance=True, explicit_H=False, use_chirality=False):
    # Distance is either graph distance(True) or Euclidean distance(False,
    # only support datasets providing Cartesian coordinates)
    self.graph_distance = graph_distance
@@ -313,11 +318,13 @@ class WeaveFeaturizer(Featurizer):
    self.dtype = object
    # If includes explicit hydrogens
    self.explicit_H = explicit_H
    # If uses use_chirality
    self.use_chirality = use_chirality

  def _featurize(self, mol):
    """Encodes mol as a WeaveMol object."""
    # Atom features
    idx_nodes = [(a.GetIdx(), atom_features(a, explicit_H=self.explicit_H))
    idx_nodes = [(a.GetIdx(), atom_features(a, explicit_H=self.explicit_H, use_chirality=self.use_chirality))
                 for a in mol.GetAtoms()]
    idx_nodes.sort()  # Sort by ind to ensure same order as rd_kit
    idx, nodes = list(zip(*idx_nodes))