Commit f47ea28f authored by JinyuanSun's avatar JinyuanSun
Browse files

modified: copilot_public/main.py

	modified:   copilot_public/new_function_template.py
	modified:   copilot_public/utils.py
parent 13e3708e
Loading
Loading
Loading
Loading
+138 −43
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ import chatmol_fn as cfn
from stmol import showmol
from streamlit_float import *
from viewer_utils import show_pdb, update_view
from utils import test_openai_api
from utils import test_openai_api, function_args_to_streamlit_ui
from streamlit_molstar import st_molstar, st_molstar_rcsb, st_molstar_remote
import new_function_template
import shutil
@@ -16,16 +16,17 @@ import pickle
import re
import streamlit_analytics
import requests
import inspect

st.set_page_config(layout="wide")
st.session_state.new_added_functions = []
# Debug aid: dump any chat history carried over from a previous rerun.
# (Replaces a bare `try/except: pass` — only the missing-key case is expected.)
if "messages" in st.session_state:
    print(st.session_state["messages"])


# Matches <chatmol_sys>...</chatmol_sys> spans so system annotations can be
# stripped from messages before display.
pattern = r"<chatmol_sys>.*?</chatmol_sys>"
# Queue of tool calls awaiting user confirmation in manual mode.
if "function_queue" not in st.session_state:
    st.session_state["function_queue"] = []

if "cfn" in st.session_state:
    cfn = st.session_state["cfn"]
else:
@@ -33,7 +34,7 @@ else:
    st.session_state["cfn"] = cfn

st.title("ChatMol copilot", anchor="center")
st.sidebar.write("2023 Dec 16 public version")
st.sidebar.write("2024 Jan 05 public version")
st.sidebar.write(
    "ChatMol copilot is a copilot for protein engineering. Also chekcout our [GitHub](https://github.com/JinyuanSun/ChatMol)."
)
@@ -41,6 +42,7 @@ st.write("Enjoy modeling proteins with ChatMol copilot! 🤖️ 🧬")
float_init()

st.sidebar.title("Settings")
mode = st.sidebar.selectbox("Mode", ["automatic", "manual"], index=0)
with streamlit_analytics.track():
    project_id = st.sidebar.text_input("Project Name", "Project-X")
openai_api_key = st.sidebar.text_input("OpenAI API key", type="password")
@@ -83,6 +85,9 @@ if st.sidebar.button("Clear Project History"):
    if os.path.exists(f"./{hash_string}"):
        shutil.rmtree(f"./{hash_string}")
        st.session_state.messages = []
        st.session_state.function_queue = []
        st.session_state.new_added_functions = []
        st.session_state.cfn = cfn.ChatmolFN()
