Commit 6870f30c authored by jameskrw's avatar jameskrw
Browse files

minor

parent d47db92f
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ class NavigationService(BaseService):
         self.environments = {}
         self.env_configs = {}
         self.config=config
+        print(f"[DEBUG] {self.config}")
     
     def create_environments_batch(self, ids2configs: Dict[str, Any]) -> None:
         """
+3 −2
Original line number Diff line number Diff line
@@ -55,9 +55,10 @@ def service_state_reward_wrapper(step_batch_func):
     def wrapped_step_batch(self, ids2actions):
         # Call the original step_batch function
         step_batch_results = step_batch_func(self, ids2actions)
-        if self.config.get("use_state_reward", False):
+        if not self.config.get("use_state_reward", False):
+            print("[DEUBG] State reward wrapper closed")
             return step_batch_results
-        print("State reward wrapper enabled")
+        print("[DEUBG] State reward wrapper enabled")
         input_to_llm = []
         for id, result in step_batch_results.items():
             obs, reward, done, info = result
+9 −6
Original line number Diff line number Diff line
@@ -173,7 +173,7 @@ def run_llm_judge(input_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         )
         
         # Calculate parse success rate
-        parse_successes = sum(1 for r in results if r["success"] and (r["score"] == 1.0 or r["score"] == 0.0))
+        parse_successes = sum(1 for r in results if r["parse_success"] and r["success"])
         parse_success_rate = parse_successes / completed_requests if completed_requests > 0 else 0
         
         # Log scalar metrics to wandb with step to ensure proper plotting
@@ -197,12 +197,11 @@ def run_llm_judge(input_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         columns = ["id", "env_name", "prompt", "response", "parsed_answer"]
         
         # Split results by category
-        correct_grounding = [r for r in grounding_results if r["success"] and r["score"] == 1.0]
+        correct_grounding = [r for r in grounding_results if r["success"] and r["score"] > 1e-6]
         incorrect_grounding = [r for r in grounding_results if r["success"] and r["score"] == 0.0]
-        correct_worldmodeling = [r for r in worldmodeling_results if r["success"] and r["score"] == 1.0]
+        correct_worldmodeling = [r for r in worldmodeling_results if r["success"] and r["score"] > 1e-6]
         incorrect_worldmodeling = [r for r in worldmodeling_results if r["success"] and r["score"] == 0.0]
-        parse_failed = [r for r in results if r["success"] and 
-                       not re.search(r'<answer>(YES|NO)</answer>', r["response"], re.IGNORECASE)]
+        parse_failed = [r for r in results if r["success"] and (not r["parse_success"])]
         
         # Function to extract answer from response
         def extract_parsed_answer(response):
@@ -415,6 +414,9 @@ def process_llm_judgments(input_data: List[Dict[str, Any]], config: Optional[Dic
             if answer_match:
                 answer = answer_match.group(1).upper()
                 score = 1.0 if answer == "YES" else 0.0
+                parse_success = True
+            else:
+                parse_success = False
         
         # Create the result dictionary
         result = {
@@ -425,7 +427,8 @@ def process_llm_judgments(input_data: List[Dict[str, Any]], config: Optional[Dic
             "response": response_data["response"],
             "success": response_data["success"],
             "score": score,
-            "error": response_data["error"]
+            "error": response_data["error"],
+            "parse_success": parse_success
         }
         
         results.append(result)