Commit ec8551ab authored by Zhenqin Wu's avatar Zhenqin Wu
Browse files

Add support for protein sequence featurization

parent 7e745b93
Loading
Loading
Loading
Loading
+138 −0
Original line number Diff line number Diff line
import logging
import os
import tempfile
from typing import Iterable

import numpy as np

from amino_acids import AminoAcid
from deepchem.feat.base_classes import Featurizer
from deepchem.utils.protein_sequence_data_utils import read_fasta
from deepchem.utils.protein_sequence_data_utils import read_hhm
from deepchem.utils.protein_sequence_data_utils import run_hhblits_local
from deepchem.utils.protein_sequence_feature_utils import read_a3m_as_mat
from deepchem.utils.protein_sequence_feature_utils import sequence_one_hot_encoding
from deepchem.utils.protein_sequence_feature_utils import sequence_deletion_probability
from deepchem.utils.protein_sequence_feature_utils import sequence_weights
from deepchem.utils.protein_sequence_feature_utils import sequence_profile
from deepchem.utils.protein_sequence_feature_utils import sequence_profile_no_gap
from deepchem.utils.protein_sequence_feature_utils import sequence_profile_with_prior
from deepchem.utils.protein_sequence_feature_utils import sequence_identity
from deepchem.utils.protein_sequence_feature_utils import sequence_static_prop
from deepchem.utils.protein_sequence_feature_utils import sequence_gap_matrix
from deepchem.utils.protein_sequence_feature_utils import profile_combinatorial
from deepchem.utils.protein_sequence_feature_utils import mutual_information
from deepchem.utils.protein_sequence_feature_utils import mean_contact_potential



logger = logging.getLogger(__name__)


class ProteinSequenceFeaturizer(Featurizer):
  """Abstract class for calculating a set of features for a
  protein sequence (used for structure prediction).

  The defining feature of a `ProteinSequenceFeaturizer` is that it
  reads protein sequences (and pre-saved multiple sequence alignments)
  to generate features for structure prediction.

  Child classes need to implement the `_featurize` method for
  calculating features for a single protein sequence. Note that `_featurize`
  methods should take two arguments: protein sequence and a directory path
  for saving multiple sequence alignments.
  """

  def prepare_msa(self, sequence, path):
    """Make sure a multiple sequence alignment exists under `path`.

    Parameters
    ----------
    sequence: str
      Protein sequence handed to `run_hhblits_local` when no alignment
      file is present yet.
    path: str
      Directory expected to contain `results.a3m`.

    Returns
    -------
    bool
      True if `<path>/results.a3m` exists (either beforehand or after
      calling `run_hhblits_local`), False otherwise.
    """
    msa_file = os.path.join(path, 'results.a3m')
    if not os.path.exists(msa_file):
      run_hhblits_local(sequence, path)
    return os.path.exists(msa_file)

  def featurize(self,
                protein_seqs: Iterable[str],
                log_every_n: int = 1000):
    """Calculate features for protein sequences.

    Parameters
    ----------
    protein_seqs: str or Iterable[str]
      Each entry is either a raw protein sequence or the path to a folder
      of pre-saved multiple sequence alignments. Such a folder must hold
      the raw sequence as a single-record FASTA file named "input.seq".
      A single string is treated as one entry.
    log_every_n: int, default 1000
      Logging messages reported every `log_every_n` samples.

    Returns
    -------
    features: np.ndarray
      `np.asarray` applied to the list of per-entry `_featurize` results.
    """
    if isinstance(protein_seqs, str):
      # Handle single sequence/path
      protein_seqs = [protein_seqs]
    else:
      # Convert iterables to list
      protein_seqs = list(protein_seqs)

    features = []
    for i, s in enumerate(protein_seqs):
      if i % log_every_n == 0:
        logger.info("Featurizing datapoint %i", i)

      if os.path.isdir(s):
        # `s` is a folder, pre-saved MSAs are provided.
        # Raw sequence should be saved under this folder as "input.seq"
        _, sequence = read_fasta(os.path.join(s, "input.seq"))
        assert len(sequence) == 1
        features.append(self._featurize(sequence[0], s))
      else:
        # `s` is a protein sequence. The scratch folder is never exposed to
        # the caller, so remove it once the features have been computed
        # instead of leaking one directory per sequence.
        with tempfile.TemporaryDirectory() as path:
          features.append(self._featurize(s, path))

    features = np.asarray(features)
    return features


