Commit ffa5539b authored by JinyuanSun's avatar JinyuanSun
Browse files

small fix

parent ccae1803
Loading
Loading
Loading
Loading
+1 −5
Original line number Diff line number Diff line
import json
import os
import rdkit
import time
import pandas as pd
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import (
    Function,
@@ -13,7 +9,7 @@ import chatmol_fn as cfn


class ConversationHandler:
    def __init__(self, client, cfn, model_name="gpt-3.5-turbo-1106"):
    def __init__(self, client, cfn, model_name="gpt-4o"):
        self.client = client
        self.model_name = model_name
        self.cfn = cfn
+4 −1
Original line number Diff line number Diff line
@@ -76,6 +76,7 @@ def submit_docking_task(protein_file, ligand_file, center_x=0, center_y=0, cente
    if aa_list:
        files['aa_list'] = (None, aa_list)
    response = requests.post('https://dockingapi.cloudmol.org/api/dock', headers=headers, files=files)
    # breakpoint()
    return response.json()

def submit_pocket_prediction_task(protein_file):
@@ -142,6 +143,7 @@ def redis_reader(key):
class ChatmolFN:
    def __init__(self, work_dir="./"):
        self.WORK_DIR = "./"
        print("WORK_DIR = ", self.WORK_DIR)
        self.STREAMLIT_GUI = True
        self.VIEW_DICTS = {}
        self.viewer_height = 300
@@ -280,8 +282,9 @@ class ChatmolFN:
        docking_code = submit_docking_task(protein_pdb_file_path, ligand_pdb_file_path, aa_list=pocket_aas)
        docking_code = docking_code['hash_code']
        print(docking_code)
        time.sleep(15)
        status = ""
        status_prev = ''

        print("status:")
        while status != '"completed"':
            status = query_docking_status(docking_code)
+36 −10
Original line number Diff line number Diff line
@@ -4,13 +4,13 @@ import streamlit as st
import chatmol_fn as cfn_
from stmol import showmol
from streamlit_float import *
from viewer_utils import show_pdb #, update_view
from utils import test_openai_api, function_args_to_streamlit_ui
from streamlit_molstar import st_molstar #, st_molstar_rcsb, st_molstar_remote
from viewer_utils import show_pdb
from utils import test_openai_api, function_args_to_streamlit_ui, test_ds_api
from streamlit_molstar import st_molstar
import hashlib
import new_function_template
import shutil
from chat_helper import ConversationHandler, compose_chat_completion_message #, extract_function_and_execute
from chat_helper import ConversationHandler, compose_chat_completion_message
import os
import json
import pickle
@@ -39,9 +39,9 @@ else:
    st.session_state["cfn"] = cfn

st.title("ChatMol Copilot", anchor="center")
st.sidebar.write("2024 May 14 public version")
st.sidebar.write("2025 Mar 11 public version")
st.sidebar.write(
    "ChatMol copilot is a AI platform for protein engineering, molecular design and computation. Also chekcout our [GitHub](https://github.com/JinyuanSun/ChatMol)."
    "ChatMol copilot is a AI platform for protein engineering, molecular design and computation. Also chekcout our [GitHub](https://github.com/ChatMol/ChatMol)."
)
st.write("The LLM Powered Agent for Protein Modeling and Molecular Computation 🤖️ 🧬")
float_init()
@@ -65,7 +65,7 @@ if project_id + str(openai_api_key) == "Project-X":

model = st.sidebar.selectbox(
    "Model",
    ["gpt-3.5-turbo",  "gpt-4o",  "gpt-4-turbo", "gpt-4"],
    ["gpt-4o",  "gpt-4o-mini",  "deepseek-chat"],
)
st.session_state["openai_model"] = model

@@ -86,6 +86,27 @@ if st.session_state["openai_model"].startswith("gpt"):
        api_key_test = test_openai_api(openai_api_key)
        st.session_state["api_key"] = api_key_test

elif st.session_state["openai_model"].startswith("deepseek"):
    openai_api_key = os.getenv("DEEPSEEK_API_KEY", openai_api_key)
    if "api_key" in st.session_state:
        api_key_test = st.session_state["api_key"]
        if st.session_state.api_key is False:
            # api_key_test = True
            api_key_test = test_ds_api(openai_api_key)
            st.session_state.api_key = api_key_test
            if api_key_test is False:
                st.warning(
                    "The provided DeepSeek API key seems to be invalid. Please check again. If you don't have an DeepSeek API key, please visit https://api-docs.deepseek.com/ to get one."
                )
                st.stop()
    else:
        # api_key_test = True
        api_key_test = test_ds_api(openai_api_key)
        st.session_state["api_key"] = api_key_test
else:
    st.warning("Please select a valid model.")
    st.stop()

m.update((openai_api_key + project_id).encode())
hash_string = m.hexdigest()

@@ -117,11 +138,10 @@ if os.path.exists(f"{work_dir}/.history"):
    with open(f"{work_dir}/.history", "rb") as f:
        st.session_state.messages = pickle.load(f)

if model != "glm-4":
if model.startswith("gpt"):
    client = OpenAI(api_key=openai_api_key)
else:
    from zhipuai import ZhipuAI
    client = ZhipuAI()
    client = OpenAI(api_key=openai_api_key, base_url="https://api.deepseek.com")

conversation = ConversationHandler(client, cfn, model_name=model)

@@ -246,6 +266,11 @@ with chatcol:
                function_name = tool_call["name"]
                function_to_call = tool_call["func"]
                function_args = json.loads(tool_call["args"])
            except Exception as e:
                print(f"The error is:\n{e}")
                print("Error in function_args")

            try:
                function_response = function_to_call(**function_args)
                if function_response:
                    st.session_state.messages.append(
@@ -260,6 +285,7 @@ with chatcol:
                    
            except Exception as e:
                print(f"The error is:\n{e}")
                print("Error in function_response")
                st.session_state.messages.append(
                    {
                        "tool_call_id": tool_call["tool_call_id"],
+16 −1
Original line number Diff line number Diff line
@@ -45,7 +45,7 @@ def test_openai_api(api_key):
    client = OpenAI(api_key=api_key)
    try:
        response = client.chat.completions.create(
                    model="gpt-3.5-turbo-1106",
                    model="gpt-4o-mini",
                    messages=[{"role": "user", "content": "Test prompt"}],
                    max_tokens=10,
                )
@@ -56,6 +56,21 @@ def test_openai_api(api_key):
        print("OpenAI API is not working")
        return False
    
def test_ds_api(api_key):
    client = OpenAI(api_key=api_key)
    try:
        response = client.chat.completions.create(
                    model="deepseek-chat",
                    messages=[{"role": "user", "content": "Test prompt"}],
                    max_tokens=10,
                )
        print(response)
        return True

    except Exception as e:
        print("DeepSeek API is not working")
        return False


def query_pythia(pdb_file):
    try: