Commit c94fa0b1 authored by nd-02110114's avatar nd-02110114
Browse files

🐛 go back to original implementaion

parent 88dd25b5
Loading
Loading
Loading
Loading
+11 −13
Original line number Diff line number Diff line
@@ -144,10 +144,10 @@ class MolecularFeaturizer(Featurizer):
  molecule.

  The defining feature of a `MolecularFeaturizer` is that it
  uses SMILES strings and RDKIT molecule objects to represent
  uses SMILES strings and RDKit molecule objects to represent
  small molecules. All other featurizers which are subclasses of
  this class should plan to process input which comes as smiles
  strings or RDKIT molecules.
  strings or RDKit molecules.

  Child classes need to implement the _featurize method for
  calculating features for a single molecule.
@@ -157,7 +157,7 @@ class MolecularFeaturizer(Featurizer):
  The subclasses of this class require RDKit to be installed.
  """

  def featurize(self, molecules, log_every_n=1000, canonical=True):
  def featurize(self, molecules, log_every_n=1000):
    """Calculate features for molecules.

    Parameters
@@ -167,8 +167,6 @@ class MolecularFeaturizer(Featurizer):
      strings.
    log_every_n: int, default 1000
      Logging messages reported every `log_every_n` samples.
    canonical: bool, default True
      Whether to use a canonical order of atoms returned by RDKit

    Returns
    -------
@@ -177,6 +175,8 @@ class MolecularFeaturizer(Featurizer):
    """
    try:
      from rdkit import Chem
      from rdkit.Chem import rdmolfiles
      from rdkit.Chem import rdmolops
      from rdkit.Chem.rdchem import Mol
    except ModuleNotFoundError:
      raise ValueError("This class requires RDKit to be installed.")
@@ -194,13 +194,11 @@ class MolecularFeaturizer(Featurizer):
        logger.info("Featurizing datapoint %i" % i)
      try:
        if isinstance(mol, str):
          # mol must be a SMILES string so parse
          # mol must be a RDKit Mol object, so parse a SMILES
          mol = Chem.MolFromSmiles(mol)
        # canonicalize
        if canonical:
          canonical_smiles = Chem.MolToSmiles(mol)
          mol = Chem.MolFromSmiles(canonical_smiles)

          # SMILES is unique, so set a canonical order of atoms
          new_order = rdmolfiles.CanonicalRankAtoms(mol)
          mol = rdmolops.RenumberAtoms(mol, new_order)
        features.append(self._featurize(mol))
      except:
        logger.warning(
+56 −23
Original line number Diff line number Diff line
@@ -16,8 +16,8 @@ class GraphData:
    Graph connectivity in COO format with shape [2, num_edges]
  edge_features: np.ndarray, optional (default None)
    Edge feature matrix with shape [num_edges, num_edge_features]
  graph_features: np.ndarray, optional (default None)
    Graph feature vector with shape [num_graph_features,]
  node_pos_features: np.ndarray, optional (default None)
    Node position matrix with shape [num_nodes, num_dimensions].
  num_nodes: int
    The number of nodes in the graph
  num_node_features: int
@@ -40,7 +40,7 @@ class GraphData:
      node_features: np.ndarray,
      edge_index: np.ndarray,
      edge_features: Optional[np.ndarray] = None,
      graph_features: Optional[np.ndarray] = None,
      node_pos_features: Optional[np.ndarray] = None,
  ):
    """
    Parameters
@@ -51,8 +51,8 @@ class GraphData:
      Graph connectivity in COO format with shape [2, num_edges]
    edge_features: np.ndarray, optional (default None)
      Edge feature matrix with shape [num_edges, num_edge_features]
    graph_features: np.ndarray, optional (default None)
      Graph feature vector with shape [num_graph_features,]
    node_pos_features: np.ndarray, optional (default None)
      Node position matrix with shape [num_nodes, num_dimensions].
    """
    # validate params
    if isinstance(node_features, np.ndarray) is False:
@@ -74,14 +74,18 @@ class GraphData:
        raise ValueError('The first dimension of edge_features must be the \
                          same as the second dimension of edge_index.')

    if graph_features is not None and isinstance(graph_features,
                                                 np.ndarray) is False:
      raise ValueError('graph_features must be np.ndarray or None.')
    if node_pos_features is not None:
      if isinstance(node_pos_features, np.ndarray) is False:
        raise ValueError('node_pos_features must be np.ndarray or None.')
      elif node_pos_features.shape[0] != node_features.shape[0]:
        raise ValueError(
            'The length of node_pos_features must be the same as the \
                          length of node_features.')

    self.node_features = node_features
    self.edge_index = edge_index
    self.edge_features = edge_features
    self.graph_features = graph_features
    self.node_pos_features = node_pos_features
    self.num_nodes, self.num_node_features = self.node_features.shape
    self.num_edges = edge_index.shape[1]
    if self.edge_features is not None:
@@ -106,12 +110,18 @@ class GraphData:
      raise ValueError(
          "This function requires PyTorch Geometric to be installed.")

    edge_features = self.edge_features
    if edge_features is not None:
      edge_features = torch.from_numpy(self.edge_features).float()
    node_pos_features = self.node_pos_features
    if node_pos_features is not None:
      node_pos_features = torch.from_numpy(self.node_pos_features).float()

    return Data(
        x=torch.from_numpy(self.node_features),
        x=torch.from_numpy(self.node_features).float(),
        edge_index=torch.from_numpy(self.edge_index).long(),
        edge_attr=None
        if self.edge_features is None else torch.from_numpy(self.edge_features),
    )
        edge_attr=edge_features,
        pos=node_pos_features)

  def to_dgl_graph(self):
    """Convert to DGL graph data instance
