Unverified Commit 8f1d0186 authored by Kangrui Wang's avatar Kangrui Wang Committed by GitHub
Browse files

Merge pull request #30 from RAGEN-AI/format

updated env states
parents abf92372 4cd9bc1d
Loading
Loading
Loading
Loading
+26 −0
Original line number Diff line number Diff line
@@ -314,6 +314,32 @@ class FrozenLakeEnv(BaseEnv):
        player_pos = self._get_player_position()
        return self.gym_env.desc[player_pos] in [b'G', b'H']

    def get_env_state(self):
        """
        Summarize the current environment layout as a dictionary.

        Returns:
            Dict: player position, goal position, hole positions — each as
                (row, col) coordinate tuples — plus the grid dimensions.
        """
        grid = self.gym_env.desc
        rows, cols = grid.shape

        def _cells(marker):
            # Every cell carrying `marker`, as plain-int (row, col) tuples.
            return [tuple(int(c) for c in cell) for cell in np.argwhere(grid == marker)]

        return {
            # _get_player_position already yields a (row, col) pair.
            "player_position": self._get_player_position(),
            # Exactly one goal cell 'G' is assumed to exist on the board.
            "target_position": _cells(b'G')[0],
            "hole_positions": _cells(b'H'),
            "grid_size": (rows, cols),
        }

if __name__ == "__main__":
    """
+1 −1
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@ class FrozenLakeEnvConfig(BaseEnvConfig):
    # "free_think", "no_think", "grounding", "worldmodeling", "grounding_worldmodeling"
    # "grounding_symbolic", "worldmodeling_symbolic", "grounding_worldmodeling_symbolic"
    # "grounding_structured", "worldmodeling_structured", "grounding_worldmodeling_structured"
    use_accuracy_reward: bool = False
    use_state_reward: bool = False
    
    def config_id(self) -> str:
        id_fields=["is_slippery", "size", "p", "render_mode", "max_actions_per_step", "min_actions_to_succeed","format_reward"]
+91 −1
Original line number Diff line number Diff line
@@ -317,7 +317,7 @@ class NavigationEnv(BaseEnv):
        success = (dist <= self.SUCCESS_THRESHOLD)
        return float(success), dist
    
    def _render(self, init_obs=False):
    def _render(self, init_obs=True):
        """Render the environment observation.
        
        This method creates either a text representation or an image of the environment
@@ -399,6 +399,96 @@ class NavigationEnv(BaseEnv):
        """Close the environment."""
        self.env.stop()
        
    def get_env_state(self):
        """
        Get the current state of the navigation environment focusing on visible objects.

        Returns:
            Dict: target object type, raw and rounded (2 dp) distance to the
                target, the target's relative direction, and the visible
                objects closest to the agent (type, rounded distance,
                relative direction), capped at ``config.max_objects_in_state``.
        """
        # Agent pose: only the y-axis rotation (yaw) matters on the ground plane.
        agent_meta = self.env.last_event.metadata["agent"]
        agent_pos = agent_meta["position"]
        agent_yaw = agent_meta["rotation"]["y"]

        def _relative_direction(pos):
            """Classify `pos` as ahead/right/left/back relative to the agent's heading."""
            dx = pos["x"] - agent_pos["x"]
            dz = pos["z"] - agent_pos["z"]
            # atan2(dx, dz): 0 degrees points along +z, clockwise positive.
            # Normalize into (-180, 180] so 0 means directly in front.
            angle = (math.degrees(math.atan2(dx, dz)) - agent_yaw) % 360
            if angle > 180:
                angle -= 360
            if -45 <= angle <= 45:
                return "ahead"
            if 45 < angle <= 135:
                return "right"
            if -135 <= angle < -45:
                return "left"
            return "back"

        # Target information.
        target_position = self.episode_data["target_position"]
        target_type = self.episode_data["targetObjectType"]
        # measure_success returns (success, distance); only the distance is
        # reported here — success is intentionally discarded.
        _success, distance = self.measure_success()

        # Visible objects with distance and relative-direction data.
        visible_objects = []
        for obj in self.env.last_event.metadata["objects"]:
            if not obj.get("visible", False):
                continue
            obj_pos = obj["position"]
            # Planar (x, z) distance from agent to object, rounded to 2 dp.
            obj_distance = round(
                math.hypot(agent_pos["x"] - obj_pos["x"],
                           agent_pos["z"] - obj_pos["z"]),
                2,
            )
            visible_objects.append({
                "type": obj["objectType"],
                "distance": obj_distance,
                "relative_direction": _relative_direction(obj_pos),
            })

        # Closest objects first; cap the list so the state stays bounded.
        visible_objects.sort(key=lambda o: o["distance"])

        return {
            "target_obj_type": target_type,
            "target_obj_distance": distance,
            "distance_to_target": round(distance, 2),
            "target_relative_direction": _relative_direction(target_position),
            "visible_objects": visible_objects[:self.config.max_objects_in_state],
        }


