Unverified Commit 1b4a80b1 authored by jinyuan sun's avatar jinyuan sun Committed by GitHub
Browse files

Merge pull request #27 from JinyuanSun/sjy_dev

update readme of copilot
parents 89c54774 7ccb5fcb
Loading
Loading
Loading
Loading
+49 −2
Original line number Diff line number Diff line
@@ -18,6 +18,53 @@ pip install -r requirements.txt
```bash
streamlit run main.py
```
## Want more functionality?

## Online Demo
We provided an online demo for you to try. [Click here](https://copilot.cloudmol.org/) to try it. This is not necessarily the latest version, but it is enough for you to try.
 No newline at end of file
You can easily add more functionality to ChatMol copilot. 

```python

# 1. Define a function
def translate_to_protein(self, seq:str, pname=None):
    """Translate a DNA/RNA sequence to a protein and report it as FASTA text.

    ``pname``, if given, is used as the FASTA header; otherwise ``protein``.
    """
    from Bio.Seq import Seq
    nucleotide_seq = Seq(seq)
    protein_seq = nucleotide_seq.translate()
    if pname:
        return f"The protein sequence of {seq} is `>{pname}\n{protein_seq}`"
    else:
        return f"The protein sequence of {seq} is `>protein\n{protein_seq}`"

# 2. Add it to the conversation
cfn.translate_to_protein = translate_to_protein.__get__(cfn)
conversation.tools.append(
    { # This is the description of the function
        "type": "function",
        "function": {
            "name": "translate_to_protein",
            "description": "Translate a DNA/RNA sequence to a protein sequence",
            "parameters": {
                "type": "object",
                "properties": {
                    "seq": {"type": "string", "description": "The DNA/RNA sequence"},
                },
                # "required" must sit inside the JSON-schema "parameters"
                # object for OpenAI function calling; as a sibling of
                # "parameters" it is ignored by the API.
                "required": ["seq"],
            },
        },
    }
)
conversation.available_functions["translate_to_protein"] = cfn.translate_to_protein
```
By adding the above code to `main.py` at line 97 after `conversation = ConversationHandler(client, cfn, model_name=model)`, you can add this translation function to ChatMol copilot.

You are more than welcome to contribute any function to ChatMol copilot.
1. Fork this repo
2. Create a new branch
3. Add your function in `copilot_public/new_function_template.py`.  
   In this file, you need to define a function and clearly define the parameters and return value of this function, also add test case in `test_data`. You can refer to the existing content in `copilot_public/new_function_template.py`. We have a button named `Add from template`. You can click it to add your function to ChatMol copilot.
4. Create a pull request
5. We will review your code and merge it to the main branch  

**If you still don't know what to do, just paste this and the content in `copilot_public/new_function_template.py` into the input box of ChatGPT and ask it to do all the coding for you.** *Remember to add the magic prompt: "I don't have fingers, can you write the complete code for me."*

## Online Version
We provided an online version for you. [Click here](https://chatmol.org/copilot/) to try it.  
 No newline at end of file
+0 −1
Original line number Diff line number Diff line
@@ -43,7 +43,6 @@ class ChatmolFN:
            else:
                return f"Failed to query {query}. HTTP Status Code: {response.status_code}"


    def fetch_asked_pdb(self, pdb_id, database=["rcsb", "afdb", "esm"]):
        """
        Download the PDB file for a given protein from RCSB (pdb_id), AlphaFoldDB (uniprot id) or esmatlas (MGnifyid),
+151 −96
Original line number Diff line number Diff line
@@ -4,11 +4,9 @@ import chatmol_fn as cfn
from stmol import showmol
from streamlit_float import *
from viewer_utils import show_pdb, update_view
# from utils import test_openai_api
from streamlit_js_eval import streamlit_js_eval
from utils import test_openai_api
from streamlit_molstar import st_molstar, st_molstar_rcsb, st_molstar_remote


import new_function_template
import shutil
from chat_helper import ConversationHandler, compose_chat_completion_message
import os
@@ -18,43 +16,27 @@ import pickle
import re
import streamlit_analytics
import requests

st.set_page_config(layout="wide")
st.session_state["viewport_width"] = streamlit_js_eval(
    js_expressions="window.innerWidth", key="ViewportWidth"
)
# protein_viewer_width = protein_viewer_height = st.session_state["viewport_width"] * 0.45
st.session_state.new_added_functions = []
try:
    print(st.session_state["messages"])
except:
    pass
# width = streamlit_js_eval(js_expressions='screen.width', key = 'SCR')
# print(width)
def test_openai_api(api_key):
    """Return True if a minimal chat completion succeeds with *api_key*, else False."""
    client = OpenAI(api_key=api_key)
    try:
        # Cheapest possible probe: a tiny prompt capped at 10 tokens against
        # the model currently selected in the Streamlit session.
        response = client.chat.completions.create(
                    model=st.session_state["openai_model"],
                    messages=[{"role": "user", "content": "Test prompt"}],
                    max_tokens=10,
                )
        return True

    except Exception as e:
        # Any failure (bad key, network, quota) is treated as "not working".
        print("OpenAI API is not working")
        return False


pattern = r"<chatmol_sys>.*?</chatmol_sys>"
# wide
if 'cfn' in st.session_state:
if "cfn" in st.session_state:
    cfn = st.session_state["cfn"]
else:
    cfn = cfn.ChatmolFN()
    st.session_state["cfn"] = cfn

st.title("ChatMol copilot", anchor="center")
st.sidebar.write("2023 Dec 12 public version")
st.sidebar.write("ChatMol copilot is a copilot for protein engineering. Also chekcout our [GitHub](https://github.com/JinyuanSun/ChatMol).")
st.sidebar.write("2023 Dec 16 public version")
st.sidebar.write(
    "ChatMol copilot is a copilot for protein engineering. Also chekcout our [GitHub](https://github.com/JinyuanSun/ChatMol)."
)
st.write("Enjoy modeling proteins with ChatMol copilot! 🤖️ 🧬")
float_init()

@@ -69,22 +51,27 @@ except:
    print("No OPENAI_API_KEY found in environment variables")

if project_id + str(openai_api_key) == "Project-X":
    st.warning("Please change the project name to your own project name, and provide your own OpenAI API key.")
    st.warning(
        "Please change the project name to your own project name, and provide your own OpenAI API key."
    )
    st.stop()

model = st.sidebar.selectbox("Model", ["gpt-3.5-turbo-1106", "gpt-4-32k-0613", "gpt-3.5-turbo-16k", "gpt-4-1106-preview"])
model = st.sidebar.selectbox(
    "Model",
    ["gpt-3.5-turbo-1106", "gpt-4-32k-0613", "gpt-3.5-turbo-16k", "gpt-4-1106-preview"],
)
st.session_state["openai_model"] = model




if 'api_key' in st.session_state:
if "api_key" in st.session_state:
    api_key_test = st.session_state["api_key"]
    if st.session_state.api_key is False:
        api_key_test = test_openai_api(openai_api_key)
        st.session_state.api_key = api_key_test
        if api_key_test is False:
            st.warning("The provided OpenAI API key seems to be invalid. Please check again. If you don't have an OpenAI API key, please visit https://platform.openai.com/ to get one.")
            st.warning(
                "The provided OpenAI API key seems to be invalid. Please check again. If you don't have an OpenAI API key, please visit https://platform.openai.com/ to get one."
            )
            st.stop()
else:
    api_key_test = test_openai_api(openai_api_key)
@@ -95,6 +82,7 @@ hash_string = "WD_" + str(hash(openai_api_key+project_id)).replace("-", "_")
if st.sidebar.button("Clear Project History"):
    if os.path.exists(f"./{hash_string}"):
        shutil.rmtree(f"./{hash_string}")
        st.session_state.messages = []
# try to bring back the previous session
work_dir = f"./{hash_string}"
cfn.WORK_DIR = work_dir
@@ -106,13 +94,65 @@ if not os.path.exists(work_dir):

client = OpenAI(api_key=openai_api_key)
conversation = ConversationHandler(client, cfn, model_name=model)

if add_translator := st.sidebar.checkbox("Add translator"):
    def translate_to_protein(self, seq: str, pname=None):
        """Translate a DNA/RNA sequence into a protein sequence.

        Returns a message embedding the translation as a FASTA record; the
        header is ``pname`` when given, otherwise ``protein``.
        """
        from Bio.Seq import Seq
        nucleotide_seq = Seq(seq)
        protein_seq = nucleotide_seq.translate()
        if pname:
            # Use the supplied name as the FASTA header, matching the
            # template version in new_function_template.py (previously the
            # header stayed ">protein" and pname dangled after the record).
            return f"The protein sequence of {seq} is `>{pname}\n{protein_seq}`"
        else:
            return f"The protein sequence of {seq} is `>protein\n{protein_seq}`"

    # Bind the function as a method on the shared ChatmolFN instance.
    cfn.translate_to_protein = translate_to_protein.__get__(cfn)

    conversation.tools.append(
        {
            "type": "function",
            "function": {
                "name": "translate_to_protein",
                "description": "Translate a DNA/RNA sequence to a protein sequence",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "seq": {"type": "string", "description": "The DNA/RNA sequence"},
                    },
                    # "required" belongs inside the JSON-schema "parameters"
                    # object; as a sibling of "parameters" the API ignores it.
                    "required": ["seq"],
                },
            },
        }
    )
    conversation.available_functions["translate_to_protein"] = cfn.translate_to_protein

if add_from_template := st.sidebar.checkbox("Add from template"):
    # Load user-contributed functions and their OpenAI tool descriptions
    # from copilot_public/new_function_template.py.
    function_info = new_function_template.get_info()
    descriptions = function_info['descriptions']
    new_funcs = function_info['functions']
    test_data = new_function_template.test_data
    for description, new_func in zip(descriptions, new_funcs):
        func_name = description['function']['name']
        try:
            # Run the template's self-test before exposing the function.
            test_results = new_function_template.test_new_function(new_func, func_name, test_data)
            if not test_results:
                # A False result means the function ran but produced the
                # wrong output; previously this was silently ignored and the
                # broken tool was registered anyway.
                st.warning(f"Function `{func_name}` failed its template self-test and was not added.")
                continue
            conversation.tools.append(description)
            conversation.available_functions[func_name] = new_func.__get__(cfn)
            if func_name not in st.session_state.new_added_functions:
                st.sidebar.success(f"Function `{func_name}` added successfully.")
                st.session_state.new_added_functions.append(func_name)
        except Exception as e:
            st.warning(f"Failed to add function from template. Error: {e}")

available_functions = conversation.available_functions
available_tools = conversation.tools
if "openai_model" not in st.session_state:
    st.session_state["openai_model"] = model

if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "system", "content": "You are ChatMol copilot, a helpful copilot in molecule analysis with tools. Use tools only when you need them."}]
    st.session_state.messages = [
        {
            "role": "system",
            "content": "You are ChatMol copilot, a helpful copilot in molecule analysis with tools. Use tools only when you need them. Only answer to questions related molecular modelling.",
        }
    ]

chatcol, displaycol = st.columns([1, 1])
with chatcol:
@@ -120,15 +160,18 @@ with chatcol:
        # print(type(message))
        try:
            if message["role"] != "system":
                with st.chat_message(message["role"]):
                    cleaned_string = re.sub(pattern, '', message["content"], flags=re.DOTALL)
                cleaned_string = re.sub(
                    pattern, "", message["content"], flags=re.DOTALL
                )
                if cleaned_string != "":
                    with st.chat_message(message["role"]):
                        st.markdown(cleaned_string)
        except:
            if message.role != "system":
                cleaned_string = re.sub(pattern, "", message.content, flags=re.DOTALL)
                if cleaned_string != "":
                    with st.chat_message(message.role):
                    cleaned_string = re.sub(pattern, '', message.content, flags=re.DOTALL)
                    if cleaned_string != "": st.markdown(cleaned_string)
                        st.markdown(cleaned_string)

if prompt := st.chat_input("What is up?"):
    with chatcol:
@@ -137,7 +180,7 @@ if prompt := st.chat_input("What is up?"):
            prompt += chatmol_sys_info
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            cleaned_string = re.sub(pattern, '', prompt, flags=re.DOTALL)
            cleaned_string = re.sub(pattern, "", prompt, flags=re.DOTALL)
            st.markdown(cleaned_string)

        with st.chat_message("assistant"):
@@ -145,39 +188,37 @@ if prompt := st.chat_input("What is up?"):
            full_response = ""
            tool_calls = []
            tool_call = None
            # print("=====================================")
            # print(st.session_state.messages)
            # print('+++++++++++++++++++++++++++++++++++++')
            # print(available_tools)
            for response in client.chat.completions.create(
                model=st.session_state["openai_model"],
                # messages=[
                    # {"role": m["role"], "content": m["content"]}
                    # for m in st.session_state.messages
                # ],
                messages=st.session_state.messages,
                stream=True,
                tools=available_tools,
                tool_choice="auto",
            ):
                full_response += (response.choices[0].delta.content or "")
                full_response += response.choices[0].delta.content or ""
                message_placeholder.markdown(full_response + "")
                # print(response)
                tool_call_chunk_list = response.choices[0].delta.tool_calls
                if tool_call_chunk_list:
                    for tool_call_chunk in tool_call_chunk_list:
                        if len(tool_calls) <= tool_call_chunk.index:
                            tool_calls.append({"id": "", "type": "function", "function": { "name": "", "arguments": "" } })
                        # print(tool_call_chunk.function.arguments)
                            tool_calls.append(
                                {
                                    "id": "",
                                    "type": "function",
                                    "function": {"name": "", "arguments": ""},
                                }
                            )
                        tool_call = tool_calls[tool_call_chunk.index]
                        if tool_call_chunk.id:
                            tool_call["id"] += tool_call_chunk.id
                        if tool_call_chunk.function.name:
                            tool_call["function"]["name"] += tool_call_chunk.function.name
                            tool_call["function"][
                                "name"
                            ] += tool_call_chunk.function.name
                        if tool_call_chunk.function.arguments:
                            tool_call["function"]["arguments"] += tool_call_chunk.function.arguments
                        # if tool_call_chunk.function.type:
                        #     tool_call["function"]["type"] = tool_call_chunk.function.type
                            tool_call["function"][
                                "arguments"
                            ] += tool_call_chunk.function.arguments
            function_response = ""
            st.session_state.messages.append(
                {
@@ -185,17 +226,14 @@ if prompt := st.chat_input("What is up?"):
                    "content": full_response,
                }
            )
            # if tool_calls != []:
                # response_message = {"role": "assistant","content": full_response}
            # else:
                # response_message = compose_chat_completion_message(role="assistant", content=full_response, tool_call_dict_list=tool_calls)
                # st.session_state.messages.append(response_message)
            # print(tool_calls)
            if tool_call:
                response_message = compose_chat_completion_message(role="assistant", content=full_response, tool_call_dict_list=tool_calls)
                response_message = compose_chat_completion_message(
                    role="assistant",
                    content=full_response,
                    tool_call_dict_list=tool_calls,
                )
                st.session_state.messages.append(response_message)
                for tool_call in tool_calls:
                    # print(tool_call)
                    function_name = tool_call["function"]["name"]
                    function_to_call = available_functions[function_name]
                    try:
@@ -203,7 +241,7 @@ if prompt := st.chat_input("What is up?"):
                        function_response = function_to_call(**function_args)
                        st.session_state.messages.append(
                            {
                                "tool_call_id": tool_call['id'],
                                "tool_call_id": tool_call["id"],
                                "role": "tool",
                                "name": function_name,
                                "content": function_response,
@@ -213,7 +251,7 @@ if prompt := st.chat_input("What is up?"):
                        print(f"The error is:\n{e}")
                        st.session_state.messages.append(
                            {
                                "tool_call_id": tool_call['id'],
                                "tool_call_id": tool_call["id"],
                                "role": "tool",
                                "name": function_name,
                                "content": f"error: {e}",
@@ -225,9 +263,9 @@ if prompt := st.chat_input("What is up?"):
                for response in client.chat.completions.create(
                    model=st.session_state["openai_model"],
                    messages=st.session_state.messages,
                    stream=True
                    stream=True,
                ):
                    full_response += (response.choices[0].delta.content or "")
                    full_response += response.choices[0].delta.content or ""
                    message_placeholder.markdown(full_response)
                st.session_state.messages.append(
                    {
@@ -249,13 +287,20 @@ if uploaded_file:
    with open(pdb_file, "wb") as f:
        f.write(uploaded_file.getbuffer())

    pdb_string = "\n".join([x for x in uploaded_file.getvalue().decode("utf-8").split("\n") if x.startswith("ATOM") or x.startswith("HETATM")])
    pdb_string = "\n".join(
        [
            x
            for x in uploaded_file.getvalue().decode("utf-8").split("\n")
            if x.startswith("ATOM") or x.startswith("HETATM")
        ]
    )
    view = show_pdb(
        pdb_str=pdb_string,
                color=color_options[selected_color],
                show_sidechains=show_sidechains,
                show_ligands=show_ligands,
                show_mainchains=show_mainchains)
        color=color_options["Rainbow"],
        show_sidechains=False,
        show_ligands=True,
        show_mainchains=False,
    )

    cfn.VIEW_DICTS[pdb_id] = view

@@ -265,15 +310,23 @@ with displaycol:
    with container:
        col1, col2 = st.columns([1, 1])
        with col1:
            viewer_selection = st.selectbox("Select a viewer", options=["molstar", "py3Dmol"], index=0)
            viewer_selection = st.selectbox(
                "Select a viewer", options=["molstar", "py3Dmol"], index=0
            )
        if viewer_selection == "molstar":
            pdb_files = [f for f in os.listdir(cfn.WORK_DIR) if f.endswith(".pdb")]
            if len(pdb_files) > 0:
                with col2:
                    pdb_file = st.selectbox("Select a pdb file", options=pdb_files, index=0)
                st_molstar(f"{cfn.WORK_DIR}/{pdb_file}", height=400)
                    pdb_file = st.selectbox(
                        "Select a pdb file", options=pdb_files, index=0
                    )
                st_molstar(f"{cfn.WORK_DIR}/{pdb_file}", height=500)
        if viewer_selection == "py3Dmol":
            color_options = {"Confidence": "pLDDT", "Rainbow": "rainbow", "Chain": "chain"}
            color_options = {
                "Confidence": "pLDDT",
                "Rainbow": "rainbow",
                "Chain": "chain",
            }
            selected_color = st.sidebar.selectbox(
                "Color Scheme", options=list(color_options.keys()), index=0
            )
@@ -283,7 +336,9 @@ with displaycol:
            show_ligands = st.sidebar.checkbox("Show Ligands", value=True)
            if len(cfn.VIEW_DICTS) > 0:
                with col2:
                    select_view = st.selectbox("Select a view", options=list(cfn.VIEW_DICTS.keys()), index=0)
                    select_view = st.selectbox(
                        "Select a view", options=list(cfn.VIEW_DICTS.keys()), index=0
                    )
                view = cfn.VIEW_DICTS[select_view]
                showmol(view, height=400, width=500)
        float_parent()
+54 −0
Original line number Diff line number Diff line
import types

def translate_to_protein(self, seq: str, pname=None):
    """Translate a DNA/RNA sequence to a protein sequence.

    Example user function for the ChatMol template: returns a human-readable
    message embedding the translation as a FASTA record. The header is
    ``pname`` when provided, otherwise ``protein``.

    NOTE(review): Bio.Seq.translate warns when len(seq) is not a multiple
    of 3 — confirm callers pass whole codons.
    """
    from Bio.Seq import Seq

    nucleotide_seq = Seq(seq)
    protein_seq = nucleotide_seq.translate()
    if pname:
        return f"The protein sequence of {seq} is `>{pname}\n{protein_seq}`"
    else:
        return f"The protein sequence of {seq} is `>protein\n{protein_seq}`"

# OpenAI tool descriptions for the user functions above. Per the OpenAI
# function-calling schema, "required" lists mandatory properties and must
# live INSIDE the "parameters" JSON-schema object (it was previously a
# sibling of "parameters", where the API ignores it).
function_descriptions = [{  # This is the description of the function
    "type": "function",
    "function": {
        "name": "translate_to_protein",
        "description": "Translate a DNA/RNA sequence to a protein sequence",
        "parameters": {
            "type": "object",
            "properties": {
                "seq": {"type": "string", "description": "The DNA/RNA sequence"},
            },
            "required": ["seq"],
        },
    },
}]

# Reference I/O used by test_new_function to validate each user function.
# The DNA below is 15 nt (5 whole codons: ATG CGA ATT TGG GCC) and
# translates to MRIWA under the standard codon table. The previous entry
# expected "MRFL", which is not the translation of the input, so the
# self-test could never pass; the 16-nt input also triggered a Biopython
# partial-codon warning.
test_data = {
    "translate_to_protein": {
        "input": {
            "self": None,  # module-level function still takes a `self` slot
            "seq": "ATGCGAATTTGGGCC",
        },
        "output": "The protein sequence of ATGCGAATTTGGGCC is `>protein\nMRIWA`",
    }
}

### DO NOT MODIFY BELOW THIS LINE ###
def get_all_functions():
    """Return the user-defined functions in this module.

    Excludes the template infrastructure (this function, ``get_info`` and
    ``test_new_function``) and anything imported from other modules —
    previously every module-level function was returned and only the order
    of definition kept ``zip`` pairing them with the right descriptions.
    """
    infra = {"get_all_functions", "get_info", "test_new_function"}
    all_functions = []
    for name, func in globals().items():
        if (
            isinstance(func, types.FunctionType)
            and func.__module__ == __name__
            and name not in infra
        ):
            all_functions.append(func)
    return all_functions

def get_info():
    """Bundle the module's user functions with their OpenAI tool descriptions."""
    info = {
        "functions": get_all_functions(),
        "descriptions": function_descriptions,
    }
    return info

def test_new_function(function, function_name, test_data):
    return (
        function(**test_data[function_name]["input"])
        == test_data[function_name]["output"]
    )
+2 −1
Original line number Diff line number Diff line
@@ -8,9 +8,10 @@ py3Dmol==2.0.0.post2
rdkit==2023.9.1
Requests==2.31.0
stmol==0.0.9
streamlit==1.24.0
streamlit==1.29.0
streamlit_analytics==0.4.1
streamlit_float==0.3.2
streamlit_js_eval==0.1.5
streamlit_molstar==0.4.6
tqdm==4.66.1
ipython_genutils
 No newline at end of file
Loading