Commit bda6314f authored by JinyuanSun's avatar JinyuanSun
Browse files

add copilot

parent c15f1fd2
Loading
Loading
Loading
Loading
+23 −0
Original line number Diff line number Diff line
# ChatMol copilot

[![ChatMol copilot demo video](https://img.youtube.com/vi/9uMFZMQqTf8/0.jpg)](https://www.youtube.com/watch?v=9uMFZMQqTf8)


## Introduction
This is ChatMol copilot. Like other copilots, it is designed to help you with your work. Here the LLM is empowered by computational biology tools and databases. We show some use cases in the video, and you can also try it yourself.

## Installation

```bash
git clone https://github.com/JinyuanSun/ChatMol
cd ChatMol/copilot_public
pip install -r requirements.txt
```

## Usage
```bash
streamlit run main.py
```

## Online Demo
We provide an online demo. [Click here](https://copilot.cloudmol.org/) to try it. It is not necessarily the latest version, but it is enough to get a feel for the tool.
 No newline at end of file
+311 −0
Original line number Diff line number Diff line
import json
import os
import rdkit
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import (
    Function,
    ChatCompletionMessageToolCall,
)
import requests
import chatmol_fn as cfn


class ConversationHandler:
    """Drive a tool-calling chat conversation.

    The handler keeps the running message history, advertises a fixed set of
    tool schemas to the chat-completions endpoint and dispatches every tool
    call the model requests to the matching method of the ``cfn`` object.
    """

    def __init__(self, client, cfn, model_name="gpt-3.5-turbo-1106"):
        """Store the client and build the tool schemas and dispatch table.

        Parameters:
        - client: object exposing ``chat.completions.create(...)``.
        - cfn: object providing the tool implementations listed in
          ``available_functions`` (one attribute per tool name).
        - model_name (str): model identifier passed to every completion call.
        """
        self.client = client
        self.model_name = model_name
        self.cfn = cfn
        # Full conversation history, sent with every request.
        self.messages = []
        # JSON-schema descriptions of the tools offered to the model.
        self.tools = [
            {
                "type": "function",
                "function": {
                    "name": "fetch_asked_pdb",
                    "description": "Show the 3D structure of a specified protein by ID. This function supports three databases:\
    - RCSB PDB: Uses PDB IDs to download protein structures from the RCSB Protein Data Bank.\
    - AlphaFoldDB: Requires UniProt IDs to download predicted protein structures from AlphaFold Database.\
    - ESM: Uses MGnify IDs to download predicted structures from ESM (Evolutionary Scale Modeling).",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "pdb_id": {
                                "type": "string",
                                "description": "The PDB ID of the molecule",
                            },
                            "database": {
                                "type": "string",
                                "description": "The database name, chose in 'rcsb', 'afdb', 'esm' based on the id provided",
                            },
                        },
                        "required": ["pdb_id", "database"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "pocket_prediction",
                    "description": "query pocketapi.cloudmol.org to predict ligand binding sites of input pdb file.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "pdb_file": {
                                "type": "string",
                                "description": "The file path to a local pdb file",
                            },
                        },
                        "required": ["pdb_file"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "protein_structure_prediction",
                    "description": "Preict the structure of a protein sequence",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "seq": {
                                "type": "string",
                                "description": "The protein sequence",
                            },
                            "name": {
                                "type": "string",
                                "description": "The name of the protein sequence",
                            },
                        },
                        "required": ["seq", "name"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "display_protein_structure",
                    "description": "display a protein pdb file structure",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "pdb_file": {
                                "type": "string",
                                "description": "The file path to a local pdb file",
                            },
                        },
                        "required": ["pdb_file"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "protein_single_point_mutation_prediction",
                    "description": "Predict the effect of mutations on protein stability",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "pdb_file": {
                                "type": "string",
                                "description": "The file path to a local pdb file",
                            },
                            "mutations": {
                                "type": "string",
                                "description": "The mutations to be displayed, in format of 'A_12_F,C_45_D' (a comma separated list of mutations, where each mutation is of the form <wildtype>_<residue_number>_<mutation>",
                            },
                        },
                        "required": ["pdb_file", "mutations"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "recommand_stable_mutations",
                    "description": "design stablizing mutations for a protein",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "pdb_file": {
                                "type": "string",
                                "description": "The file path to a local pdb file",
                            },
                            "cutoff": {
                                "type": "string",
                                "description": "The cutoff of the stability score, default is -2",
                            },
                        },
                        "required": ["pdb_file"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "get_protein_sequence_from_pdb",
                    "description": "Get the sequence of a protein from a pdb file",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "pdb_file": {
                                "type": "string",
                                "description": "The file path to a local pdb file",
                            },
                            "chain_id": {
                                "type": "string",
                                "description": "The chain id of the protein, default is A",
                            },
                        },
                        "required": ["pdb_file"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "get_smiles_from_name",
                    "description": "Get the SMILES string of a compound by its name",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "compound_name": {
                                "type": "string",
                                "description": "The name of the compound",
                            },
                        },
                        "required": ["compound_name"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "generate_3D_conformation_and_save",
                    "description": "Generate 3D conformation and save as PDB",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "smiles": {
                                "type": "string",
                                "description": "The SMILES string of the compound",
                            },
                            "file_name": {
                                "type": "string",
                                "description": "The file name of the PDB file, e.g. 'lys.pdb'",
                            },
                        },
                        "required": ["smiles", "file_name"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "search_rcsb",
                    "description": "Search RCSB PDB database and get some pdb IDs",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {"type": "string", "description": "The query"},
                            "max_num": {"type": "string", "description": "The max number of results"},
                        },
                        "required": ["query"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "query_uniprot",
                    "description": "Query UniProt database and get some UniProt IDs or informations of fasta content",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {"type": "string", "description": "The query"},
                            "type": {
                                "type": "string",
                                "description": "should be one of 'search_query', 'get_txt', or 'get_fasta', search_query is to search UniProt database and get some UniProt IDs, get_txt is to get the txt content of a UniProt ID containing annotations, get_fasta is to get the fasta content of a UniProt ID",
                            },
                        },
                        # "required" belongs to the parameter schema, next to
                        # "properties", like in every other tool above.
                        "required": ["query", "type"],
                    },
                },
            },
        ]
        # Tool name -> callable used to execute a tool call.
        self.available_functions = {
            "fetch_asked_pdb": self.cfn.fetch_asked_pdb,
            "pocket_prediction": self.cfn.pocket_prediction,
            "display_protein_structure": self.cfn.display_protein_structure,
            "protein_structure_prediction": self.cfn.protein_structure_prediction,
            "get_work_dir": self.cfn.get_work_dir,
            "recommand_stable_mutations": self.cfn.recommand_stable_mutations,
            "protein_single_point_mutation_prediction": self.cfn.protein_single_point_mutation_prediction,
            "get_smiles_from_name": self.cfn.get_smiles_from_name,
            "generate_3D_conformation_and_save": self.cfn.generate_3D_conformation_and_save,
            "get_protein_sequence_from_pdb": self.cfn.get_protein_sequence_from_pdb,
            "search_rcsb": self.cfn.search_rcsb,
            "query_uniprot": self.cfn.query_uniprot,
        }

    def setup_workdir(self, work_dir):
        """Set ``WORK_DIR`` on the tool provider to ``work_dir``."""
        self.cfn.WORK_DIR = work_dir

    def run_round(self, user_message):
        """Run one user turn, executing any tool calls the model requests.

        The user message and the model's reply are appended to
        ``self.messages``. When the reply contains tool calls, each one is
        executed, its result is appended as a ``tool`` message, and a second
        completion is requested so the model can see the results.

        Parameters:
        - user_message (str): the text typed by the user.

        Returns:
        - tuple ``(response, second_response)``; ``second_response`` is
          ``None`` when the model requested no tool call.
        """
        self.messages.append({"role": "user", "content": user_message})

        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.messages,
            tools=self.tools,
            tool_choice="auto",
        )

        response_message = response.choices[0].message
        tool_calls = response_message.tool_calls
        self.messages.append(response_message)
        if not tool_calls:
            return response, None

        for tool_call in tool_calls:
            function_name = tool_call.function.name
            function_to_call = self.available_functions.get(function_name)
            if function_to_call is None:
                # The model asked for a tool that is not registered; report it
                # back so the tool call still receives an answer.
                function_response = f"Unknown function: {function_name}"
            else:
                function_args = json.loads(tool_call.function.arguments)
                function_response = function_to_call(**function_args)
            if not isinstance(function_response, str):
                # Tools may return None or other objects; the message content
                # sent back to the model is kept textual.
                function_response = str(function_response)

            self.messages.append(
                {
                    "tool_call_id": tool_call.id,
                    "role": "tool",
                    "name": function_name,
                    "content": function_response,
                }
            )

        # Ask again so the model can read the tool results. This must go
        # through self.client: there is no module-level ``client``.
        second_response = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.messages,
        )
        return response, second_response


