Unverified Commit 6463a0bc authored by jinyuan sun's avatar jinyuan sun Committed by GitHub
Browse files

Merge pull request #43 from ChatMol/get_ride_of_openai

beta test of chatmol v2.2, add deepseek support
parents 8d0d8c78 11edb4da
Loading
Loading
Loading
Loading
+23 −11
Original line number Original line Diff line number Diff line
@@ -70,6 +70,10 @@ class PyMOLAgent:
        "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", "claude-3-opus-20240229", "claude-3-5-sonnet-20240620"
        "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", "claude-3-opus-20240229", "claude-3-5-sonnet-20240620"
    }
    }


    DEEPSEEK_MODELS = {
        "deepseek-chat"
    }

    def __init__(
    def __init__(
        self,
        self,
        model: str = "gpt-4o",
        model: str = "gpt-4o",
@@ -108,6 +112,7 @@ show cartoon
        self.api_urls = {
        self.api_urls = {
            "openai": "https://api.openai.com/v1/chat/completions",
            "openai": "https://api.openai.com/v1/chat/completions",
            "anthropic": "https://api.anthropic.com/v1/messages",
            "anthropic": "https://api.anthropic.com/v1/messages",
            "deepseek": "https://api.deepseek.com/chat/completions",
        }
        }
        self.stashed_commands = []
        self.stashed_commands = []
    @classmethod
    @classmethod
@@ -117,6 +122,8 @@ show cartoon
            return "openai"
            return "openai"
        elif model in cls.ANTHROPIC_MODELS:
        elif model in cls.ANTHROPIC_MODELS:
            return "anthropic"
            return "anthropic"
        elif model in cls.DEEPSEEK_MODELS:
            return "deepseek"
        else:
        else:
            # Default to OpenAI if unknown model
            # Default to OpenAI if unknown model
            print(f"Warning: Unknown model '{model}'. Defaulting to OpenAI.")
            print(f"Warning: Unknown model '{model}'. Defaulting to OpenAI.")
@@ -125,10 +132,10 @@ show cartoon
    def load_config(self) -> Dict[str, Union[str, Dict[str, str]]]:
    def load_config(self) -> Dict[str, Union[str, Dict[str, str]]]:
        """Load configuration from JSON file."""
        """Load configuration from JSON file."""
        default_config = {
        default_config = {
            "provider": self.provider,
            "api_keys": {
            "api_keys": {
                "openai": "",
                "openai": "",
                "anthropic": ""
                "anthropic": "",
                "deepseek": "",
            }
            }
        }
        }
        
        
@@ -143,7 +150,7 @@ show cartoon
            self.save_config(default_config)
            self.save_config(default_config)
            return default_config
            return default_config


    def save_config(self, config: Dict[str, Union[str, Dict[str, str]]]) -> None:
    def save_config(self, config) -> None:
        """Save configuration to JSON file."""
        """Save configuration to JSON file."""
        try:
        try:
            with open(self.config_file, 'w') as f:
            with open(self.config_file, 'w') as f:
@@ -153,8 +160,14 @@ show cartoon


    def get_api_key(self) -> Optional[str]:
    def get_api_key(self) -> Optional[str]:
        """Get API key from environment or config file."""
        """Get API key from environment or config file."""
        env_var = f"{'ANTHROPIC' if self.provider == 'anthropic' else 'OPENAI'}_API_KEY"
        # env_var = f"{'ANTHROPIC' if self.provider == 'anthropic' else 'OPENAI'}_API_KEY"
        return os.getenv(env_var) or self.config["api_keys"].get(self.provider, "")
        # return os.getenv(env_var) or self.config["api_keys"].get(self.provider, "")
        if self.provider == "anthropic":
            return os.getenv("ANTHROPIC_API_KEY") or self.config["api_keys"].get(self.provider, "")
        if self.provider == "openai":
            return os.getenv("OPENAI_API_KEY") or self.config["api_keys"].get(self.provider, "")
        if self.provider == "deepseek":
            return os.getenv("DEEPSEEK_API_KEY") or self.config["api_keys"].get(self.provider, "")


    def set_api_key(self, api_key: str) -> None:
    def set_api_key(self, api_key: str) -> None:
        """Set the API key for the current provider."""
        """Set the API key for the current provider."""
@@ -168,11 +181,11 @@ show cartoon
        print(f"API key for {self.provider} set and saved successfully.")
        print(f"API key for {self.provider} set and saved successfully.")
        
        
        # Reload PyMOL configuration
        # Reload PyMOL configuration
        cmd.reinitialize()
        # cmd.reinitialize()
        cmd.do("@~/.pymolrc")
        # cmd.do("@~/.pymolrc")
        cmd.do(
        # cmd.do(
            "load https://raw.githubusercontent.com/JinyuanSun/ChatMol/main/chatmol.py"
        #     "load https://raw.githubusercontent.com/JinyuanSun/ChatMol/main/chatmol.py"
        )
        # )


    def update_model(self, model_name: str) -> str:
    def update_model(self, model_name: str) -> str:
        """Update the model and automatically detect provider."""
        """Update the model and automatically detect provider."""
@@ -181,7 +194,6 @@ show cartoon
        
        
        if new_provider != self.provider:
        if new_provider != self.provider:
            self.provider = new_provider
            self.provider = new_provider
            self.config["provider"] = new_provider
            self.save_config(self.config)
            self.save_config(self.config)
            self.api_key = self.get_api_key()
            self.api_key = self.get_api_key()
            print(f"Provider automatically switched to {new_provider}")
            print(f"Provider automatically switched to {new_provider}")