Commit 803238af authored by wayne's avatar wayne
Browse files

vlnce env

parent 44161e21
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
# First, import the modules that are assumed to be always available
from .sokoban import SokobanEnv, SokobanEnvConfig, SokobanService, SokobanServiceConfig
from .frozenlake import FrozenLakeEnv, FrozenLakeEnvConfig, FrozenLakeService, FrozenLakeServiceConfig
from .vlnce import VLNCEEnv, VLNCEEnvConfig, VLNCEService, VLNCEServiceConfig    

REGISTERED_ENV = {
    "sokoban": {
@@ -14,6 +15,12 @@ REGISTERED_ENV = {
        "config_cls": FrozenLakeEnvConfig,
        "service_cls": FrozenLakeService,
        "service_config_cls": FrozenLakeServiceConfig
    },
    "vlnce": {
        "env_cls": VLNCEEnv,
        "config_cls": VLNCEEnvConfig,
        "service_cls": VLNCEService,
        "service_config_cls": VLNCEServiceConfig
    }
}

+4 −0
Original line number Diff line number Diff line
from .env_config import VLNCEEnvConfig
from .env import VLNCEEnv
from .service import VLNCEService
from .service_config import VLNCEServiceConfig
 No newline at end of file
+25 −14
Original line number Diff line number Diff line
@@ -5,17 +5,15 @@ from typing import Dict, List, Optional, Tuple, Any
from vagen.env.utils.env_utils import NoLoggerWarnings, set_seed
from vagen.env.utils.context_utils import convert_numpy_to_PIL
from vagen.env.utils.parse_utils import PARSE_FUNC_MAP
# from .prompt import system_prompt, init_observation_template, action_template, format_prompt
from prompt import system_prompt, init_observation_template, action_template, format_prompt
# from .env_config import VLNCEEnvConfig
from env_config import VLNCEEnvConfig
from .prompt import system_prompt, init_observation_template, action_template, format_prompt
# from prompt import system_prompt, init_observation_template, action_template, format_prompt
from .env_config import VLNCEEnvConfig
# from env_config import VLNCEEnvConfig
from vagen.env.utils.state_reward_text_utils import env_state_reward_wrapper
import habitat
import os
import sys
# print(os.path.dirname(__file__))
# sys.path.append(os.path.dirname(__file__))
from VLN_CE.vlnce_baselines.config.default import get_config
import json
import regex as re
import time
@@ -36,6 +34,8 @@ def reset_env(env,i):
    return observations

def init_env(simulator_config_path):
    import habitat
    from VLN_CE.vlnce_baselines.config.default import get_config
    config = get_config(simulator_config_path, None)
    dataset = habitat.datasets.make_dataset(id_dataset=config.TASK_CONFIG.DATASET.TYPE, config=config.TASK_CONFIG.DATASET)
    env = habitat.Env(config.TASK_CONFIG,dataset=dataset)
@@ -120,8 +120,10 @@ class VLNCEEnv(BaseEnv):
        self.parse_func = PARSE_FUNC_MAP[self.config.prompt_format]
        
        self._current_step = 0
        self._max_episode_steps = 30 # TODO: dynamically tune this parameter
        self._episode_start_time = 0
        self._success = False

        self.history_renders = []


    def get_initial_checkpoint(self):
@@ -132,6 +134,8 @@ class VLNCEEnv(BaseEnv):
        previous_metrics = env.get_metrics()
        locations_length = len(env._task.measurements.measures['ndtw'].locations)
        
        self.history_renders.append(self._render(init_obs=True))
        
        for action_str in history_actions:
            action_name, action_value = parse_action(action_str, self.config.data_source)
            action_index = ACTION_LOOKUP[action_name]
@@ -152,6 +156,7 @@ class VLNCEEnv(BaseEnv):
            
            for _ in range(num_action_repeat):
                env.step(action_index)
            self.history_renders.append(self._render(init_obs=False))
                
        previous_state = env.sim.get_agent_state()
        previous_metrics = env.get_metrics()
@@ -167,6 +172,7 @@ class VLNCEEnv(BaseEnv):
    def reset(self, seed=None):
        self._current_step = 0
        self._episode_start_time = time.time()
        self._success = False
        self.total_reward = 0
        self.valid_actions = []
        self.reward = 0
@@ -187,7 +193,7 @@ class VLNCEEnv(BaseEnv):
                env._task.measurements.measures[key]._metric = value

        
        return self._render(), {}
        return self.history_renders, {}

    # @env_state_reward_wrapper
    def step(self, action_str: str):
@@ -242,7 +248,6 @@ class VLNCEEnv(BaseEnv):
        if metrics["turn_metrics"]["action_is_valid"] and rst.get("format_correct", True):
            assert len(action_list) == 1
            for action in action_list:
                self._current_step += 1

                try:
                    action_name, action_value = parse_action(action, self.config.data_source)
@@ -261,8 +266,8 @@ class VLNCEEnv(BaseEnv):
                    done = True
                    if orcale_success:
                        success = True
                        self.reward += 10.0  # Success reward
                        metrics['traj_metrics']['success'] = True
                        self._success = True
                else:
                    self._execute_action(action_name, action_value)
                
@@ -272,7 +277,9 @@ class VLNCEEnv(BaseEnv):
                    
                self.valid_actions.append(action)
                
                if self._current_step >= self._max_episode_steps:
                self._current_step += 1
                
                if self._current_step >= self.config.max_episode_steps:
                    done = True
                    break
                
@@ -348,7 +355,7 @@ class VLNCEEnv(BaseEnv):
        }
        return metrics
    
    def _render(self):
    def _render(self, init_obs=False):
        """
        Render the environment as an observation.
        
@@ -361,7 +368,8 @@ class VLNCEEnv(BaseEnv):
                - "obs_str": String observation for the LLM
                - "multi_modal_data": Optional dictionary with image data for vision mode
        """
        init_obs = len(self.config.history_actions) == 0 and self._current_step == 0
        if init_obs:
            assert len(self.config.history_actions) == 0 and self._current_step == 0
        img_placeholder = self.config.get("image_placeholder", "<image>")
        
        # Get format prompt without examples for action/init templates
@@ -388,6 +396,9 @@ class VLNCEEnv(BaseEnv):
            "multi_modal_data": multi_modal_data
        }
    
    def compute_reward(self) -> float:
        """Return the terminal reward for the finished episode.

        A fixed bonus of 10.0 is granted when the episode ended in
        success (as tracked by ``self._success``); otherwise 0.0.
        """
        if self._success:
            return 10.0
        return 0.0

if __name__ == "__main__":
    """
+1 −0
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ class VLNCEEnvConfig(BaseEnvConfig):
    episode_id: int = 1
    history_actions: List[str] = field(default_factory=list)
    data_source: str = "r2r"
    max_episode_steps: int = 0
    
    def config_id(self) -> str:
        id_fields=["simulator_config_path", "episode_id", "history_actions", "data_source"]
+54 −75
Original line number Diff line number Diff line
from typing import Dict, List, Tuple, Optional, Any, Union
from typing import Dict, List, Tuple, Optional, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
from vagen.env.base.base_service import BaseService
from vagen.env.vlnce.env import VLNCEEnv
from vagen.env.vlnce.env_config import VLNCEEnvConfig
from vagen.server.serial import serialize_observation

from .env import VLNCEEnv
from .env_config import VLNCEEnvConfig
from ..base.base_service_config import BaseServiceConfig
from .service_config import VLNCEServiceConfig
from vagen.env.utils.state_reward_text_utils import service_state_reward_wrapper

class VLNCEService(BaseService):
    """
    Service class for VLNCE environments.
    Service class for VLNCE (Vision-and-Language Navigation in Continuous
    Environments) environments, backed by the Habitat simulator.
    Implements batch operations with parallel processing for efficiency.
    """
    
    def __init__(self, config:BaseServiceConfig):
    def __init__(self, config:VLNCEServiceConfig):
        """
        Initialize the VLNCEService.
        
        Args:
            config: Service configuration; ``config.max_workers`` bounds the
                thread pool used for parallel batch operations.
        """
        self.max_workers = config.get('max_workers', 10)
        self.max_workers = config.max_workers
        self.environments = {}
        self.env_configs = {}
        self.config=config
        print(f"[DEBUG] {self.config}")
    
    def create_environments_batch(self, ids2configs: Dict[Any, Any]) -> None:
    def create_environments_batch(self, ids2configs: Dict[str, Any]) -> None:
        """
        Create multiple VLNCE environments in parallel.
        
@@ -44,19 +44,10 @@ class VLNCEService(BaseService):
            if env_name != 'VLNCE':
                return env_id, None, f"Expected environment type 'VLNCE', got '{env_name}'"
            
            try:
                # Get VLNCE specific configuration
                env_config_dict = config.get('env_config', {})
                
                # Create environment config
            env_config_dict = config['env_config']
            env_config = VLNCEEnvConfig(**env_config_dict)
                
                # Create environment
            env = VLNCEEnv(env_config)
                
            return env_id, (env, env_config), None
            except Exception as e:
                return env_id, None, str(e)
        
        # Use ThreadPoolExecutor for parallel creation
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
@@ -78,7 +69,7 @@ class VLNCEService(BaseService):
                self.environments[env_id] = env
                self.env_configs[env_id] = env_config
    
    def reset_batch(self, ids2seeds: Dict[Any, Any]) -> Dict[Any, Tuple[Any, Any]]:
    def reset_batch(self, ids2seeds: Dict[str, Any]) -> Dict[str, Tuple[Any, Any]]:
        """
        Reset multiple VLNCE environments in parallel.
        
@@ -93,16 +84,11 @@ class VLNCEService(BaseService):
        
        # Define worker function
        def reset_single_env(env_id, seed):     
            try:
                if env_id not in self.environments:
                    return env_id, None, f"Environment {env_id} not found"
                
            env = self.environments[env_id]
            observation, info = env.reset(seed=seed)
            serialized_observation = serialize_observation(observation)
            return env_id, (serialized_observation, info), None
            except Exception as e:
                return env_id, None, str(e)
            
        
        # Use ThreadPoolExecutor for parallel reset
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
@@ -124,8 +110,7 @@ class VLNCEService(BaseService):
        
        return results
    
    @service_state_reward_wrapper
    def step_batch(self, ids2actions: Dict[Any, Any]) -> Dict[Any, Tuple[Dict, float, bool, Dict]]:
    def step_batch(self, ids2actions: Dict[str, Any]) -> Dict[str, Tuple[Dict, float, bool, Dict]]:
        """
        Take a step in multiple VLNCE environments in parallel.
        
@@ -140,16 +125,26 @@ class VLNCEService(BaseService):
        
        # Define worker function
        def step_single_env(env_id, action):
            try:
                if env_id not in self.environments:
                    return env_id, None, f"Environment {env_id} not found"
                
            env = self.environments[env_id]
            try:
                observation, reward, done, info = env.step(action)
            except Exception as e:
                print(f"Error stepping VLNCE environment {env_id}: {e}")
                try:
                    observation,info=env.reset()
                    reward=0.0
                    done=True 
                except Exception as e:
                    print(f"Error resetting VLNCE environment {env_id} after step failure: {e}")
                    env.close()
                    self.environments.pop(env_id, None)
                    observation={"obs_str":"error"}
                    reward=0.0
                    done=True
                    info={}
            serialized_observation = serialize_observation(observation)
            return env_id, (serialized_observation, reward, done, info), None
            except Exception as e:
                return env_id, None, str(e)
            
        
        # Use ThreadPoolExecutor for parallel step
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
@@ -171,7 +166,7 @@ class VLNCEService(BaseService):
        
        return results
    
    def compute_reward_batch(self, env_ids: List[str]) -> Dict[Any, float]:
    def compute_reward_batch(self, env_ids: List[str]) -> Dict[str, float]:
        """
        Compute the total reward for multiple VLNCE environments in parallel.
        
@@ -181,18 +176,14 @@ class VLNCEService(BaseService):
        Returns:
            A dictionary mapping each environment ID to its computed total reward
        """
        raise NotImplementedError("compute_reward_batch is not implemented in VLNCEService.")
        results = {}
        
        # Define worker function
        def compute_reward_single_env(env_id):
            try:
                if env_id not in self.environments:
                    return env_id, None, f"Environment {env_id} not found"
                
            env = self.environments[env_id]
            return env_id, env.compute_reward(), None
            except Exception as e:
                return env_id, None, str(e)
           
        
        # Use ThreadPoolExecutor for parallel computation
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
@@ -214,7 +205,7 @@ class VLNCEService(BaseService):
        
        return results
    
    def get_system_prompts_batch(self, env_ids: List[str]) -> Dict[Any, str]:
    def get_system_prompts_batch(self, env_ids: List[str]) -> Dict[str, str]:
        """
        Get system prompts for multiple VLNCE environments in parallel.
        
@@ -228,14 +219,9 @@ class VLNCEService(BaseService):
        
        # Define worker function
        def get_system_prompt_single_env(env_id):
            try:
                if env_id not in self.environments:
                    return env_id, None, f"Environment {env_id} not found"
                
            env = self.environments[env_id]
            return env_id, env.system_prompt(), None
            except Exception as e:
                return env_id, None, str(e)
       
        
        # Use ThreadPoolExecutor for parallel retrieval
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
@@ -270,15 +256,10 @@ class VLNCEService(BaseService):
        
        # Define worker function
        def close_single_env(env_id):      
            try:
                if env_id not in self.environments:
                    return f"Environment {env_id} not found"
                
            env = self.environments[env_id]
            env.close()
            return None
            except Exception as e:
                return str(e)
            
        
        # Use ThreadPoolExecutor for parallel closing
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
@@ -295,5 +276,3 @@ class VLNCEService(BaseService):
        for env_id in env_ids:
            self.environments.pop(env_id, None)
            self.env_configs.pop(env_id, None)
    
    
 No newline at end of file
Loading