Commit 44161e21 authored by wayne's avatar wayne
Browse files

fix vlnce env

parent 98685a3b
Loading
Loading
Loading
Loading
+110 −63
Original line number Original line Diff line number Diff line
@@ -5,10 +5,16 @@ from typing import Dict, List, Optional, Tuple, Any
from vagen.env.utils.env_utils import NoLoggerWarnings, set_seed
from vagen.env.utils.env_utils import NoLoggerWarnings, set_seed
from vagen.env.utils.context_utils import convert_numpy_to_PIL
from vagen.env.utils.context_utils import convert_numpy_to_PIL
from vagen.env.utils.parse_utils import PARSE_FUNC_MAP
from vagen.env.utils.parse_utils import PARSE_FUNC_MAP
from .prompt import system_prompt, init_observation_template, action_template, format_prompt
# from .prompt import system_prompt, init_observation_template, action_template, format_prompt
from .env_config import VLNCEEnvConfig
from prompt import system_prompt, init_observation_template, action_template, format_prompt
# from .env_config import VLNCEEnvConfig
from env_config import VLNCEEnvConfig
from vagen.env.utils.state_reward_text_utils import env_state_reward_wrapper
from vagen.env.utils.state_reward_text_utils import env_state_reward_wrapper
import habitat
import habitat
import os
import sys
# print(os.path.dirname(__file__))
# sys.path.append(os.path.dirname(__file__))
from VLN_CE.vlnce_baselines.config.default import get_config
from VLN_CE.vlnce_baselines.config.default import get_config
import json
import json
import regex as re
import regex as re
@@ -41,7 +47,42 @@ def init_env(simulator_config_path):
        
        
    return env, reset
    return env, reset


STOP = "stop"
MOVE_FORWARD = "move forward"
TURN_LEFT = "turn left"
TURN_RIGHT = "turn right"


# Maps the textual action names to the discrete action indices expected by
# the simulator's env.step().
ACTION_LOOKUP = {
    STOP: 0,
    MOVE_FORWARD: 1,
    TURN_LEFT: 2,
    TURN_RIGHT: 3,
}

# Allowed magnitudes differ per dataset: R2R turns in 15-degree increments,
# RxR in 30-degree increments; forward motion is 25cm increments for both.
# Hoisted to module level so the patterns are not rebuilt on every call.
_ACTION_PATTERNS = {
    "r2r": rf'({MOVE_FORWARD}) (25|50|75)cm|(({TURN_LEFT})|({TURN_RIGHT})) (15|30|45) degrees|({STOP})',
    "rxr": rf'({MOVE_FORWARD}) (25|50|75)cm|(({TURN_LEFT})|({TURN_RIGHT})) (30|60|90) degrees|({STOP})',
}


def parse_action(action, data_source):
    """Parse an action string into an (action_name, action_value) tuple.

    Args:
        action: Raw action text, e.g. "move forward 50cm",
            "turn left 30 degrees" or "stop". Surrounding whitespace is
            ignored.
        data_source: Dataset key, "r2r" or "rxr" (they allow different
            turn magnitudes).

    Returns:
        (MOVE_FORWARD, centimeters), (TURN_LEFT, degrees),
        (TURN_RIGHT, degrees), or (STOP, None).

    Raises:
        ValueError: If data_source is unknown or the action text does not
            exactly match one of the allowed action forms.
    """
    action = action.strip()
    if data_source not in _ACTION_PATTERNS:
        raise ValueError(f"Unknown data source: {data_source}")
    # fullmatch (not match) so trailing garbage such as "stop it" or
    # "move forward 50cmxyz" is rejected instead of silently accepted.
    if re.fullmatch(_ACTION_PATTERNS[data_source], action) is None:
        raise ValueError(f"Unknown action: {action}")

    if action.startswith(MOVE_FORWARD):
        # "move forward 50cm" -> 50
        return (MOVE_FORWARD, int(action.split(" ")[-1].replace("cm", "")))
    elif action.startswith(TURN_LEFT):
        # "turn left 30 degrees" -> 30
        return (TURN_LEFT, int(action.split(" ")[-2]))
    elif action.startswith(TURN_RIGHT):
        return (TURN_RIGHT, int(action.split(" ")[-2]))
    else:
        return (STOP, None)
                
                


