Commit 4a7a1ae4 authored by YaningGao's avatar YaningGao
Browse files

minor

parent e155ae3b
Loading
Loading
Loading
Loading
+13 −20
Original line number Diff line number Diff line
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import logging
from typing import Dict, List, Tuple, Optional, Any, Union
from transformers import AutoModel, AutoImageProcessor
from PIL import Image

import math

class AverageMeter(object):
    """Computes and stores the average and current value"""
@@ -47,9 +46,9 @@ class BaseMetric:
            try:
                measure = self.metric(**kwargs)
            except Exception as e:
                print("Error calculating metric: {}".format(e))
                print(f"Error calculating metric: {e}")
                continue
            if not (measure >= 0) and not (measure <= 0):  # Check for NaN
            if math.isnan(measure):
                continue
            values.append(measure)

@@ -81,7 +80,6 @@ class DINOScoreCalculator(BaseMetric):
        self.model, self.processor = self.get_DINOv2_model(model_size)
        self.device = device
        self.model = self.model.to(self.device)

        self.metric = self.calculate_DINOv2_similarity_score

    def get_DINOv2_model(self, model_size):
@@ -96,8 +94,10 @@ class DINOScoreCalculator(BaseMetric):
        return AutoModel.from_pretrained(model_size), AutoImageProcessor.from_pretrained(model_size)

    def process_input(self, image, processor):
        """Process images efficiently in batches when possible"""
        if isinstance(image, list):
            if all(isinstance(img, Image.Image) for img in image):
                # Process all images in a single batch to maximize GPU utilization
                with torch.no_grad():
                    inputs = processor(images=image, return_tensors="pt").to(self.device)
                    outputs = self.model(**inputs)
@@ -111,18 +111,21 @@ class DINOScoreCalculator(BaseMetric):
        
        if isinstance(image, str):
            image = Image.open(image)
            
        if isinstance(image, Image.Image):
            with torch.no_grad():
                inputs = processor(images=image, return_tensors="pt").to(self.device)
                outputs = self.model(**inputs)
                features = outputs.last_hidden_state.mean(dim=1)
            return features
        elif isinstance(image, torch.Tensor):
            features = image.unsqueeze(0) if image.dim() == 1 else image
            return features
        else:
            raise ValueError("Input must be a file path, PIL Image, or tensor of features")
        return features

    def calculate_DINOv2_similarity_score(self, **kwargs):
        """Calculate similarity score between two images"""
        image1 = kwargs.get('gt_im')
        image2 = kwargs.get('gen_im')
        features1 = self.process_input(image1, self.processor)
@@ -130,13 +133,14 @@ class DINOScoreCalculator(BaseMetric):

        cos = nn.CosineSimilarity(dim=1)
        sim = cos(features1, features2).item()
        sim = (sim + 1) / 2
        sim = (sim + 1) / 2  # Convert from [-1, 1] to [0, 1] range

        return sim
    
    def calculate_batch_scores(self, gt_images: List[Any], gen_images: List[Any]) -> List[float]:
        """
        Calculate similarity scores for multiple image pairs in a single batch
        Calculate similarity scores for multiple image pairs in a single batch.
        DINO can process all images in a batch efficiently.
        """      
        if not gt_images: 
            return []
@@ -150,14 +154,3 @@ class DINOScoreCalculator(BaseMetric):
        scores = [(sim.item() + 1) / 2 for sim in similarities]
        
        return scores
 No newline at end of file


# Compatibility function for existing code
def get_dino_model(model_size="small", device="cuda:0"):
    """
    Build and return a fresh DINOScoreCalculator.

    Backward-compatibility shim for existing callers; per the original
    note, the service should construct DINOScoreCalculator directly.

    Args:
        model_size: Forwarded unchanged as ``model_size``.
        device: Forwarded unchanged as ``device``.

    Returns:
        The newly constructed DINOScoreCalculator.
    """
    # Log first so the creation is recorded even if construction fails.
    message = f"Creating new DINO model: {model_size} on {device}"
    logging.info(message)

    calculator = DINOScoreCalculator(model_size=model_size, device=device)
    return calculator
 No newline at end of file
+17 −30
Original line number Diff line number Diff line
@@ -3,7 +3,8 @@ from PIL import Image
import os
from dreamsim import dreamsim
import logging