class ContactMapProteinSequenceFeaturizer(ProteinSequenceFeaturizer):
  """Featurizer producing per-residue (1D) and residue-pair (2D) features
  from a protein sequence and its multiple sequence alignment (MSA).
  """

  def _featurize(self, sequence, path):
    """Compute 1D and 2D features for one protein sequence.

    Parameters
    ----------
    sequence: str
      Protein sequence.
    path: str
      Directory holding (or receiving) the alignment files `results.a3m`
      and `results.hhm`.

    Returns
    -------
    feats_1D: np.ndarray
      Per-residue feature blocks concatenated along axis 1.
    feats_2D: np.ndarray
      Residue-pair feature blocks concatenated along axis 2.

    Raises
    ------
    AssertionError
      If no alignment could be found or prepared under `path`.
    """
    # Features not implemented yet: sec_structure, solv_surf
    # Call prepare_msa outside of an `assert` statement so that it still runs
    # under `python -O`; the exception type is kept for existing callers.
    if not self.prepare_msa(sequence, path):
      raise AssertionError("MSA not found under %s" % path)
    hhm_profile = read_hhm(sequence, os.path.join(path, 'results.hhm'))

    a3m_seq_IDs, a3m_seqs = read_fasta(os.path.join(path, 'results.a3m'))
    a3m_seq_mat = read_a3m_as_mat(sequence, a3m_seq_IDs, a3m_seqs)

    seq_one_hot = sequence_one_hot_encoding(sequence)
    seq_del_prob = sequence_deletion_probability(sequence, a3m_seqs)
    prof = sequence_profile(a3m_seq_mat)
    prof_no_gap = sequence_profile_no_gap(a3m_seq_mat)

    weights = sequence_weights(a3m_seq_mat)
    w_prof = sequence_profile(a3m_seq_mat, weights=weights)
    w_prof_no_gap = sequence_profile_no_gap(a3m_seq_mat, weights=weights)

    prior_prof = sequence_profile_with_prior(w_prof)
    prior_prof_no_gap = sequence_profile_with_prior(w_prof_no_gap)

    static_prop = sequence_static_prop(a3m_seq_mat, weights)

    feats_1D = np.concatenate([seq_one_hot, seq_del_prob, prof, prof_no_gap,
                               w_prof, w_prof_no_gap, prior_prof,
                               prior_prof_no_gap, hhm_profile, static_prop], 1)

    gap_matrix = sequence_gap_matrix(a3m_seq_mat)
    if gap_matrix.ndim == 2:
      # The gap matrix has no trailing feature axis; add one so it can be
      # concatenated along axis 2 with the other pairwise features.
      gap_matrix = np.expand_dims(gap_matrix, 2)
    w_prof_2D = profile_combinatorial(a3m_seq_mat, weights, w_prof)
    MI = mutual_information(w_prof, w_prof_2D)
    MCP = mean_contact_potential(w_prof_2D)

    # Features not implemented yet: pseudo_bias, pseudo_frob, pseudolikelihood
    feats_2D = np.concatenate([gap_matrix, MI, MCP], 2)

    return feats_1D, feats_2D
+19 −0
Original line number Diff line number Diff line
@@ -88,3 +88,22 @@ from deepchem.utils.vina_utils import load_docked_ligands
from deepchem.utils.voxel_utils import convert_atom_to_voxel
from deepchem.utils.voxel_utils import convert_atom_pair_to_voxel
from deepchem.utils.voxel_utils import voxelize

from deepchem.utils.protein_sequence_feature_utils import read_a3m_as_mat
from deepchem.utils.protein_sequence_feature_utils import sequence_one_hot_encoding
from deepchem.utils.protein_sequence_feature_utils import sequence_deletion_probability
from deepchem.utils.protein_sequence_feature_utils import sequence_weights
from deepchem.utils.protein_sequence_feature_utils import sequence_profile
from deepchem.utils.protein_sequence_feature_utils import sequence_profile_no_gap
from deepchem.utils.protein_sequence_feature_utils import sequence_profile_with_prior
from deepchem.utils.protein_sequence_feature_utils import sequence_identity
from deepchem.utils.protein_sequence_feature_utils import sequence_static_prop
from deepchem.utils.protein_sequence_feature_utils import sequence_gap_matrix
from deepchem.utils.protein_sequence_feature_utils import profile_combinatorial
from deepchem.utils.protein_sequence_feature_utils import mutual_information
from deepchem.utils.protein_sequence_feature_utils import mean_contact_potential

from deepchem.utils.protein_sequence_data_utils import read_fasta
from deepchem.utils.protein_sequence_data_utils import write_fasta
from deepchem.utils.protein_sequence_data_utils import read_hhm
from deepchem.utils.protein_sequence_data_utils import run_hhblits_local
 No newline at end of file
+129 −0
Original line number Diff line number Diff line
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Jun  7 11:49:05 2017

@author: zqwu
"""
import numpy as np

# One-letter residue codes in index order (indices 0-20); 'X' is the last
# letter and therefore gets index 20.
_AMINO_ACID_LETTERS = 'ARNDCQEGHILKMFPSTWYVX'

# Maps a one-letter residue code to its integer index. The gap symbol '-'
# is inserted first but carries the highest index, 21.
AminoAcid = {'-': 21}
AminoAcid.update(
    (letter, index) for index, letter in enumerate(_AMINO_ACID_LETTERS))

# Structure strings keyed by one-letter residue code, for the 20 letters
# 'A'..'V' of `AminoAcid`; there is no entry for 'X' or the gap symbol '-'.
# NOTE(review): presumably SMILES, as the name says. Only some entries carry
# stereo markers ('@' / '@@'); confirm whether that inconsistency matters
# to consumers of this table.
AminoAcid_SMILES = {
    'A': 'O=C(O)C(N)C',
    'R': 'NC(CCCNC(N)=N)C(O)=O',
    'N': 'O=C(N)C[C@H](N)C(=O)O',
    'D': 'O=C(O)CC(N)C(=O)O',
    'C': 'C([C@@H](C(=O)O)N)S',
    'Q': 'O=C(N)CCC(N)C(=O)O',
    'E': 'C(CC(=O)O)C(C(=O)O)N',
    'G': 'C(C(=O)O)N',
    'H': 'C1=C(NC=N1)C[C@@H](C(=O)O)N',
    'I': 'CC[C@H](C)[C@@H](C(=O)O)N',
    'L': 'CC(C)C[C@@H](C(=O)O)N',
    'K': 'C(CCN)CC(C(=O)O)N',
    'M': 'CSCCC(C(=O)O)N',
    'F': 'c1ccc(cc1)C[C@@H](C(=O)O)N',
    'P': 'C1CC(NC1)C(=O)O',
    'S': 'C([C@@H](C(=O)O)N)O',
    'T': 'C[C@H]([C@@H](C(=O)O)N)O',
    'W': 'c1ccc2c(c1)c(c[nH]2)C[C@@H](C(=O)O)N',
    'Y': 'N[C@@H](Cc1ccc(O)cc1)C(O)=O',
    'V': 'CC(C)[C@@H](C(=O)O)N'}

# Frequency value per residue symbol, keyed like `AminoAcid`.
# The values are used as given; they are not normalized here.
AminoAcidFreq = dict((
    ('-', 0.050), ('A', 0.074), ('R', 0.042), ('N', 0.044),
    ('D', 0.059), ('C', 0.033), ('Q', 0.058), ('E', 0.037),
    ('G', 0.074), ('H', 0.029), ('I', 0.038), ('L', 0.076),
    ('K', 0.072), ('M', 0.018), ('F', 0.040), ('P', 0.050),
    ('S', 0.081), ('T', 0.062), ('W', 0.013), ('Y', 0.033),
    ('V', 0.068), ('X', 0.050),
))

# BLOSUM62 substitution scores,
# from https://www.ncbi.nlm.nih.gov/Class/FieldGuide/BLOSUM62.txt
# 22 x 22 matrix; rows and columns follow the index order of `AminoAcid`
# (A, R, N, ..., V, X). The last row/column is labelled '*' here and sits
# at index 21, the index `AminoAcid` assigns to the gap symbol '-'.
BLOSUM62 = np.array([
    [ 4, -1, -2, -2,  0, -1, -1,  0, -2, -1, -1, -1, -1, -2, -1,  1,  0, -3, -2,  0,  0, -4], # A 
    [-1,  5,  0, -2, -3,  1,  0, -2,  0, -3, -2,  2, -1, -3, -2, -1, -1, -3, -2, -3, -1, -4], # R 
    [-2,  0,  6,  1, -3,  0,  0,  0,  1, -3, -3,  0, -2, -3, -2,  1,  0, -4, -2, -3, -1, -4], # N 
    [-2, -2,  1,  6, -3,  0,  2, -1, -1, -3, -4, -1, -3, -3, -1,  0, -1, -4, -3, -3, -1, -4], # D 
    [ 0, -3, -3, -3,  9, -3, -4, -3, -3, -1, -1, -3, -1, -2, -3, -1, -1, -2, -2, -1, -2, -4], # C 
    [-1,  1,  0,  0, -3,  5,  2, -2,  0, -3, -2,  1,  0, -3, -1,  0, -1, -2, -1, -2, -1, -4], # Q 
    [-1,  0,  0,  2, -4,  2,  5, -2,  0, -3, -3,  1, -2, -3, -1,  0, -1, -3, -2, -2, -1, -4], # E 
    [ 0, -2,  0, -1, -3, -2, -2,  6, -2, -4, -4, -2, -3, -3, -2,  0, -2, -2, -3, -3, -1, -4], # G 
    [-2,  0,  1, -1, -3,  0,  0, -2,  8, -3, -3, -1, -2, -1, -2, -1, -2, -2,  2, -3, -1, -4], # H 
    [-1, -3, -3, -3, -1, -3, -3, -4, -3,  4,  2, -3,  1,  0, -3, -2, -1, -3, -1,  3, -1, -4], # I 
    [-1, -2, -3, -4, -1, -2, -3, -4, -3,  2,  4, -2,  2,  0, -3, -2, -1, -2, -1,  1, -1, -4], # L 
    [-1,  2,  0, -1, -3,  1,  1, -2, -1, -3, -2,  5, -1, -3, -1,  0, -1, -3, -2, -2, -1, -4], # K 
    [-1, -1, -2, -3, -1,  0, -2, -3, -2,  1,  2, -1,  5,  0, -2, -1, -1, -1, -1,  1, -1, -4], # M 
    [-2, -3, -3, -3, -2, -3, -3, -3, -1,  0,  0, -3,  0,  6, -4, -2, -2,  1,  3, -1, -1, -4], # F 
    [-1, -2, -2, -1, -3, -1, -1, -2, -2, -3, -3, -1, -2, -4,  7, -1, -1, -4, -3, -2, -2, -4], # P 
    [ 1, -1,  1,  0, -1,  0,  0,  0, -1, -2, -2,  0, -1, -2, -1,  4,  1, -3, -2, -2,  0, -4], # S 
    [ 0, -1,  0, -1, -1, -1, -1, -2, -2, -1, -1, -1, -1, -2, -1,  1,  5, -2, -2,  0,  0, -4], # T 
    [-3, -3, -4, -4, -2, -2, -3, -2, -2, -3, -2, -3, -1,  1, -4, -3, -2, 11,  2, -3, -2, -4], # W 
    [-2, -2, -2, -3, -2, -1, -2, -3,  2, -1, -1, -2, -1,  3, -3, -2, -2,  2,  7, -1, -1, -4], # Y 
    [ 0, -3, -3, -3, -1, -2, -2, -3, -3,  3,  1, -2,  1, -1, -2, -2,  0, -3, -1,  4, -1, -4], # V 
    [ 0, -1, -1, -1, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2,  0,  0, -2, -1, -1, -1, -4], # X 
    [-4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,  1]]) # * (index 21)

# Residue-residue contact potential,
# from https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2144252/pdf/10048329.pdf B matrix
# 22 x 22 symmetric matrix; rows and columns follow the index order of
# `AminoAcid` (A, R, N, ..., V, X, -). The X row/column is the
# frequency-weighted sum over the 20 residues and the gap row/column
# duplicates X.
# NOTE(review): the T row/column is all zeros - presumably T is the
# reference state of the cited potential; confirm against the paper.
ContactPotential = np.array([
    [-0.20,  0.27,  0.24,  0.30, -0.26,  0.21,  0.43, -0.03,  0.21, -0.35, -0.37,  0.20, -0.23, -0.33,  0.07,  0.15,  0.00, -0.40, -0.15, -0.38, -0.02, -0.02], # A
    # The X and '-' entries of this row were -0.03, contradicting both the
    # X row (0.03) and the frequency-weighted sum of this row (~ +0.03).
    [ 0.27,  0.13,  0.02, -0.71,  0.32, -0.12, -0.75,  0.14,  0.04,  0.18,  0.09,  0.50,  0.17,  0.08, -0.02,  0.12,  0.00, -0.41, -0.37,  0.17,  0.03,  0.03], # R
    [ 0.24,  0.02, -0.04, -0.12,  0.28, -0.05, -0.01,  0.10,  0.10,  0.55,  0.36, -0.14,  0.32,  0.29,  0.13,  0.14,  0.00, -0.09,  0.01,  0.39,  0.12,  0.12], # N
    [ 0.30, -0.71, -0.12,  0.27,  0.38,  0.12,  0.40,  0.17, -0.22,  0.54,  0.62, -0.69,  0.62,  0.48,  0.25,  0.01,  0.00,  0.06, -0.07,  0.66,  0.15,  0.15], # D
    [-0.26,  0.32,  0.28,  0.38, -1.34,  0.04,  0.46, -0.09, -0.19, -0.48, -0.50,  0.35, -0.49, -0.53, -0.18,  0.09,  0.00, -0.74, -0.16, -0.51, -0.12, -0.12], # C
    [ 0.21, -0.12, -0.05,  0.12,  0.04,  0.14,  0.10,  0.20,  0.22,  0.14,  0.08, -0.20, -0.01, -0.04, -0.05,  0.25,  0.00, -0.11, -0.18,  0.17,  0.07,  0.07], # Q
    [ 0.43, -0.75, -0.01,  0.40,  0.46,  0.10,  0.45,  0.48, -0.11,  0.38,  0.37, -0.87,  0.24,  0.34,  0.26,  0.10,  0.00, -0.15, -0.16,  0.41,  0.13,  0.13], # E
    [-0.03,  0.14,  0.10,  0.17, -0.09,  0.20,  0.48, -0.20,  0.23,  0.21,  0.14,  0.12,  0.08,  0.11, -0.01,  0.10,  0.00, -0.24, -0.04,  0.04,  0.08,  0.08], # G
    [ 0.21,  0.04,  0.10, -0.22, -0.19,  0.22, -0.11,  0.23, -0.33,  0.19,  0.10,  0.26, -0.17, -0.19, -0.05,  0.15,  0.00, -0.46, -0.21,  0.18,  0.05,  0.05], # H
    [-0.35,  0.18,  0.55,  0.54, -0.48,  0.14,  0.38,  0.21,  0.19, -0.60, -0.79,  0.21, -0.60, -0.65,  0.05,  0.35,  0.00, -0.65, -0.33, -0.68, -0.07, -0.07], # I
    [-0.37,  0.09,  0.36,  0.62, -0.50,  0.08,  0.37,  0.14,  0.10, -0.79, -0.81,  0.16, -0.68, -0.78, -0.08,  0.26,  0.00, -0.70, -0.44, -0.80, -0.14, -0.14], # L
    [ 0.20,  0.50, -0.14, -0.69,  0.35, -0.20, -0.87,  0.12,  0.26,  0.21,  0.16,  0.38,  0.22,  0.11,  0.12,  0.10,  0.00, -0.28, -0.40,  0.16,  0.04,  0.04], # K
    [-0.23,  0.17,  0.32,  0.62, -0.49, -0.01,  0.24,  0.08, -0.17, -0.60, -0.68,  0.22, -0.56, -0.89, -0.16,  0.32,  0.00, -0.94, -0.51, -0.47, -0.11, -0.11], # M
    [-0.33,  0.08,  0.29,  0.48, -0.53, -0.04,  0.34,  0.11, -0.19, -0.65, -0.78,  0.11, -0.89, -0.82, -0.19,  0.10,  0.00, -0.78, -0.49, -0.67, -0.18, -0.18], # F
    [ 0.07, -0.02,  0.13,  0.25, -0.18, -0.05,  0.26, -0.01, -0.05,  0.05, -0.08,  0.12, -0.16, -0.19, -0.07,  0.17,  0.00, -0.73, -0.40, -0.08, -0.00, -0.00], # P
    [ 0.15,  0.12,  0.14,  0.01,  0.09,  0.25,  0.10,  0.10,  0.15,  0.35,  0.26,  0.10,  0.32,  0.10,  0.17,  0.13,  0.00,  0.07,  0.07,  0.25,  0.14,  0.14], # S
    [ 0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00], # T
    [-0.40, -0.41, -0.09,  0.06, -0.74, -0.11, -0.15, -0.24, -0.46, -0.65, -0.70, -0.28, -0.94, -0.78, -0.73,  0.07,  0.00, -0.74, -0.55, -0.62, -0.36, -0.36], # W
    [-0.15, -0.37,  0.01, -0.07, -0.16, -0.18, -0.16, -0.04, -0.21, -0.33, -0.44, -0.40, -0.51, -0.49, -0.40,  0.07,  0.00, -0.55, -0.27, -0.27, -0.21, -0.21], # Y
    [-0.38,  0.17,  0.39,  0.66, -0.51,  0.17,  0.41,  0.04,  0.18, -0.68, -0.80,  0.16, -0.47, -0.67, -0.08,  0.25,  0.00, -0.62, -0.27, -0.72, -0.11, -0.11], # V
    [-0.02,  0.03,  0.12,  0.15, -0.12,  0.07,  0.13,  0.08,  0.05, -0.07, -0.14,  0.04, -0.11, -0.18, -0.00,  0.14,  0.00, -0.36, -0.21, -0.11,  0.00,  0.00], # X, calculated as weighted (freq) sum
    [-0.02,  0.03,  0.12,  0.15, -0.12,  0.07,  0.13,  0.08,  0.05, -0.07, -0.14,  0.04, -0.11, -0.18, -0.00,  0.14,  0.00, -0.36, -0.21, -0.11,  0.00,  0.00]]) # -, same as X
+88 −0
Original line number Diff line number Diff line
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 12 10:01:57 2020

@author: Zhenqin Wu
"""
import os
import numpy as np

