Commit 408de6f7 authored by YaningGao's avatar YaningGao
Browse files

svg revise

parent b7d08752
Loading
Loading
Loading
Loading
+18 −38
Original line number Diff line number Diff line
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoImageProcessor
from PIL import Image
import torch.nn as nn
import threading
import logging
from typing import Dict, List, Tuple, Optional, Any, Union
from transformers import AutoModel, AutoImageProcessor
from PIL import Image


# @TODO clean codes of this section

_model_cache = {}
_model_cache_lock = threading.Lock()
_model_counter = 0  

def get_dino_model(model_size="small", device="cuda:0"):
    global _model_counter
    cache_key = f"{model_size}_{device}"
    
    with _model_cache_lock:
        if cache_key not in _model_cache:
            _model_counter += 1
            import os
            pid = os.getpid()
            logging.info(f"Process {pid}: Created DINO model #{_model_counter}: {model_size} on {device}")
            _model_cache[cache_key] = DINOScoreCalculator(model_size=model_size, device=device)
        return _model_cache[cache_key]

class AverageMeter(object):
    """Computes and stores the average and current value"""

@@ -46,6 +25,7 @@ class AverageMeter(object):
        self.count += n
        self.avg = self.sum / self.count


class BaseMetric:
    def __init__(self):
        self.meter = AverageMeter()
@@ -59,7 +39,7 @@ class BaseMetric:
        """
        values = []
        batch_size = len(next(iter(batch.values())))
        for index in tqdm(range(batch_size)):
        for index in range(batch_size):
            kwargs = {}
            for key in ["gt_im", "gen_im", "gt_svg", "gen_svg", "caption"]:
                if key in batch:
@@ -69,7 +49,7 @@ class BaseMetric:
            except Exception as e:
                print("Error calculating metric: {}".format(e))
                continue
            if math.isnan(measure):
            if not (measure >= 0) and not (measure <= 0):  # Check for NaN
                continue
            values.append(measure)

@@ -85,14 +65,13 @@ class BaseMetric:
            return score, values

    def metric(self, **kwargs):
        """
        This method should be overridden by subclasses to provide the specific metric computation.
        """
        """This method should be overridden by subclasses"""
        raise NotImplementedError("The metric method must be implemented by subclasses.")
    
    def get_average_score(self):
        return self.meter.avg


class DINOScoreCalculator(BaseMetric): 
    def __init__(self, config=None, model_size='large', device='cuda:0'):
        super().__init__()
@@ -158,19 +137,11 @@ class DINOScoreCalculator(BaseMetric):
    def calculate_batch_scores(self, gt_images: List[Any], gen_images: List[Any]) -> List[float]:
        """
        Calculate similarity scores for multiple image pairs in a single batch
        
        Args:
            gt_images: List of ground truth images (PIL Images, file paths, or tensors)
            gen_images: List of generated images (PIL Images, file paths, or tensors)
            
        Returns:
            List of similarity scores (float values between 0-1)
        """      
        if not gt_images: 
            return []
        
        gt_features = self.process_input(gt_images, self.processor)
        
        gen_features = self.process_input(gen_images, self.processor)
        
        cos = nn.CosineSimilarity(dim=1)
@@ -181,3 +152,12 @@ class DINOScoreCalculator(BaseMetric):
        return scores


# Compatibility function for existing code
def get_dino_model(model_size="small", device="cuda:0"):
    """
    Create a new DINO model instance.
    This function exists for backward compatibility.
    The service should use DINOScoreCalculator directly.
    """
    logging.info(f"Creating new DINO model: {model_size} on {device}")
    return DINOScoreCalculator(model_size=model_size, device=device)
 No newline at end of file
+12 −52
Original line number Diff line number Diff line
@@ -2,38 +2,8 @@ import torch
from PIL import Image
import os
from dreamsim import dreamsim
import threading
import logging

# Create global cache and lock, similar to DINO implementation
_model_cache = {}
_model_cache_lock = threading.Lock()
_model_counter = 0


def get_dreamsim_model(device="cuda:0"):
    """
    Get a singleton instance of DreamSim model, using cache to avoid duplicate loading

    Args:
        device: Device to run model on

    Returns:
        DreamSimScoreCalculator: Instance of DreamSim calculator
    """
    global _model_counter

    # Use device as cache key
    cache_key = f"dreamsim_{device}"

    with _model_cache_lock:
        if cache_key not in _model_cache:
            _model_counter += 1
            pid = os.getpid()
            logging.info(f"Process {pid}: Created DreamSim model #{_model_counter} on {device}")
            _model_cache[cache_key] = DreamSimScoreCalculator(device=device)
        return _model_cache[cache_key]


class DreamSimScoreCalculator:
    """
@@ -43,11 +13,6 @@ class DreamSimScoreCalculator:
    def __init__(self, pretrained=True, cache_dir="~/.cache", device=None):
        """
        Initialize DreamSim model.

        Args:
            pretrained: Whether to use pretrained model
            cache_dir: Cache directory for model weights
            device: Device to run the model on (defaults to CUDA if available, else CPU)
        """
        cache_dir = os.path.expanduser(cache_dir)

@@ -63,13 +28,6 @@ class DreamSimScoreCalculator:
    def calculate_similarity_score(self, gt_im, gen_im):
        """
        Calculate similarity score between ground truth and generated images.

        Args:
            gt_im: Ground truth PIL Image
            gen_im: Generated PIL Image

        Returns:
            float: Similarity score (1 - distance, normalized to [0, 1])
        """
        # Preprocess images
        img1 = self.preprocess(gt_im)
