Unverified Commit 57307452 authored by jinyuan sun's avatar jinyuan sun Committed by GitHub
Browse files

Merge pull request #35 from JinyuanSun/dev02

Dev02
parents a632a38c 82cdbce3
Loading
Loading
Loading
Loading
+12 −18
Original line number Original line Diff line number Diff line
import json
import json
import requests
import requests
# import pymol
import os
import os
# import pandas as pd
# import openai
# import time
# import types
# import urllib.parse
import pprint
import pprint
# from datetime import datetime
# from pymol import cmd
# from urllib.parse import quote



os.environ["REGISTRY_HOST_PORT"]="100.89.180.132:9999"
os.environ["REGISTRY_HOST_PORT"]="100.89.180.132:9999"
api_key = os.getenv("OPENAI_API_KEY")
api_key = os.getenv("OPENAI_API_KEY")
app_pw = os.getenv("GMAIL_APP_PASSWORD")
app_pw = os.getenv("GMAIL_APP_PASSWORD")
registry_host_port = os.getenv("REGISTRY_HOST_PORT")
registry_host_port = os.getenv("REGISTRY_HOST_PORT")


#os.environ["REGISTRY_HOST_PORT"]="100.89.180.132:9999"

# Search for pymol service endpoint
# print("Registry Host Port = ",registry_host_port)
# print(requests.get("http://"+registry_host_port+"/registry"))
# print("END")
registry = requests.get("http://"+registry_host_port+"/registry").json()
registry = requests.get("http://"+registry_host_port+"/registry").json()
pymol_endpoint = ""
pymol_endpoint = ""
for key in registry.keys():
for key in registry.keys():
@@ -276,6 +260,7 @@ def get_default_value(param_desc):
    
    
# Dynamic generation of new python functions from registry
# Dynamic generation of new python functions from registry
def func_code_gen(registry):
def func_code_gen(registry):
    molgen_funcs = ['DenovoGen','MotifExtend','SuperStructure','ScaffoldMorphine','LinkerGen']
    for key in registry.keys():
    for key in registry.keys():
        r = registry[key]
        r = registry[key]
        param_desc_str = r['param_desc'].replace("'",'"')
        param_desc_str = r['param_desc'].replace("'",'"')
@@ -304,6 +289,7 @@ def func_code_gen(registry):
        func_arg_str = ",".join(func_args)
        func_arg_str = ",".join(func_args)
        print("func_arg_str = ", func_arg_str)
        print("func_arg_str = ", func_arg_str)
        dec_func.append("def "+r['service_name']+"(self," + func_arg_str +"):")
        dec_func.append("def "+r['service_name']+"(self," + func_arg_str +"):")
        dec_func.append("    from chatmol_fn import redis_writer, redis_reader")
        dec_func.append(f"    print('{service_name} is called')")       
        dec_func.append(f"    print('{service_name} is called')")       
        dec_func.append("    param_dict = {}")
        dec_func.append("    param_dict = {}")
        for param_name in param_desc.keys():
        for param_name in param_desc.keys():
@@ -315,9 +301,17 @@ def func_code_gen(registry):
        dec_func.append("    #Call the generic FastAPI")
        dec_func.append("    #Call the generic FastAPI")
        code = f"    messages = call_fastapi('{service_name}', param_dict)"
        code = f"    messages = call_fastapi('{service_name}', param_dict)"
        dec_func.append(code)
        dec_func.append(code)
        dec_func.append("    return 'The results are: ' + str(messages)")
        dec_func.append("    message = ''")

        if (service_name in molgen_funcs):
            # create smiles_key = service_name_smiles
            dec_func.append(f"    smiles_key = '{service_name}' + '_smiles'")
            dec_func.append(f"    redis_writer(smiles_key, messages)")
            #dec_func.append("    message = 'Generated ' + str(len(messages)) + 'molecules\n'")
            dec_func.append(f"    message = 'Generated ' + str(len(messages)) + ' molecules. '")
            dec_func.append(f"    message += 'Generated SMILES list is save to redis cache with smiles_key: ' + smiles_key ") 
        dec_func.append("    return message + ' The results are: ' + str(messages)")
        code_str = "\n".join(dec_func)
        code_str = "\n".join(dec_func)
        print(code_str)
        func_code_dict[service_name] = code_str
        func_code_dict[service_name] = code_str
        
        
    return func_code_dict
    return func_code_dict
+14 −1
Original line number Original line Diff line number Diff line
@@ -279,8 +279,20 @@ class ConversationHandler:
                    },
                    },
                    "required": ["pdb_file1", "pdb_file2"],
                    "required": ["pdb_file1", "pdb_file2"],
                }
                }
            },
            { "type": "function",
                "function": {
                    "name": "python_executer",
                    "description": "Python executer creates a python function from python code (string), and execute it.",
                     "parameters": {
                        "type": "object",
                        "properties": {
                            "function_name": {"type": "string", "description": "The python funciton name"},
                        },
                    },
                    "required": ["function_name"],                   
                }
            }
            }
            
            
            
        ]
        ]
        self.available_functions = {
        self.available_functions = {
@@ -299,6 +311,7 @@ class ConversationHandler:
            "blind_docking": self.cfn.blind_docking,
            "blind_docking": self.cfn.blind_docking,
            "call_proteinmpnn_api": self.cfn.call_proteinmpnn_api,
            "call_proteinmpnn_api": self.cfn.call_proteinmpnn_api,
            "compare_protein_structures": self.cfn.compare_protein_structures,
            "compare_protein_structures": self.cfn.compare_protein_structures,
            "python_executer": self.cfn.python_executer,
        }
        }


    def setup_workdir(self, work_dir):
    def setup_workdir(self, work_dir):
+100 −0
Original line number Original line Diff line number Diff line
import json
import json
import requests
import requests
import py3Dmol
import py3Dmol
import pickle
import redis
from tqdm import tqdm
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
from cloudmol.cloudmol import PymolFold
from cloudmol.cloudmol import PymolFold
from utils import query_pythia, handle_file_not_found_error
from utils import query_pythia, handle_file_not_found_error
import os
import os
import re
from io import StringIO
from io import StringIO
from stmol import showmol
from stmol import showmol
from rdkit import Chem
from rdkit import Chem
@@ -101,6 +104,41 @@ def save_best_docking_result(docking_code ,file_path):
        #     f.write(receptor.read())
        #     f.write(receptor.read())
    return f"Docking result saved as {file_path}"
    return f"Docking result saved as {file_path}"


def redis_writer(key, data):
    # Serialize the Python object using pickle
    r = redis.Redis(host='localhost', port=6379, db=0)  
    # First, check redis service
    try: 
        r.ping()
    except:
        print("Redis is out of service")
        return None
    
    # Save the serialized object to Redis
    try:
        serialized_data = pickle.dumps(data)
        r.set(key, serialized_data)
    except:
        print("Error in redis_writer")

def redis_reader(key):
    # Retrieve the serialized object from Redis
    r = redis.Redis(host='localhost', port=6379, db=0)
    # First, check redis service
    try: 
        r.ping()
    except:
        print("Redis is out of service")
        return None
    try:
        retrieved_data = r.get(key)
        # Deserialize the object using pickle
        data = pickle.loads(retrieved_data)
        return data
    except:
        print("Error in redis_reader. Check the key!")
        return None

class ChatmolFN:
class ChatmolFN:
    def __init__(self, work_dir="./"):
    def __init__(self, work_dir="./"):
        self.WORK_DIR = "./"
        self.WORK_DIR = "./"
@@ -271,6 +309,7 @@ class ChatmolFN:
        # # print(log)
        # # print(log)
        # return res_df.to_string()
        # return res_df.to_string()



    @handle_file_not_found_error
    @handle_file_not_found_error
    def call_proteinmpnn_api(
    def call_proteinmpnn_api(
            self,
            self,
@@ -409,3 +448,64 @@ class ChatmolFN:
        writer.write(mol)
        writer.write(mol)
        writer.close()
        writer.close()
        return f"The conformation of {smiles} is saved as {file_name}"
        return f"The conformation of {smiles} is saved as {file_name}"
    
    def python_executer(self, function_name):
        # Guardrails to prevent dangerous code execution
        print("function_name ------------------------------------------------------------- =", function_name)
        work_dir = self.get_work_dir()
        print("work_dir = ", work_dir)
        file_path = "./Project-X/.history"
        text = "Test"
        try:
            with open(file_path, 'rb') as f:
                binary_data = f.read()
                try:
                    text = binary_data.decode('utf-8')
                except UnicodeDecodeError:
                    text = binary_data.decode('latin-1')  # or another appropriate encoding
            print("TEXT = ", text)
        except Exception as e:
            print(f"An error occurred: {e}")
        # Find all function definitions using a regex
        functions = re.findall(r'def\s+\w+\s*\(.*?\):\n(?:\s+.*\n)*', text)
        # Return the last function definition
        function_code = functions[-1] if functions else None

        print("function code ---------------------------------------")
        print(function_code)
        print("Function code ends heer -----------------------------")
        dangerous_keywords = [
            'exec', 'eval', 'import os', 'import subprocess', 'os.system', 'subprocess.call',
            'subprocess.Popen', 'compile', 'builtins', '__import__', 'globals', 'locals',#'open'
        ]
        # Check if the function code contains any dangerous keywords
        for keyword in dangerous_keywords:
            if re.search(r'\b' + re.escape(keyword) + r'\b', function_code):
                raise ValueError(f"Use of dangerous keyword '{keyword}' detected in the function code.")
            
        # Regular expression to extract the function name
        pattern = r'\bdef\s+(\w+)\s*\(.*\)\s*:'

        # Find all matches in the code
        matches = re.findall(pattern, function_code)
        function_name = "NoName"
        if (len(matches) < 1):
            raise ValueError(f"No function is defined")
        function_name = matches[0]
        print("Function_Name = ", function_name)
        # Compile the function code
        code_obj = compile(function_code, function_name, 'exec')

        # Prepare a restricted execution environment
        exec_env = {}
        print("Function_Name 3 = ", function_name)
        # Execute the compiled code object in the restricted namespace
        exec(code_obj, exec_env)
        # Retrieve the function from the restricted environment
        service_func = exec_env.get(function_name)
        # Ensure the function exists
        if service_func is None:
            raise ValueError(f"Function '{function_name}' is not defined in the provided code.")
        # Call the function with the provided parameters
        results = service_func()
        return "This is the filtered molecules: "+ str(results)
+6 −1
Original line number Original line Diff line number Diff line
@@ -43,7 +43,7 @@ st.sidebar.write("2024 May 14 public version")
st.sidebar.write(
st.sidebar.write(
    "ChatMol copilot is a AI platform for protein engineering, molecular design and computation. Also chekcout our [GitHub](https://github.com/JinyuanSun/ChatMol)."
    "ChatMol copilot is a AI platform for protein engineering, molecular design and computation. Also chekcout our [GitHub](https://github.com/JinyuanSun/ChatMol)."
)
)
st.write("Enjoy modeling proteins with ChatMol copilot! 🤖️ 🧬")
st.write("The LLM Powered Agent for Protein Modeling and Molecular Computation 🤖️ 🧬")
float_init()
float_init()


st.sidebar.title("Settings")
st.sidebar.title("Settings")
@@ -338,6 +338,10 @@ if prompt := st.chat_input("What is up?"):
                    "content": full_response,
                    "content": full_response,
                }
                }
            )
            )
            print("Debug: Full response from assistant", full_response)
            with open(f"{work_dir}/workspace", "a") as f:
                f.write(full_response)

            if tool_call:
            if tool_call:
                st.session_state.messages = st.session_state.messages[:-1]
                st.session_state.messages = st.session_state.messages[:-1]
                if st.session_state.openai_model == "glm-4":
                if st.session_state.openai_model == "glm-4":
@@ -442,6 +446,7 @@ if prompt := st.chat_input("What is up?"):
                            "content": full_response,
                            "content": full_response,
                        }
                        }
                    )
                    )
                    print("Debug: Full response from tool calling", full_response)


                message_placeholder.markdown(full_response)
                message_placeholder.markdown(full_response)
uploaded_file = st.sidebar.file_uploader("Upload PDB file", type=["pdb"])
uploaded_file = st.sidebar.file_uploader("Upload PDB file", type=["pdb"])
+4 −7
Original line number Original line Diff line number Diff line
import types
import types
import os
import os
import re
import requests
import requests
import urllib.parse
import urllib.parse
from enum import Enum
from enum import Enum
from build_from_registry import *
from build_from_registry import *


registry_host_port = os.getenv("registry_host_port")
registry_host_port = os.getenv("REGISTRY_HOST_PORT")
print("registry_host_PORT  ", registry_host_port)
print("registry_host_PORT  ", registry_host_port)


# func_code_dict
# func_sche_dict

test_data = []
test_data = []


def call_fastapi(service: str, params={}):
def call_fastapi(service: str, params={}):
@@ -20,7 +18,6 @@ def call_fastapi(service: str, params={}):
    print("1 input service name = ", service)
    print("1 input service name = ", service)
    print("2 Get all registerred services")
    print("2 Get all registerred services")
    registry = requests.get("http://"+registry_host_port+"/registry").json()
    registry = requests.get("http://"+registry_host_port+"/registry").json()
    #registry = requests.get("http://100.89.180.132:9999/registry").json()
    print("3 Length of Registry = ",len(registry))
    print("3 Length of Registry = ",len(registry))


    # Search for service endpoint by service name
    # Search for service endpoint by service name
@@ -52,6 +49,7 @@ def call_fastapi(service: str, params={}):


func_code_list = []
func_code_list = []
func_sche_list = []
func_sche_list = []
func_list = []


for key in func_code_dict.keys():
for key in func_code_dict.keys():
    sche = {"type": "function", "function": func_sche_dict[key]}
    sche = {"type": "function", "function": func_sche_dict[key]}
@@ -59,8 +57,7 @@ for key in func_code_dict.keys():
    func_sche_list.append(sche)
    func_sche_list.append(sche)




func_list = []
for i in range(len(func_code_list)):
for i in range(len(func_sche_list)):
    code = func_code_list[i]
    code = func_code_list[i]
    sche = func_sche_list[i]
    sche = func_sche_list[i]
    service_name = sche['function']['name']
    service_name = sche['function']['name']
Loading