def read_fasta(path):
  """Read all records of a FASTA-formatted file.

  Parameters
  ----------
  path: str
    Path of the file to read.

  Returns
  -------
  seq_descs: list of str
    Header of each record, without the leading '>' and surrounding
    whitespace.
  seqs: list of str
    Sequence of each record; the stripped lines that follow a header are
    joined without separators.

  Notes
  -----
  Lines that appear before the first header do not belong to any record
  and are ignored (they used to be prepended to the first sequence).
  """
  seq_descs = []
  seqs = []
  desc = None
  chunks = []
  with open(path, 'r') as f:
    for line in f:
      if line.startswith('>'):
        # A new header closes the record collected so far, if any.
        if desc is not None:
          seq_descs.append(desc)
          seqs.append(''.join(chunks))
        desc = line.strip()[1:]
        chunks = []
      elif desc is not None:
        chunks.append(line.strip())
  if desc is not None:
    seq_descs.append(desc)
    seqs.append(''.join(chunks))
  return seq_descs, seqs


def write_fasta(sequence, path, name=None):
  """Append one FASTA record, followed by a blank line, to `path`.

  Parameters
  ----------
  sequence: str
    Sequence written on the line after the header.
  path: str
    File to append to; it is opened in append mode.
  name: str, optional
    Record name used in the '>' header line; 'TEMP' when omitted.
  """
  header = 'TEMP' if name is None else name
  record = ['>%s\n' % header, '%s\n' % sequence, '\n']
  with open(path, 'a') as out:
    out.write(''.join(record))


def _hhm_value_to_prob(token, asterisks_replace):
  """Convert one hhm table entry to 2**(-value / 1000); '*' maps to
  `asterisks_replace`."""
  if token == '*':
    return asterisks_replace
  return 2**(-float(token) / 1000.)


def read_hhm(sequence, path, asterisks_replace=0.):
  """Read the profile table of an .hhm file into a numeric matrix.

  Only the text after the last '#' in the file is parsed. Its first five
  and last two lines are skipped; of the remaining lines, those with 23
  whitespace-separated fields fill the 20 amino-acid columns and those with
  10 fields fill the 10 gap/transition columns. Empty lines are ignored.

  Parameters
  ----------
  sequence: str
    Protein sequence; only its length is used, to size the output.
  path: str
    Path of the .hhm file.
  asterisks_replace: float, default 0.
    Value stored for entries given as '*'.

  Returns
  -------
  np.ndarray
    Array of shape `(len(sequence), 30)`: 20 amino-acid columns followed
    by 10 gap columns.

  Raises
  ------
  ValueError
    If a table line has a field count other than 0, 10 or 23.
  """
  with open(path, 'r') as f:
    hhm_file = f.read()
  profile_part = hhm_file.split('#')[-1]
  profile_part = profile_part.split('\n')
  whole_profile = [i.split() for i in profile_part]
  # This part strips away the header and the footer.
  whole_profile = whole_profile[5:-2]
  gap_profile = np.zeros((len(sequence), 10))
  aa_profile = np.zeros((len(sequence), 20))
  count_aa = 0
  count_gap = 0
  for line_values in whole_profile:
    if len(line_values) == 23:
      # The first two fields and the last field are skipped; the 20 fields
      # in between are the per-residue entries.
      for j, t in enumerate(line_values[2:-1]):
        aa_profile[count_aa, j] = _hhm_value_to_prob(t, asterisks_replace)
      count_aa += 1
    elif len(line_values) == 10:
      for j, t in enumerate(line_values):
        gap_profile[count_gap, j] = _hhm_value_to_prob(t, asterisks_replace)
      count_gap += 1
    elif not line_values:
      pass
    else:
      raise ValueError(
          'Wrong length of line %s in hhm file. Expected 0, 10 or 23, got %d'
          % (line_values, len(line_values)))
  hmm_profile = np.hstack([aa_profile, gap_profile])
  return hmm_profile


def run_hhblits_local(sequence, path, name=None):
  """Prepare the input file and shell commands for a local hhblits run.

  Writes `sequence` to `<path>/input.seq` (unless that file already exists)
  and builds the hhblits / reformat.pl / hhmake command lines that operate
  on files under `path`.

  NOTE(review): the commands are only built and returned here, not
  executed - confirm whether the caller is expected to run them.

  Parameters
  ----------
  sequence: str
    Protein sequence to search with.
  path: str
    Working directory for the input and result files.
  name: str, optional
    Record name passed on to `write_fasta`.

  Returns
  -------
  list of str
    The four command lines, in the order they are meant to be run.

  Raises
  ------
  KeyError
    If the environment variable PNET_HHDB_PATH is not set.
  """
  # os.path.exists takes a single path; the input file location has to be
  # assembled with os.path.join first.
  input_file = os.path.join(path, 'input.seq')
  if not os.path.exists(input_file):
    write_fasta(sequence, input_file, name=name)

  commands = []
  commands.append('hhblits -v 1 -maxfilt 100000 -realign_max 100000 -all \
      -B 100000 -Z 100000 -diff inf -id 99 -cov 50 -i %s/input.seq -d \
      %s/UniRef30_2020_02 -oa3m %s/results.a3m -cpu 4 -n 3 -e 0.001' % \
      (path, os.environ['PNET_HHDB_PATH'], path))
  commands.append('reformat.pl -v 0 -r a3m clu %s/results.a3m %s/results.clu' \
      % (path, path))
  commands.append('reformat.pl -v 0 -r a3m fas %s/results.a3m %s/results.fas' \
      % (path, path))
  commands.append('hhmake -i %s/results.a3m' % path)
  return commands
 No newline at end of file
+173 −0
Original line number Diff line number Diff line
# -*- coding: utf-8 -*-
"""
Created on Sun Oct 11 23:31:45 2020

@author: Zhenqin Wu
"""
import os
import numpy as np
from deepchem.utils.amino_acid_utils import AminoAcid, AminoAcidFreq, BLOSUM62, ContactPotential
from scipy.spatial.distance import pdist, squareform


def amino_acid_to_numeric(s):
  """Return the integer index of residue symbol `s` in `AminoAcid`.

  Symbols that are not keys of `AminoAcid` get the index of 'X'.
  """
  return AminoAcid.get(s, AminoAcid['X'])


def read_a3m_as_mat(sequence, a3m_seq_IDs, a3m_seqs, clean=True):
  """Encode the sequences of an .a3m alignment as an integer matrix.

  Lowercase characters are dropped from every aligned sequence and the
  remaining symbols are mapped with `amino_acid_to_numeric`.

  Parameters
  ----------
  sequence: str
    Query sequence; must equal the first entry of `a3m_seqs`.
  a3m_seq_IDs: list of str
    Alignment record names (not used here).
  a3m_seqs: list of str
    Aligned sequences as read from the .a3m file.
  clean: bool, default True
    If True, skip sequences whose `sequence_identity` with the query is
    below 1% of the query length.

  Returns
  -------
  np.ndarray
    Integer array with one row per retained sequence.
  """
  assert sequence == a3m_seqs[0]
  query = [amino_acid_to_numeric(res) for res in sequence]
  min_identity = len(query) * 0.01
  rows = []
  for aligned in a3m_seqs:
    encoded = [amino_acid_to_numeric(ch) for ch in aligned if not ch.islower()]
    if not clean or sequence_identity(query, encoded) >= min_identity:
      rows.append(encoded)
  return np.array(rows).astype(int)


def sequence_one_hot_encoding(sequence):
  """One-hot encode an amino acid sequence.

  Returns an array of shape `(len(sequence), len(AminoAcid))` with a single
  1 per row, at the column given by `amino_acid_to_numeric`.

  Raises
  ------
  ValueError
    If the sequence contains a character that is neither '-' nor
    alphabetic.
  """
  if any(ch != '-' and not ch.isalpha() for ch in sequence):
    raise ValueError('Sequence %s not recognized' % sequence)
  encoded = np.zeros((len(sequence), len(AminoAcid)))
  for row, ch in enumerate(sequence):
    column = amino_acid_to_numeric(ch)
    encoded[row, column] = 1
  return encoded


def sequence_deletion_probability(sequence, a3m_seqs):
  """Per-column fraction of aligned sequences preceded by lowercase symbols.

  For every non-lowercase character of an aligned sequence, a flag records
  whether at least one lowercase character occurred since the previous
  non-lowercase one. The flags are averaged over all sequences.

  Parameters
  ----------
  sequence: str
    Query sequence; only its length is used, to shape the result.
  a3m_seqs: list of str
    Aligned sequences as read from an .a3m file.

  Returns
  -------
  np.ndarray
    Array of shape `(len(sequence), 1)`.
  """
  indicator_rows = []
  for aligned in a3m_seqs:
    flags = []
    seen_lowercase = False
    for ch in aligned:
      if ch.islower():
        seen_lowercase = True
        continue
      flags.append(1 if seen_lowercase else 0)
      seen_lowercase = False
    indicator_rows.append(flags)
  indicators = np.array(indicator_rows)
  fraction = indicators.sum(0) / len(indicators)
  return fraction.reshape((len(sequence), 1))


