Commit e155ae3b authored by YaningGao's avatar YaningGao
Browse files

minor

parent 408de6f7
Loading
Loading
Loading
Loading
+30 −8
Original line number Diff line number Diff line
@@ -177,7 +177,8 @@ class SVGService(BaseService):
        score_configs = []
        
        for env_id, result in env_processing_results.items():
            if result["valid"] and result["gen_image"] is not None:
            # Only process valid SVGs - skip invalid ones entirely for scoring
            if result["valid"] and result["gen_image"] is not None and result["metrics"]["turn_metrics"]["svg_is_valid"]:
                valid_env_ids.append(env_id)
                gt_images.append(result["env"].gt_image)
                gen_images.append(result["gen_image"])
@@ -201,16 +202,34 @@ class SVGService(BaseService):
                env = result["env"]
                scores = batch_results[i]
                
                # Update reward
                env.reward += scores["total_score"]
                env.total_reward += env.reward
                
                # Determine if action is effective - either:
                # 1. First generation (no previous score) with a positive score
                # 2. Improved score compared to previous generation
                previous_score = 0.0
                is_first_generation = True
                
                if env_id in self.cache and self.cache[env_id].get('scores') is not None:
                    previous_score = self.cache[env_id]['scores'].get('total_score', 0.0)
                    is_first_generation = False
                
                # Check effectiveness based on whether it's the first generation or an improvement
                if is_first_generation:
                    result["metrics"]["turn_metrics"]["action_is_effective"] = scores["total_score"] > 0
                else:
                    result["metrics"]["turn_metrics"]["action_is_effective"] = scores["total_score"] > previous_score
                
                # Update other metrics
                result["metrics"]["turn_metrics"]["dino_score"] = scores["dino_score"]
                result["metrics"]["turn_metrics"]["dreamsim_score"] = scores["dreamsim_score"]
                info = result["rst"].copy()
                info["scores"] = scores
                info["metrics"] = result["metrics"]
                
                # Update cache if needed
                if env_id in self.cache:
                    self.cache[env_id]['gen_image'] = env.gen_image
                    self.cache[env_id]['gen_svg_code'] = env.gen_svg_code
@@ -219,6 +238,7 @@ class SVGService(BaseService):
                observation = env._render(init_obs=False)
                results[env_id] = serialize_step_result((observation, env.reward, False, info))
        
        # Handle invalid cases or cases not processed above
        for env_id, result in env_processing_results.items():
            if env_id not in results:
                env = result["env"]
@@ -232,13 +252,15 @@ class SVGService(BaseService):
                elif "traj_metrics" not in info["metrics"]:
                    info["metrics"]["traj_metrics"] = {}
                    
                # For invalid SVGs, explicitly set scores to zero
                info["metrics"]["turn_metrics"]["action_is_valid"] = False
                info["metrics"]["turn_metrics"]["action_is_effective"] = False
                
                if "scores" not in info:
                # Set all scores to zero for invalid SVGs
                info["scores"] = {
                    "dino_score": 0.0,
                    "structural_score": 0.0,
                    "dreamsim_score": 0.0,
                    "total_score": 0.0
                }