Commit a2be59ac authored by miaecle's avatar miaecle
Browse files

Merge remote-tracking branch 'remotes/mine/textcnn'

parents 88138d3c 184278e9
Loading
Loading
Loading
Loading
+14 −12
Original line number Diff line number Diff line
@@ -185,14 +185,17 @@ class WeaveTensorGraph(TensorGraph):
  def predict_on_generator(self, generator, transformers=[], outputs=None):
      out = super(WeaveTensorGraph, self).predict_on_generator(
          generator, 
          transformers=transformers, 
          transformers=[], 
          outputs=outputs)
      if outputs is None:
        outputs = self.outputs
      if len(outputs) == 1:
      if len(outputs) > 1:
        out = np.stack(out, axis=1)
      
      out = undo_transforms(out, transformers)
      return out
      else:
        return np.stack(out, axis=1)
      
      

class DTNNTensorGraph(TensorGraph):

@@ -345,7 +348,6 @@ class DTNNTensorGraph(TensorGraph):

        yield feed_dict

  '''
  def predict(self, dataset, transformers=[], outputs=None):
    if outputs is None:
      outputs = self.outputs
@@ -357,7 +359,6 @@ class DTNNTensorGraph(TensorGraph):
      return retval
    retval = np.concatenate(retval, axis=-1)
    return undo_transforms(retval, transformers)
  '''

class DAGTensorGraph(TensorGraph):

@@ -511,14 +512,15 @@ class DAGTensorGraph(TensorGraph):
  def predict_on_generator(self, generator, transformers=[], outputs=None):
      out = super(DAGTensorGraph, self).predict_on_generator(
          generator, 
          transformers=transformers, 
          transformers=[], 
          outputs=outputs)
      if outputs is None:
        outputs = self.outputs
      if len(outputs) == 1:
      if len(outputs) > 1:
        out = np.stack(out, axis=1)
      
      out = undo_transforms(out, transformers)
      return out
      else:
        return np.stack(out, axis=1)

class PetroskiSuchTensorGraph(TensorGraph):
  """
+6 −5
Original line number Diff line number Diff line
@@ -274,11 +274,12 @@ class TextCNNTensorGraph(TensorGraph):
  def predict_on_generator(self, generator, transformers=[], outputs=None):
      out = super(TextCNNTensorGraph, self).predict_on_generator(
          generator, 
          transformers=transformers, 
          transformers=[], 
          outputs=outputs)
      if outputs is None:
        outputs = self.outputs
      if len(outputs) == 1:
      if len(outputs) > 1:
        out = np.stack(out, axis=1)
      
      out = undo_transforms(out, transformers)
      return out
 No newline at end of file
      else:
        return np.stack(out, axis=1)
 No newline at end of file
+1 −1
Original line number Diff line number Diff line
@@ -140,7 +140,7 @@ hps['graphconvreg'] = {
hps['dtnn'] = {
    'batch_size': 64,
    'nb_epoch': 100,
    'learning_rate': 0.0005,
    'learning_rate': 0.001,
    'n_embedding': 50,
    'n_distance': 170,
    'seed': 123
+1 −1

File changed.

Contains only whitespace changes.