# try to bring back the previous session
work_dir = f"./{hash_string}"
cfn.WORK_DIR = work_dir
@@ -157,7 +162,6 @@ if "messages" not in st.session_state:
chatcol, displaycol = st.columns([1, 1])
with chatcol:
    for message in st.session_state.messages:
        # print(type(message))
        try:
            if message["role"] != "system":
                cleaned_string = re.sub(
@@ -172,6 +176,52 @@ with chatcol:
                if cleaned_string != "":
                    with st.chat_message(message.role):
                        st.markdown(cleaned_string)
    # ---- Pending tool-call queue --------------------------------------
    # Run every queued tool call that is still "pending", record each
    # result (or error) as a "tool" message, and mark the entry done.
    function_called = False

    for tool_call in st.session_state.function_queue:
        if tool_call["status"] == "pending":
            function_called = True
            # with st.chat_message("assistant"):
                # st.write(f"Please provide arguments for {tool_call['name']}")
            try:
                function_name = tool_call["name"]
                function_to_call = tool_call["func"]
                # Arguments are stored as a JSON string (the format the
                # OpenAI tool-call API emits); decode before invoking.
                function_args = json.loads(tool_call["args"])
                function_response = function_to_call(**function_args)
                # NOTE(review): a falsy response leaves the entry "pending",
                # so it will be re-executed on the next rerun — confirm that
                # is intended.
                if function_response:
                    st.session_state.messages.append(
                        {
                            "tool_call_id": tool_call["tool_call_id"],
                            "role": "tool",
                            "name": function_name,
                            "content": function_response,
                        }
                    )
                    tool_call["status"] = "done"

            except Exception as e:
                # NOTE(review): if tool_call["name"] itself raised above,
                # function_name is unbound here and this handler would
                # raise NameError — worth hardening.
                print(f"The error is:\n{e}")
                st.session_state.messages.append(
                    {
                        "tool_call_id": tool_call["tool_call_id"],
                        "role": "tool",
                        "name": function_name,
                        "content": f"error: {e}",
                    }
                )
                tool_call["status"] = "done"
    # If any tool actually ran, stream a follow-up assistant reply that can
    # incorporate the new tool results.
    if function_called:
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""
            for response in client.chat.completions.create(
                model=st.session_state["openai_model"],
                messages=st.session_state.messages,
                stream=True,
            ):
                full_response += response.choices[0].delta.content or ""
                message_placeholder.markdown(full_response)
            # NOTE(review): full_response is displayed but never appended to
            # st.session_state.messages — the streamed reply is lost on the
            # next rerun; confirm whether that is intentional.

if prompt := st.chat_input("What is up?"):
    with chatcol:
@@ -233,12 +283,15 @@ if prompt := st.chat_input("What is up?"):
                    tool_call_dict_list=tool_calls,
                )
                st.session_state.messages.append(response_message)
                if mode == "automatic":
                    for tool_call in tool_calls:
                        function_name = tool_call["function"]["name"]
                        function_to_call = available_functions[function_name]
                        try:
                            function_args = json.loads(tool_call["function"]["arguments"])

                            function_response = function_to_call(**function_args)
                            if function_response:
                                st.session_state.messages.append(
                                    {
                                        "tool_call_id": tool_call["id"],
@@ -257,8 +310,51 @@ if prompt := st.chat_input("What is up?"):
                                    "content": f"error: {e}",
                                }
                            )
            # if full_response:
            # message_placeholder.markdown(full_response)

                if mode == "manual":
                    # Manual mode: queue each tool call for user confirmation
                    # instead of executing it immediately.
                    print("manual mode")
                    for tool_call in tool_calls:
                        function_name = tool_call["function"]["name"]
                        function_to_call = available_functions[function_name]
                        function_arg_string = tool_call["function"]["arguments"]
                        st.session_state.function_queue.append(
                            {
                                "tool_call_id": tool_call["id"],
                                "role": "tool",
                                "name": function_name,
                                "func": function_to_call,
                                "args": function_arg_string,
                                "status": "pending",
                                "content": "",
                            }
                        )
                    for queued_call in st.session_state.function_queue:
                        if queued_call["status"] != "pending":
                            continue
                        try:
                            # BUG FIX: use this queue entry's own stored
                            # func/args/name. The original read the loop
                            # variables left over from the enqueue loop above,
                            # so every pending entry was rendered/executed
                            # with the *last* tool call's function and args.
                            function_args = json.loads(queued_call["args"])
                            function_response = function_args_to_streamlit_ui(
                                queued_call["func"], function_args, queued_call["tool_call_id"]
                            )
                            print(function_response)
                            if function_response:
                                st.session_state.messages.append(
                                    {
                                        "tool_call_id": queued_call["tool_call_id"],
                                        "role": "tool",
                                        "name": queued_call["name"],
                                        "content": function_response,
                                    }
                                )
                                queued_call["status"] = "done"
                        except Exception as e:
                            print(f"The error is:\n{e}")
                            st.session_state.messages.append(
                                {
                                    "tool_call_id": queued_call["tool_call_id"],
                                    "role": "tool",
                                    "name": queued_call["name"],
                                    "content": f"error: {e}",
                                }
                            )
                            queued_call["status"] = "done"
                if function_response:
                    for response in client.chat.completions.create(
                        model=st.session_state["openai_model"],
@@ -275,7 +371,6 @@ if prompt := st.chat_input("What is up?"):
                    )

                message_placeholder.markdown(full_response)
            # print(st.session_state.messages)
uploaded_file = st.sidebar.file_uploader("Upload PDB file", type=["pdb"])

if uploaded_file:
+0 −10
Original line number Diff line number Diff line
@@ -221,8 +221,6 @@ test_data = {
    }
}

import types


def predict_rna_secondary_structure(self, rna_seq: str):
    from seqfold import fold, dg, dot_bracket
@@ -235,14 +233,8 @@ def predict_rna_secondary_structure(self, rna_seq: str):
    Returns:
    - A dictionary with the minimum free energy and the dot-bracket representation of the structure.
    """

    # Calculate the minimum free energy (MFE) using seqfold
    mfe = dg(rna_seq)

    # Get the structural details using seqfold
    structs = fold(rna_seq)

    # Get the dot-bracket representation of the structure
    dot_bracket_structure = dot_bracket(rna_seq, structs)

    return f"Minimum Free Energy (MFE): `{mfe}`\nDot-Bracket Structure: `{dot_bracket_structure}`"
@@ -270,8 +262,6 @@ test_data["predict_rna_secondary_structure"] = {
        "rna_seq": "GGGAGGTCGTTACATCTGGGTAACACCGGTACTGATCCGGTGACCTCCC",
    },
    "output": {"Minimum Free Energy (MFE): `-13.4`\nDot-Bracket Structure: `((((((((.((((......))))..((((.......)))).))))))))`"
        # "Minimum Free Energy (MFE)": -13.4,
        # "Dot-Bracket Structure": "((((((((.((((......))))..((((.......)))).))))))))"
    },
}

+30 −0
Original line number Diff line number Diff line
import requests
import inspect
import functools
import streamlit as st
from openai import OpenAI

def handle_file_not_found_error(func):
@@ -11,6 +13,34 @@ def handle_file_not_found_error(func):
            return f"Current working directory is: {args[0].get_work_dir()}"
    return wrapper

def function_args_to_streamlit_ui(func, args=None, tool_call_id=None):
    """Render input widgets for *func*'s parameters and run it on Submit.

    Inspects ``func``'s signature, draws one Streamlit widget per parameter
    (pre-filled from *args* when provided), and calls ``func`` with the
    collected values once the user presses the Submit button.

    Args:
        func: Callable whose arguments should be collected from the user.
        args: Optional dict of pre-filled argument values (e.g. parsed from
            an LLM tool call). Defaults to no pre-filled values.
        tool_call_id: Unique id used to namespace widget keys so several
            pending calls can be rendered on the same page.

    Returns:
        The return value of ``func(**collected_args)`` if Submit was
        pressed this run, otherwise ``None``.
    """
    # BUG FIX: the original did `name in args` with args defaulting to None,
    # which raises TypeError when no pre-filled values are supplied.
    args = args if args is not None else {}
    signature = inspect.signature(func)
    docstring = inspect.getdoc(func)
    if docstring:
        st.write(docstring)
    args_values = {}
    for name, param in signature.parameters.items():
        prefill = args.get(name)
        widget_key = f"{tool_call_id}_{name}"
        if param.annotation is str:
            # Sequences can be long, so "seq" gets a multi-line text area.
            if name == "seq":
                args_values[name] = st.text_area(name, key=widget_key, value=prefill)
            else:
                args_values[name] = st.text_input(name, key=widget_key, value=prefill)
        elif param.annotation is int:
            args_values[name] = st.number_input(name, key=widget_key, value=prefill)
        else:
            # Fallback for unannotated / otherwise-typed parameters.
            args_values[name] = st.text_input(name, key=widget_key, value=prefill)
    # Namespaced key prevents Streamlit duplicate-element errors when several
    # pending tool calls each render their own Submit button.
    if st.button('Submit', key=f"{tool_call_id}_submit"):
        print(args_values)
        result = func(**args_values)
        st.write(result)
        return result

def test_openai_api(api_key):
    client = OpenAI(api_key=api_key)
    try: