Commit 394dfba6 authored by jameskrw's avatar jameskrw
Browse files

udpated llm judge

parent fef7d8a8
Loading
Loading
Loading
Loading
+399 −376

File changed.

Preview size limit exceeded, changes collapsed.

+28 −82
Original line number Diff line number Diff line
# Model
model:
  name: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
  temperature: 0.1
  max_tokens: 500

# API
api:
  max_parallel_requests: 50
  max_retries: 3
  request_timeout: 15

# Log
log_dir: "./logs/llm_judge"
wandb:
  project: "vagen_process_reward_judge"
  run_name: "llm_judge"
  correct_grounding_samples: 8
  incorrect_grounding_samples: 8
  correct_worldmodeling_samples: 8
  incorrect_worldmodeling_samples: 8

# Prompt
prompt_templates:
  sokoban:
  # Default environment templates that other environments can reference
  default_env:
    grounding: |
      Compare the natural language description with the state information dictionary.
      Answer YES if the description accurately matches the state, or NO if it doesn't.
@@ -23,7 +32,7 @@ prompt_templates:
      {state_information_dict}
      
      Description:
      "{natural_language_description}"
      {natural_language_description}
      
      Your answer should be within {max_tokens} tokens and MUST end with <answer>YES</answer> or <answer>NO</answer>.
    worldmodeling: |
@@ -35,84 +44,21 @@ prompt_templates:
      {state_information_dict}
      
      Description:
      "{natural_language_description}"
      {natural_language_description}
      
      Your answer should be within {max_tokens} tokens and MUST end with <answer>YES</answer> or <answer>NO</answer>.
  sokoban:
    grounding: ${prompt_templates.default_env.grounding}
    worldmodeling: ${prompt_templates.default_env.worldmodeling}
  
  frozenlake:
    grounding: |
      Compare the natural language description with the state information dictionary.
      Answer YES if the description accurately matches the state, or NO if it doesn't.
      Think step by step and end with your answer in <answer>YES</answer> or <answer>NO</answer> format.
      
      State Information:
      {state_information_dict}
      
      Description:
      "{natural_language_description}"
      
      Your answer should be within {max_tokens} tokens and MUST end with <answer>YES</answer> or <answer>NO</answer>.
    worldmodeling: |
      Compare the natural language description with the state information dictionary.
      Answer YES if the description accurately matches the state, or NO if it doesn't.
      Think step by step and end with your answer in <answer>YES</answer> or <answer>NO</answer> format.
      
      State Information:
      {state_information_dict}
      
      Description:
      "{natural_language_description}"
      
      Your answer should be within {max_tokens} tokens and MUST end with <answer>YES</answer> or <answer>NO</answer>.
    grounding: ${prompt_templates.default_env.grounding}
    worldmodeling: ${prompt_templates.default_env.worldmodeling}
  
  maniskill:
    grounding: |
      Compare the natural language description with the state information dictionary.
      Answer YES if the description accurately matches the state, or NO if it doesn't.
      Think step by step and end with your answer in <answer>YES</answer> or <answer>NO</answer> format.
      
      State Information:
      {state_information_dict}
      
      Description:
      "{natural_language_description}"
      
      Your answer should be within {max_tokens} tokens and MUST end with <answer>YES</answer> or <answer>NO</answer>.
    worldmodeling: |
      Compare the natural language description with the state information dictionary.
      Answer YES if the description accurately matches the state, or NO if it doesn't.
      Think step by step and end with your answer in <answer>YES</answer> or <answer>NO</answer> format.
      
      State Information:
      {state_information_dict}
      
      Description:
      "{natural_language_description}"
      
      Your answer should be within {max_tokens} tokens and MUST end with <answer>YES</answer> or <answer>NO</answer>.
    grounding: ${prompt_templates.default_env.grounding}
    worldmodeling: ${prompt_templates.default_env.worldmodeling}
  
  navigation:
    grounding: |
      Compare the natural language description with the state information dictionary.
      Answer YES if the description accurately matches the state, or NO if it doesn't.
      Think step by step and end with your answer in <answer>YES</answer> or <answer>NO</answer> format.
      
      State Information:
      {state_information_dict}
      
      Description:
      "{natural_language_description}"
      
      Your answer should be within {max_tokens} tokens and MUST end with <answer>YES</answer> or <answer>NO</answer>.
    worldmodeling: |
      Compare the natural language description with the state information dictionary.
      Answer YES if the description accurately matches the state, or NO if it doesn't.
      Think step by step and end with your answer in <answer>YES</answer> or <answer>NO</answer> format.
      
      State Information:
      {state_information_dict}
      
      Description:
      "{natural_language_description}"
      
      Your answer should be within {max_tokens} tokens and MUST end with <answer>YES</answer> or <answer>NO</answer>.
 No newline at end of file
    grounding: ${prompt_templates.default_env.grounding}
    worldmodeling: ${prompt_templates.default_env.worldmodeling}
 No newline at end of file
