Commit 4448d68d authored by Jiabo Li's avatar Jiabo Li
Browse files

Simplify ChatMol usage without using special keywords. User can use free...

Simplify ChatMol usage without using special keywords. User can use free format query in addition to the regular PyMOL commands. This is a seamless integration of ChatMol commands and the strict PyMOL command. Also switch GPT-4o, as it is super fast
parent 57307452
Loading
Loading
Loading
Loading
+116 −14
Original line number Original line Diff line number Diff line
import sys
import io
import os
import os
from openai import OpenAI
import threading
import threading
import requests
import requests
import json
import json
from pymol import cmd
import http.server
import http.server
from pymol import cmd
from openai import OpenAI


self_call_counter = 0

pymol_cmd_list = [
    'abort', 'accept', 'align', 'alignto', 'alphatoall', 'alter', 'alter_state', 'angle', 'api',
    'as', 'assign_stereo', 'backward', 'bg_color', 'bond', 'button', 'cache', 'callout', 'capture',
    'cartoon', 'cd', 'center', 'centerofmass', 'check', 'clean', 'clip', 'cls', 'color', 'color_deep',
    'conda', 'copy', 'copy_to', 'count_atoms', 'count_discrete', 'count_frames', 'count_states', 'create',
    'cycle_valence', 'decline', 'deprotect', 'desaturate', 'deselect', 'diagnostics', 'dihedral', 'dir',
    'disable', 'drag', 'dss', 'dump', 'edit', 'edit_mode', 'embed', 'enable', 'ending', 'extract', 'feedback',
    'fetch', 'fit', 'fix_chemistry', 'flag', 'fnab', 'focal_blur', 'fork', 'forward', 'fragment', 'frame',
    'full_screen', 'fuse', 'get', 'get_angle', 'get_area', 'get_bond', 'get_chains', 'get_dihedral', 'get_distance',
    'get_extent', 'get_position', 'get_property', 'get_property_list', 'get_renderer', 'get_sasa_relative',
    'get_symmetry', 'get_title', 'get_type', 'get_version', 'get_view', 'get_viewport', 'gradient', 'group',
    'h_add', 'h_fill', 'h_fix', 'help', 'help_setting', 'hide', 'id_atom', 'identify', 'index', 'indicate',
    'intra_fit', 'intra_rms', 'intra_rms_cur', 'invert', 'isodot', 'isolevel', 'isomesh', 'isosurface', 'iterate',
    'iterate_state', 'join_states', 'label', 'load', 'load_embedded', 'load_mtz', 'load_png', 'load_traj', 'loadall',
    'log', 'log_close', 'log_open', 'ls', 'madd', 'map_double', 'map_halve', 'map_new', 'map_set', 'map_set_border',
    'map_trim', 'mappend', 'mask', 'matrix_copy', 'matrix_reset', 'mclear', 'mcopy', 'mdelete', 'mdo', 'mdump', 'mem',
    'meter_reset', 'middle', 'minsert', 'mmatrix', 'mmove', 'morph', 'mouse', 'move', 'movie.load', 'movie.nutate',
    'movie.pause', 'movie.produce', 'movie.rock', 'movie.roll', 'movie.screw', 'movie.sweep', 'movie.tdroll',
    'movie.zoom', 'mplay', 'mpng', 'mse2met', 'mset', 'mstop', 'mtoggle', 'multifilesave', 'multisave', 'mview', 'order',
    'orient', 'origin', 'overlap', 'pair_fit', 'phi_psi', 'pi_interactions', 'pip', 'png', 'pop', 'protect', 'pseudoatom',
    'pwd', 'python', 'quit', 'ramp_new', 'ramp_update', 'ray', 'rebond', 'rebuild', 'recolor', 'redo', 'reference', 'refresh',
    'refresh_wizard', 'reinitialize', 'remove', 'remove_picked', 'rename', 'replace', 'replace_wizard', 'reset', 'resume',
    'rewind', 'rms', 'rms_cur', 'rock', 'rotate', 'run', 'save', 'scene', 'scene_order', 'sculpt_activate', 'sculpt_deactivate',
    'sculpt_iterate', 'sculpt_purge', 'select', 'set_atom_property', 'set_bond', 'set_color', 'set_dihedral', 'set_key',
    'set_name', 'set_symmetry', 'set_view', 'show', 'skip', 'slice_new', 'smooth', 'sort', 'space', 'spectrum', 'splash',
    'split_chains', 'split_states', 'stereo', 'symexp', 'symmetry_copy', 'toggle', 'translate', 'turn', 'unbond', 'uniquify',
    'unmask', 'unset', 'unset_bond', 'unset_deep', 'update', 'util.cbas', 'util.cbaw', 'util.cba', 'util.cbam', 'util.cbap',
    'util.cbay', 'vdw_fit', 'volume', 'volume_color', 'volume_panel', 'volume_ramp_new', 'window', 'wizard','zoom'
]




class PyMOLCommandHandler(http.server.BaseHTTPRequestHandler):
class PyMOLCommandHandler(http.server.BaseHTTPRequestHandler):
    def __init__(self):
    def __init__(self):
@@ -74,7 +111,8 @@ stashed_commands = []
# Save API Key in ~/.PyMOL/apikey.txt
# Save API Key in ~/.PyMOL/apikey.txt
API_KEY_FILE = os.path.expanduser('~')+"/.PyMOL/apikey.txt"
API_KEY_FILE = os.path.expanduser('~')+"/.PyMOL/apikey.txt"
OPENAI_KEY_ENV = "OPENAI_API_KEY"
OPENAI_KEY_ENV = "OPENAI_API_KEY"
GPT_MODEL = "gpt-3.5-turbo-1106"
#GPT_MODEL = "gpt-3.5-turbo-1106"
GPT_MODEL = "gpt-4o"
client = None
client = None


