Commit 8ab7f584 authored by JinyuanSun's avatar JinyuanSun
Browse files

fix bugs and add docking tools

parent e240a407
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -130,3 +130,5 @@ dmypy.json
.pyre/

test*
01d70f7c997c6a7d73e8fc592865b84f7371642b7afdba535726ba70f020183e*
ce73fb8a1b802e6746c58ac3bf915d79506e2b5edc36e83f1cbfa3f6071a9a92*
 No newline at end of file
+19 −0
Original line number Diff line number Diff line
import json
import os
import rdkit
import time
import pandas as pd
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import (
    Function,
@@ -228,6 +230,22 @@ class ConversationHandler:
                    "required": ["query", "type"],
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "blind_docking",
                    "description": "Perform blind docking using the input protein and ligand",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "protein_pdb_file_path": {"type": "string", "description": "The file path to a local pdb file of the protein"},
                            "ligand_pdb_file_path": {"type": "string","description": "The file path to a local pdb file of the ligand"},
                            "complex_file_path": {"type": "string","description": "The path to save the complex PDB file. Need to be in the same directory as the protein and ligand pdb files"},
                        },
                    },
                    "required": ["query", "type"],
                },
            },
            
        ]
        self.available_functions = {
@@ -243,6 +261,7 @@ class ConversationHandler:
            "get_protein_sequence_from_pdb": self.cfn.get_protein_sequence_from_pdb,
            "search_rcsb": self.cfn.search_rcsb,
            "query_uniprot": self.cfn.query_uniprot,
            "blind_docking": self.cfn.blind_docking
        }

    def setup_workdir(self, work_dir):
+139 −0
Original line number Diff line number Diff line
@@ -6,10 +6,99 @@ import matplotlib.pyplot as plt
from cloudmol.cloudmol import PymolFold
from utils import query_pythia, handle_file_not_found_error
import os
from io import StringIO
from stmol import showmol
from rdkit import Chem
from rdkit.Chem import AllChem
import time
import pandas as pd

from Bio.PDB import PDBParser

def read_first_model_pdbqt(pdbqt_filename):
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("PDBQT", pdbqt_filename)[0]
    atoms = []
    for atom in structure.get_atoms():
        atoms.append(atom)
    return atoms

def format_as_pdb_hetatm(serial:int, atom_name:str, element:str, resseq:int, x, y, z):

    x = "{:6s}{:5d} {:^4s}{:1s}{:3s} {:1s}{:4d}{:1s}   {:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6.2f}          {:>2s}{:2s}".format('HETATM', serial, atom_name, "", "LIG", 'X', 1, "", float(x), float(y), float(z), 1.00, 0, element, '')
    return x + '\n'

def concate_ligand_to_receptor(ligand_file_path, receptor_file_path, output_filename):
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("receptor", receptor_file_path)[0]
    ligand_atoms = read_first_model_pdbqt(ligand_file_path)
    resseqs = [residue.id[1] for residue in structure.get_residues()]
    serial = max([atom.serial_number for atom in structure.get_atoms()]) + 1
    resseq = max(resseqs) + 1
    with open(output_filename, 'w+') as f:
        with open(receptor_file_path, 'r') as receptor_file:
            for line in receptor_file:
                if line.startswith("ATOM"):
                    f.write(line)
        for i, atom in enumerate(ligand_atoms):
            x, y, z = atom.get_coord()
            f.write(format_as_pdb_hetatm(serial+i, atom.get_name(), atom.element, resseq, x, y, z))
        f.write("TER\n")
    return output_filename

def parse_vina_output(vina_output):
    start = vina_output.find("mode |   affinity")
    end = vina_output.find("Writing output")
    result_str = vina_output[start:end].strip()
    result_data = StringIO(result_str)
    df = pd.read_csv(result_data, delim_whitespace=True, skiprows=3, names=['Mode', 'Affinity (kcal/mol)', 'RMSD l.b.', 'RMSD u.b.'])
    return df

def submit_docking_task(protein_file, ligand_file, center_x=0, center_y=0, center_z=0, box_size_x=20, box_size_y=20, box_size_z=20, aa_list=None):
    headers = {
        'accept': 'application/json',
    }

    files = {
        'protein_file': open(protein_file, 'rb'),
        'ligand_file': open(ligand_file, 'rb'),
        'center_x': (None, str(center_x)),
        'center_y': (None, str(center_y)),
        'center_z': (None, str(center_z)),
        'box_size_x': (None, str(box_size_x)),
        'box_size_y': (None, str(box_size_y)),
        'box_size_z': (None, str(box_size_z)),
    }
    if aa_list:
        files['aa_list'] = (None, aa_list)
    response = requests.post('https://dockingapi.cloudmol.org/api/dock', headers=headers, files=files)
    return response.json()

def submit_pocket_prediction_task(protein_file):
    headers = {
        'accept': 'application/json',
    }
    files = {
        'file': open(protein_file, 'rb'),
    }
    response = requests.post('https://pocketapi.cloudmol.org/predict', headers=headers, files=files)
    return response.json()

def query_docking_status(docking_code):
    response = requests.get(f'https://dockingapi.cloudmol.org/task_status/{docking_code}')# .decode('utf-8')
    return response.text

def get_docking_result(docking_code):
    response = requests.get(f'https://dockingapi.cloudmol.org/task_progress/{docking_code}')
    return response.text

def save_best_docking_result(docking_code ,file_path):
    response = requests.get(f'https://dockingapi.cloudmol.org/get_best_pose/{docking_code}/best_pose.pdb')
    with open(file_path, 'w') as f:
        f.write(response.text)
        # with open(receptor_file_path, 'r') as receptor:
        #     f.write(receptor.read())
    return f"Docking result saved as {file_path}"

class ChatmolFN:
    def __init__(self, work_dir="./"):
@@ -132,6 +221,56 @@ class ChatmolFN:
        max_num = min(int(max_num), len(pdb_ids))
        return f"The top {max_num} PDB IDs are {pdb_ids[:max_num]}"

    def blind_docking(self, protein_pdb_file_path, ligand_pdb_file_path, complex_file_path):
        """
        Blind docking between a protein and a ligand
        Parameters:
        - protein_pdb_file_path (str): The path to the protein PDB file.
        - ligand_pdb_file_path (str): The path to the ligand PDB file.
        - complex_file_path (str): The path to save the complex PDB file.

        """
        print('Submitting pocket prediction task...')
        pocket_prediction = submit_pocket_prediction_task(protein_pdb_file_path)
            
        pocket_aas = pocket_prediction['Confident pocket residues'].replace('+', ',')
        if len(pocket_prediction['Confident pocket residues'].split('+')) < 2:
            pocket_aas = pocket_prediction['Likely pocket residues'].replace('+', ',')
        print('Pocket residues:', pocket_aas)
        print('Submitting docking task...')
        docking_code = submit_docking_task(protein_pdb_file_path, ligand_pdb_file_path, aa_list=pocket_aas)
        docking_code = docking_code['hash_code']
        print(docking_code)
        status = ""
        status_prev = ''
        print("status:")
        while status != '"completed"':
            status = query_docking_status(docking_code)
        
            print(f"Debug: {status}")
            if status == '"completed"':
                print("finished")
                save_best_docking_result(docking_code, complex_file_path)
                concate_ligand_to_receptor(complex_file_path, protein_pdb_file_path, complex_file_path)
                log = get_docking_result(docking_code)
                if log.endswith("done.\n"):
                    res_df = parse_vina_output(log)
                    # st.
                    return res_df.to_string()
            else:
                time.sleep(5)
        # save_best_docking_result(docking_code, complex_file_path)
        # concate_ligand_to_receptor(complex_file_path, protein_pdb_file_path, complex_file_path)
        # log = get_docking_result(docking_code)
        # res_df = parse_vina_output(log)
        # while res_df.empty:
        #     log = get_docking_result(docking_code)
        #     res_df = parse_vina_output(log)
        # # print(res_df.to_string())
        # # print(log)
        # return res_df.to_string()


    @handle_file_not_found_error
    def protein_single_point_mutation_prediction(self, pdb_file, mutations):
        pythia_res = query_pythia(pdb_file)
+14 −7
Original line number Diff line number Diff line
from openai import OpenAI
import streamlit as st
import chatmol_fn as cfn
import chatmol_fn as cfn_
from stmol import showmol
from streamlit_float import *
from viewer_utils import show_pdb, update_view
from utils import test_openai_api, function_args_to_streamlit_ui
from streamlit_molstar import st_molstar, st_molstar_rcsb, st_molstar_remote
from streamlit_molstar.docking import st_molstar_docking
import hashlib
import new_function_template
import shutil
from chat_helper import ConversationHandler, compose_chat_completion_message
@@ -17,7 +19,7 @@ import re
import streamlit_analytics
import requests
import inspect

