Unverified commit a8f150eb authored by Bharath Ramsundar, committed by GitHub

Merge pull request #2641 from Suzukazole/usptotok

Fix splitter errors for datasets without labels
parents 22a8fbd7 c99d1260
@@ -1500,10 +1500,10 @@ class DiskDataset(Dataset):
     """Gets size of shards on disk."""
     if not len(self.metadata_df):
       raise ValueError("No data in dataset.")
-    sample_y = load_from_disk(
-        os.path.join(self.data_dir,
-                     next(self.metadata_df.iterrows())[1]['y']))
-    return len(sample_y)
+    sample_ids = load_from_disk(
+        os.path.join(self.data_dir,
+                     next(self.metadata_df.iterrows())[1]['ids']))
+    return len(sample_ids)
 
   def _get_metadata_filename(self) -> Tuple[str, str]:
     """Get standard location for metadata file."""
@@ -2369,11 +2369,11 @@ class DiskDataset(Dataset):
       if y is not None:
         y_sel = y[shard_inds]
       else:
-        y_sel = None
+        y_sel = np.array([])
       if w is not None:
         w_sel = w[shard_inds]
       else:
-        w_sel = None
+        w_sel = np.array([])
       ids_sel = ids[shard_inds]
       Xs.append(X_sel)
       ys.append(y_sel)
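The switch from None to np.array([]) keeps the selections homogeneous: the per-shard pieces accumulated in ys and ws are presumably combined with NumPy downstream, and empty arrays compose cleanly where None would raise. A minimal sketch:

import numpy as np

ys = [np.array([]), np.array([])]
print(np.concatenate(ys))  # array([], dtype=float64): empty shards are harmless

try:
    np.concatenate([None, None])  # None is coerced to a 0-d array
except ValueError as e:
    print("ValueError:", e)  # zero-dimensional arrays cannot be concatenated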
@@ -2399,9 +2399,16 @@ class DiskDataset(Dataset):
             np.where(sorted_indices == orig_index)[0][0]
             for orig_index in select_shard_indices
         ])
-        X, y, w, ids = X[reverted_indices], y[reverted_indices], w[
-            reverted_indices], ids[reverted_indices]
-        yield (X, y, w, ids)
+        if y.size == 0:
+          tup_y = y
+        else:
+          tup_y = y[reverted_indices]
+        if w.size == 0:
+          tup_w = w
+        else:
+          tup_w = w[reverted_indices]
+        X, ids = X[reverted_indices], ids[reverted_indices]
+        yield (X, tup_y, tup_w, ids)
       start = end
       select_shard_num += 1
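The size guard exists because fancy-indexing an empty array with the reverted indices would fail; empty y and w therefore pass through unchanged. A minimal demonstration:

import numpy as np

y = np.array([])  # label-free dataset
reverted_indices = np.array([2, 0, 1])
try:
    y[reverted_indices]
except IndexError as e:
    print("IndexError:", e)  # index 2 is out of bounds for axis 0 with size 0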
new file: reaction_smiles.csv
reactions
CCS(=O)(=O)Cl.OCCBr>CCN(CC)CC.CCOCC>CCS(=O)(=O)OCCBr
CC(C)CS(=O)(=O)Cl.OCCCl>CCN(CC)CC.CCOCC>CC(C)CS(=O)(=O)OCCCl
O=[N+]([O-])c1cccc2cnc(Cl)cc12>CC(=O)O.O.[Fe].[Na+].[OH-]>Nc1cccc2cnc(Cl)cc12
Cc1cc2c([N+](=O)[O-])cccc2c[n+]1[O-].O=P(Cl)(Cl)Cl>>Cc1cc2c([N+](=O)[O-])cccc2c(Cl)n1
\ No newline at end of file
@@ -106,3 +106,29 @@ def test_disk_dataset_get_legacy_shape_multishard():
   assert y_shape == (num_datapoints, num_tasks)
   assert w_shape == (num_datapoints, num_tasks)
   assert ids_shape == (num_datapoints,)
+
+
+def test_get_shard_size():
+  """
+  Test that using ids to compute the shard size does not break the method.
+
+  The issue arises when loading a dataset that has no label column. In that
+  case the create_dataset method of the DataLoader class sets y to None,
+  which made the previous implementation of get_shard_size() fail, since it
+  relied on the dataset having a non-None y column. This in turn broke every
+  method that depends on it, such as the splitters.
+
+  Note
+  ----
+  DiskDatasets without labels cannot be resharded!
+  """
+  current_dir = os.path.dirname(os.path.abspath(__file__))
+  file_path = os.path.join(current_dir, "reaction_smiles.csv")
+  featurizer = dc.feat.DummyFeaturizer()
+  loader = dc.data.CSVLoader(
+      tasks=[], feature_field="reactions", featurizer=featurizer)
+  dataset = loader.create_dataset(file_path)
+  assert dataset.get_shard_size() == 4
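For context, a sketch of the downstream call this fix unblocks, splitting the label-free dataset built in the test above. RandomSplitter and train_valid_test_split are existing DeepChem APIs; dataset refers to the test fixture.

splitter = dc.splits.RandomSplitter()
# Previously this raised inside get_shard_size() because y was None.
train, valid, test = splitter.train_valid_test_split(dataset)
print(train.get_shard_size())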