Unverified Commit 3573cca2 authored by Suzukazole's avatar Suzukazole
Browse files

update test, docs

parent f114ba67
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
reactions
CCS(=O)(=O)Cl.OCCBr>CCN(CC)CC.CCOCC>CCS(=O)(=O)OCCBr
CC(C)CS(=O)(=O)Cl.OCCCl>CCN(CC)CC.CCOCC>CC(C)CS(=O)(=O)OCCCl
O=[N+]([O-])c1cccc2cnc(Cl)cc12>CC(=O)O.O.[Fe].[Na+].[OH-]>Nc1cccc2cnc(Cl)cc12
Cc1cc2c([N+](=O)[O-])cccc2c[n+]1[O-].O=P(Cl)(Cl)Cl>>Cc1cc2c([N+](=O)[O-])cccc2c(Cl)n1
 No newline at end of file
+20 −13
Original line number Diff line number Diff line
@@ -109,18 +109,25 @@ def test_disk_dataset_get_legacy_shape_multishard():


def test_get_shard_size():
  """
  Check that get_shard_size() works for datasets with and without labels.

  Two scenarios are exercised:

  1. A DiskDataset built from in-memory numpy arrays (X, y, w, ids). Its
     shard size must equal the number of datapoints, and must follow the
     requested value after resharding.
  2. A dataset loaded from a CSV file that has no labels column. The
     create_dataset method of the DataLoader class sets y to None in this
     case, which used to make get_shard_size() fail because it relied on
     the dataset having a not None y column. That failure in turn broke
     every method depending on it, such as the splitters.

  Note
  ----
  DiskDatasets without labels cannot be resharded! This is why only the
  labelled dataset is resharded below.
  """
  n_samples = 100
  n_features = 10
  n_tasks = 10

  # Scenario 1: labelled dataset generated in memory.
  # The random draws are kept in the order X, y, w.
  X = np.random.rand(n_samples, n_features)
  y = np.random.randint(2, size=(n_samples, n_tasks))
  w = np.random.randint(2, size=(n_samples, n_tasks))
  ids = np.array(["id"] * n_samples)

  labelled = dc.data.DiskDataset.from_numpy(X, y, w, ids)
  assert labelled.get_shard_size() == 100

  labelled.reshard(shard_size=15)
  assert labelled.get_shard_size() == 15

  # Scenario 2: unlabelled dataset loaded from the CSV fixture that lives
  # next to this test module (tasks=[] means no label columns are read).
  here = os.path.dirname(os.path.abspath(__file__))
  csv_path = os.path.join(here, "reaction_smiles.csv")
  loader = dc.data.CSVLoader(
      tasks=[],
      feature_field="reactions",
      featurizer=dc.feat.DummyFeaturizer())

  unlabelled = loader.create_dataset(csv_path)
  # The fixture is expected to hold four reactions, all in a single shard.
  assert unlabelled.get_shard_size() == 4