m = hashlib.sha256()
st.set_page_config(layout="wide")
st.session_state.new_added_functions = []

@@ -30,7 +32,7 @@ if "function_queue" not in st.session_state:
if "cfn" in st.session_state:
    cfn = st.session_state["cfn"]
else:
    cfn = cfn.ChatmolFN()
    cfn = cfn_.ChatmolFN()
    st.session_state["cfn"] = cfn

st.title("ChatMol copilot", anchor="center")
@@ -80,14 +82,17 @@ else:
    st.session_state["api_key"] = api_key_test


hash_string = "WD_" + str(hash(openai_api_key + project_id)).replace("-", "_")
# hash_string = "WD_" + str(hash(openai_api_key + project_id)).replace("-", "_")
# pub_dir = openai_api_key + project_id
m.update((openai_api_key + project_id).encode())
hash_string = m.hexdigest()
if st.sidebar.button("Clear Project History"):
    if os.path.exists(f"./{hash_string}"):
        shutil.rmtree(f"./{hash_string}")
        st.session_state.messages = []
        st.session_state.function_queue = []
        st.session_state.new_added_functions = []
        st.session_state.cfn = cfn.ChatmolFN()
        st.session_state.cfn = cfn_.ChatmolFN()
# try to bring back the previous session
work_dir = f"./{hash_string}"
cfn.WORK_DIR = work_dir
@@ -155,7 +160,7 @@ if "messages" not in st.session_state:
    st.session_state.messages = [
        {
            "role": "system",
            "content": "You are ChatMol copilot, a helpful copilot in molecule analysis with tools. Use tools only when you need them. Only answer to questions related molecular modelling.",
            "content": "You are ChatMol copilot, a helpful copilot in molecule analysis with tools. Use tools only when you need them. Answer to questions related molecular modelling.",
        }
    ]

@@ -406,7 +411,7 @@ with displaycol:
        col1, col2 = st.columns([1, 1])
        with col1:
            viewer_selection = st.selectbox(
                "Select a viewer", options=["molstar", "py3Dmol"], index=0
                "Select a viewer", options=["molstar", "py3Dmol", 'molstar docking'], index=0
            )
        if viewer_selection == "molstar":
            pdb_files = [f for f in os.listdir(cfn.WORK_DIR) if f.endswith(".pdb")]
@@ -416,6 +421,8 @@ with displaycol:
                        "Select a pdb file", options=pdb_files, index=0
                    )
                st_molstar(f"{cfn.WORK_DIR}/{pdb_file}", height=500)
        if viewer_selection == "molstar docking":
            st_molstar_docking(f"{cfn.WORK_DIR}/{pdb_file}", height=500)
        if viewer_selection == "py3Dmol":
            color_options = {
                "Confidence": "pLDDT",
+56 −0
Original line number Diff line number Diff line
import types

def translate_to_protein(self, seq: str, pname=None):
    from Bio.Seq import Seq

    nucleotide_seq = Seq(seq)
    protein_seq = nucleotide_seq.translate()
    if pname:
        return f"The protein sequence of {seq} is `>{pname}\n{protein_seq}`"
    else:
        return f"The protein sequence of {seq} is `>protein\n{protein_seq}`"

function_descriptions = [
    {  # This is the description of the function
    "type": "function",
    "function": {
        "name": "translate_to_protein",
        "description": "Translate a DNA/RNA sequence to a protein sequence",
        "parameters": {
            "type": "object",
            "properties": {
                "seq": {"type": "string", "description": "The DNA/RNA sequence"},
            },
        },
        "required": ["seq"],
    },
}
]

test_data = {
    "translate_to_protein": {
        "input": {
            "self": None,
            "seq": "ATGCGAATTTGGGCCC",
        },
        "output": "The protein sequence of ATGCGAATTTGGGCCC is `>protein\nMRFL`",
    }
}

### DO NOT MODIFY BELOW THIS LINE ###
def get_all_functions():
    all_functions = []
    global_functions = globals()
    for _, func in global_functions.items():
        if isinstance(func, types.FunctionType):
            all_functions.append(func)
    return all_functions

def get_info():
    return {"functions": get_all_functions(), "descriptions": function_descriptions}

def test_new_function(function, function_name, test_data):
    return (
        function(**test_data[function_name]["input"])
        == test_data[function_name]["output"]
    )
Loading