Unverified Commit 9d7616d6 authored by jinyuan sun's avatar jinyuan sun Committed by GitHub
Browse files

Merge pull request #31 from JinyuanSun/dev2405

Add new functions to tool calling from ChatMol-OS registry
parents e240a407 2fc72cbc
Loading
Loading
Loading
Loading
+329 −0
Original line number Diff line number Diff line
import json
import requests
import pymol
import os
import pandas as pd
import openai
import time
import types
import urllib.parse
import pprint
from datetime import datetime
from pymol import cmd
from urllib.parse import quote


os.environ["REGISTRY_HOST_PORT"]="100.89.180.132:9999"
api_key = os.getenv("OPENAI_API_KEY")
app_pw = os.getenv("GMAIL_APP_PASSWORD")
registry_host_port = os.getenv("REGISTRY_HOST_PORT")

#os.environ["REGISTRY_HOST_PORT"]="100.89.180.132:9999"

# Search for pymol service endpoint
registry = requests.get("http://"+registry_host_port+"/registry").json()
pymol_endpoint = ""
for key in registry.keys():
    r = registry[key]
    if (r['service_name'] == 'PyMOL'):
        pymol_endpoint = r['endpoint']
        print("Endpoint PyMOL = ", pymol_endpoint)     
try:
    response = requests.post("http://"+pymol_endpoint)
    pprint.pprint(response)
except:
    print("Something is wrong")

GPT_MODEL = "gpt-3.5-turbo"
def chat_completion_request(messages, functions=None, function_call=None, model=GPT_MODEL):
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
    }
    json_data = {"model": model, "messages": messages}
    if functions is not None:
        json_data.update({"functions": functions})
    if function_call is not None:
        json_data.update({"function_call": function_call})
    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=json_data,
        )
        return response
    except Exception as e:
        print("Unable to generate ChatCompletion response")
        print(f"Exception: {e}")
        return e

service_ip = pymol_endpoint.split(":")[0]
print("PyMOL service IP = ",service_ip)

# dictionaries for function schema, code and funcs
func_sche_dict = {}
func_code_dict = {}
func_comp_dict = {}

code_str = """
def load_protein_into_pymol(arguments):
    param_obj = json.loads(arguments)
    obj_id = param_obj.get('protein_pdb_id','')
    obj_id = obj_id.lower()
    loaded_objects = cmd.get_names('objects')
    print("Existing Objects: ", loaded_objects)
    if (obj_id in loaded_objects):
        print("Object "+obj_id+" already loaded")
    else:
        cmd.fetch(obj_id)
        print("Send this cmd to PyMOL:","fetch "+obj_id)
        command = f"fetch {obj_id};"
        #response = requests.post('http://localhost:8101/send_message', data=command)
        print("command",command)
        response = requests.post('http://'+service_ip+':8101/send_message', data=command)
        loaded_objects = cmd.get_names('objects')
        print("Loaded Objects: ", loaded_objects)    
"""
    
func_code_dict['load_protein_into_pymol'] = code_str
func_sche_dict['load_protein_into_pymol'] =     {
        "name": "load_protein_into_pymol",
        "description": "fetch a protein from PDB",
        "parameters": {
            "type": "object",
            "properties": {
                "protein_pdb_id": {
                    "type": "string",
                    "description": "PDB ID of a protein. It is usually a four character code",
                },
            },
            "required": ["protein_pdb_id"],
        },
    }

code_str = """ 
def remove_object_from_3D_view(arguments):
    param_obj = json.loads(arguments)
    obj_id = param_obj.get('object_id','')
    obj_id = obj_id.lower()
    loaded_objects = cmd.get_names('objects')
    print("Existing Objects: ", loaded_objects)
    if (obj_id in loaded_objects):
        print("Delete object "+obj_id+"!")
        print("Send this cmd to PyMOL:","delete "+obj_id)
        command = f"delete {obj_id};"
        response = requests.post('http://'+service_ip+':8101/send_message', data=command)
        cmd.delete(obj_id)
    else:
        print("Objects: "+obj_id+ " does not exist.")
"""
        