@@ -84,8 +42,6 @@ class DreamSimScoreCalculator:
            distance = self.model(img1, img2).item()

        # Convert distance to similarity score (1 - normalized distance)
        # DreamSim usually outputs values in range [0, 1] where lower means more similar
        # We invert it so that higher means more similar (1 = identical)
        similarity = 1.0 - min(1.0, max(0.0, distance))

        return similarity
@@ -93,13 +49,6 @@ class DreamSimScoreCalculator:
    def calculate_batch_scores(self, gt_images, gen_images):
        """
        Calculate similarity scores for a batch of image pairs.

        Args:
            gt_images: List of ground truth PIL Images
            gen_images: List of generated PIL Images

        Returns:
            List[float]: List of similarity scores
        """
        # Preprocess all images
        gt_processed = [self.preprocess(img) for img in gt_images]
@@ -121,3 +70,14 @@ class DreamSimScoreCalculator:
            scores.append(similarity)

        return scores


# Compatibility function for existing code
def get_dreamsim_model(device="cuda:0"):
    """
    Create a new DreamSim model instance.
    This function exists for backward compatibility.
    The service should use DreamSimScoreCalculator directly.
    """
    logging.info(f"Creating new DreamSim model on {device}")
    return DreamSimScoreCalculator(device=device)
 No newline at end of file
+5 −15
Original line number Diff line number Diff line
@@ -101,16 +101,8 @@ class SVGEnv(BaseEnv):
        
        return self._render(init_obs=True), {}

    def step(self, action_str: str, dino_model=None) -> Tuple[Dict, float, bool, Dict]:
        """Execute a step in the environment.
        
        Args:
            action_str: Raw text response from LLM
            dino_model: Optional DINO model for scoring
            
        Returns:
            Observation, reward, done, info
        """
    def step(self, action_str: str, dino_model=None, dreamsim_model=None) -> Tuple[Dict, float, bool, Dict]:
        """Execute a step in the environment."""
        # Process the LLM response to extract actions
        rst = self.parse_func(
            response=action_str,
@@ -170,7 +162,7 @@ class SVGEnv(BaseEnv):
                _, gen_image = process_and_rasterize_svg(self.gen_svg_code)
                self.gen_image = gen_image
                
                # Calculate score
                # Calculate score using service models if provided
                score_config = self.config.get_score_config()
                scores = calculate_total_score(
                    gt_im=self.gt_image,
@@ -178,7 +170,8 @@ class SVGEnv(BaseEnv):
                    gt_code=self.gt_svg_code,
                    gen_code=self.gen_svg_code,
                    score_config=score_config,
                    dino_model=dino_model
                    dino_model=dino_model,
                    dreamsim_model=dreamsim_model
                ) 
                
                # Set metrics and update reward
@@ -189,9 +182,6 @@ class SVGEnv(BaseEnv):
                metrics["turn_metrics"]["action_is_effective"] = scores["total_score"] > 0
                    
            except Exception as e:
                import traceback
                print(f"Error processing SVG: {e}")
                traceback.print_exc()
                # Reset actions and update metrics
                self.valid_actions = []
                metrics["turn_metrics"]["action_is_valid"] = False
+0 −32
Original line number Diff line number Diff line
import numpy as np
import cv2
from vagen.env.svg.dino import DINOScoreCalculator


def calculate_structural_accuracy(gt_im, gen_im):
    "range from 0 - 1"
@@ -20,24 +18,6 @@ def calculate_structural_accuracy(gt_im, gen_im):
def calculate_total_score(gt_im, gen_im, gt_code, gen_code, score_config, dino_model=None, dreamsim_model=None):
    """
    Calculate all metrics and return a comprehensive score
    
    Args:
        gt_im: Ground truth image
        gen_im: Generated image
        gt_code: Ground truth SVG code
        gen_code: Generated SVG code
        score_config: Dictionary containing scoring parameters
            - model_size: small, base, large
            - dino_only: Whether to use only DINO for scoring
            - dino_weight: Weight for DINO score
            - structural_weight: Weight for structural score
            - dreamsim_weight: Weight for DreamSim score
            - device: Dictionary with keys "dino" and "dreamsim" specifying device
        dino_model: Pre-loaded DINO model (optional)
        dreamsim_model: Pre-loaded DreamSim model (optional)
        
    Returns:
        dict: Dictionary of all scores including the total weighted score
    """
    # Get configuration parameters with defaults
    model_size = score_config.get("model_size", "small")
@@ -108,18 +88,6 @@ def calculate_total_score_batch(gt_images, gen_images, gt_codes, gen_codes, scor
                                dreamsim_model=None):
    """
    Calculate scores for multiple image pairs in batch mode

    Args:
        gt_images: List of ground truth images
        gen_images: List of generated images
        gt_codes: List of ground truth SVG codes
        gen_codes: List of generated SVG codes
        score_configs: List of scoring parameters dictionaries
        dino_model: Pre-loaded DINO model (optional)
        dreamsim_model: Pre-loaded DreamSim model (optional)

    Returns:
        List of dictionaries containing all scores
    """
    batch_size = len(gt_images)
    if batch_size == 0:
+100 −147

File changed.

Preview size limit exceeded, changes collapsed.

Loading