Commit a2fa9f36 authored by jameskrw
Browse files

minor

parent 4b45dab9
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -18,6 +18,8 @@ wandb:
  incorrect_grounding_samples: 8
  correct_worldmodeling_samples: 8
  incorrect_worldmodeling_samples: 8
  parse_failed_samples: 8
  table_logging_frequency: 10

# Prompt
prompt_templates:
+86 −123
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ _GLOBAL_STEPS = {} # Track global step count per process
_PROCESS_LOCKS = {}  # Semaphore for each process
_HYDRA_LOCKS = {}  # Semaphore for Hydra initialization
_HYDRA_INITIALIZED = {}  # Track Hydra initialization per process

_PID_CONFIG= {}  # Store config per process
# Context manager to ensure proper cleanup of wandb sessions
@contextmanager
def wandb_run_context():
@@ -63,8 +63,13 @@ def _get_hydra_config(pid: int) -> DictConfig:
            # Mark as initialized for this process
            _HYDRA_INITIALIZED[pid] = True
        
        # Load and return the config
        return hydra.compose(config_name="llm_as_judge")
        if pid not in _PID_CONFIG:
            # Load the config for this process
            config = hydra.compose(config_name="llm_as_judge")
            _PID_CONFIG[pid] = config
        else:
            config = _PID_CONFIG[pid]
        return config
            
