Commit 5151acad authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

changes

parent 54105489
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -543,9 +543,6 @@ class JsonLoader(DataLoader):

    """

    if not isinstance(input_files, list):
      input_files = [input_files]

    def shard_generator():
      """Yield X, y, w, and ids for shards."""
      for shard_num, shard in enumerate(
+3 −3
Original line number Diff line number Diff line
@@ -888,7 +888,7 @@ class NumpyDataset(Dataset):
        for i in order:
          yield (self._X[i], self._y[i], self._w[i], self._ids[i])

    class TorchDataset(torch.utils.data.IterableDataset):
    class TorchDataset(torch.utils.data.IterableDataset):  # type: ignore

      def __iter__(self):
        return iterate()
@@ -1415,7 +1415,7 @@ class DiskDataset(Dataset):
          for i in range(X.shape[0]):
            yield (X[i], y[i], w[i], ids[i])

    class TorchDataset(torch.utils.data.IterableDataset):
    class TorchDataset(torch.utils.data.IterableDataset):  # type: ignore

      def __iter__(self):
        return iterate()
@@ -2174,7 +2174,7 @@ class ImageDataset(Dataset):
          yield (get_image(self._X, i), get_image(self._y, i), self._w[i],
                 self._ids[i])

    class TorchDataset(torch.utils.data.IterableDataset):
    class TorchDataset(torch.utils.data.IterableDataset):  # type: ignore

      def __iter__(self):
        return iterate()
+4 −12
Original line number Diff line number Diff line
@@ -265,10 +265,6 @@ class StructureFeaturizer(Featurizer):

    """

    # Special case handling of single crystal structure
    if not isinstance(structures, Iterable):
      structures = [structures]
    else:
    # Convert iterables to list
    structures = list(structures)

@@ -336,10 +332,6 @@ class CompositionFeaturizer(Featurizer):

    """

    # Special case handling of single crystal composition
    if not isinstance(compositions, Iterable):
      compositions = [compositions]
    else:
    # Convert iterables to list
    compositions = list(compositions)

+2 −2
Original line number Diff line number Diff line
@@ -50,7 +50,7 @@ class ElementPropertyFingerprint(CompositionFeaturizer):

    self.data_source = data_source

  def _featurize(self, composition: "pymatgen.Composition"):
  def _featurize(self, composition: "pymatgen.Composition"):  # type: ignore
    """
    Calculate chemical fingerprint from crystal composition.

@@ -124,7 +124,7 @@ class SineCoulombMatrix(StructureFeaturizer):
    self.max_atoms = int(max_atoms)
    self.flatten = flatten

  def _featurize(self, struct: "pymatgen.Structure"):
  def _featurize(self, struct: "pymatgen.Structure"):  # type: ignore
    """
    Calculate sine Coulomb matrix from pymatgen structure.