Commit b7b56fa3 authored by nd-02110114's avatar nd-02110114
Browse files

♻️ fix default atom type

parent 60a12f60
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -79,9 +79,9 @@ class MolGraphConvFeaturizer(MolecularFeaturizer):
  to modify return values of `construct_atom_feature` or `construct_bond_feature`.

  The default node representation are constructed by concatenating the following values,
  and the feature length is 38.
  and the feature length is 39.

  - Atom type: A one-hot vector of this atom, "C", "N", "O", "F", "P", "S", "Br", "I", "other atoms".
  - Atom type: A one-hot vector of this atom, "C", "N", "O", "F", "P", "S", "Cl", "Br", "I", "other atoms".
  - Chirality: A one-hot vector of the chirality, "R" or "S".
  - Formal charge: Integer electronic charge.
  - Partial charge: Calculated partial charge.
@@ -111,7 +111,7 @@ class MolGraphConvFeaturizer(MolecularFeaturizer):
  >>> type(out[0])
  <class 'deepchem.feat.graph_data.GraphData'>
  >>> out[0].num_node_features
  38
  39
  >>> out[0].num_edge_features
  11

+4 −4
Original line number Diff line number Diff line
@@ -13,13 +13,13 @@ class TestMolGraphConvFeaturizer(unittest.TestCase):

    # assert "C1=CC=CN=C1"
    assert graph_feat[0].num_nodes == 6
    assert graph_feat[0].num_node_features == 38
    assert graph_feat[0].num_node_features == 39
    assert graph_feat[0].num_edges == 12
    assert graph_feat[0].num_edge_features == 11

    # assert "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"
    assert graph_feat[1].num_nodes == 22
    assert graph_feat[1].num_node_features == 38
    assert graph_feat[1].num_node_features == 39
    assert graph_feat[1].num_edges == 44
    assert graph_feat[1].num_edge_features == 11

@@ -31,12 +31,12 @@ class TestMolGraphConvFeaturizer(unittest.TestCase):

    # assert "C1=CC=CN=C1"
    assert graph_feat[0].num_nodes == 6
    assert graph_feat[0].num_node_features == 38
    assert graph_feat[0].num_node_features == 39
    assert graph_feat[0].num_edges == 12 + 6
    assert graph_feat[0].num_edge_features == 11

    # assert "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"
    assert graph_feat[1].num_nodes == 22
    assert graph_feat[1].num_node_features == 38
    assert graph_feat[1].num_node_features == 39
    assert graph_feat[1].num_edges == 44 + 22
    assert graph_feat[1].num_edge_features == 11
+6 −6
Original line number Diff line number Diff line
@@ -48,7 +48,7 @@ class GAT(nn.Module):

  def __init__(
      self,
      in_node_dim: int = 38,
      in_node_dim: int = 39,
      hidden_node_dim: int = 64,
      heads: int = 4,
      dropout: float = 0.0,
@@ -59,8 +59,8 @@ class GAT(nn.Module):
    """
    Parameters
    ----------
    in_node_dim: int, default 38
      The length of the initial node feature vectors. The 38 is
    in_node_dim: int, default 39
      The length of the initial node feature vectors. The 39 is
      based on `MolGraphConvFeaturizer`.
    hidden_node_dim: int, default 64
      The length of the hidden node feature vectors.
@@ -152,7 +152,7 @@ class GATModel(TorchModel):
  """

  def __init__(self,
               in_node_dim: int = 38,
               in_node_dim: int = 39,
               hidden_node_dim: int = 64,
               heads: int = 4,
               dropout: float = 0.0,
@@ -165,8 +165,8 @@ class GATModel(TorchModel):

    Parameters
    ----------
    in_node_dim: int, default 38
      The length of the initial node feature vectors. The 38 is
    in_node_dim: int, default 39
      The length of the initial node feature vectors. The 39 is
      based on `MolGraphConvFeaturizer`.
    hidden_node_dim: int, default 64
      The length of the hidden node feature vectors.
+1 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ DEFAULT_ATOM_TYPE_SET = [
    "F",
    "P",
    "S",
    "Cl",
    "Br",
    "I",
]
+3 −3
Original line number Diff line number Diff line
@@ -33,15 +33,15 @@ class TestGraphConvUtils(unittest.TestCase):
    atoms = self.mol.GetAtoms()
    assert atoms[0].GetSymbol() == "C"
    one_hot = get_atom_type_one_hot(atoms[0])
    assert one_hot == [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    assert one_hot == [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    # check unknown atoms
    atoms = self.mol_copper_sulfate.GetAtoms()
    assert atoms[0].GetSymbol() == "Cu"
    one_hot = get_atom_type_one_hot(atoms[0])
    assert one_hot == [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
    assert one_hot == [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
    one_hot = get_atom_type_one_hot(atoms[0], include_unknown_set=False)
    assert one_hot == [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    assert one_hot == [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    # check original set
    atoms = self.mol.GetAtoms()