func_code_dict['remove_object_from_3D_view'] = code_str
func_sche_dict['remove_object_from_3D_view'] = {
        "name": "remove_object_from_3D_view",
        "description": "delete an object from 3D view of PyMOL",
        "parameters": {
            "type": "object",
            "properties": {
                "object_id": {
                    "type": "string",
                    "description": "The object name in a 3D view of PyMOL",
                },
            },
            "required": ["object_id"],
        },
    }

 
code_str = """
def color_protein_chains(arguments):
    param_obj = json.loads(arguments)
    obj_id = param_obj.get('protein_pdb_id','')
    obj_id = obj_id.lower()
    colors = ['red','blue','yellow','purple','magenta','brown']
    
    loaded_objects = cmd.get_names('objects')
    print("Existing Objects: ", loaded_objects)
    
    if (obj_id in loaded_objects):
        print("Object "+obj_id+" already loaded")
    else:
        cmd.fetch(obj_id)
        print("Send this cmd to PyMOL:","fetch "+obj_id)
        command = f"fetch {obj_id};"
        response = requests.post('http://'+service_ip+':8101/send_message', data=command)
        loaded_objects = cmd.get_names('objects')
        print("Loaded Objects: ", loaded_objects)
        if (obj_id not in loaded_objects):
            print("Protein "+obj_id+" does not exist")

    # Get the list of chains in the structure
    chains = cmd.get_chains(obj_id)
    if (len(chains) == 0):
        print("Not chain exist for protein "+obj_id)

    # You can also iterate over the chains and perform operations
    commands = ""
    for i in range(len(chains)):
        i6 = i%6
        chain = chains[i]
        cmd_str = f" color {colors[i6]}, chain {chain};"
        commands += cmd_str
        print(f" color {colors[i6]}, chain {chain};")
        cmd.color(colors[i6], "chain "+chain)
    print("Send the following command to PyMol")
    print(commands)
    response = requests.post('http://'+service_ip+':8101/send_message', data=commands)
"""

func_code_dict['color_protein_chains'] = code_str
func_sche_dict['color_protein_chains'] = {
        "name": "color_protein_chains",
        "description": "a protein and display its chains with different colors",
        "parameters": {
            "type": "object",
            "properties": {
                "protein_pdb_id": {
                    "type": "string",
                    "description": "PDB ID of a protein. It is usually a four character code",
                },
            },
            "required": ["protein_pdb_id"],
        },
    } 

code_str = """
def clear_objects(arguments):
    print("Clear all object in PyMol view")
    loaded_objects = cmd.get_names('objects')
    commands = ""
    for obj in loaded_objects:
        cmd.delete(obj)
        commands += f"delete {obj}; "
    print("send the following command to PyMol")
    print(commands)
    response = requests.post('http://'+service_ip+':8101/send_message', data=commands)
"""

func_code_dict['clear_objects'] = code_str
func_sche_dict['clear_objects'] = {
        "name": "clear_objects",
        "description": "clear all objects if no specific object name is specified",
        "parameters": {
            "type": "object",
            "properties": {
                "protein_pdb_id": {
                    "type": "string",
                    "description": "PDB ID of a protein. It is usually a four character code",
                },
            },
        },
    }
   

# Create Funcation Calling Schema for OpenAI function calling
def func_schema_gen(registry):
    for key in registry.keys():
        r = registry[key]
        #print(r)
        func_schema = {}
        func_schema['name'] = r['service_name']
        func_schema['description'] = r['description']
        params = {}
        props ={}
        params['type'] = 'object'
        params['properties'] = props
        required = []
        param_desc_str = r['param_desc'].replace("'",'"')
        print("desc = ",param_desc_str)
        param_desc = json.loads(param_desc_str)
        for param_name in param_desc.keys():
            if (param_name == "output_sdf"):
                continue
            p_desc = param_desc[param_name]
            if (p_desc.find('Optional') == -1 and p_desc.find('optional') == -1):
                required.append(param_name)
            props[param_name] = {'type':"string","description":param_desc[param_name]}
        params['reguired'] = required
        func_schema['parameters'] = params
        func_sche_dict[r['service_name']] = func_schema
        pprint.pprint(func_schema)
    return func_sche_dict

# Get the default value for optional parameters
def get_default_value(param_desc):
    if (param_desc.find("Optional") < 0 and param_desc.find("optional") < 0):
        return None
    default_value = None
    try:
        if (param_desc.find("Default")> -1 or param_desc.find("default")> -1):
            param_desc = param_desc.replace("  "," ")
            words = param_desc.split(" ")
            for i in range(len(words)):
                w = words[i]
                if (w.find("Default") > -1 or w.find("default") > -1):
                    default_value = words[i+1]
                    if (default_value[-1] == "," or default_value[-1] == "."):
                        default_value = default_value[:-1]
                    break
    except:
        pass
    return default_value
                
    
# Dynamic generation of new python functions from registry
def func_code_gen(registry):
    for key in registry.keys():
        r = registry[key]
        param_desc_str = r['param_desc'].replace("'",'"')
        param_desc = json.loads(param_desc_str)
        param_names = list(param_desc.keys())
        func_args = ",".join(param_names)
        dec_func=[]
        service_name = r['service_name']
        
        reg_args = []
        opt_args = []

        # Check for optinal parameters with default values
        for param_name in param_desc.keys():
            if (param_name == "output_sdf"):
                continue
            desc = param_desc[param_name]
            default_value = get_default_value(desc)
            if (default_value):
                opt_args.append(f"{param_name}={default_value}")
                #opt_args.append(f"{param_name}")
            else:
                reg_args.append(param_name)
        func_args = reg_args + opt_args
        print("func_args = ", func_args)
        func_arg_str = ",".join(func_args)
        print("func_arg_str = ", func_arg_str)
        dec_func.append("def "+r['service_name']+"(self," + func_arg_str +"):")
        dec_func.append(f"    print('{service_name} is called')")       
        dec_func.append("    param_dict = {}")
        for param_name in param_desc.keys():
            if (param_name == 'output_sdf'):
                code = f"    param_dict['{param_name}'] = 'testout.sdf'"
            else:
                code = f"    param_dict['{param_name}'] = {param_name}"
            dec_func.append(code)
        dec_func.append("    #Call the generic FastAPI")
        code = f"    messages = call_fastapi('{service_name}', param_dict)"
        dec_func.append(code)
        dec_func.append("    return 'The results are: ' + str(messages)")

        code_str = "\n".join(dec_func)
        func_code_dict[service_name] = code_str
        
    return func_code_dict

func_schema_gen(registry)
func_code_dict = func_code_gen(registry)
print("D func_sche_dict", func_sche_dict.keys())
print("D func_code_dict", func_code_dict.keys())

for key in func_code_dict.keys():
    pprint.pprint(func_code_dict[key])
    pprint.pprint(func_sche_dict[key])
+21 −4
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ from viewer_utils import show_pdb, update_view
from utils import test_openai_api, function_args_to_streamlit_ui
from streamlit_molstar import st_molstar, st_molstar_rcsb, st_molstar_remote
import new_function_template
import new_function_registry
import shutil
from chat_helper import ConversationHandler, compose_chat_completion_message
import os
@@ -33,10 +34,10 @@ else:
    cfn = cfn.ChatmolFN()
    st.session_state["cfn"] = cfn

st.title("ChatMol copilot", anchor="center")
st.title("ChatMol Copilot", anchor="center")
st.sidebar.write("2024 Jan 05 public version")
st.sidebar.write(
    "ChatMol copilot is a copilot for protein engineering. Also chekcout our [GitHub](https://github.com/JinyuanSun/ChatMol)."
    "ChatMol copilot is a AI platform for protein engineering, molecular design and computation. Also chekcout our [GitHub](https://github.com/JinyuanSun/ChatMol)."
)
st.write("Enjoy modeling proteins with ChatMol copilot! 🤖️ 🧬")
float_init()
@@ -137,7 +138,7 @@ if add_from_template := st.sidebar.checkbox("Add from template"):
    test_data = new_function_template.test_data
    for description, new_func in zip(descriptions, new_funcs):
        try:
            test_results = new_function_template.test_new_function(new_func, description['function']['name'], test_data)
            #test_results = new_function_template.test_new_function(new_func, description['function']['name'], test_data)
            conversation.tools.append(description)
            conversation.available_functions[description['function']['name']] = new_func.__get__(cfn)
            if description['function']['name'] not in st.session_state.new_added_functions:
@@ -146,6 +147,23 @@ if add_from_template := st.sidebar.checkbox("Add from template"):
        except Exception as e:
            st.warning(f"Failed to add function from template. Error: {e}")

if add_from_registry := st.sidebar.checkbox("Add from registry"):
    function_info = new_function_registry.get_info()
    descriptions = function_info['descriptions']
    new_funcs = function_info['functions']
    test_data = new_function_registry.test_data
    for description, new_func in zip(descriptions, new_funcs):
        try:
            #test_results = new_function_registry.test_new_function(new_func, description['function']['name'], test_data)
            conversation.tools.append(description)
            conversation.available_functions[description['function']['name']] = new_func.__get__(cfn)
            if description['function']['name'] not in st.session_state.new_added_functions:
                st.sidebar.success(f"Function `{description['function']['name']}` added successfully.")
                st.session_state.new_added_functions.append(description['function']['name'])
        except Exception as e:
            st.warning(f"Failed to add function from template. Error: {e}")
    print("Add from registry is on")

available_functions = conversation.available_functions
available_tools = conversation.tools
if "openai_model" not in st.session_state:
@@ -289,7 +307,6 @@ if prompt := st.chat_input("What is up?"):
                        function_to_call = available_functions[function_name]
                        try:
                            function_args = json.loads(tool_call["function"]["arguments"])

                            function_response = function_to_call(**function_args)
                            if function_response:
                                st.session_state.messages.append(
+93 −0
Original line number Diff line number Diff line
import types
import os
import requests
import urllib.parse
from enum import Enum
from build_from_registry import *

registry_host_port = os.getenv("registry_host_port")
print("registry_host_PORT  ", registry_host_port)

# func_code_dict
# func_sche_dict

test_data = []

def call_fastapi(service: str, params={}):
    # call fastapi endpoint
    # From service, find the available endpoint
    global registry_host_port
    print("1 input service name = ", service)
    print("2 Get all registerred services")
    registry = requests.get("http://"+registry_host_port+"/registry").json()
    #registry = requests.get("http://100.89.180.132:9999/registry").json()
    print("3 Length of Registry = ",len(registry))

    # Search for service endpoint by service name
    service_endpoint = ""
    status = "Inactive"
    for key in registry.keys():
        r = registry[key]
        print("Registry = ",r)
        if (r['service_name'] == service):
            service_endpoint = r['endpoint']
            status = r['status']
    if (service_endpoint == ""):
        print("Service not found")
        return "Service not found"
    if  (service_endpoint != "" and status== "Inactive"):
        print("Service is inactive")
        return "Service is inactive"

    # Now, constructure encoded_url for FastAPI call

    if (len(params) == 0):
        # just check the status 
        url = url = "http://" + service_endpoint + "/status"
        message = requests.get(url).json()
    else:
        url = "http://" + service_endpoint+"?" + urllib.parse.urlencode(params)
        message  = requests.post(url).json()
    return message

func_code_list = []
func_sche_list = []

for key in func_code_dict.keys():
    sche = {"type": "function", "function": func_sche_dict[key]}
    func_code_list.append(func_code_dict[key])
    func_sche_list.append(sche)


func_list = []
for i in range(len(func_sche_list)):
    code = func_code_list[i]
    sche = func_sche_list[i]
    service_name = sche['function']['name']
    code_obj = compile(code, service_name, 'exec')
    # Execute the compiled code object in the prepared namespace
    exec(code_obj, globals())
    service_func = globals().get(service_name)
    print("service_func", service_func)
    func_list.append(service_func)


### DO NOT MODIFY BELOW THIS LINE ###
def get_all_functions():
    all_functions = []
    global_functions = globals()
    for _, func in global_functions.items():
        # Made one line change
        if isinstance(func, types.FunctionType) and func in func_list:
            all_functions.append(func)
    return all_functions

def get_info():
    #return {"functions": get_all_functions(), "descriptions": function_descriptions}
    return {"functions": func_list, "descriptions": func_sche_list}

def test_new_function(function, function_name, test_data):
    return (
        function(**test_data[function_name]["input"])
        == test_data[function_name]["output"]
    )