Commit 927750df authored by ZHENQIN WU
Browse files

performance table update and description for IRV transformer

parent b6580cf7
Loading
Loading
Loading
Loading
+26 −7
Original line number Diff line number Diff line
@@ -327,17 +327,22 @@ Scaffold splitting
|chembl          |MT-NN regression    |Index       |0.443         |0.427         |
|                |MT-NN regression    |Random      |0.464         |0.434         |
|                |MT-NN regression    |Scaffold    |0.484         |0.361         |
|qm7             |MT-NN regression    |Index       |0.994         |0.010         |
|                |MT-NN regression    |Random      |0.860         |0.773         |
|                |MT-NN regression    |User-defined|0.996         |0.996         | 
|qm7             |MT-NN regression    |Index       |0.994         |0.969         |
|                |MT-NN regression    |Random      |0.995         |0.992         |
|                |MT-NN regression    |Stratified  |0.992         |0.992         | 
|qm7b            |MT-NN regression    |Index       |0.883         |0.785         |
|                |MT-NN regression    |Random      |0.864         |0.838         |
|                |MT-NN regression    |Stratified  |0.871         |0.847         | 
|kaggle          |MT-NN regression    |User-defined|0.748         |0.452         |

|Dataset         |Model               |Splitting   |Train score/MAE(kcal/mol)|Valid score/MAE(kcal/mol)|
|----------------|--------------------|------------|-------------------------|-------------------------|
|qm7             |MT-NN regression    |Index       |18.3                     |172.0                    |
|                |MT-NN regression    |Random      |44.3                     |59.1                     |
|qm7             |MT-NN regression    |Index       |22.1                     |23.2                     |
|                |MT-NN regression    |Random      |16.2                     |17.7                     |
|                |MT-NN regression    |Stratified  |20.5                     |20.8                     |
|                |MT-NN regression    |User-defined|9.0                      |9.5                      |


* General features

Number of tasks and examples in the datasets
@@ -358,7 +363,8 @@ Number of tasks and examples in the datasets
|pdbbind(refined)|1          |3706       |
|pdbbind(full)   |1          |11908      |
|chembl(5thresh) |691        |23871      |
|gdb7            |1          |7165       |
|qm7             |1          |7165       |
|qm7b            |14         |7211       |



@@ -369,6 +375,8 @@ Time needed for benchmark test(~20h in total)
|tox21           |logistic regression |30              |60             |
|                |Multitask network   |30              |60             |
|                |robust MT-NN        |30              |90             |
|                |random forest       |30              |6000           |
|                |IRV                 |30              |650            |
|                |graph convolution   |40              |160            |
|muv             |logistic regression |600             |450            |
|                |Multitask network   |600             |400            |
@@ -381,22 +389,33 @@ Time needed for benchmark test(~20h in total)
|sider           |logistic regression |15              |80             |
|                |Multitask network   |15              |75             |
|                |robust MT-NN        |15              |150            |
|                |random forest       |15              |2200           |
|                |IRV                 |15              |150            |
|                |graph convolution   |20              |50             |
|toxcast         |logistic regression |80              |2600           |
|                |Multitask network   |80              |2300           |
|                |robust MT-NN        |80              |4000           |
|                |graph convolution   |80              |900            |
|clintox         |logistic regression |15              |10             |
|                |Multitask network   |15              |20             |
|                |robust MT-NN        |15              |30             |
|                |random forest       |15              |200            |
|                |IRV                 |15              |10             |
|                |graph convolution   |20              |130            |
|delaney         |MT-NN regression    |10              |40             |
|                |graphconv regression|10              |40             |
|                |random forest       |10              |30             |
|sampl           |MT-NN regression    |10              |30             |
|                |graphconv regression|10              |40             |
|                |random forest       |10              |20             |
|nci             |MT-NN regression    |400             |1200           |
|                |graphconv regression|400             |2500           |
|pdbbind(core)   |MT-NN regression    |0(featurized)   |30             |
|pdbbind(refined)|MT-NN regression    |0(featurized)   |40             |
|pdbbind(full)   |MT-NN regression    |0(featurized)   |60             |
|chembl          |MT-NN regression    |200             |9000           |
|gdb7            |MT-NN regression    |10              |110            |
|qm7             |MT-NN regression    |10              |400            |
|qm7b            |MT-NN regression    |10              |600            |
|kaggle          |MT-NN regression    |2200            |3200           |


+34 −0
Original line number Diff line number Diff line
@@ -633,6 +633,25 @@ class IRVTransformer():
    self.transform_y = transform_y

  def realize(self, similarity, y, w):
    """find samples with top ten similarity values in the reference dataset
    
    Parameters:
    -----------
    similarity: np.ndarray
      similarity value between target dataset and reference dataset
      should have size of (n_samples_in_target, n_samples_in_reference)
    y: np.array
      labels for a single task
    w: np.array
      weights for a single task
   
    Returns:
    ----------
    features: list
      n_samples * np.array of size (2*K,)
      each array includes K similarity values and corresponding labels

    """
    features = []
    similarity_xs = similarity * np.sign(w)
    for similarity_x in similarity_xs:
@@ -646,6 +665,21 @@ class IRVTransformer():
    return features

  def X_transform(self, X_target):
    """ Calculate similarity between target dataset(X_target) and 
    reference dataset (X) as the number of bits set in both fingerprints
    divided by the number of bits set in either:
         similarity = |X_target ∩ X| / |X_target ∪ X|
    Parameters:
    -----------
    X_target: np.ndarray
      fingerprints of target dataset
      should have the same length as X along the second axis
    
    Returns:
    ----------
    X_target: np.ndarray
      features of size(batch_size, 2*K*n_tasks)
    
    """
    X_target2 = []
    n_features = X_target.shape[1]
    similarity = np.matmul(X_target, np.transpose(self.X)) / (