def set_api_key(api_key):
def set_api_key(api_key):
@@ -118,9 +156,15 @@ def chat_with_gpt(message, max_history=10):


    conversation_history += f"User: {message}\nChatGPT:"
    conversation_history += f"User: {message}\nChatGPT:"


    system_prompt = """
    You are PyMOL expert and able to convert a user query about molecule operation into one or more PyMOL commands.
    For example, is a user ask 'Please load protein 3hiv', you should provide the response as 'fetch 3hiv'.
    Sometimes, the use may just ask some general questions, in these case, you just answer these questions.
    """

    try:
    try:
        messages = [
        messages = [
            {"role": "system", "content": "You are an AI language model specialized in providing command line code solutions related to PyMOL. Generate clear and effective solutions in a continuous manner. When providing demos or examples, try to use 'fetch' if object name is not provided. Prefer academic style visulizations. Code within triple backticks, comment and code should not in the same line."}
            {"role": "system", "content": system_prompt}
        ]
        ]


        # Keep only the max_history latest exchanges to avoid making the conversation too long
        # Keep only the max_history latest exchanges to avoid making the conversation too long
@@ -166,13 +210,18 @@ def chatlite(question):
    lite_conversation_history += data['answer']
    lite_conversation_history += data['answer']
    lite_conversation_history += "\n"
    lite_conversation_history += "\n"
    commands = data['answer']
    commands = data['answer']
    commands = commands.strip()
    # Check for valid pymol command
    words = commands.split(" ")
    commands = commands.split('\n')
    commands = commands.split('\n')
    if (len(words)> 0 and words[0] in pymol_cmd_list):
        for command in commands:
        for command in commands:
            if command == '':
            if command == '':
                continue
                continue
            else:
            else:
            cmd.do(command)
                #cmd.do(command)
    print("Answers from ChatMol-Lite: ")
                original_do(command)
    print("Answers from ChatMol-Lite is: ")
    for command in commands:
    for command in commands:
        if command == '':
        if command == '':
            continue
            continue
@@ -192,7 +241,8 @@ def start_chatgpt_cmd(message, execute:bool=True, lite:bool=False):
        if (len(stashed_commands) == 0):
        if (len(stashed_commands) == 0):
            print("There is no stashed commands")
            print("There is no stashed commands")
        for command in stashed_commands:
        for command in stashed_commands:
            cmd.do(command)
            #cmd.do(command)
            original_do(command)
        # clear stash
        # clear stash
        stashed_commands.clear()
        stashed_commands.clear()
        return 0
        return 0
@@ -208,6 +258,19 @@ def start_chatgpt_cmd(message, execute:bool=True, lite:bool=False):
    response = chat_with_gpt(message)
    response = chat_with_gpt(message)
    print("ChatGPT: " + response.strip())
    print("ChatGPT: " + response.strip())


    #Execute pymol command
    # Check for valid pymol command
    commands = response.strip()
    words = commands.split(" ")
    commands = commands.split('\n')
    if (len(words)> 0 and words[0] in pymol_cmd_list):
        for command in commands:
            if command == '':
                continue
            else:
                #cmd.do(command)
                original_do(command)

    try:
    try:
        command_blocks = []
        command_blocks = []
        # I think it would be better to reset stashed_commands to empty for each chat.
        # I think it would be better to reset stashed_commands to empty for each chat.
@@ -226,19 +289,58 @@ def start_chatgpt_cmd(message, execute:bool=True, lite:bool=False):
                        index_ = command.index("#")
                        index_ = command.index("#")
                        if execute:
                        if execute:
                            print(command[:index_])
                            print(command[:index_])
                            cmd.do(command[:index_])
                            #cmd.do(command[:index_])
                            original_do(command[:index_])
                        else:
                        else:
                            stashed_commands.append(command[:index_])
                            stashed_commands.append(command[:index_])
                    else:
                    else:
                        if execute:
                        if execute:
                            print(command)
                            print(command)
                            cmd.do(command)
                            #cmd.do(command)
                            original_do(command)
                        else:
                        else:
                            stashed_commands.append(command)
                            stashed_commands.append(command)


    except Exception as e:
    except Exception as e:
        print(f"Error command execution code: {e}")
        print(f"Error command execution code: {e}")


# Save the original cmd.do function
original_do = cmd.do

stdout_output = ''
stderr_output = ''

def my_custom_do(command):
    global self_call_counter
    self_call_counter = 0
    try:
        # Capture stdout and stderr
        old_stdout = sys.stdout
        sys.stdout = stdout_capture = io.StringIO()

        # Execute the original command
        original_do(command)
        
        # Get the captured stdout and stderr
        stdout_output = stdout_capture.getvalue()
        
    except Exception as e:
        print("An error occurred while executing the PyMOL command.")

    finally:
        # Restore stdout and stderr
        sys.stdout = old_stdout

    #print("stdout_output:", stdout_output)
    if (len(stdout_output) >1):
       print("Free format PyMOL command:",command)
       stdout_output = ""
       if (self_call_counter < 1):
            #chatlite(command)
            start_chatgpt_cmd(command)

# Overwrite cmd.do with the custom function
cmd.do = my_custom_do


client = load_api_key()
client = load_api_key()