def compose_chat_completion_message(
    role="assistant", content="", tool_call_dict_list=None
):
    """Build a ``ChatCompletionMessage`` from plain Python values.

    Parameters:
    - role (str): role stored on the message.
    - content (str): text content of the message.
    - tool_call_dict_list (list[dict] | None): one dict per tool call, each
      with an ``"id"`` key and a ``"function"`` dict holding ``"name"`` and
      ``"arguments"``. ``None`` (the default) means no tool calls.

    Returns:
    - ChatCompletionMessage whose ``tool_calls`` is the list of converted
      ``ChatCompletionMessageToolCall`` objects (empty when none were given).
    """
    # A None default replaces the former shared mutable ``[]`` default.
    tool_calls = [
        ChatCompletionMessageToolCall(
            id=tool_call_dict["id"],
            function=Function(
                name=tool_call_dict["function"]["name"],
                arguments=tool_call_dict["function"]["arguments"],
            ),
            type="function",
        )
        for tool_call_dict in (tool_call_dict_list or [])
    ]
    return ChatCompletionMessage(
        role=role,
        content=content,
        tool_calls=tool_calls,
    )
 No newline at end of file
+223 −0
Original line number Diff line number Diff line
import json
import requests
import py3Dmol
from tqdm import tqdm
import matplotlib.pyplot as plt
from cloudmol.cloudmol import PymolFold
from utils import query_pythia, handle_file_not_found_error
import os
from stmol import showmol
from rdkit import Chem
from rdkit.Chem import AllChem