def sequence_weights(seq_mat):
  """Weight each aligned sequence by the inverse of its neighbour count.

  Two sequences are neighbours when their Hamming distance (fraction of
  differing positions) is below 0.38. A sequence's weight is
  `1 / (1 + number of neighbours)`.

  Parameters
  ----------
  seq_mat: np.ndarray
    2D integer-encoded alignment, one row per sequence.

  Returns
  -------
  np.ndarray
    Array of shape `(n_alignments, 1)`.
  """
  n_align, _ = seq_mat.shape
  is_neighbor = squareform(pdist(seq_mat, 'hamming') < 0.38)
  cluster_sizes = 1 + np.sum(is_neighbor, 0)
  return (1.0 / cluster_sizes).reshape((n_align, 1))


def sequence_profile(seq_mat, weights=None):
  """Per-column symbol frequencies over 22 symbol indices (0..21).

  Parameters
  ----------
  seq_mat: np.ndarray
    2D integer-encoded alignment, one row per sequence.
  weights: np.ndarray, optional
    Per-sequence weights of shape `(n_alignments, 1)`; all ones when
    omitted.

  Returns
  -------
  np.ndarray
    Array of shape `(n_columns, 22)` whose rows are the (weighted) counts
    divided by their row sum.
  """
  n_symbols = 22
  counts = np.zeros((seq_mat.shape[1], n_symbols))
  if weights is None:
    weights = np.ones((seq_mat.shape[0], 1))
  for symbol in range(n_symbols):
    hits = seq_mat == symbol
    counts[:, symbol] = (hits * weights).sum(0)
  totals = counts.sum(1, keepdims=True)
  return counts / totals


def sequence_profile_no_gap(seq_mat, weights=None):
  """Per-column symbol frequencies over indices 0..20, ignoring index 21.

  Parameters
  ----------
  seq_mat: np.ndarray
    2D integer-encoded alignment, one row per sequence.
  weights: np.ndarray, optional
    Per-sequence weights of shape `(n_alignments, 1)`; all ones when
    omitted. (This argument used to be accepted but silently ignored.)

  Returns
  -------
  np.ndarray
    Array of shape `(n_columns, 21)` whose rows are the (weighted) counts
    divided by their row sum. A column containing only index-21 entries
    has a zero row sum and therefore yields NaN.
  """
  prof_ct = np.zeros((seq_mat.shape[1], 21))
  if weights is None:
    weights = np.ones((seq_mat.shape[0], 1))
  for i in range(21):
    prof_ct[:, i] = ((seq_mat == i) * weights).sum(0)
  prof_freq = prof_ct / prof_ct.sum(1, keepdims=True)
  return prof_freq


def sequence_profile_with_prior(prof_freq):
  """Blend observed profile frequencies with substitution-based pseudo
  frequencies.

  For each row `f` of `prof_freq`, pseudo frequencies `g` are obtained from
  `f`, the background frequencies in `AminoAcidFreq` and `BLOSUM62`, and the
  result is `(alpha * f + beta * g) / (alpha + beta)` with `beta = 10` and
  `alpha` = (number of symbols with non-zero frequency) - 1.

  Parameters
  ----------
  prof_freq: np.ndarray
    2D array of per-column symbol frequencies; the number of columns
    selects how many leading symbols (in `AminoAcid` index order) are used.

  Returns
  -------
  np.ndarray
    Array with the same shape as `prof_freq`.
  """
  n_symbols = prof_freq.shape[1]
  beta = 10
  ordered_symbols = sorted(AminoAcid.keys(), key=lambda x: AminoAcid[x])
  background = [AminoAcidFreq[aa] for aa in ordered_symbols][:n_symbols]
  background = np.array(background) / sum(background)
  substitution = BLOSUM62[:n_symbols, :n_symbols]
  pair_prior = np.matmul(
      background.reshape((-1, 1)),
      background.reshape((1, -1))) * np.exp(0.3176 * substitution)
  blended = np.zeros_like(prof_freq)
  for row, observed in enumerate(prof_freq):
    alpha = np.count_nonzero(observed > 0) - 1
    pseudo = np.matmul((observed / background).reshape((1, -1)), pair_prior)[0]
    pseudo = pseudo / pseudo.sum()
    blended[row] = (observed * alpha + pseudo * beta) / (alpha + beta)
  return blended
    
  
def sequence_identity(s1, s2):
  """Count positions where `s1` and `s2` agree and `s1` is not a gap.

  Both inputs are integer-encoded sequences; the gap index is
  `AminoAcid['-']`.
  """
  first = np.array(s1)
  second = np.array(s2)
  same_symbol = first == second
  not_gap = first != AminoAcid['-']
  return (same_symbol * not_gap).sum()


