Commit 98685a3b authored by 张泽凯's avatar 张泽凯
Browse files

add vlnce env

parent e51eb887
Loading
Loading
Loading
Loading

vagen/env/vlnce/env.py

0 → 100644
+383 −0
Original line number Diff line number Diff line
from vagen.env.base.base_env import BaseEnv
import numpy as np
import copy
from typing import Dict, List, Optional, Tuple, Any
from vagen.env.utils.env_utils import NoLoggerWarnings, set_seed
from vagen.env.utils.context_utils import convert_numpy_to_PIL
from vagen.env.utils.parse_utils import PARSE_FUNC_MAP
from .prompt import system_prompt, init_observation_template, action_template, format_prompt
from .env_config import VLNCEEnvConfig
from vagen.env.utils.state_reward_text_utils import env_state_reward_wrapper
import habitat
from VLN_CE.vlnce_baselines.config.default import get_config
import json
import regex as re
import time

def reset_env(env,i):
    """Reset a habitat ``Env`` directly to the episode at dataset index ``i``.

    Bypasses habitat's normal episode iterator so a specific episode can be
    selected, then returns the task's initial observations for that episode.

    Args:
        env: a ``habitat.Env`` instance (private members are accessed).
        i: integer index into ``env._dataset.episodes``.

    Returns:
        The observations dict produced by ``env.task.reset``.
    """
    env._reset_stats()
    assert len(env.episodes) > 0, "Episodes list is empty"
    if env._current_episode is not None:
        # Drop the cached shortest path so it cannot leak into the new episode.
        env._current_episode._shortest_path_cache = None

    env._current_episode = env._dataset.episodes[i]
    # Reconfigure the simulator for the newly selected episode (scene, start pose).
    env.reconfigure(env._config)

    observations = env.task.reset(episode=env.current_episode)
    # Reset all task measurements (ndtw, spl, ...) for the fresh episode.
    env._task.measurements.reset_measures(
        episode=env.current_episode, task=env.task
    )
    return observations

def init_env(simulator_config_path):
    """Build a habitat ``Env`` plus a reset-by-episode-id helper.

    Args:
        simulator_config_path: path to the VLN-CE baseline YAML config.

    Returns:
        Tuple of (``habitat.Env``, ``reset`` callable). ``reset(episode_id)``
        rewinds the env to that episode and returns its initial observations.
    """
    config = get_config(simulator_config_path, None)
    dataset = habitat.datasets.make_dataset(
        id_dataset=config.TASK_CONFIG.DATASET.TYPE, config=config.TASK_CONFIG.DATASET
    )
    env = habitat.Env(config.TASK_CONFIG, dataset=dataset)
    # BUG FIX: key the mapping by str(episode_id). The reset() helper below
    # looks up str(episode_id); habitat episode ids can be ints depending on
    # the dataset, in which case the original raw-keyed mapping never matched.
    mapping = {
        str(episode.episode_id): index
        for index, episode in enumerate(env._dataset.episodes)
    }

    def reset(episode_id):
        """Reset ``env`` to the episode with the given id (int or str)."""
        key = str(episode_id)
        assert key in mapping, f"Episode ID {episode_id} not found in mapping"
        return reset_env(env, mapping[key])

    return env, reset




class VLNCEEnv(BaseEnv):
    """
    VLNCE Environment for training and evaluating language models as agents.

    Wraps a habitat VLN-CE episode. The agent acts through short text commands
    ("move forward 25cm", "turn left 30 degrees", "stop") parsed from an LLM
    response; each command is decomposed into repeated discrete habitat steps.
    """

    # Canonical action phrase -> habitat discrete action id.
    # BUG FIX: was "move foward"; the regex in step() extracts "move forward",
    # so the misspelled key made every forward action raise KeyError.
    ACTION_LOOKUP = {
        "stop": 0,
        "move forward": 1,
        "turn left": 2,
        "turn right": 3,
    }

    def __init__(self, config: VLNCEEnvConfig):
        """
        Initialize the VLNCE environment.

        Args:
            config (VLNCEEnvConfig): Configuration parameters for the environment
                (simulator config path, episode id, history actions, data source).
        """
        BaseEnv.__init__(self)
        self.config = config

        self.env, self.reset_env = init_env(config.simulator_config_path)
        self.obs = self.reset_env(self.config.episode_id)
        # BUG FIX: original read the undefined bare name `obs`.
        self.instruction = self.obs["instruction"]["text"]
        self.initial_states = self.get_initial_checkpoint()

        # Measurement keys restored on every reset() to rewind to the checkpoint.
        self.reset_keys = ['ndtw', 'success', 'spl', 'oracle_success', 'oracle_spl']

        # Initialize episode state
        self.total_reward = 0
        self.valid_actions = []
        self.reward = 0

        # Store the format prompt function for later use
        self.format_prompt_func = format_prompt[self.config.prompt_format]
        self.parse_func = PARSE_FUNC_MAP[self.config.prompt_format]

        self._current_step = 0
        self._max_episode_steps = 30  # TODO: dynamically tune this parameter
        self._episode_start_time = 0

        # Per-data-source action grammar: r2r turns in 15-degree units,
        # rxr in 30-degree units.
        self.action_patterns = {
            "r2r": r'(move forward) (25|50|75)cm|(turn (?:left|right)) (15|30|45) degrees|(stop)',
            "rxr": r'(move forward) (25|50|75)cm|(turn (?:left|right)) (30|60|90) degrees|(stop)',
        }

    def get_initial_checkpoint(self):
        """
        Replay the configured history actions and snapshot the resulting state.

        Returns:
            Dict with the agent state, metrics, and ndtw locations length used
            by reset() to rewind the episode to this checkpoint.
        """
        env = self.env

        def _snapshot():
            # Capture everything reset() needs to restore this exact point.
            return {
                "previous_state": env.sim.get_agent_state(),
                "previous_metrics": env.get_metrics(),
                "locations_length": len(env._task.measurements.measures['ndtw'].locations),
            }

        # BUG FIX: original only assigned inside the loop, so an empty
        # history_actions (the config default) raised NameError at return.
        checkpoint = _snapshot()
        for action in self.config.history_actions:
            env.step(action)
            checkpoint = _snapshot()
        return checkpoint

    def reset(self, seed=None):
        """Rewind the episode to the stored checkpoint and return the first obs."""
        self._current_step = 0
        self._episode_start_time = time.time()
        self.total_reward = 0
        self.valid_actions = []
        self.reward = 0

        initial_states = self.initial_states
        env = self.env
        reset_keys = self.reset_keys

        previous_state = initial_states["previous_state"]
        previous_metrics = initial_states["previous_metrics"]
        locations_length = initial_states["locations_length"]

        # Restore the agent pose, truncate the ndtw trace to the checkpoint
        # length, and re-seed the tracked metric values.
        env.sim.set_agent_state(previous_state.position, previous_state.rotation)
        env._task.measurements.measures['ndtw'].locations = env._task.measurements.measures['ndtw'].locations[:locations_length]
        env._task.measurements.update_measures(episode=env.current_episode, action=None, task=env._task)
        for key, value in previous_metrics.items():
            if key in reset_keys:
                env._task.measurements.measures[key]._metric = value

        return self._render(), {}

    # @env_state_reward_wrapper
    def step(self, action_str: str):
        """
        Take a step in the environment based on the agent's action.

        This method:
        1. Parses the raw LLM response to extract actions
        2. Executes each valid action in sequence
        3. Calculates rewards and metrics
        4. Generates the next observation

        Args:
            action_str (str): Raw string from LLM containing actions

        Returns:
            Tuple[Dict, float, bool, Dict]:
                - obs: Dictionary with observation string and optional image data
                - reward: Numeric reward for the step
                - done: Boolean indicating if episode is complete
                - info: Dictionary containing metrics and parsed action data
        """
        # Parse the LLM's raw response to extract actions
        rst = self.parse_func(
            response=action_str,
            special_token_list=self.config.special_token_list,
            action_sep=self.config.action_sep,
            max_actions=self.config.max_actions_per_step
        )

        action_list = rst['actions']

        metrics = {
            "turn_metrics": {
                "action_is_valid": len(action_list) > 0,
                "action_is_effective": False,
            },
            "traj_metrics": {
                "success": False,
            }
        }

        self.reward = 0
        self.valid_actions = []
        done = False
        success = False
        info = {}
        info.update(rst)

        # BUG FIX: seed these from the current metrics so the info dict is
        # well-formed even when no action is executed (the original hit a
        # NameError on invalid/unparseable actions).
        metr = self._get_metrics()
        orcale_success = metr['orcale_success']
        distance_to_goal = metr['distance_to_goal']

        if metrics["turn_metrics"]["action_is_valid"] and rst.get("format_correct", True):
            assert len(action_list) == 1
            for action in action_list:
                action_pattern = self.action_patterns[self.config.data_source]
                action_match = re.match(action_pattern, action)

                if action_match is None:
                    metrics['turn_metrics']['action_is_valid'] = False
                    break
                # The pattern is a three-way alternation: the action phrase
                # lands in group 1 (move), 3 (turn) or 5 (stop); the magnitude
                # in group 2 or 4.
                # BUG FIX: original read only group(1) (None for turn/stop)
                # and called groups(2), which returns the whole tuple.
                action_name = (action_match.group(1)
                               or action_match.group(3)
                               or action_match.group(5))
                magnitude = action_match.group(2) or action_match.group(4)

                if action_name == "stop":
                    # We don't execute the stop action for efficiency
                    metr = self._get_metrics()
                    orcale_success = metr['orcale_success']
                    distance_to_goal = metr['distance_to_goal']

                    done = True
                    if orcale_success:
                        success = True
                        self.reward += 10.0  # Success reward
                        metrics['traj_metrics']['success'] = True
                else:
                    action_num = int(magnitude)
                    action_int = self.ACTION_LOOKUP[action_name]
                    self._execute_action(action_int, action_num)

                    metr = self._get_metrics()
                    orcale_success = metr['orcale_success']
                    distance_to_goal = metr['distance_to_goal']

                self.valid_actions.append(action)

                if done:
                    break

                self._current_step += 1
                if self._current_step >= self._max_episode_steps:
                    done = True
                    break

        if metrics['turn_metrics']['action_is_valid'] and rst.get("format_correct", True):
            self.reward += self.config.format_reward
            info["is_format_rewarded"] = True
        else:
            info["is_format_rewarded"] = False

        # Update info dict. NOTE: the "orcale" misspelling is kept on purpose —
        # downstream consumers read these exact keys.
        info["metrics"] = metrics
        info['distance_to_goal'] = distance_to_goal
        info["orcale_success"] = orcale_success
        info['instruction'] = self.instruction
        info['env_step'] = self._current_step
        info['episode_elapsed_seconds'] = time.time() - self._episode_start_time
        info['task_success'] = success
        self.info = info
        # Update total reward
        self.total_reward += self.reward

        return self._render(init_obs=False), self.reward, done, info

    def system_prompt(self):
        """
        Get the system prompt for the environment.

        Returns a prompt explaining the environment to the LLM agent,
        with different prompts for text and vision modes.

        Returns:
            str: System prompt string with environment description and instructions
        """
        format_prompt_text = self.format_prompt_func(
            add_example=True  # Always true for system prompt
        )

        # `system_prompt` here resolves to the module-level function imported
        # from .prompt, not this method.
        return system_prompt(self.instruction, self.config.data_source) + '\n' + format_prompt_text

    def close(self):
        """Release the underlying habitat environment."""
        self.env.close()

    def _execute_action(self, action_index, action_num):
        """Execute one text command as repeated discrete habitat steps.

        Args:
            action_index: habitat action id from ACTION_LOOKUP (not "stop").
            action_num: magnitude in cm (forward) or degrees (turns); must be
                a multiple of the per-step unit (25cm, 15 deg for r2r, 30 deg
                for rxr).
        """
        assert action_num > 0
        if action_index == self.ACTION_LOOKUP["stop"]:
            raise AssertionError("Stop action should not be executed")
        if action_index == self.ACTION_LOOKUP["move forward"]:
            assert action_num in [25, 50, 75]
            unit = 25  # each habitat forward step covers 25cm
        elif action_index in (self.ACTION_LOOKUP["turn left"], self.ACTION_LOOKUP["turn right"]):
            # Turn granularity differs per data source (see action_patterns).
            unit = 15 if self.config.data_source == "r2r" else 30
            assert action_num in [unit, 2 * unit, 3 * unit], (
                f"Turn action should be {unit}, {2 * unit} or {3 * unit} degrees"
            )
        else:
            raise AssertionError(f"Invalid action index: {action_index}")
        for _ in range(action_num // unit):
            self.env.step(action_index)

    def _get_metrics(self):
        """Return oracle success (within 3.0m of goal) and distance to goal."""
        distance_to_goal = self.env.get_metrics()['distance_to_goal']
        success = distance_to_goal <= 3.0
        metrics = {
            # Misspelled key kept intentionally — step()/info consumers use it.
            "orcale_success": success,
            "distance_to_goal": distance_to_goal,
        }
        return metrics

    def _render(self, init_obs=None):
        """
        Render the environment as an observation.

        This method grabs the current RGB frame and formats the observation
        string based on whether this is the initial observation or a
        subsequent one.

        Args:
            init_obs: True for the initial-observation template, False for the
                post-action template. BUG FIX: this parameter did not exist in
                the original signature although step() passed init_obs=False.
                When None, it is inferred: no replayed history actions means
                the first render is the initial observation.

        Returns:
            Dict: Observation dictionary containing:
                - "obs_str": String observation for the LLM
                - "multi_modal_data": Dictionary with image data
        """
        if init_obs is None:
            init_obs = len(self.config.history_actions) == 0
        img_placeholder = self.config.get("image_placeholder", "<image>")

        # Get format prompt without examples for action/init templates
        format_prompt_text = self.format_prompt_func(
            add_example=False  # No examples for action and init obs
        )

        # BUG FIX: original referenced the undefined bare name `env`.
        frame = self.env.sim.get_sensor_observations()['rgb']
        multi_modal_data = {
            img_placeholder: [convert_numpy_to_PIL(frame)]
        }

        if init_obs:
            obs_str = init_observation_template(
                observation=img_placeholder,
            ) + "\n" + format_prompt_text
        else:
            obs_str = action_template(
                observation=img_placeholder,
            ) + "\n" + format_prompt_text

        return {
            "obs_str": obs_str,
            "multi_modal_data": multi_modal_data
        }


if __name__ == "__main__":
    """
    Example usage of the VLNCE environment.

    This code demonstrates how to create an instance of the environment,
    reset it, and interact with it using manual input actions.
    """
    config = VLNCEEnvConfig()
    env = VLNCEEnv(config)

    print(env.system_prompt())
    obs, info = env.reset()
    print(obs["obs_str"])

    i = 0
    import os
    os.makedirs("./test_VLNCE", exist_ok=True)
    img = obs["multi_modal_data"][config.image_placeholder][0]
    img.save(f"./test_VLNCE/VLNCE_{i}.png")

    # BUG FIX: `done` was read before assignment in the original loop.
    done = False
    while not done:
        i += 1
        action = input("Enter action: ")
        action = f"<answer>{action}</answer>"

        obs, reward, done, info = env.step(action)
        print(obs["obs_str"])

        img = obs["multi_modal_data"][config.image_placeholder][0]
        img.save(f"./test_VLNCE/VLNCE_{i}.png")
        print(f"Success: {info['metrics']['traj_metrics']['success']}")

        if done:
            break

    # NOTE(review): compute_reward() is not defined in VLNCEEnv — presumably
    # inherited from BaseEnv; confirm.
    print(f"Total reward: {env.compute_reward()}")
    print(info)
    env.close()
 No newline at end of file
+28 −0
Original line number Diff line number Diff line
from vagen.env.base.base_env_config import BaseEnvConfig
from dataclasses import dataclass, fields,field
from typing import Optional, List, Union
import hashlib
import json

@dataclass
class VLNCEEnvConfig(BaseEnvConfig):
    """Configuration for the VLNCE environment."""
    env_name: str = "VLNCE"
    prompt_format: str = "no_think"
    # "free_think", "no_think", "grounding", "worldmodeling", "grounding_worldmodeling"
    # "grounding_symbolic", "worldmodeling_symbolic", "grounding_worldmodeling_symbolic"
    # "grounding_structured", "worldmodeling_structured", "grounding_worldmodeling_structured"

    # BUG FIX: the missing type annotation made this a plain class attribute
    # rather than a dataclass field, so fields(self) skipped it and config_id()
    # silently dropped it from the hash despite listing it in id_fields.
    simulator_config_path: str = "/nvme-ssd1/zwy/navid_ws/R1-V/src/r1-v/VLN_CE/vlnce_baselines/config/r2r_baselines/navid_r2r.yaml"
    episode_id: int = 1
    history_actions: List[int] = field(default_factory=list)
    data_source: str = "r2r"

    def config_id(self) -> str:
        """Return a stable identifier hashing the episode-defining fields."""
        id_fields = ["simulator_config_path", "episode_id", "history_actions", "data_source"]
        # Loop variable renamed from `field` to avoid shadowing dataclasses.field.
        id_dict = {f.name: getattr(self, f.name) for f in fields(self) if f.name in id_fields}
        digest = hashlib.sha256(json.dumps(id_dict, sort_keys=True).encode()).hexdigest()
        return f"VLNCEEnvConfig({digest})"

if __name__ == "__main__":
    # Smoke test: print the deterministic config hash for the default settings.
    print(VLNCEEnvConfig().config_id())
 No newline at end of file
+101 −0
Original line number Diff line number Diff line
def system_prompt(instruction, data_source):
    """Build the VLN system prompt with the instruction interpolated.

    Args:
        instruction: natural-language navigation instruction for the episode.
        data_source: "r2r" (15-degree turn units) or "rxr" (30-degree units).

    Returns:
        str: the full system prompt.

    Raises:
        KeyError: if data_source is not a known dataset key.
    """
    # BUG FIX: the action sentence read "select choose one action"; cleaned up
    # to "choose one action" in both variants.
    system_prompt_action_merged = {
        "r2r": (
            "You are a helpful and intelligent agent for Vision-and-Language Navigation (VLN) in indoor environments. "
            "Your goal is to follow the given instruction to reach a specified destination. \n"
            "**Instruction**: %s\n"
            "At each step, you receive a first-person image (starting view if first step (step 1), or post-action view otherwise). "
            "\n"
            "Your task is to choose one action from: move forward [25cm|50cm|75cm], turn left [15|30|45 degrees], turn right [15|30|45 degrees], or stop."
            "Note: Each action is a small, incremental adjustment. Complex navigation requires combining multiple such steps. "
        ),
        "rxr": (
            "You are a helpful and intelligent agent for Vision-and-Language Navigation (VLN) in indoor environments. "
            "Your goal is to follow the given instruction to reach a specified destination. \n"
            "**Instruction**: %s\n"
            "At each step, you receive a first-person image (starting view if first step (step 1), or post-action view otherwise). "
            "\n"
            "Your task is to choose one action from: move forward [25cm|50cm|75cm], turn left [30|60|90 degrees], turn right [30|60|90 degrees], or stop."
            "Note: Each action is a small, incremental adjustment. Complex navigation requires combining multiple such steps. "
        ),
    }
    return system_prompt_action_merged[data_source] % instruction

def init_observation_template(observation):
    """Format the very first observation shown to the agent."""
    return (
        "[Initial Observation]:\n"
        f"{observation}\n"
        "Decide your next action."
    )

def action_template(observation):
    """Format a post-action observation shown to the agent."""
    return (
        "After that, the observation is:\n"
        f"{observation}\n"
        "Decide your next action."
    )


# Format configurations defining the structure of each format.
# Each entry provides the response skeleton ("format"), a one-line instruction
# ("description"), and a worked example ("example") appended when requested.
FORMAT_CONFIGS = {
    # NOTE(review): "free_think" is currently disabled; re-enable by
    # uncommenting if the think-then-answer format is needed.
    # "free_think": {
    #     "format": "<think>...</think><answer>...</answer>",
    #     "description": "You should first give your reasoning, and then your answer.",
    #     "example": "<think>I can see the target is on my down left, I should go down then left to reach the target</think><answer>Down{action_sep}Left</answer>"
    # },
    
    "no_think": {
        "format": "<answer>...</answer>",
        "description": "You should provide only your answer.",
        "example": "<answer>move forward 25cm</answer>"
    },

}

def format_prompt_generator(format_type):
    """
    Generates a prompt function for the specified format type.

    Args:
        format_type (str): key into FORMAT_CONFIGS identifying the format

    Returns:
        function: A function that generates a prompt for the specified format
    """
    def prompt_function(**kwargs):
        """
        Generate a prompt for the specified format.

        Args:
            add_example (bool): Whether to append the worked example

        Returns:
            str: The formatted prompt
        """
        cfg = FORMAT_CONFIGS[format_type]
        # Default to True, matching the robot-environment examples.
        add_example = kwargs.get("add_example", True)

        prompt = (
            f"{cfg['description']}\n"
            "Your response should be in the format of:\n"
            f"{cfg['format']}"
        )
        if add_example:
            prompt = f"{prompt}\ne.g. {cfg['example']}"
        return prompt

    return prompt_function

# Build the lookup from format type to its generated prompt function.
format_prompt = dict(
    (fmt, format_prompt_generator(fmt)) for fmt in FORMAT_CONFIGS
)

if __name__ == "__main__":
    # Example usage: print every registered format prompt with its example.
    separator = "\n" + "=" * 50 + "\n"
    for name, builder in format_prompt.items():
        print(f"{name} format prompt:")
        print(builder(add_example=True))
        print(separator)
 No newline at end of file
+299 −0

File added.

Preview size limit exceeded, changes collapsed.

+8 −0
Original line number Diff line number Diff line
from vagen.env.base.base_service_config import BaseServiceConfig
from dataclasses import dataclass, fields,field
from typing import Dict, List, Tuple, Optional, Any, Union

@dataclass
class FrozenLakeServiceConfig(BaseServiceConfig):
    # NOTE(review): this class is named FrozenLake* but ships in the VLNCE
    # commit — presumably copy-pasted as a starting point; confirm the
    # intended service name before relying on it.
    # Maps a model/resource name to the device index it runs on
    # (e.g. the CLIP model on GPU 0).
    device: Dict[str, Any] = field(default_factory=lambda: {"clip": 0})
    # Whether the service computes state-based rewards.
    use_state_reward: bool = False
 No newline at end of file