class VLNCEEnv(BaseEnv):
class VLNCEEnv(BaseEnv):
@@ -49,13 +90,6 @@ class VLNCEEnv(BaseEnv):
    VLNCE Environment for training and evaluating language models as agents.
    VLNCE Environment for training and evaluating language models as agents.
    """
    """


    ACTION_LOOKUP = {
        "stop": 0,
        "move foward": 1,
        "turn left": 2,
        "turn right": 3,
    }



    def __init__(self, config: VLNCEEnvConfig):
    def __init__(self, config: VLNCEEnvConfig):
        """
        """
@@ -69,7 +103,7 @@ class VLNCEEnv(BaseEnv):
        self.config = config
        self.config = config
       
       
        self.env, self.reset_env = init_env(config.simulator_config_path)
        self.env, self.reset_env = init_env(config.simulator_config_path)
        self.obs = self.reset_env(self.config.episode_id)
        obs = self.reset_env(self.config.episode_id)
        self.instruction = obs["instruction"]["text"]
        self.instruction = obs["instruction"]["text"]
        self.initial_states = self.get_initial_checkpoint()
        self.initial_states = self.get_initial_checkpoint()
        
        
@@ -89,16 +123,36 @@ class VLNCEEnv(BaseEnv):
        self._max_episode_steps = 30 # TODO: dynamically tune this parameter
        self._max_episode_steps = 30 # TODO: dynamically tune this parameter
        self._episode_start_time = 0
        self._episode_start_time = 0


        self.action_patterns = {
            "r2r": r'(move forward) (25|50|75)cm|(turn (?:left|right)) (15|30|45) degrees|(stop)',
            "rxr": r'(move forward) (25|50|75)cm|(turn (?:left|right)) (30|60|90) degrees|(stop)',
        }


    def get_initial_checkpoint(self):
    def get_initial_checkpoint(self):
        env = self.env
        env = self.env
        history_actions = self.config.history_actions
        history_actions = self.config.history_actions
        for action in history_actions:
        
            env.step(action)
        previous_state = env.sim.get_agent_state()
        previous_metrics = env.get_metrics()
        locations_length = len(env._task.measurements.measures['ndtw'].locations)
        
        for action_str in history_actions:
            action_name, action_value = parse_action(action_str, self.config.data_source)
            action_index = ACTION_LOOKUP[action_name]
            num_action_repeat = 0
            if action_name == STOP:
                assert False
            elif action_name == MOVE_FORWARD:
                num_action_repeat = action_value // 25
            elif action_name == TURN_LEFT or action_name == TURN_RIGHT:
                if self.config.data_source == 'r2r':
                    num_action_repeat = action_value // 15
                elif self.config.data_source == 'rxr':
                    num_action_repeat = action_value // 30
                else:
                    assert False
            else:
                assert False
            
            for _ in range(num_action_repeat):
                env.step(action_index)
                
        previous_state = env.sim.get_agent_state()
        previous_state = env.sim.get_agent_state()
        previous_metrics = env.get_metrics()
        previous_metrics = env.get_metrics()
        locations_length = len(env._task.measurements.measures['ndtw'].locations)
        locations_length = len(env._task.measurements.measures['ndtw'].locations)
@@ -161,7 +215,7 @@ class VLNCEEnv(BaseEnv):
            response=action_str,
            response=action_str,
            special_token_list=self.config.special_token_list,
            special_token_list=self.config.special_token_list,
            action_sep=self.config.action_sep,
            action_sep=self.config.action_sep,
            max_actions=self.config.max_actions_per_step
            max_actions=1
        )
        )
        
        
        action_list = rst['actions']
        action_list = rst['actions']
@@ -169,7 +223,6 @@ class VLNCEEnv(BaseEnv):
        metrics = {
        metrics = {
            "turn_metrics": {
            "turn_metrics": {
                "action_is_valid": len(action_list) > 0,
                "action_is_valid": len(action_list) > 0,
                "action_is_effective": False,
            },
            },
            "traj_metrics": {
            "traj_metrics": {
                "success": False,
                "success": False,
@@ -180,22 +233,26 @@ class VLNCEEnv(BaseEnv):
        self.valid_actions = []
        self.valid_actions = []
        done = False
        done = False
        success = False
        success = False
        metr = self._get_metrics()
        distance_to_goal = metr['distance_to_goal']
        orcale_success = metr['orcale_success']
        info = {}
        info = {}
        info.update(rst)

        
        
        if metrics["turn_metrics"]["action_is_valid"] and rst.get("format_correct", True):
        if metrics["turn_metrics"]["action_is_valid"] and rst.get("format_correct", True):
            assert len(action_list) == 1
            assert len(action_list) == 1
            for action in action_list:
            for action in action_list:
                action_pattern = self.action_patterns[self.config.data_source]
                self._current_step += 1
                action_match = re.match(action_pattern, action)


                if action_match is None:
                try:
                    action_name, action_value = parse_action(action, self.config.data_source)
                except ValueError as e:
                    # If parsing fails, mark action as invalid
                    metrics['turn_metrics']['action_is_valid'] = False
                    metrics['turn_metrics']['action_is_valid'] = False
                    print(f"Invalid action: {action}. Error: {e}")
                    break              
                    break              
                # Extract the action from the match
                action_str = action_match.group(1)
                    
                    
                if action_str == "stop":
                if action_name == STOP:
                    # We don't execute the stop action for efficiency
                    # We don't execute the stop action for efficiency
                    metr = self._get_metrics()
                    metr = self._get_metrics()
                    orcale_success = metr['orcale_success'] 
                    orcale_success = metr['orcale_success'] 
@@ -207,9 +264,7 @@ class VLNCEEnv(BaseEnv):
                        self.reward += 10.0  # Success reward
                        self.reward += 10.0  # Success reward
                        metrics['traj_metrics']['success'] = True
                        metrics['traj_metrics']['success'] = True
                else:
                else:
                    action_num = int(action_match.groups(2))
                    self._execute_action(action_name, action_value)
                    action_int = self.ACTION_LOOKUP[action_str]
                    self._execute_action(action_int, action_num)
                
                
                    metr = self._get_metrics()
                    metr = self._get_metrics()
                    orcale_success = metr['orcale_success']
                    orcale_success = metr['orcale_success']
@@ -217,14 +272,13 @@ class VLNCEEnv(BaseEnv):
                    
                    
                self.valid_actions.append(action)
                self.valid_actions.append(action)
                
                
                if done:
                    break
                
                self._current_step += 1
                if self._current_step >= self._max_episode_steps:
                if self._current_step >= self._max_episode_steps:
                    done = True
                    done = True
                    break
                    break
                
                
                if done:
                    break
        
        if metrics['turn_metrics']['action_is_valid'] and rst.get("format_correct", True):
        if metrics['turn_metrics']['action_is_valid'] and rst.get("format_correct", True):
            self.reward += self.config.format_reward
            self.reward += self.config.format_reward
            info["is_format_rewarded"] = True
            info["is_format_rewarded"] = True
@@ -232,6 +286,8 @@ class VLNCEEnv(BaseEnv):
            info["is_format_rewarded"] = False
            info["is_format_rewarded"] = False
        
        
           # Update info dict
           # Update info dict
        
        info.update(rst)
        info["metrics"] = metrics
        info["metrics"] = metrics
        info['distance_to_goal'] = distance_to_goal
        info['distance_to_goal'] = distance_to_goal
        info["orcale_success"] = orcale_success
        info["orcale_success"] = orcale_success
@@ -243,7 +299,7 @@ class VLNCEEnv(BaseEnv):
        # Update total reward
        # Update total reward
        self.total_reward += self.reward
        self.total_reward += self.reward
        
        
        return self._render(init_obs=False), self.reward, done, info
        return self._render(), self.reward, done, info


    def system_prompt(self):
    def system_prompt(self):
        """
        """
@@ -264,34 +320,24 @@ class VLNCEEnv(BaseEnv):
    def close(self):
    def close(self):
        self.env.close()
        self.env.close()


    def _execute_action(self, action_name, action_value):
        """Execute a parsed (non-stop) action in the simulator.

        The simulator only takes unit steps (25cm forward, one turn
        increment), so larger magnitudes are executed as repeated unit
        steps of the same discrete action index.

        Args:
            action_name: One of MOVE_FORWARD / TURN_LEFT / TURN_RIGHT.
                STOP is handled by the caller and must never reach here.
            action_value: Magnitude — centimeters for forward motion,
                degrees for turns.
        """
        assert action_name in ACTION_LOOKUP, f"Invalid action name: {action_name}"
        action_index = ACTION_LOOKUP[action_name]
        if action_name == STOP:
            assert False, "Stop action should not be executed"
        elif action_name == MOVE_FORWARD:
            assert action_value in [25, 50, 75]
            for _ in range(action_value // 25):
                self.env.step(action_index)
        elif action_name == TURN_LEFT or action_name == TURN_RIGHT:
            # Turn granularity depends on the dataset: R2R uses 15-degree
            # unit turns, RxR uses 30-degree unit turns.
            if self.config.data_source == "r2r":
                assert action_value in [15, 30, 45], "Turn left/right action should be 15, 30 or 45 degrees"
                for _ in range(action_value // 15):
                    self.env.step(action_index)
            elif self.config.data_source == "rxr":
                assert action_value in [30, 60, 90], "Turn left/right action should be 30, 60 or 90 degrees"
                for _ in range(action_value // 30):
                    self.env.step(action_index)
            else:
                # Previously an unrecognized data_source fell through
                # silently and executed zero simulator steps; fail loudly
                # instead so misconfiguration is caught immediately.
                assert False, f"Unknown data source: {self.config.data_source}"
        
    def _get_metrics(self):
    def _get_metrics(self):
        distance_to_goal = self.env.get_metrics()['distance_to_goal']
        distance_to_goal = self.env.get_metrics()['distance_to_goal']
@@ -315,7 +361,7 @@ class VLNCEEnv(BaseEnv):
                - "obs_str": String observation for the LLM
                - "obs_str": String observation for the LLM
                - "multi_modal_data": Optional dictionary with image data for vision mode
                - "multi_modal_data": Optional dictionary with image data for vision mode
        """
        """
        init_obs = len(self.config.history_actions) == 0
        init_obs = len(self.config.history_actions) == 0 and self._current_step == 0
        img_placeholder = self.config.get("image_placeholder", "<image>")
        img_placeholder = self.config.get("image_placeholder", "<image>")
        
        
        # Get format prompt without examples for action/init templates
        # Get format prompt without examples for action/init templates
@@ -323,7 +369,7 @@ class VLNCEEnv(BaseEnv):
            add_example=False  # No examples for action and init obs
            add_example=False  # No examples for action and init obs
        ) 
        ) 
        
        
        frame = env.sim.get_sensor_observations()['rgb']
        frame = self.env.sim.get_sensor_observations()['rgb'][:,:,:-1]
        multi_modal_data = {
        multi_modal_data = {
            img_placeholder: [convert_numpy_to_PIL(frame)]
            img_placeholder: [convert_numpy_to_PIL(frame)]
        }
        }
@@ -363,6 +409,7 @@ if __name__ == "__main__":
    img = obs["multi_modal_data"][config.image_placeholder][0]
    img = obs["multi_modal_data"][config.image_placeholder][0]
    img.save(f"./test_VLNCE/VLNCE_{i}.png")
    img.save(f"./test_VLNCE/VLNCE_{i}.png")
    
    
    done = False
    while not done:
    while not done:
        i += 1
        i += 1
        action = input("Enter action: ")
        action = input("Enter action: ")
+1 −1
Original line number Original line Diff line number Diff line
@@ -14,7 +14,7 @@ class VLNCEEnvConfig(BaseEnvConfig):
    
    
    simulator_config_path = "/nvme-ssd1/zwy/navid_ws/R1-V/src/r1-v/VLN_CE/vlnce_baselines/config/r2r_baselines/navid_r2r.yaml"      
    simulator_config_path = "/nvme-ssd1/zwy/navid_ws/R1-V/src/r1-v/VLN_CE/vlnce_baselines/config/r2r_baselines/navid_r2r.yaml"      
    episode_id: int = 1
    episode_id: int = 1
    history_actions: List[int] = field(default_factory=list)
    history_actions: List[str] = field(default_factory=list)
    data_source: str = "r2r"
    data_source: str = "r2r"
    
    
    def config_id(self) -> str:
    def config_id(self) -> str:
+2 −2

File changed.

Contains only whitespace changes.