class ChatmolFN:
    """Tool implementations that the chat model can call.

    Each public method carries out one task (download, query, prediction,
    display, ...) and returns a text result. Methods that build a py3Dmol
    viewer show it through ``showmol`` and remember it in ``VIEW_DICTS`` when
    ``STREAMLIT_GUI`` is true, otherwise they call ``view.show()``.
    """

    def __init__(self, work_dir="./"):
        """Initialise the working directory and viewer settings.

        Parameters:
        - work_dir (str): directory where downloaded/generated files are saved.
        """
        # Honour the argument (it used to be ignored in favour of "./").
        self.WORK_DIR = work_dir
        self.STREAMLIT_GUI = True
        # Viewers created so far, keyed by the id or file they display.
        self.VIEW_DICTS = {}
        self.viewer_height = 300
        self.viewer_width = 300

    def _show_view(self, key, view):
        """Display ``view``; in Streamlit mode also store it under ``key``."""
        if self.STREAMLIT_GUI:
            self.VIEW_DICTS[key] = view
            showmol(view, height=self.viewer_height, width=self.viewer_width)
        else:
            view.show()

    @staticmethod
    def _iter_pythia_scores(pythia_res):
        """Yield ``(mutation, score)`` pairs from the pythia text result.

        Lines that do not consist of exactly two whitespace-separated fields
        (for instance a blank trailing line) are skipped instead of raising.
        """
        for line in pythia_res.split("\n"):
            fields = line.split()
            if len(fields) == 2:
                yield fields[0], fields[1]

    def query_uniprot(self, query, type="search_query"):
        """Query UniProt.

        Parameters:
        - query (str): search text or UniProt ID; spaces are replaced by '+'.
        - type (str): 'search_query' (first ten lines of the TSV search
          result), 'get_txt' (text record of an ID) or 'get_fasta' (fasta
          record of an ID).

        Returns:
        - str: the requested content, or a failure message for a non-200
          response or an unknown ``type``.
        """
        query = query.replace(" ", "+")
        if type == "search_query":
            url = f"https://rest.uniprot.org/uniprot/search?query={query}&format=tsv"
        elif type == "get_txt":
            url = f"https://www.uniprot.org/uniprot/{query}.txt"
        elif type == "get_fasta":
            url = f"https://www.uniprot.org/uniprot/{query}.fasta"
        else:
            return f"Unknown query type {type!r}, expected one of 'search_query', 'get_txt' or 'get_fasta'"

        response = requests.get(url)
        if response.status_code != 200:
            return f"Failed to query {query}. HTTP Status Code: {response.status_code}"
        if type == "search_query":
            return "\n".join(response.text.split("\n")[:10])
        if type == "get_txt":
            return f"Full uniprot record:\n{response.text}"
        return f"Fasta of {query}:\n{response.text}\n"

    def fetch_asked_pdb(self, pdb_id, database="rcsb"):
        """
        Download the PDB file for a given protein from RCSB (pdb_id), AlphaFoldDB (uniprot id) or esmatlas (MGnifyid),
        save it to the working directory and show it using py3Dmol.
        Parameters:
        - pdb_id (str): The ID of the protein. The format depends on the selected database.
        - database (str): A database name, one of 'rcsb', 'afdb', 'esm'.
        Returns:
        - str: a confirmation with the saved path, or a failure message for an
          unknown database or a non-200 response.
        """
        url_templates = {
            "rcsb": "https://files.rcsb.org/download/{pdb_id}.pdb",
            "esm": "https://api.esmatlas.com/fetchPredictedStructure/{pdb_id}.pdb",
            "afdb": "https://alphafold.ebi.ac.uk/files/AF-{pdb_id}-F1-model_v4.pdb",
        }
        if database not in url_templates:
            # Previously this fell through and failed on an unbound variable.
            return f"Unknown database {database!r}, expected one of 'rcsb', 'afdb' or 'esm'"

        url = url_templates[database].format(pdb_id=pdb_id)
        # NOTE(review): certificate verification stays disabled for the ESM
        # atlas only, exactly as before; re-enable it if the host allows.
        response = requests.get(url, verify=(database != "esm"))
        if response.status_code != 200:
            return f"Failed to download PDB file for {pdb_id}. HTTP Status Code: {response.status_code}"

        system = response.text
        with open(f"{self.WORK_DIR}/{pdb_id}.pdb", "w") as ofile:
            ofile.write(system)

        view = py3Dmol.view(height=self.viewer_height, width=self.viewer_width)
        view.addModelsAsFrames(system)
        view.setStyle({"model": -1}, {"cartoon": {"color": "spectrum"}})
        view.zoomTo()
        self._show_view(pdb_id, view)
        return f"{pdb_id} shows here and saved to {self.WORK_DIR}/{pdb_id}.pdb"

    def get_work_dir(self):
        """Return the current working directory setting."""
        return self.WORK_DIR

    def get_smiles_from_name(self, compound_name: str):
        """Look up the canonical SMILES of a compound by name via PubChem.

        Parameters:
        - compound_name (str): the compound name to resolve.

        Returns:
        - str: a sentence containing the SMILES, or a failure message when
          either PubChem request does not succeed.
        """
        from urllib.parse import quote

        base_url = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound"

        # Get CID from compound name (the name is escaped for use in a path).
        cid_response = requests.get(
            f"{base_url}/name/{quote(compound_name, safe='')}/cids/TXT"
        )
        cids = cid_response.text.split()
        if cid_response.status_code != 200 or not cids:
            return f"Failed to find {compound_name} in PubChem. HTTP Status Code: {cid_response.status_code}"
        # Several CIDs may come back, one per line; use the first one.
        cid = cids[0]

        # Get SMILES from CID
        smiles_response = requests.get(
            f"{base_url}/cid/{cid}/property/CanonicalSMILES/TXT"
        )
        if smiles_response.status_code != 200:
            return f"Failed to get the SMILES of {compound_name}. HTTP Status Code: {smiles_response.status_code}"
        smiles = smiles_response.text.strip()

        return f"The SMILES of {compound_name} is {smiles}"

    def display_protein_structure(self, pdb_file):
        """Show the structure stored in a local pdb file.

        Parameters:
        - pdb_file (str): path of the pdb file to display.

        Returns:
        - str: "The protein is showed here!" on success, "wrong file!" when
          anything fails (the error is printed).
        """
        try:
            with open(pdb_file, "r") as handle:
                system = handle.read()
            view = py3Dmol.view(height=self.viewer_height, width=self.viewer_width)
            view.addModelsAsFrames(system)
            view.setStyle({"model": -1}, {"cartoon": {"color": "spectrum"}})
            view.zoomTo()
            self._show_view(pdb_file, view)
            return "The protein is showed here!"
        except Exception as e:
            # Deliberate best effort: the model gets a short failure text.
            print(f"The error is:\n{e}")
            return "wrong file!"

    def search_rcsb(self, query, max_num=3):
        """Search RCSB PDB and return the first ``max_num`` matching IDs.

        Parameters:
        - query (str): free-text query.
        - max_num (int | str): maximum number of IDs to report.
        """
        import biotite.database.rcsb as rcsb

        query = rcsb.BasicQuery(query)
        pdb_ids = rcsb.search(query)
        max_num = min(int(max_num), len(pdb_ids))
        return f"The top {max_num} PDB IDs are {pdb_ids[:max_num]}"

    @handle_file_not_found_error
    def protein_single_point_mutation_prediction(self, pdb_file, mutations):
        """Report the pythia score of each requested mutation.

        Parameters:
        - pdb_file (str): path to a local pdb file passed to ``query_pythia``.
        - mutations (str): requested mutations, e.g. 'A_12_F,C_45_D'.

        Returns:
        - str: one "<mutation> <score>" line per scored mutation whose name
          occurs in ``mutations``.
        """
        pythia_res = query_pythia(pdb_file)
        return "".join(
            f"{m} {score}\n"
            for m, score in self._iter_pythia_scores(pythia_res)
            if m in mutations
        )

    @handle_file_not_found_error
    def recommand_stable_mutations(self, pdb_file, cutoff=-2):
        """List mutations whose pythia score is below ``cutoff``.

        Parameters:
        - pdb_file (str): path to a local pdb file passed to ``query_pythia``.
        - cutoff (float | str): score threshold; only lower scores are kept.

        Returns:
        - str: one "<mutation> <score>" line per selected mutation.
        """
        pythia_res = query_pythia(pdb_file)
        threshold = float(cutoff)
        return "".join(
            f"{m} {score}\n"
            for m, score in self._iter_pythia_scores(pythia_res)
            if float(score) < threshold
        )

    @handle_file_not_found_error
    def get_protein_sequence_from_pdb(self, pdb_file, chain_id="A"):
        """Return the sequence of one chain of a pdb file.

        Parameters:
        - pdb_file (str): path to a local pdb file.
        - chain_id (str): chain to extract, 'A' by default.

        Returns:
        - str: a sentence containing the sequence, or a message saying the
          chain was not found.
        """
        from Bio import SeqIO

        for record in SeqIO.parse(pdb_file, "pdb-atom"):
            if record.annotations["chain"] == chain_id:
                return f"The sequnece of chain {chain_id} in pdb file {pdb_file} is {str(record.seq)}"
        # Give the model an explicit answer instead of an implicit None.
        return f"Chain {chain_id} not found in pdb file {pdb_file}"

    @handle_file_not_found_error
    def pocket_prediction(self, pdb_file):
        """
        query pocketapi.cloudmol.org to predict ligand binding sites of input pdb file.

        The structure is shown with residues coloured by the confidence tier
        reported by the service.

        Parameters:
        - pdb_file (str): path to a local pdb file.

        Returns:
        - str: the raw response text of the service, or a failure message for
          a non-200 response.
        """
        headers = {"accept": "application/octet-stream"}
        with open(pdb_file, "rb") as handle:
            response = requests.post(
                "https://pocketapi.cloudmol.org/predict",
                headers=headers,
                files={"file": handle},
            )
        if response.status_code != 200:
            return f"Failed to predict pockets for {pdb_file}. HTTP Status Code: {response.status_code}"
        x = response.json()

        with open(pdb_file, "r") as handle:
            system = handle.read()

        colors = ["#FF0000", "#FFFF00", "#00FF00", "#00FFFF", "#0000FF"]
        # Residue-number sets per tier, built once; later tiers take priority.
        tiers = [
            (set(x["Likely pocket residues"].split("+")), colors[1]),
            (set(x["Confident pocket residues"].split("+")), colors[3]),
            (set(x["Highly confident pocket residues"].split("+")), colors[4]),
        ]

        pdbview = py3Dmol.view(height=self.viewer_height, width=self.viewer_width)
        pdbview.addModel(system, "pdb")
        pdbview.setStyle({"cartoon": {"color": "#193f90"}})
        pdbview.setBackgroundColor("white")

        atom_count = 0
        for line in system.split("\n"):
            if not line.startswith("ATOM"):
                continue
            atom_count += 1
            # Use the serial written in the file (columns 7-11): a running
            # counter drifts as soon as TER/HETATM records consume serials.
            try:
                serial = int(line[6:11])
            except ValueError:
                serial = atom_count
            resn = str(int(line[22:26]))
            color = colors[0]
            for residues, tier_color in tiers:
                if resn in residues:
                    color = tier_color
            pdbview.setStyle(
                {"model": -1, "serial": serial}, {"cartoon": {"color": color}}
            )
        pdbview.setStyle({"hetflag": True}, {"stick": {"radius": 0.3}})
        pdbview.zoomTo()
        self._show_view("pocket_" + pdb_file.split("/")[-1], pdbview)
        return response.text

    def protein_structure_prediction(self, seq, name):
        """Protein structure prediction

        Parameters:
        - seq (str): the protein sequence.
        - name (str): name used for the output file.

        Returns:
        - str: a sentence with the path of the predicted pdb file.
        """
        pf = PymolFold()
        pf.set_path(self.WORK_DIR)  # change the path to save results
        pf.query_esmfold(seq, name)
        pdb_filename = os.path.join(pf.ABS_PATH, name) + ".pdb"
        return f"Predicted structure saved as {pdb_filename}"

    def generate_3D_conformation_and_save(self, smiles: str, file_name: str):
        """Generate a 3D conformation for a SMILES string and save it as PDB.

        Parameters:
        - smiles (str): SMILES of the compound.
        - file_name (str): output file name, joined onto the working directory.

        Returns:
        - str: a confirmation with the saved path, or a failure message when
          the SMILES cannot be parsed or no conformation can be embedded.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles signals an unparsable SMILES by returning None.
            return f"Failed to parse SMILES {smiles}"
        mol = Chem.AddHs(mol)
        if AllChem.EmbedMolecule(mol, AllChem.ETKDG()) == -1:
            # -1 means no conformer was generated, so there is nothing to save.
            return f"Failed to generate a 3D conformation for {smiles}"
        AllChem.MMFFOptimizeMolecule(mol)
        file_name = os.path.join(self.WORK_DIR, file_name)
        writer = Chem.PDBWriter(file_name)
        writer.write(mol)
        writer.close()
        return f"The conformation of {smiles} is saved as {file_name}"

copilot_public/main.py

0 → 100644
+292 −0

File added.

Preview size limit exceeded, changes collapsed.

+16 −0
Original line number Diff line number Diff line
Bio==1.6.0
biopython==1.81
biotite==0.38.0
cloudmol==0.1.2
matplotlib==3.7.1
openai==1.3.8
py3Dmol==2.0.0.post2
rdkit==2023.9.1
Requests==2.31.0
stmol==0.0.9
streamlit==1.24.0
streamlit_analytics==0.4.1
streamlit_float==0.3.2
streamlit_js_eval==0.1.5
streamlit_molstar==0.4.6
tqdm==4.66.1
Loading