Unverified Commit 8174ab67 authored by Karl Leswing's avatar Karl Leswing Committed by GitHub
Browse files

Merge pull request #1105 from ueser/patch-1

8x faster graph normalization
parents bc649405 5408a47f
Loading
Loading
Loading
Loading
+2 −3
Original line number Diff line number Diff line
@@ -6,7 +6,6 @@ import torch.optim as optim
import random
import numpy as np
from sklearn.metrics import roc_auc_score
import scipy


def symmetric_normalize_adj(adj):
@@ -24,8 +23,8 @@ def symmetric_normalize_adj(adj):
    adj = adj[:n_atoms, :n_atoms]
    degree = np.sum(adj, axis=1)
    D = np.diag(degree)
    D_sqrt = scipy.linalg.sqrtm(D)
    D_sqrt_inv = scipy.linalg.inv(D_sqrt)
    D_sqrt = np.sqrt(D)
    D_sqrt_inv = np.linalg.inv(D_sqrt)
    sym_norm = D_sqrt_inv.dot(adj)
    sym_norm = sym_norm.dot(D_sqrt_inv)
    new_adj = np.zeros(orig_shape)