Commit fc7d9850 authored by Bharath's avatar Bharath
Browse files

Simple shape fix

parent c08762b1
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -636,7 +636,7 @@ def convert_df_to_numpy(df, feature_type, tasks, mol_id_field):
    for task in range(n_tasks):
      if y[ind, task] == "":
        missing[ind, task] = 1
  x = np.array(list(df[feature_type].values))
  x = np.squeeze(np.array(list(df[feature_type].values)))
  ############################################################## DEBUG
  time2 = time.time()
  print("CONVERT_DF_TO_NUMPY X COMP TOOK %0.3f s" % (time2-time1))
+3 −0
Original line number Diff line number Diff line
@@ -41,6 +41,9 @@ class TestBasicDatasetAPI(TestDatasetAPI):
    ################################################################# DEBUG
    print("solubility_dataset.get_data_shape()")
    print(solubility_dataset.get_data_shape())
    X, y, w, ids = solubility_dataset.to_numpy()
    print("X.shape, y.shape, w.shape, ids.shape")
    print(X.shape, y.shape, w.shape, ids.shape)
    ################################################################# DEBUG
    assert solubility_dataset.get_data_shape() == (1024,)