Unverified Commit 6ccf705d authored by Daiki Nishikawa's avatar Daiki Nishikawa Committed by GitHub
Browse files

Merge pull request #2242 from nd-02110114/fix-mol2vec

Fix bug for Mol2VecFingerprint
parents 7dc98630 3ddc610f
Loading
Loading
Loading
Loading
+2 −15
Original line number Diff line number Diff line
@@ -42,8 +42,7 @@ class Mol2VecFingerprint(MolecularFeaturizer):
  def __init__(self,
               pretrain_model_path: Optional[str] = None,
               radius: int = 1,
               unseen: str = 'UNK',
               gather_method: str = 'sum'):
               unseen: str = 'UNK'):
    """
    Parameters
    ----------
@@ -56,9 +55,6 @@ class Mol2VecFingerprint(MolecularFeaturizer):
      github repository.
    unseen: str, optional (default 'UNK')
      The string to used to replace uncommon words/identifiers while training.
    gather_method: str, optional (default 'sum')
      How to aggregate vectors of identifiers are extracted from Mol2vec.
      'sum' or 'mean' is supported.
    """
    try:
      from gensim.models import word2vec
@@ -68,7 +64,6 @@ class Mol2VecFingerprint(MolecularFeaturizer):

    self.radius = radius
    self.unseen = unseen
    self.gather_method = gather_method
    self.sentences2vec = sentences2vec
    self.mol2alt_sentence = mol2alt_sentence
    if pretrain_model_path is None:
@@ -98,13 +93,5 @@ class Mol2VecFingerprint(MolecularFeaturizer):
      1D array of mol2vec fingerprint. The default length is 300.
    """
    sentence = self.mol2alt_sentence(mol, self.radius)
    vec_identifiers = self.sentences2vec(
        sentence, self.model, unseen=self.unseen)
    if self.gather_method == 'sum':
      feature = np.sum(vec_identifiers, axis=0)
    elif self.gather_method == 'mean':
      feature = np.mean(vec_identifiers, axis=0)
    else:
      raise ValueError(
          'Not supported gather_method type. Please set "sum" or "mean"')
    feature = self.sentences2vec([sentence], self.model, unseen=self.unseen)[0]
    return feature
+2 −8
Original line number Diff line number Diff line
import unittest

import numpy as np

from deepchem.feat import Mol2VecFingerprint


@@ -23,9 +21,5 @@ class TestMol2VecFingerprint(unittest.TestCase):
    Test simple fingerprint.
    """
    featurizer = Mol2VecFingerprint()
    feature_sum = featurizer([self.mol])
    assert feature_sum.shape == (1, 300)
    featurizer = Mol2VecFingerprint(gather_method='mean')
    feature_mean = featurizer([self.mol])
    assert feature_mean.shape == (1, 300)
    assert not np.allclose(feature_sum, feature_mean)
    feature = featurizer([self.mol])
    assert feature.shape == (1, 300)
+1 −1
Original line number Diff line number Diff line
@@ -21,4 +21,4 @@ dependencies:
    - pymatgen
    - simdna
    - xgboost
    - -e git+https://github.com/samoturk/mol2vec#egg=mol2vec
    - git+https://github.com/samoturk/mol2vec