Commit d57e4de9 authored by Ninad Bhat
Browse files

yapf updates

parent 61967fc1
Loading
Loading
Loading
Loading
+20 −23
Original line number Diff line number Diff line
@@ -488,8 +488,8 @@ class UserCSVLoader(CSVLoader):
    shard[feature_fields] = shard[feature_fields].apply(pd.to_numeric)
    X_shard = shard[feature_fields].to_numpy()
    time2 = time.time()
    logger.info(
        "TIMING: user specified processing took %0.3f s" % (time2 - time1))
    logger.info("TIMING: user specified processing took %0.3f s" %
                (time2 - time1))
    return (X_shard, np.ones(len(X_shard), dtype=bool))


@@ -768,7 +768,6 @@ class SDFLoader(DataLoader):
    if not isinstance(inputs, list):
      inputs = [inputs]


    processed_files = []
    for input_file in inputs:
      filename, extension = os.path.splitext(input_file)
@@ -780,16 +779,14 @@ class SDFLoader(DataLoader):
        zip_ref = zipfile.ZipFile(input_file, 'r')
        zip_ref.extractall(path=zip_dir)
        zip_ref.close()
        zip_files = [
            os.path.join(zip_dir, name) for name in zip_ref.namelist()
        ]
        zip_files = [os.path.join(zip_dir, name) for name in zip_ref.namelist()]
        for zip_file in zip_files:
          _, extension = os.path.splitext(zip_file)
          extension = extension.lower()
          if extension in [".sdf"]:
            processed_files.append(zip_file)
      else:
          raise ValueError("unsupported file format")
        raise ValueError("Unsupported file format")

    inputs = processed_files

@@ -834,8 +831,7 @@ class SDFLoader(DataLoader):
    Iterator[pd.DataFrame]
      Iterator over shards
    """
    return load_sdf_files(
        input_files=input_files,
    return load_sdf_files(input_files=input_files,
                          clean_mols=self.sanitize,
                          tasks=self.tasks,
                          shard_size=shard_size)
@@ -1035,11 +1031,12 @@ class ImageLoader(DataLoader):

    if in_memory:
      if data_dir is None:
        return NumpyDataset(
            load_image_files(image_files), y=labels, w=weights, ids=image_files)
        return NumpyDataset(load_image_files(image_files),
                            y=labels,
                            w=weights,
                            ids=image_files)
      else:
        dataset = DiskDataset.from_numpy(
            load_image_files(image_files),
        dataset = DiskDataset.from_numpy(load_image_files(image_files),
                                         y=labels,
                                         w=weights,
                                         ids=image_files,
@@ -1188,8 +1185,8 @@ class InMemoryLoader(DataLoader):

  # FIXME: Signature of "_featurize_shard" incompatible with supertype "DataLoader"
  def _featurize_shard(  # type: ignore[override]
      self, shard: List, global_index: int
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
      self, shard: List, global_index: int) -> Tuple[np.ndarray, np.ndarray,
                                                     np.ndarray, np.ndarray]:
    """Featurizes a shard of an input data.

    Parameters
+32 −19
Original line number Diff line number Diff line
@@ -5,8 +5,9 @@ import deepchem as dc
def test_sdf_load():
  """Load an SDF file with SDFLoader and check the dataset length.

  Reads "membrane_permeability.sdf" from the directory containing this test
  file, using the "LogP(RRCK)" task, and expects two entries.
  """
  current_dir = os.path.dirname(os.path.realpath(__file__))
  featurizer = dc.feat.CircularFingerprint(size=16)
  # One loader instance is all the test needs.
  loader = dc.data.SDFLoader(["LogP(RRCK)"],
                             featurizer=featurizer,
                             sanitize=True)
  dataset = loader.create_dataset(
      os.path.join(current_dir, "membrane_permeability.sdf"))
  assert len(dataset) == 2
@@ -15,26 +16,32 @@ def test_sdf_load():
def test_singleton_sdf_load():
  """Load "singleton.sdf" with SDFLoader and expect exactly one entry.

  The file is looked up in the directory containing this test file and is
  featurized for the "LogP(RRCK)" task.
  """
  current_dir = os.path.dirname(os.path.realpath(__file__))
  featurizer = dc.feat.CircularFingerprint(size=16)
  # One loader instance is all the test needs.
  loader = dc.data.SDFLoader(["LogP(RRCK)"],
                             featurizer=featurizer,
                             sanitize=True)
  dataset = loader.create_dataset(os.path.join(current_dir, "singleton.sdf"))
  assert len(dataset) == 1