def sequence_static_prop(seq_mat, weights):
  """Build per-position alignment summary features.

  Each output row is `[position, sequence length, number of alignments,
  sum of weights]`; only the position differs between rows.

  Parameters
  ----------
  seq_mat: np.ndarray
    2D integer-encoded alignment, one row per sequence.
  weights: np.ndarray
    Per-sequence weights.

  Returns
  -------
  np.ndarray
    Array of shape `(sequence length, 4)`.
  """
  num_alignments, seq_length = seq_mat.shape
  summary = np.array([seq_length, num_alignments, weights.sum()])
  repeated = np.stack([summary for _ in range(seq_length)], 0)
  positions = np.arange(seq_length).reshape((-1, 1))
  return np.concatenate([positions, repeated], 1)


def sequence_gap_matrix(seq_mat):
  """Fraction of sequences having a gap at both columns of each pair.

  Parameters
  ----------
  seq_mat: np.ndarray
    2D integer-encoded alignment, one row per sequence.

  Returns
  -------
  np.ndarray
    Square 2D array with one row/column per alignment column.
  """
  is_gap = (seq_mat == AminoAcid['-']) * 1
  co_gap_counts = np.matmul(np.transpose(is_gap), is_gap)
  return co_gap_counts / seq_mat.shape[0]


def profile_combinatorial(seq_mat, weights, w_prof):
  """Weighted joint symbol frequencies for every pair of alignment columns.

  Parameters
  ----------
  seq_mat: np.ndarray
    2D integer-encoded alignment of shape `(n_alignments, n_columns)`.
  weights: np.ndarray
    Per-sequence weights with `n_alignments` entries.
  w_prof: np.ndarray
    1D profile; only its second dimension is used, as the number of
    symbols to tabulate.

  Returns
  -------
  np.ndarray
    Array of shape `(n_columns, n_columns, n_symbols, n_symbols)`; entry
    `[a, b, i, j]` is the weight of sequences with symbol `i` at column `a`
    and symbol `j` at column `b`, divided by the total weight.
  """
  n_seqs, n_cols = seq_mat.shape
  n_symbols = w_prof.shape[1]
  seq_weights = weights.reshape((n_seqs, 1, 1))
  # Encode each (symbol at column a, symbol at column b) pair as one integer.
  first_symbol = seq_mat.reshape((n_seqs, n_cols, 1))
  second_symbol = seq_mat.reshape((n_seqs, 1, n_cols))
  pair_codes = first_symbol * 22 + second_symbol
  pair_freq = np.zeros((n_cols, n_cols, n_symbols, n_symbols))
  for first in range(n_symbols):
    for second in range(n_symbols):
      code = first * 22 + second
      pair_freq[:, :, first, second] = (
          (pair_codes == code) * seq_weights).sum(0)
  return pair_freq / np.sum(seq_weights)


def mutual_information(prof_1D, prof_2D):
  """Mutual-information features for every pair of alignment columns.

  This series of properties were calculated based on:
    "Mutual information without the influence of phylogeny or entropy
     dramatically improves residue contact prediction"

  Parameters
  ----------
  prof_1D: np.ndarray
    Per-column symbol frequencies, one row per alignment column.
  prof_2D: np.ndarray
    4D joint symbol frequencies for each column pair.

  Returns
  -------
  np.ndarray
    Array of shape `(n_columns, n_columns, 4)` stacking, in this order:
    MI, MI divided by the joint entropy, MI minus the average product
    correction, and MI minus the average sum correction. The diagonal of
    MI and of the two corrected variants is zeroed.
  """
  n_res = prof_1D.shape[0]

  def strip_diagonal(mat):
    # triu + tril counts the diagonal twice, so this keeps only the
    # off-diagonal entries.
    return 2 * mat - np.triu(mat) - np.tril(mat)

  # Entropies use log base 21; the small offset avoids log(0).
  entropy_1d = np.sum(-prof_1D * np.log(prof_1D + 1e-7) / np.log(21), axis=1)
  entropy_2d = np.sum(
      -prof_2D * np.log(prof_2D + 1e-7) / np.log(21), axis=(2, 3))
  as_column = entropy_1d.reshape((n_res, 1))
  as_row = entropy_1d.reshape((1, n_res))
  MI = strip_diagonal(-entropy_2d + as_column + as_row)
  MIr = MI / (entropy_2d + 1e-5)

  mean_per_column = np.sum(MI, axis=1) / (n_res - 1)
  mean_overall = np.sum(MI) / n_res / (n_res - 1)
  col_means = mean_per_column.reshape((n_res, 1))
  row_means = mean_per_column.reshape((1, n_res))
  product_correction = col_means * row_means / mean_overall
  sum_correction = col_means + row_means - mean_overall
  MIp = strip_diagonal(MI - product_correction)
  MIa = strip_diagonal(MI - sum_correction)
  return np.stack([MI, MIr, MIp, MIa], 2)


def mean_contact_potential(prof_2D):
  """Contact potential averaged over the joint symbol frequencies of each
  column pair.

  Parameters
  ----------
  prof_2D: np.ndarray
    4D joint symbol frequencies; the third dimension gives the number of
    symbols, which selects the leading block of `ContactPotential`.

  Returns
  -------
  np.ndarray
    The frequency-weighted sum of `ContactPotential` over the two symbol
    axes, with a trailing axis of length 1 added.
  """
  n_symbols = prof_2D.shape[2]
  potential = ContactPotential[:n_symbols, :n_symbols]
  potential = potential.reshape((1, 1, n_symbols, n_symbols))
  expected = np.sum(prof_2D * potential, axis=(2, 3))
  return np.expand_dims(expected, 2)
 No newline at end of file