Commit b731912a authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

fix IRV transformer

parent e35c9900
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -1610,7 +1610,6 @@ class IRVTransformer(Transformer):
      n_samples * np.array of size (2*K,)
      each array includes K similarity values and corresponding labels
    """
    import tensorflow as tf
    features = []
    similarity_xs = similarity * np.sign(w)
    [target_len, reference_len] = similarity_xs.shape
@@ -1622,8 +1621,9 @@ class IRVTransformer(Transformer):
                                                 100, target_len), :]
      # generating batch of data by slicing similarity matrix
      # into 100*reference_dataset_length
      value, indice = tf.nn.top_k(similarity, k=self.K + 1, sorted=True)
      top_label = tf.gather(y, indice)
      indice = np.argsort(similarity)[:, -(self.K + 1):][:, ::-1]
      value = np.take_along_axis(similarity, indice, axis=1)
      top_label = np.take(y, indice)
      values.append(value)
      top_labels.append(top_label)
    values_np = np.concatenate(values, axis=0)