if __name__ == "__main__":
    # Example usage
+1 −0
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ class NavigationEnvConfig(BaseEnvConfig):
    prompt_format: str = "free_think" 
    # "free_think", "no_think", "grounding", "worldmodeling", "grounding_worldmodeling"
    use_accuracy_reward: bool = False
    max_objects_in_state: int = 10

    def config_id(self) -> str:
        """Generate a unique identifier for this configuration."""
+19 −18
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ class PrimitiveSkillEnv(BaseEnv):
        self.parse_func = parse_function_map[self.config.prompt_format]
        # Define the state keys for the environment
        self.state_keys = self.env.state_keys
        self.last_info = None
    
    def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
@@ -46,8 +47,8 @@ class PrimitiveSkillEnv(BaseEnv):
                - info: Empty dictionary for initial state
        """
        _, info = self.env.reset(seed=seed)
        obs = self._render(info, init_obs=True)
        self.last_info = info
        obs = self._render(init_obs=True)
        self.initial_reward = self._compute_reward()
        self.total_reward = 0
        self.steps = 0
@@ -113,9 +114,8 @@ class PrimitiveSkillEnv(BaseEnv):
            metrics["traj_metrics"]['success'] = True
        
        done = terminated or truncated
        info["action_is_valid"] = metrics["turn_metrics"]['action_is_valid']
        
        obs = self._render(info, init_obs=False, valid_actions=valid_actions)
        obs = self._render(init_obs=False, valid_actions=valid_actions)
        output_info["metrics"] = metrics
        
        self.total_reward += reward
@@ -181,32 +181,22 @@ class PrimitiveSkillEnv(BaseEnv):
        """
        return self._compute_reward() + self.total_reward - self.initial_reward - self.steps * 0.1

    def _get_current_state(self):
        """
        Get a representation of the current state for comparison.
        
        Returns:
            dict: Dictionary representation of important state components
        """
        # This is a simple implementation - customize based on your environment
        return {k: v for k, v in self.last_info.items() if k.endswith('_position')}
    
    def _render(self, info, init_obs=False, valid_actions=None,seed=42):
    def _render(self, init_obs=True, valid_actions=None,seed=42):
        """
        Render the environment as an observation.
        
        Args:
            info (dict): Environment info dictionary
            init_obs (bool): If True, create initial observation
            valid_actions (list): List of valid actions executed (for step observations)
        
        Returns:
            Dict: Observation dictionary containing observation string and optional image data
        """
        new_info = handle_info(info.copy(), state_keys=self.state_keys,mask_success=self.config.mask_success, env=self.env)
        info = self.last_info.copy()
        new_info = handle_info(info, state_keys=self.state_keys,mask_success=self.config.mask_success, env=self.env)
        positions_list = list(new_info['obj_positions'].values())
        # random.seed(seed)
        # random.shuffle(positions_list)  # This shuffles the list in-place

        object_positions = str(positions_list)
        # object_names=str([key.removesuffix("_position") for key in new_info['obj_positions'].keys()])
        other_information = str(new_info['other_info'])
@@ -267,6 +257,17 @@ class PrimitiveSkillEnv(BaseEnv):
            }
    
    
    def get_env_state(self):
        """
        Get the current state of the environment.

        Returns:
            dict: Object positions extracted from the most recent info dict.
        """
        # handle_info normalizes the raw info into structured fields; only
        # the object positions are exposed as the environment state.
        processed = handle_info(
            self.last_info,
            state_keys=self.state_keys,
            mask_success=self.config.mask_success,
            env=self.env,
        )
        return processed["obj_positions"]
    
    
    def _parse_action(self, action_str):
        """
        Parse a single action string into an action array.
Loading