def test_singleton_sdf_zip_load():
  """Load "singleton.zip" with SDFLoader and expect exactly one entry.

  The archive is looked up in the directory containing this test file and is
  passed straight to `create_dataset` for the "LogP(RRCK)" task.
  """
  current_dir = os.path.dirname(os.path.realpath(__file__))
  featurizer = dc.feat.CircularFingerprint(size=16)
  # One loader instance is all the test needs.
  loader = dc.data.SDFLoader(["LogP(RRCK)"],
                             featurizer=featurizer,
                             sanitize=True)
  dataset = loader.create_dataset(os.path.join(current_dir, "singleton.zip"))
  assert len(dataset) == 1


def test_sharded_sdf_load():
  """Load an SDF file with `shard_size=1` and check the shard count.

  Reads "membrane_permeability.sdf" from the directory containing this test
  file for the "LogP(RRCK)" task, and expects two shards holding two entries
  in total.
  """
  current_dir = os.path.dirname(os.path.realpath(__file__))
  featurizer = dc.feat.CircularFingerprint(size=16)
  # Build the loader and the dataset once each; repeating either call only
  # redoes the same work and discards the first result.
  loader = dc.data.SDFLoader(["LogP(RRCK)"],
                             featurizer=featurizer,
                             sanitize=True)
  dataset = loader.create_dataset(os.path.join(current_dir,
                                               "membrane_permeability.sdf"),
                                  shard_size=1)
  assert dataset.get_number_shards() == 2
  assert len(dataset) == 2

@@ -42,8 +49,9 @@ def test_sharded_sdf_load():
def test_sharded_multi_file_sdf_load():
  current_dir = os.path.dirname(os.path.realpath(__file__))
  featurizer = dc.feat.CircularFingerprint(size=16)
  loader = dc.data.SDFLoader(
      ["LogP(RRCK)"], featurizer=featurizer, sanitize=True)
  loader = dc.data.SDFLoader(["LogP(RRCK)"],
                             featurizer=featurizer,
                             sanitize=True)
  input_files = [
      os.path.join(current_dir, "membrane_permeability.sdf"),
      os.path.join(current_dir, "singleton.sdf")
@@ -52,23 +60,28 @@ def test_sharded_multi_file_sdf_load():
  assert dataset.get_number_shards() == 3
  assert len(dataset) == 3


def test_sharded_multi_file_sdf_zip_load():
  """Load "multiple_sdf.zip" with `shard_size=1` and check the shard count.

  The archive is looked up in the directory containing this test file and is
  featurized for the "LogP(RRCK)" task; three shards holding three entries
  in total are expected.
  """
  current_dir = os.path.dirname(os.path.realpath(__file__))
  featurizer = dc.feat.CircularFingerprint(size=16)
  # Build the loader and the dataset once each; repeating either call only
  # redoes the same work and discards the first result.
  loader = dc.data.SDFLoader(["LogP(RRCK)"],
                             featurizer=featurizer,
                             sanitize=True)
  dataset = loader.create_dataset(os.path.join(current_dir, "multiple_sdf.zip"),
                                  shard_size=1)
  assert dataset.get_number_shards() == 3
  assert len(dataset) == 3


def test_sdf_load_with_csv():
  """Test a case where SDF labels are in associated csv file.

  Loads "water.sdf" from the directory containing this test file with
  `shard_size=1` for the "atomization_energy" task, then checks the entry
  count, the shard count and the reported task names.
  """
  current_dir = os.path.dirname(os.path.realpath(__file__))
  featurizer = dc.feat.CircularFingerprint(size=16)
  # Build the loader and the dataset once each; repeating either call only
  # redoes the same work and discards the first result.
  loader = dc.data.SDFLoader(["atomization_energy"],
                             featurizer=featurizer,
                             sanitize=True)
  dataset = loader.create_dataset(os.path.join(current_dir, "water.sdf"),
                                  shard_size=1)
  assert len(dataset) == 10
  assert dataset.get_number_shards() == 10
  assert dataset.get_task_names() == ["atomization_energy"]