Commit 47006c5b authored by peastman's avatar peastman
Browse files

Minor improvements to molnet loader functions

parent 413c6a43
Loading
Loading
Loading
Loading
+5 −4
Original line number Diff line number Diff line
@@ -17,10 +17,11 @@ from deepchem.splits.splitters import Splitter
logger = logging.getLogger(__name__)

featurizers = {
    'ECFP': dc.feat.CircularFingerprint(size=1024),
    'GraphConv': dc.feat.ConvMolFeaturizer(),
    'Weave': dc.feat.WeaveFeaturizer(),
    'Raw': dc.feat.RawFeaturizer()
    'ecfp': dc.feat.CircularFingerprint(size=1024),
    'graphconv': dc.feat.ConvMolFeaturizer(),
    'weave': dc.feat.WeaveFeaturizer(),
    'raw': dc.feat.RawFeaturizer(),
    'smiles2img': dc.feat.SmilesToImage(img_size=80, img_spec='std')
}

splitters = {
+4 −4
Original line number Diff line number Diff line
@@ -68,9 +68,9 @@ def load_delaney(
    splitter = kwargs['split']
    logger.warning("'split' is deprecated.  Use 'splitter' instead.")
  if isinstance(featurizer, str):
    featurizer = dc.molnet.defaults.featurizers[featurizer]
    featurizer = dc.molnet.defaults.featurizers[featurizer.lower()]
  if isinstance(splitter, str):
    splitter = dc.molnet.defaults.splitters[splitter]
    splitter = dc.molnet.defaults.splitters[splitter.lower()]
  if data_dir is None:
    data_dir = DEFAULT_DIR
  if save_dir is None:
@@ -80,8 +80,8 @@ def load_delaney(
  # Try to reload cached datasets.

  if reload:
    featurizer_name = str(featurizer.__class__.__name__)
    splitter_name = str(splitter.__class__.__name__)
    featurizer_name = str(featurizer)
    splitter_name = str(splitter)
    if not move_mean:
      featurizer_name = featurizer_name + "_mean_unmoved"
    save_folder = os.path.join(save_dir, "delaney-featurized", featurizer_name,
+40 −3
Original line number Diff line number Diff line
"""
Contains an abstract base class that supports chemically aware data splits.
"""
import inspect
import os
import random
import tempfile
@@ -270,7 +271,30 @@ class Splitter(object):
    >>> str(dc.splits.RandomSplitter())
    'RandomSplitter'
    """
    return self.__class__.__name__
    args_spec = inspect.getfullargspec(self.__init__)  # type: ignore
    args_names = [arg for arg in args_spec.args if arg != 'self']
    args_num = len(args_names)
    args_default_values = [None for _ in range(args_num)]
    if args_spec.defaults is not None:
      defaults = list(args_spec.defaults)
      args_default_values[-len(defaults):] = defaults

    override_args_info = ''
    for arg_name, default in zip(args_names, args_default_values):
      if arg_name in self.__dict__:
        arg_value = self.__dict__[arg_name]
        # validation
        # skip list
        if isinstance(arg_value, list):
          continue
        if isinstance(arg_value, str):
          # skip path string
          if "\\/." in arg_value or "/" in arg_value or '.' in arg_value:
            continue
        # main logic
        if default != arg_value:
          override_args_info += '_' + arg_name + '_' + str(arg_value)
    return self.__class__.__name__ + override_args_info

  def __repr__(self) -> str:
    """Convert self to repr representation.
@@ -284,9 +308,22 @@ class Splitter(object):
    --------
    >>> import deepchem as dc
    >>> dc.splits.RandomSplitter()
    RandomSplitter
    RandomSplitter[]
    """
    return self.__str__()
    args_spec = inspect.getfullargspec(self.__init__)  # type: ignore
    args_names = [arg for arg in args_spec.args if arg != 'self']
    args_info = ''
    for arg_name in args_names:
      value = self.__dict__[arg_name]
      # for str
      if isinstance(value, str):
        value = "'" + value + "'"
      # for list
      if isinstance(value, list):
        threshold = get_print_threshold()
        value = np.array2string(np.array(value), threshold=threshold)
      args_info += arg_name + '=' + str(value) + ', '
    return self.__class__.__name__ + '[' + args_info[:-2] + ']'


class RandomSplitter(Splitter):