from concurrent.futures import ThreadPoolExecutor
from typing import List, Any

class DreamSimScoreCalculator:
    """
@@ -46,38 +47,24 @@ class DreamSimScoreCalculator:

        return similarity

    def calculate_batch_scores(self, gt_images: List[Any], gen_images: List[Any]) -> List[float]:
        """
        Calculate similarity scores for multiple image pairs.

        Since DreamSim doesn't natively support batch comparison, each pair
        is preprocessed, moved to ``self.device`` and scored individually.

        Args:
            gt_images: Ground-truth images, each accepted by ``self.preprocess``.
            gen_images: Generated images, paired with ``gt_images`` by position.

        Returns:
            One similarity per pair, computed as ``1 - distance`` with the
            distance clamped to [0, 1]. An empty list if either input is empty.

        Raises:
            ValueError: If both lists are non-empty but differ in length.
        """
        if not gt_images or not gen_images:
            return []

        # Fail loudly on mismatched inputs instead of raising a bare
        # IndexError (or silently ignoring the surplus images).
        if len(gt_images) != len(gen_images):
            raise ValueError(
                "gt_images and gen_images must have the same length, "
                f"got {len(gt_images)} and {len(gen_images)}"
            )

        scores = []
        for gt_image, gen_image in zip(gt_images, gen_images):
            # Handle one pair at a time so only two tensors are created per step.
            gt_tensor = self.preprocess(gt_image).to(self.device)
            gen_tensor = self.preprocess(gen_image).to(self.device)

            with torch.no_grad():
                distance = self.model(gt_tensor, gen_tensor).item()

            # Clamp the distance to [0, 1] before converting it to a similarity.
            scores.append(1.0 - min(1.0, max(0.0, distance)))

        return scores
 No newline at end of file


# Compatibility function for existing code
def get_dreamsim_model(device="cuda:0"):
    """
    Build and return a fresh DreamSimScoreCalculator.

    Backward-compatibility shim for existing callers; per the original
    note, the service should construct DreamSimScoreCalculator directly.

    Args:
        device: Forwarded unchanged as ``device``.

    Returns:
        The newly constructed DreamSimScoreCalculator.
    """
    # Log first so the creation is recorded even if construction fails.
    message = f"Creating new DreamSim model on {device}"
    logging.info(message)

    calculator = DreamSimScoreCalculator(device=device)
    return calculator
 No newline at end of file
+2 −15
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ class SvgEnvConfig(BaseEnvConfig):
    action_sep: str = "~~"
    # Score configuration
    model_size: str = "small"  # 'small', 'base', or 'large'
    dino_only: bool = False
    # Weights for different scoring components
    dino_weight: Optional[float] = None
    structural_weight: Optional[float] = None
    dreamsim_weight: Optional[float] = None
@@ -43,7 +43,6 @@ class SvgEnvConfig(BaseEnvConfig):
        id_fields = [
            "dataset_name", 
            "model_size", 
            "dino_only", 
            "format_reward", 
            "format_penalty"
        ]
@@ -65,7 +64,6 @@ class SvgEnvConfig(BaseEnvConfig):
        """Get the score configuration dictionary"""
        score_config = {
            "model_size": self.model_size,
            "dino_only": self.dino_only,
            "device": self.device  # Include processed device configuration in score config
        }
        
@@ -78,14 +76,3 @@ class SvgEnvConfig(BaseEnvConfig):
            score_config["dreamsim_weight"] = self.dreamsim_weight
            
        return score_config
 No newline at end of file


if __name__ == "__main__":
    # Example usage: build a config with per-model device indices and print
    # what it reports back.
    # NOTE(review): the original author states these indices become
    # "cuda:1" and "cuda:2" — confirm against SvgEnvConfig's device handling.
    device_spec = {"dino": 1, "dreamsim": 2}
    config = SvgEnvConfig(device=device_spec)

    config_identifier = config.config_id()
    print(config_identifier)

    score_settings = config.get_score_config()
    print(score_settings)

    # Per the original note, this should show
    # {"dino": "cuda:1", "dreamsim": "cuda:2"}.
    print(f"Processed device config: {config.device}")
 No newline at end of file
+31 −115
Original line number Diff line number Diff line
@@ -21,25 +21,17 @@ def calculate_total_score(gt_im, gen_im, gt_code, gen_code, score_config, dino_m
    """
    # Get configuration parameters with defaults
    model_size = score_config.get("model_size", "small")
    dino_only = score_config.get("dino_only", False)
    
    # Get device configuration with defaults
    devices = score_config.get("device", {"dino": "cuda:0", "dreamsim": "cuda:0"})
    dino_device = devices.get("dino", "cuda:0")
    dreamsim_device = devices.get("dreamsim", "cuda:0")
    
    # Define default weights based on model size
    default_weights = {
        "small": {"dino": 3.0, "structural": 7.0, "dreamsim": 5.0},
        "base": {"dino": 5.0, "structural": 5.0, "dreamsim": 5.0},
        "large": {"dino": 6.0, "structural": 4.0, "dreamsim": 5.0}
    }
    
    # Get weights with defaults
    weights = {
        "dino": score_config.get("dino_weight", default_weights[model_size]["dino"]),
        "structural": score_config.get("structural_weight", default_weights[model_size]["structural"]),
        "dreamsim": score_config.get("dreamsim_weight", default_weights[model_size]["dreamsim"])
        "dino": score_config.get("dino_weight", 0.0),
        "structural": score_config.get("structural_weight", 0.0),
        "dreamsim": score_config.get("dreamsim_weight", 0.0)
    }
    
    # Initialize scores