def run_llm_judge(input_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
@@ -113,11 +118,16 @@ def run_llm_judge(input_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
            _WANDB_INITIALIZED[pid] = True
        
        # Get sampling parameters from wandb config
        wandb_config = wandb.config
        wandb_config = wandb.config.wandb
        correct_grounding_samples = wandb_config.get("correct_grounding_samples", 3)
        incorrect_grounding_samples = wandb_config.get("incorrect_grounding_samples", 3)
        correct_worldmodeling_samples = wandb_config.get("correct_worldmodeling_samples", 3)
        incorrect_worldmodeling_samples = wandb_config.get("incorrect_worldmodeling_samples", 3)
        parse_failed_samples = wandb_config.get("parse_failed_samples", 3)
        # Removed error_samples as we no longer log error data to wandb tables
        
        # Get table logging frequency (default to 10 if not specified)
        table_logging_frequency = wandb_config.get("table_logging_frequency", 10)
        
        # Measure execution time
        start_time = time.time()
@@ -199,110 +209,97 @@ def run_llm_judge(input_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                global_step: Current global step
                
            Returns:
                Dictionary with column names as keys and values as a single row
                List with values in the same order as the columns
            """
            if not results_subset:
                return {}
                return None
            
            # Sample results (or take all if fewer than max_samples)
            samples = random.sample(results_subset, min(max_samples, len(results_subset)))
            
            # Create a dictionary where keys are column names and values are lists
            row_data = {"step": global_step}
            
            # Add columns for each sample and its fields
            for i, sample in enumerate(samples):
                # For each sample, add columns with index for each required field
                sample_idx = i + 1
                row_data[f"sample_{sample_idx}_id"] = sample["id"]
                row_data[f"sample_{sample_idx}_env_name"] = sample["env_name"]
                row_data[f"sample_{sample_idx}_prompt"] = sample["prompt"]
                row_data[f"sample_{sample_idx}_response"] = sample["response"]
                row_data[f"sample_{sample_idx}_parsed_answer"] = extract_parsed_answer(sample["response"])
            # Create a list with values in the same order as the columns
            # First element is the step
            row_data = [global_step]
            
            # For each possible sample index (1 to max_samples)
            for i in range(max_samples):
                # If we have a sample at this index
                if i < len(samples):
                    sample = samples[i]
                    # Add each field value in order
                    row_data.append(sample["id"])
                    row_data.append(sample["env_name"])
                    row_data.append(sample["prompt"])
                    row_data.append(sample["response"])
                    row_data.append(extract_parsed_answer(sample["response"]))
                else:
                    # Add empty values for missing samples to ensure all rows have the same length
                    row_data.extend([""] * 5)  # 5 fields per sample
            
            return row_data

        # Create and update tables with the global step structure
        # Removed prepare_error_table_data function since we no longer log error data to wandb tables

        # Create and log tables with the global step structure
        def log_tables_with_step(global_step):
            # Define tables if they don't exist yet or get existing ones
            if "correct_grounding_table" not in wandb.run.summary:
            # Define columns for each table type
            correct_grounding_columns = ["step"] + [
                f"sample_{i}_{field}" 
                for i in range(1, correct_grounding_samples + 1) 
                for field in ["id", "env_name", "prompt", "response", "parsed_answer"]
            ]
                correct_grounding_table = wandb.Table(columns=correct_grounding_columns)
            
            incorrect_grounding_columns = ["step"] + [
                f"sample_{i}_{field}" 
                for i in range(1, incorrect_grounding_samples + 1) 
                for field in ["id", "env_name", "prompt", "response", "parsed_answer"]
            ]
                incorrect_grounding_table = wandb.Table(columns=incorrect_grounding_columns)
            
            correct_worldmodeling_columns = ["step"] + [
                f"sample_{i}_{field}" 
                for i in range(1, correct_worldmodeling_samples + 1) 
                for field in ["id", "env_name", "prompt", "response", "parsed_answer"]
            ]
                correct_worldmodeling_table = wandb.Table(columns=correct_worldmodeling_columns)
            
            incorrect_worldmodeling_columns = ["step"] + [
                f"sample_{i}_{field}" 
                for i in range(1, incorrect_worldmodeling_samples + 1) 
                for field in ["id", "env_name", "prompt", "response", "parsed_answer"]
            ]
                incorrect_worldmodeling_table = wandb.Table(columns=incorrect_worldmodeling_columns)
            
            parse_failed_columns = ["step"] + [
                f"sample_{i}_{field}" 
                    for i in range(1, 4)  # Up to 3 parse failure samples 
                for i in range(1, parse_failed_samples + 1)  # Use config for parse failure samples 
                for field in ["id", "env_name", "prompt", "response", "parsed_answer"]
            ]
                parse_failed_table = wandb.Table(columns=parse_failed_columns)
            
                # Initialize tables in wandb
                wandb.run.summary["correct_grounding_table"] = correct_grounding_table
                wandb.run.summary["incorrect_grounding_table"] = incorrect_grounding_table
                wandb.run.summary["correct_worldmodeling_table"] = correct_worldmodeling_table
                wandb.run.summary["incorrect_worldmodeling_table"] = incorrect_worldmodeling_table
                wandb.run.summary["parse_failed_table"] = parse_failed_table
            else:
                # Get existing tables
                correct_grounding_table = wandb.run.summary["correct_grounding_table"]
                incorrect_grounding_table = wandb.run.summary["incorrect_grounding_table"]
                correct_worldmodeling_table = wandb.run.summary["correct_worldmodeling_table"]
                incorrect_worldmodeling_table = wandb.run.summary["incorrect_worldmodeling_table"]
                parse_failed_table = wandb.run.summary["parse_failed_table"]
            # Create new tables for each step to avoid summary access issues
            correct_grounding_table = wandb.Table(columns=correct_grounding_columns)
            incorrect_grounding_table = wandb.Table(columns=incorrect_grounding_columns)
            correct_worldmodeling_table = wandb.Table(columns=correct_worldmodeling_columns)
            incorrect_worldmodeling_table = wandb.Table(columns=incorrect_worldmodeling_columns)
            parse_failed_table = wandb.Table(columns=parse_failed_columns)
            
            # Prepare data rows for each table (one row per global step)
            correct_grounding_data = prepare_table_data(correct_grounding, correct_grounding_samples, global_step)
            incorrect_grounding_data = prepare_table_data(incorrect_grounding, incorrect_grounding_samples, global_step)
            correct_worldmodeling_data = prepare_table_data(correct_worldmodeling, correct_worldmodeling_samples, global_step)
            incorrect_worldmodeling_data = prepare_table_data(incorrect_worldmodeling, incorrect_worldmodeling_samples, global_step)
            parse_failed_data = prepare_table_data(parse_failed, 3, global_step)  # Up to 3 parse failures
            parse_failed_data = prepare_table_data(parse_failed, parse_failed_samples, global_step)  # Use config parameter
            
            # Add data rows to tables
            if correct_grounding_data:
                correct_grounding_table.add_data(**correct_grounding_data)
                correct_grounding_table.add_data(*correct_grounding_data)
            if incorrect_grounding_data:
                incorrect_grounding_table.add_data(**incorrect_grounding_data)
                incorrect_grounding_table.add_data(*incorrect_grounding_data)
            if correct_worldmodeling_data:
                correct_worldmodeling_table.add_data(**correct_worldmodeling_data)
                correct_worldmodeling_table.add_data(*correct_worldmodeling_data)
            if incorrect_worldmodeling_data:
                incorrect_worldmodeling_table.add_data(**incorrect_worldmodeling_data)
                incorrect_worldmodeling_table.add_data(*incorrect_worldmodeling_data)
            if parse_failed_data:
                parse_failed_table.add_data(**parse_failed_data)
            
            # Update the tables in wandb
            wandb.run.summary["correct_grounding_table"] = correct_grounding_table
            wandb.run.summary["incorrect_grounding_table"] = incorrect_grounding_table
            wandb.run.summary["correct_worldmodeling_table"] = correct_worldmodeling_table
            wandb.run.summary["incorrect_worldmodeling_table"] = incorrect_worldmodeling_table
            wandb.run.summary["parse_failed_table"] = parse_failed_table
                parse_failed_table.add_data(*parse_failed_data)
            
            # Also log the tables to the history
            # Log the tables directly to history without using summary
            wandb.log({
                "correct_grounding_examples": correct_grounding_table,
                "incorrect_grounding_examples": incorrect_grounding_table,
@@ -311,46 +308,12 @@ def run_llm_judge(input_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                "parse_failed_examples": parse_failed_table
            }, step=global_step)

        # Similarly update the error table:
        def prepare_error_table_data(error_examples, max_samples, global_step):
            if not error_examples:
                return {}
            
            samples = error_examples[:max_samples]  # Take up to max_samples errors
            
            row_data = {"step": global_step}
            for i, sample in enumerate(samples):
                sample_idx = i + 1
                row_data[f"error_{sample_idx}_id"] = sample["id"]
                row_data[f"error_{sample_idx}_env_name"] = sample["env_name"]
                row_data[f"error_{sample_idx}_type"] = sample["type"]
                row_data[f"error_{sample_idx}_error"] = sample["error"]
            
            return row_data

        # Process error examples
        error_examples = [r for r in results if not r["success"]]
        
        # Create and update error table
        if "error_table" not in wandb.run.summary:
            error_columns = ["step"] + [
                f"error_{i}_{field}" 
                for i in range(1, 4)  # Up to 3 error samples
                for field in ["id", "env_name", "type", "error"]
            ]
            error_table = wandb.Table(columns=error_columns)
            wandb.run.summary["error_table"] = error_table
        else:
            error_table = wandb.run.summary["error_table"]

        error_data = prepare_error_table_data(error_examples, 3, global_step)  # Sample up to 3 errors
        if error_data:
            error_table.add_data(**error_data)

        wandb.run.summary["error_table"] = error_table
        wandb.log({"error_examples": error_table}, step=global_step)
        # Remove error data logging to wandb tables
        
        # Replace the original table creation with the new approach
        # Replace the original table logging with a frequency-based approach
        # Only log tables if the current step is divisible by the table logging frequency
        # This ensures we log tables at regular intervals rather than every step
        if global_step % table_logging_frequency == 0:
            log_tables_with_step(global_step)
        
        return results