Commit 9b0afeee authored by marta-sd's avatar marta-sd
Browse files

deal with indices outside of the box in RdkitGridFeaturizer._voxelize

parent b5feed9c
Loading
Loading
Loading
Loading
+3 −7
Original line number Diff line number Diff line
@@ -1095,7 +1095,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
                channel_power=None,
                nb_channel=16,
                dtype="np.int8"):
    # TODO(enf): make array index checking not a try-catch statement.

    if channel_power is not None:
      if channel_power == 0:
        nb_channel = 1
@@ -1115,22 +1115,18 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
      for key, features in feature_dict.items():
        voxels = get_voxels(coordinates, key, self.box_width, self.voxel_width)
        for voxel in voxels:
          try:
          if ((voxel >= 0) & (voxel < self.voxels_per_edge)).all():
            if hash_function is not None:
              feature_tensor[voxel[0], voxel[1], voxel[2],
                             hash_function(features, channel_power)] += 1.0
            else:
              feature_tensor[voxel[0], voxel[1], voxel[3], 0] += features
          except:
            continue
    elif feature_list is not None:
      for key in feature_list:
        voxels = get_voxels(coordinates, key, self.box_width, self.voxel_width)
        for voxel in voxels:
          try:
          if ((voxel >= 0) & (voxel < self.voxels_per_edge)).all():
            feature_tensor[voxel[0], voxel[1], voxel[2], 0] += 1.0
          except:
            continue

    return feature_tensor