@@ -50,30 +42,12 @@ def calculate_total_score(gt_im, gen_im, gt_code, gen_code, score_config, dino_m
        "total_score": 0.0
    }
    
    # Calculate DINO score if needed
    if weights["dino"] > 0:
        if dino_model is None:
            from vagen.env.svg.dino import get_dino_model
            dino_model = get_dino_model(model_size, device=dino_device)
    scores["dino_score"] = float(dino_model.calculate_DINOv2_similarity_score(gt_im=gt_im, gen_im=gen_im))
    
    # Calculate DreamSim score if needed
    if weights["dreamsim"] > 0:
        if dreamsim_model is None:
            from vagen.env.svg.dreamsim import get_dreamsim_model
            dreamsim_model = get_dreamsim_model(device=dreamsim_device)
    scores["dreamsim_score"] = float(dreamsim_model.calculate_similarity_score(gt_im=gt_im, gen_im=gen_im))
    
    # If DINO only mode, return only DINO score
    if dino_only:
        scores["total_score"] = scores["dino_score"]
        return scores
    
    # Calculate structural score if needed
    if weights["structural"] > 0:
        scores["structural_score"] = max(0.0, float(calculate_structural_accuracy(gt_im, gen_im)))
    
    # Calculate weighted total score
    weighted_sum = (
        scores["dino_score"] * weights["dino"] +
        scores["structural_score"] * weights["structural"] +
@@ -87,7 +61,8 @@ def calculate_total_score(gt_im, gen_im, gt_code, gen_code, score_config, dino_m
def calculate_total_score_batch(gt_images, gen_images, gt_codes, gen_codes, score_configs, dino_model=None,
                              dreamsim_model=None):
    """
    Calculate scores for multiple image pairs in batch mode
    Batch score calculation that leverages model batch processing.
    Always calculates all scores for metrics, regardless of weights.
    """
    batch_size = len(gt_images)
    if batch_size == 0:
@@ -105,93 +80,34 @@ def calculate_total_score_batch(gt_images, gen_images, gt_codes, gen_codes, scor
        "total_score": 0.0
    } for _ in range(batch_size)]
    
    # Check if we need to calculate DINO scores and get device
    need_dino = False
    dino_device = "cuda:0"
    for score_config in score_configs:
        if score_config.get("dino_weight", 0.0) > 0:
            need_dino = True
            devices = score_config.get("device", {"dino": "cuda:0", "dreamsim": "cuda:0"})
            dino_device = devices.get("dino", "cuda:0")
            break

    # Check if we need to calculate DreamSim scores and get device
    need_dreamsim = False
    dreamsim_device = "cuda:0"
    for score_config in score_configs:
        if score_config.get("dreamsim_weight", 0.0) > 0:
            need_dreamsim = True
            devices = score_config.get("device", {"dino": "cuda:0", "dreamsim": "cuda:0"})
            dreamsim_device = devices.get("dreamsim", "cuda:0")
            break

    # Calculate DINO scores in batch if needed
    if need_dino:
    if dino_model is None:
            from vagen.env.svg.dino import get_dino_model
            # Default to small model size if not specified
            model_size = score_configs[0].get("model_size", "small") if score_configs else "small"
            dino_model = get_dino_model(model_size, device=dino_device)

        # Calculate all DINO scores at once using batch processing
        dino_scores = dino_model.calculate_batch_scores(gt_images, gen_images)

        # Assign scores to results
        for i, score in enumerate(dino_scores):
            batch_results[i]["dino_score"] = float(score)

    # Calculate DreamSim scores in batch if needed
    if need_dreamsim:
        raise ValueError("DINO model must be provided by the service")
    if dreamsim_model is None:
            from vagen.env.svg.dreamsim import get_dreamsim_model
            dreamsim_model = get_dreamsim_model(device=dreamsim_device)
        raise ValueError("DreamSim model must be provided by the service")
    
        # Calculate all DreamSim scores at once using batch processing
    dino_scores = dino_model.calculate_batch_scores(gt_images, gen_images)
    dreamsim_scores = dreamsim_model.calculate_batch_scores(gt_images, gen_images)
    
        # Assign scores to results
        for i, score in enumerate(dreamsim_scores):
            batch_results[i]["dreamsim_score"] = float(score)
    structural_scores = [calculate_structural_accuracy(gt_images[i], gen_images[i]) 
                          for i in range(batch_size)]

    # Calculate structural scores and total scores
    # Assign scores and calculate total scores
    for i in range(batch_size):
        score_config = score_configs[i]
        result = batch_results[i]

        # Check if DINO-only mode
        dino_only = score_config.get("dino_only", False)
        if dino_only:
            result["total_score"] = result["dino_score"]
            continue

        # Get model size for default weights
        model_size = score_config.get("model_size", "small")

        # Define default weights based on model size
        default_weights = {
            "small": {"dino": 3.0, "structural": 7.0, "dreamsim": 5.0},
            "base": {"dino": 5.0, "structural": 5.0, "dreamsim": 5.0},
            "large": {"dino": 6.0, "structural": 4.0, "dreamsim": 5.0}
        }
        batch_results[i]["dino_score"] = float(dino_scores[i])
        batch_results[i]["dreamsim_score"] = float(dreamsim_scores[i])
        batch_results[i]["structural_score"] = max(0.0, float(structural_scores[i]))
        
        # Get weights with defaults
        weights = {
            "dino": score_config.get("dino_weight", default_weights[model_size]["dino"]),
            "structural": score_config.get("structural_weight", default_weights[model_size]["structural"]),
            "dreamsim": score_config.get("dreamsim_weight", default_weights[model_size]["dreamsim"])
            "dino": score_configs[i].get("dino_weight", 0.0),
            "structural": score_configs[i].get("structural_weight", 0.0),
            "dreamsim": score_configs[i].get("dreamsim_weight", 0.0)
        }

        # Calculate structural score if needed
        if weights["structural"] > 0:
            from vagen.env.svg.score import calculate_structural_accuracy
            result["structural_score"] = max(0.0, float(calculate_structural_accuracy(gt_images[i], gen_images[i])))

        # Calculate weighted total score
        weighted_sum = (
                result["dino_score"] * weights["dino"] +
                result["structural_score"] * weights["structural"] +
                result["dreamsim_score"] * weights["dreamsim"]
            batch_results[i]["dino_score"] * weights["dino"] +
            batch_results[i]["structural_score"] * weights["structural"] +
            batch_results[i]["dreamsim_score"] * weights["dreamsim"]
        )
        result["total_score"] = max(0.0, weighted_sum)
        batch_results[i]["total_score"] = max(0.0, weighted_sum)

    return batch_results
 No newline at end of file
+20 −12
Original line number Diff line number Diff line
@@ -165,10 +165,16 @@ class SVGService(BaseService):
        return results
    
    def step_batch(self, ids2actions: Dict[Any, Any]) -> Dict[Any, Tuple[Dict, float, bool, Dict]]:
        """
        Optimized step_batch method that maximizes RAM and GPU utilization
        """
        results = {}
        
        # Process SVG actions in batch
        env_processing_results, error_results = self._process_svg_actions_batch(ids2actions)
        results.update(error_results)
        
        # Collect valid SVGs for batch processing
        valid_env_ids = []
        gt_images = []
        gen_images = []
@@ -177,7 +183,6 @@ class SVGService(BaseService):
        score_configs = []
        
        for env_id, result in env_processing_results.items():
            # Only process valid SVGs - skip invalid ones entirely for scoring
            if result["valid"] and result["gen_image"] is not None and result["metrics"]["turn_metrics"]["svg_is_valid"]:
                valid_env_ids.append(env_id)
                gt_images.append(result["env"].gt_image)
@@ -187,16 +192,17 @@ class SVGService(BaseService):
                score_configs.append(result["env"].config.get_score_config())
        
        if valid_env_ids:
            # Get the models from the service
            # Get models from service
            dino_model = self.get_dino_model()
            dreamsim_model = self.get_dreamsim_model()

            # Calculate all scores at once with the service models
            # Calculate all scores at once
            batch_results = calculate_total_score_batch(
                gt_images, gen_images, gt_codes, gen_codes, score_configs,
                dino_model=dino_model, dreamsim_model=dreamsim_model
            )
            
            # Process results and update environments
            for i, env_id in enumerate(valid_env_ids):
                result = env_processing_results[env_id]
                env = result["env"]
@@ -206,9 +212,7 @@ class SVGService(BaseService):
                env.reward += scores["total_score"]
                env.total_reward += env.reward
                
                # Determine if action is effective - either:
                # 1. First generation (no previous score) with a positive score
                # 2. Improved score compared to previous generation
                # Determine effectiveness based on improvement
                previous_score = 0.0
                is_first_generation = True
                
@@ -216,29 +220,30 @@ class SVGService(BaseService):
                    previous_score = self.cache[env_id]['scores'].get('total_score', 0.0)
                    is_first_generation = False
                
                # Check effectiveness based on whether it's the first generation or an improvement
                # Check if first generation or improved
                if is_first_generation:
                    result["metrics"]["turn_metrics"]["action_is_effective"] = scores["total_score"] > 0
                else:
                    result["metrics"]["turn_metrics"]["action_is_effective"] = scores["total_score"] > previous_score
                
                # Update other metrics
                # Update metrics
                result["metrics"]["turn_metrics"]["dino_score"] = scores["dino_score"]
                result["metrics"]["turn_metrics"]["dreamsim_score"] = scores["dreamsim_score"]
                info = result["rst"].copy()
                info["scores"] = scores
                info["metrics"] = result["metrics"]
                
                # Update cache if needed
                # Update cache
                if env_id in self.cache:
                    self.cache[env_id]['gen_image'] = env.gen_image
                    self.cache[env_id]['gen_svg_code'] = env.gen_svg_code
                    self.cache[env_id]['scores'] = scores
                
                # Create observation
                observation = env._render(init_obs=False)
                results[env_id] = serialize_step_result((observation, env.reward, False, info))
        
        # Handle invalid cases or cases not processed above
        # Handle invalid cases
        for env_id, result in env_processing_results.items():
            if env_id not in results:
                env = result["env"]
@@ -252,11 +257,11 @@ class SVGService(BaseService):
                elif "traj_metrics" not in info["metrics"]:
                    info["metrics"]["traj_metrics"] = {}
                    
                # For invalid SVGs, explicitly set scores to zero
                # Set invalid metrics
                info["metrics"]["turn_metrics"]["action_is_valid"] = False
                info["metrics"]["turn_metrics"]["action_is_effective"] = False
                
                # Set all scores to zero for invalid SVGs
                # Zero scores for invalid SVGs
                info["scores"] = {
                    "dino_score": 0.0,
                    "structural_score": 0.0,
@@ -264,6 +269,7 @@ class SVGService(BaseService):
                    "total_score": 0.0
                }
                
                # Apply penalty
                reward = 0.0
                if hasattr(env.config, "format_penalty"):
                    reward = env.config.format_penalty
@@ -273,8 +279,10 @@ class SVGService(BaseService):
                env.gen_svg_code = None
                env.gen_image = None
                
                # Create observation
                observation = env._render(init_obs=False)
                
                # Update cache
                if env_id in self.cache:
                    self.cache[env_id]['gen_image'] = None
                    self.cache[env_id]['gen_svg_code'] = None
Loading