@@ -136,10 +146,13 @@ class GraphData:
    g.add_edges(
        torch.from_numpy(self.edge_index[0]).long(),
        torch.from_numpy(self.edge_index[1]).long())
    g.ndata['x'] = torch.from_numpy(self.node_features)
    g.ndata['x'] = torch.from_numpy(self.node_features).float()

    if self.node_pos_features is not None:
      g.ndata['pos'] = torch.from_numpy(self.node_pos_features).float()

    if self.edge_features is not None:
      g.edata['edge_attr'] = torch.from_numpy(self.edge_features)
      g.edata['edge_attr'] = torch.from_numpy(self.edge_features).float()

    return g

@@ -149,8 +162,28 @@ class BatchGraphData(GraphData):

  Attributes
  ----------
  node_features: np.ndarray
    Concatenated node feature matrix with shape [num_nodes, num_node_features].
    `num_nodes` is total number of nodes in the batch graph.
  edge_index: np.ndarray, dtype int
    Concatenated graph connectivity in COO format with shape [2, num_edges].
    `num_edges` is total number of edges in the batch graph.
  edge_features: np.ndarray, optional (default None)
    Concatenated edge feature matrix with shape [num_edges, num_edge_features].
    `num_edges` is total number of edges in the batch graph.
  node_pos_features: np.ndarray, optional (default None)
    Concatenated node position matrix with shape [num_nodes, num_dimensions].
    `num_nodes` is total number of edges in the batch graph.
  num_nodes: int
    The number of nodes in the batch graph.
  num_node_features: int
    The number of features per node in the graph.
  num_edges: int
    The number of edges in the batch graph.
  num_edges_features: int, optional (default None)
    The number of features per edge in the graph.
  graph_index: np.ndarray, dtype int
    This vector indicates which graph the node belongs with shape [num_nodes,]
    This vector indicates which graph the node belongs with shape [num_nodes,].

  Examples
  --------
@@ -177,7 +210,7 @@ class BatchGraphData(GraphData):
    batch_node_features = np.vstack(
        [graph.node_features for graph in graph_list])

    # before stacking edge_features or graph_features,
    # before stacking edge_features or node_pos_features,
    # we should check whether these are None or not
    if graph_list[0].edge_features is not None:
      batch_edge_features = np.vstack(
@@ -185,11 +218,11 @@ class BatchGraphData(GraphData):
    else:
      batch_edge_features = None

    if graph_list[0].graph_features is not None:
      batch_graph_features = np.vstack(
          [graph.graph_features for graph in graph_list])
    if graph_list[0].node_pos_features is not None:
      batch_node_pos_features = np.vstack(
          [graph.node_pos_features for graph in graph_list])
    else:
      batch_graph_features = None
      batch_node_pos_features = None

    # create new edge index
    num_nodes_list = [graph.num_nodes for graph in graph_list]
@@ -208,5 +241,5 @@ class BatchGraphData(GraphData):
        node_features=batch_node_features,
        edge_index=batch_edge_index,
        edge_features=batch_edge_features,
        graph_features=batch_graph_features,
        node_pos_features=batch_node_pos_features,
    )
+3 −3
Original line number Diff line number Diff line
@@ -15,13 +15,13 @@ class TestGraph(unittest.TestCase):
        [0, 1, 2, 2, 3, 4],
        [1, 2, 0, 3, 4, 0],
    ])
    graph_features = None
    node_pos_features = None

    graph = GraphData(
        node_features=node_features,
        edge_index=edge_index,
        edge_features=edge_features,
        graph_features=graph_features)
        node_pos_features=node_pos_features)

    assert graph.num_nodes == num_nodes
    assert graph.num_node_features == num_node_features
@@ -92,7 +92,7 @@ class TestGraph(unittest.TestCase):
            edge_index=edge_index_list[i],
            edge_features=np.random.random_sample((num_edge_list[i],
                                                   num_edge_features)),
            graph_features=None) for i in range(len(num_edge_list))
            node_pos_features=None) for i in range(len(num_edge_list))
    ]
    batch = BatchGraphData(graph_list)