Commit 231664cd authored by Bharath Ramsundar
Browse files

Changes

parent 5151acad
Loading
Loading
Loading
Loading
+19 −5
Original line number Diff line number Diff line
@@ -12,8 +12,9 @@ import time
import sys
import logging
import warnings
from typing import List, Optional, Dict, Tuple
from typing import List, Optional, Dict, Tuple, Any, Sequence

from deepchem.utils.typing import OneOrMany
from deepchem.utils.save import load_csv_files, load_json_files
from deepchem.utils.save import load_sdf_files
from deepchem.utils.genomics import encode_fasta_sequence
@@ -478,7 +479,7 @@ class JsonLoader(DataLoader):
  """

  def __init__(self,
               tasks: List[str],
               tasks: OneOrMany[str],
               feature_field: str,
               label_field: str = None,
               weight_field: str = None,
@@ -521,14 +522,14 @@ class JsonLoader(DataLoader):
    self.log_every_n = log_every_n

  def create_dataset(self,
                     input_files: List[str],
                     input_files: OneOrMany[str],
                     data_dir: Optional[str] = None,
                     shard_size: Optional[int] = 8192) -> DiskDataset:
    """Creates a `Dataset` from input JSON files.

    Parameters
    ----------
    input_files: List[str]
    input_files: OneOrMany[str]
      A single JSON filename or a list of JSON filenames.
    data_dir: Optional[str], default None
      Name of directory where featurized data is stored.
@@ -542,6 +543,16 @@ class JsonLoader(DataLoader):
      from `input_files`.

    """
    if not isinstance(input_files, list):
      try:
        if isinstance(input_files, str):
          input_files = [input_files]
        else:
          input_files = list(input_files)
      except TypeError:
        raise ValueError(
            "input_files is of an unrecognized form. Must be one filename or a list of filenames."
        )

    def shard_generator():
      """Yield X, y, w, and ids for shards."""
@@ -902,7 +913,10 @@ class InMemoryLoader(DataLoader):

  """

  def create_dataset(self, inputs, data_dir=None, shard_size=8192):
  def create_dataset(self,
                     inputs: Sequence[Any],
                     data_dir: Optional[str] = None,
                     shard_size: int = 8192) -> DiskDataset:
    """Creates and returns a `Dataset` object by featurizing provided files.

    Reads in `inputs` and uses `self.featurizer` to featurize the
+18 −27
Original line number Diff line number Diff line
@@ -3,7 +3,6 @@ Tests for JsonLoader class.
"""

import os
import unittest
import tempfile
import shutil
import numpy as np
@@ -12,18 +11,9 @@ from deepchem.data.data_loader import JsonLoader
from deepchem.feat.materials_featurizers import SineCoulombMatrix


class TestJsonLoader(unittest.TestCase):
  """
  Test JsonLoader
  """

  def setUp(self):
    super(TestJsonLoader, self).setUp()
    self.current_dir = os.path.dirname(os.path.abspath(__file__))

  def test_json_loader(self):
    input_file = os.path.join(self.current_dir,
                              'inorganic_crystal_sample_data.json')
def test_json_loader():
  current_dir = os.path.dirname(os.path.abspath(__file__))
  input_file = os.path.join(current_dir, 'inorganic_crystal_sample_data.json')
  featurizer = SineCoulombMatrix(max_atoms=5)
  loader = JsonLoader(
      tasks=['e_form'],
@@ -31,6 +21,7 @@ class TestJsonLoader(unittest.TestCase):
      id_field='formula',
      label_field='e_form',
      featurizer=featurizer)

  dataset = loader.create_dataset(input_file, shard_size=1)

  a = [4625.32086965, 6585.20209678, 61.00680193, 48.72230922, 48.72230922]