Commit c9d44988 authored by JinyuanSun's avatar JinyuanSun
Browse files

beta test of chatmol v2.1, add claude support

parent ebca0ff5
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -142,3 +142,4 @@ cf08d4763cf8b0309f2aa182c253388d712a1b736d4eb34f226c74588995faa0*
e9bbca7f41ac6c03b3a6c3193115e8ac8a4c4fd572a4230577a81c432b2cacfe*
06377e8a560dca176f9a7d7ac9e3184c6c67e13aa1e9a6a3972e2fde8952375b*
copilot_public/Project-*
chatmol_claude_test.py
 No newline at end of file
+160 −60
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ import threading
import json
import http.server

from typing import List, Dict, Optional
from typing import Dict, List, Optional, Literal, Union
from datetime import datetime
from pymol import cmd

@@ -62,67 +62,112 @@ class PyMOLCommandHandler(http.server.BaseHTTPRequestHandler):
        self.end_headers()

class PyMOLAgent:
    OPENAI_MODELS = {
        "gpt-4o", "gpt-4o-mini"
    }
    
    ANTHROPIC_MODELS = {
        "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", "claude-3-opus-20240229", "claude-3-5-sonnet-20240620"
    }

    def __init__(
        self,
        model: str = "gpt-4o",
        provider: Optional[Literal["openai", "anthropic"]] = None,
        system_message: Optional[str] = None,
        # max_history: int = 100,
        # command_timeout: int = 30
    ):
        self.local_api_file = os.path.expanduser("~") + "/.PyMOL/apikey.txt"
        self.api_key = self.load_api_key()
        if not self.api_key:
            raise ValueError("Please set OPENAI_API_KEY environment variable")
        self.config_dir = os.path.expanduser("~/.PyMOL")
        self.config_file = os.path.join(self.config_dir, "config.json")
        os.makedirs(self.config_dir, exist_ok=True)
        
        self.model = model
        self.system_message = system_message
        self.provider = provider or self.detect_provider(model)
        self.lite_conversation_history = ""
        self.system_message = """You are a PyMOL expert assistant, specialized in providing command line code solutions related to PyMOL. 
        self.config = self.load_config()
        self.api_key = self.get_api_key()
        
        if not self.api_key:
            raise ValueError(f"Please set {'ANTHROPIC' if self.provider == 'anthropic' else 'OPENAI'}_API_KEY environment variable or configure it in {self.config_file}")

        self.system_message = system_message or """You are a PyMOL expert assistant, specialized in providing command line code solutions related to PyMOL. 
Generate clear and effective solutions. 
Prefer academic style visulizations.
Format your responses like this:
Place PyMOL commands in ```pymol blocksxw
Place PyMOL commands in ```pymol blocks

Example response format:

```
```pymol
fetch 1abc
show cartoon
```
"""

        self.conversation_history: List[Dict[str, str]] = [
            {"role": "system", "content": self.system_message}
        ]
        self.api_url = "https://api.openai.com/v1/chat/completions"
        self.api_urls = {
            "openai": "https://api.openai.com/v1/chat/completions",
            "anthropic": "https://api.anthropic.com/v1/messages",
        }
        self.stashed_commands = []
    @classmethod
    def detect_provider(cls, model: str) -> str:
        """Automatically detect provider based on model name."""
        if model in cls.OPENAI_MODELS:
            return "openai"
        elif model in cls.ANTHROPIC_MODELS:
            return "anthropic"
        else:
            # Default to OpenAI if unknown model
            print(f"Warning: Unknown model '{model}'. Defaulting to OpenAI.")
            return "openai"

    def load_config(self) -> Dict[str, Union[str, Dict[str, str]]]:
        """Load configuration from JSON file."""
        default_config = {
            "provider": self.provider,
            "api_keys": {
                "openai": "",
                "anthropic": ""
            }
        }
        
    def load_api_key(self) -> str:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
        try:
                with open(self.local_api_file, "r") as api_key_file:
                    api_key = api_key_file.read().strip()
                    print("API key loaded from file.")
            with open(self.config_file, 'r') as f:
                return json.load(f)
        except FileNotFoundError:
                print(
                    "API key file not found. Please set your API key using 'set_api_key your_api_key_here' command"
                    + f" or by environment variable `OPENAI_KEY_ENV`."
                )
        return api_key
            self.save_config(default_config)
            return default_config
        except json.JSONDecodeError:
            print(f"Warning: Invalid JSON in {self.config_file}. Using default configuration.")
            self.save_config(default_config)
            return default_config

    def save_config(self, config: Dict[str, Union[str, Dict[str, str]]]) -> None:
        """Save configuration to JSON file."""
        try:
            with open(self.config_file, 'w') as f:
                json.dump(config, f, indent=2)
        except Exception as e:
            print(f"Warning: Could not save configuration to {self.config_file}: {str(e)}")

    def get_api_key(self) -> Optional[str]:
        """Get API key from environment or config file."""
        env_var = f"{'ANTHROPIC' if self.provider == 'anthropic' else 'OPENAI'}_API_KEY"
        return os.getenv(env_var) or self.config["api_keys"].get(self.provider, "")

    def set_api_key(self, api_key: str) -> None:
        """Set the OpenAI API key."""
        """Set the API key for the current provider."""
        api_key = api_key.strip()
        print("APIKEYFILE = ", self.local_api_file)
        try:
            with open(self.local_api_file, "w+") as api_key_file:
                api_key_file.write(api_key)
            print("API key set and saved to file successfully.")
        except:
            print(
                "API key set successfully but could not be saved to file. You may need to reset the API key next time."
            )
        self.api_key = api_key
        
        # Update config
        self.config["api_keys"][self.provider] = api_key
        self.save_config(self.config)
        
        print(f"API key for {self.provider} set and saved successfully.")
        
        # Reload PyMOL configuration
        cmd.reinitialize()
        cmd.do("@~/.pymolrc")
        cmd.do(
@@ -130,22 +175,70 @@ show cartoon
        )

    def update_model(self, model_name: str) -> str:
        """Update the GPT model used by the assistant."""
        """Update the model and automatically detect provider."""
        self.model = model_name
        new_provider = self.detect_provider(model_name)
        
        if new_provider != self.provider:
            self.provider = new_provider
            self.config["provider"] = new_provider
            self.save_config(self.config)
            self.api_key = self.get_api_key()
            print(f"Provider automatically switched to {new_provider}")
            
        return f"Model updated to: {self.model}"

    def get_headers(self) -> Dict[str, str]:
        """Get headers based on the current provider."""
        if self.provider == "anthropic":
            return {
                "Content-Type": "application/json",
                "x-api-key": self.api_key,
                "anthropic-version": "2023-06-01"
            }
        else:  # openai
            return {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }

    def add_message(self, role: str, content: str) -> None:
        """Add a message to the conversation history."""
        self.conversation_history.append({"role": role, "content": content})
    def prepare_messages(self, message: str) -> Dict:
        """Prepare the API request payload based on provider."""
        if self.provider == "anthropic":
            # Convert conversation history to messages format
            messages = []
            
            # First add any previous messages
            for msg in self.conversation_history[1:]:  # Skip system message
                if msg["role"] in ["user", "assistant"]:  # Only include user and assistant messages
                    messages.append({
                        "role": msg["role"],
                        "content": msg["content"]
                    })
            
            return {
                "model": self.model,
                "messages": messages,
                "max_tokens": 1024,
                "system": self.system_message,
                "temperature": 0.01,
            }
        else:  # openai
            return {
                "model": self.model,
                "messages": self.conversation_history,
                "temperature": 0.01,
            }

    def process_response(self, response_data: Dict) -> str:
        """Extract the assistant's message from the API response."""
        if self.provider == "anthropic":
            return response_data["content"][0]["text"]
        else:  # openai
            return response_data["choices"][0]["message"]["content"]

    def send_message(self, message: str, execute: bool = True) -> str:
        """Send a message and process PyMOL commands from the response."""
        """Send a message and process PyMOL commands."""
        message = message.strip()

        # Handle special commands
@@ -161,30 +254,26 @@ show cartoon
        # Add user message to history
        self.add_message("user", message)

        # Prepare the API request
        payload = {
            "model": self.model,
            "messages": self.conversation_history,
            "temperature": 0.01,
        }
        # Prepare and send API request
        payload = self.prepare_messages(message)
        
        try:
            # Make API call
            response = requests.post(
                self.api_url, headers=self.get_headers(), json=payload
                self.api_urls[self.provider],
                headers=self.get_headers(),
                json=payload
            )
            response.raise_for_status()

            # Parse response
            response_data = response.json()
            assistant_message = response_data["choices"][0]["message"]["content"]
            assistant_message = self.process_response(response.json())

            # Add assistant's response to history
            self.add_message("assistant", assistant_message)

            # Process PyMOL commands

            self.process_pymol_commands(assistant_message, execute)
            
            print("====================================")
            print("User:", message)
            print("Assistant:", assistant_message)
@@ -196,6 +285,17 @@ show cartoon
            print(error_msg)
            return error_msg

    def add_message(self, role: str, content: str) -> None:
        """Add a message to the conversation history."""
        self.conversation_history.append({"role": role, "content": content})

    def reset_conversation(self) -> str:
        """Reset the conversation history."""
        self.conversation_history = [
            {"role": "system", "content": self.system_message}
        ]
        return "Conversation reset."
    
    def process_pymol_commands(self, response: str, execute: bool) -> None:
        """Extract and process PyMOL commands from the response."""
        try: