Commit d8e60e9b authored by JinyuanSun's avatar JinyuanSun
Browse files

beta test of chatmol v2.3, add ollama support

parent 11edb4da
Loading
Loading
Loading
Loading
+27 −0
Original line number Original line Diff line number Diff line
@@ -113,8 +113,10 @@ show cartoon
            "openai": "https://api.openai.com/v1/chat/completions",
            "openai": "https://api.openai.com/v1/chat/completions",
            "anthropic": "https://api.anthropic.com/v1/messages",
            "anthropic": "https://api.anthropic.com/v1/messages",
            "deepseek": "https://api.deepseek.com/chat/completions",
            "deepseek": "https://api.deepseek.com/chat/completions",
            "ollama": "http://localhost:11434/api/chat"
        }
        }
        self.stashed_commands = []
        self.stashed_commands = []

    @classmethod
    @classmethod
    def detect_provider(cls, model: str) -> str:
    def detect_provider(cls, model: str) -> str:
        """Automatically detect provider based on model name."""
        """Automatically detect provider based on model name."""
@@ -190,6 +192,15 @@ show cartoon
    def update_model(self, model_name: str) -> str:
    def update_model(self, model_name: str) -> str:
        """Update the model and automatically detect provider."""
        """Update the model and automatically detect provider."""
        self.model = model_name
        self.model = model_name

        if model_name.split("@")[-1] == "ollama":
            self.provider = "ollama"
            self.model = model_name.split("@")[0]
            # self.base_url = "http://localhost:11434/api/chat"
            # self.base_url = "https://chatmol.org/ollama/api/chat"

            return f"Model updated to: {self.model}"
        
        new_provider = self.detect_provider(model_name)
        new_provider = self.detect_provider(model_name)
        
        
        if new_provider != self.provider:
        if new_provider != self.provider:
@@ -208,6 +219,10 @@ show cartoon
                "x-api-key": self.api_key,
                "x-api-key": self.api_key,
                "anthropic-version": "2023-06-01"
                "anthropic-version": "2023-06-01"
            }
            }
        elif self.provider == "ollama":
            return {
                "Content-Type": "application/json",
            }
        else:  # openai
        else:  # openai
            return {
            return {
                "Content-Type": "application/json",
                "Content-Type": "application/json",
@@ -235,6 +250,16 @@ show cartoon
                "system": self.system_message,
                "system": self.system_message,
                "temperature": 0.01,
                "temperature": 0.01,
            }
            }
        elif self.provider == "ollama":
            return {
                "model": self.model,
                "messages": self.conversation_history,
                "stream": False,
                "options": {
                    "seed": 101,
                    "temperature": 0
                }
            }
        else:  # openai
        else:  # openai
            return {
            return {
                "model": self.model,
                "model": self.model,
@@ -246,6 +271,8 @@ show cartoon
        """Extract the assistant's message from the API response."""
        """Extract the assistant's message from the API response."""
        if self.provider == "anthropic":
        if self.provider == "anthropic":
            return response_data["content"][0]["text"]
            return response_data["content"][0]["text"]
        elif self.provider == "ollama":
            return response_data["message"]["content"]
        else:  # openai
        else:  # openai
            return response_data["choices"][0]["message"]["content"]
            return response_data["choices"][0]["message"]["content"]