Commit f4524b46 authored by leswing's avatar leswing
Browse files

Change radius to 2

parent 58401dc9
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -109,7 +109,7 @@ class ScScoreModel(TensorGraph):

  def predict_mols(self, mols):
    featurizer = CircularFingerprint(
        size=self.n_features, radius=3, chiral=True)
        size=self.n_features, radius=2, chiral=True)
    features = np.expand_dims(featurizer.featurize(mols), axis=1)
    features = np.concatenate([features, features], axis=1)
    ds = NumpyDataset(features, None, None, None)
@@ -121,6 +121,9 @@ class ScScoreModel(TensorGraph):
    for layer, column in zip([self.m1_features, self.m2_features],
                             feature_columns):
      tensors[layer] = tf.feature_column.input_layer(features, [column])
    if weight_column is not None:
      tensors[self.task_weights[0]] = tf.feature_column.input_layer(
          features, [weight_column])
    if labels is not None:
      tensors[self.labels[0]] = tf.cast(labels, tf.int32)
    return tensors
+2 −10
Original line number Diff line number Diff line
@@ -306,13 +306,10 @@ class TestEstimators(unittest.TestCase):

    x_col1 = tf.feature_column.numeric_column('x1', shape=(n_features,))
    x_col2 = tf.feature_column.numeric_column('x2', shape=(n_features,))
    weight_col = tf.feature_column.numeric_column('weights', shape=(1,))

    def accuracy(labels, predictions, weights):
      return tf.metrics.accuracy(labels, tf.round(predictions), weights)

    metrics = {'accuracy': accuracy}
    estimator = model.make_estimator(
        feature_columns=[x_col1, x_col2], metrics=metrics)
        feature_columns=[x_col1, x_col2], metrics={}, weight_column=weight_col)

    # Train the model.

@@ -321,9 +318,4 @@ class TestEstimators(unittest.TestCase):
    # Evaluate the model.

    results = estimator.evaluate(input_fn=lambda: input_fn(1))
    print(results)
    assert results['loss'] < 1e-4
    # TODO(LESWING) Discuss with peastman.
    #  The output here is human readable
    # score 1-5 per molecule not a probability of class
    # assert results['accuracy'] > 0.9
+261 −228

File changed.

Preview size limit exceeded, changes collapsed.