+27 −38
Original line number Diff line number Diff line
@@ -17,71 +17,64 @@ import time
from .llm_judge import run_llm_judge

def env_state_reward_wrapper(step_func):
    """
    Decorator function that enhances the step method to include state rewards.
    
    This wrapper:
    1. Captures the state before and after executing the original step function
    2. Updates the info dictionary with appropriate content and state keys based on prompt format
    3. Handles accumulated rewards if configured
    
    Args:
        step_func: The original step function to be wrapped
        
    Returns:
        The wrapped step function with enhanced state reward functionality
    """
    def wrapped_step(self, action_str):
        if hasattr(self, 'config') and self.config.get("use_state_reward", False):
            pre_state = self.get_env_state()
            obs, reward, done, info = step_func(self, action_str)
            post_state = self.get_env_state()
            
            # Get the mapping based on prompt format
            prompt_format = self.config.get("prompt_format", None)
            if prompt_format is None:
                raise ValueError("Prompt format is not specified in the config.")
            assert ("grounding" in prompt_format or "worldmodeling" in prompt_format)
        
            pre_state = self.get_env_state()
            obs, reward, done, info = step_func(self, action_str)
            post_state = self.get_env_state()
            
            if "metrics" not in info:
                info["metrics"] = {"turn_metrics": {}, "traj_metrics": {}}
            if "turn_metrics" not in info["metrics"]:
                info["metrics"]["turn_metrics"] = {}
                
            if info.get("is_format_rewarded", False): # if no format reward, no need to calculate state reward, skipping
                info["use_state_reward"] = True
                if "observation_content" in info and info["observation_content"]:
                    info["observation_state"] = pre_state
                if "prediction_content" in info and info["prediction_content"]:
                    info["prediction_state"] = post_state
            info["use_state_reward"] = True
            else:
                info["use_state_reward"] = False
                if "observation_content" in info and info["observation_content"]:
                    info["metrics"]["turn_metrics"]["grounding_reward"] = 0.0
                if "prediction_content" in info and info["prediction_content"]:
                    info["metrics"]["turn_metrics"]["worldmodeling_reward"] = 0.0
            return obs, reward, done, info
        else:
            return step_func(self, action_str)
    return wrapped_step

def service_state_reward_wrapper(step_batch_func):
    """
    Decorator to wrap the step_batch function to calculate and apply rewards.
    
    Args:
        step_batch_func: Original step_batch function
        
    Returns:
        Wrapped step_batch function with reward calculation
    """
    def wrapped_step_batch(self, ids2actions):
        # Call the original step_batch function
        step_batch_results = step_batch_func(self, ids2actions)
        input_to_llm = []
        for id, result in step_batch_results.items():
            obs, reward, done, info = result
            env_name = self.env_configs[id].get("env_name", "default_env")
            if info.get("use_state_reward", False):
                if info.get("observation_content", None) and info.get("observation_state", None):
                    input_to_llm.append({
                        "id": id,
                        "content": info["observation_content"],
                        "state": info["observation_state"],
                        "type": "observation"
                        "type": "grounding",
                        "env_name": env_name,
                    })
                if info.get("prediction_content", None) and info.get("prediction_state", None):
                    input_to_llm.append({
                        "id": id,
                        "content": info["prediction_content"],
                        "state": info["prediction_state"],
                        "type": "prediction"
                        "type": "worldmodeling",
                        "env_name": env_name,
                    })
                    
        if len(input_to_llm) > 0:
@@ -95,10 +88,6 @@ def service_state_reward_wrapper(step_batch_func):
        for item, score in zip(input_to_llm, scores):
            id = item["id"]
            env_config = self.env_configs[id]
            if "metrics" not in new_step_batch_results[id][3]:
                new_step_batch_results[id][3]["metrics"] = {"turn_metrics": {}, "traj_metrics": {}}
            if "turn_metrics" not in new_step_batch_results[id][3]["metrics"]:
                new_step_batch_results[id][3]["metrics"]["turn_metrics"] = {}
            if item["type"] == "observation":
                new_step_batch_results[id][3]["metrics"]["turn_metrics"]["grounding_reward"] = score * env_config.get("grounding_reward_weight", 0.5)
                new_step_batch_results[id][1] += score * env_config.get("grounding_reward_weight", 0.5)