Commit 4077e7c2 authored by wayne's avatar wayne
Browse files

feat: multi-thread env -> multi-process env

parent 695eae29
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
from .env_config import VLNCEEnvConfig
from .env import VLNCEEnv
from .service import VLNCEService
# from .service_legacy import VLNCEService
from .service_ray import VLNCEService
from .service_config import VLNCEServiceConfig
 No newline at end of file
+11 −8
Original line number Diff line number Diff line
@@ -11,9 +11,7 @@ from .env_config import VLNCEEnvConfig
# from env_config import VLNCEEnvConfig
from vagen.env.utils.state_reward_text_utils import env_state_reward_wrapper
import os
import sys
# print(os.path.dirname(__file__))
# sys.path.append(os.path.dirname(__file__))

import json
import regex as re
import time
@@ -35,6 +33,8 @@ def reset_env(env,i):

def init_env(simulator_config_path):
    import habitat
    import sys
    sys.path.append(os.path.dirname(__file__))
    from VLN_CE.vlnce_baselines.config.default import get_config
    config = get_config(simulator_config_path, None)
    dataset = habitat.datasets.make_dataset(id_dataset=config.TASK_CONFIG.DATASET.TYPE, config=config.TASK_CONFIG.DATASET)
@@ -105,7 +105,8 @@ class VLNCEEnv(BaseEnv):
        self.env, self.reset_env = init_env(config.simulator_config_path)
        obs = self.reset_env(self.config.episode_id)
        self.instruction = obs["instruction"]["text"]
        self.initial_states = self.get_initial_checkpoint()
        self._current_step = 0
        self.history_renders = []
        
        
        self.reset_keys = ['ndtw', 'success', 'spl', 'oracle_success', 'oracle_spl']
@@ -119,14 +120,15 @@ class VLNCEEnv(BaseEnv):
        self.format_prompt_func = format_prompt[self.config.prompt_format]
        self.parse_func = PARSE_FUNC_MAP[self.config.prompt_format]
        
        self._current_step = 0
      
        self._episode_start_time = 0
        self._success = False
        self.initial_states = self._get_initial_checkpoint()

        
        self.history_renders = []


    def get_initial_checkpoint(self):
    def _get_initial_checkpoint(self):
        env = self.env
        history_actions = self.config.history_actions
        
@@ -170,6 +172,7 @@ class VLNCEEnv(BaseEnv):
        
        
    def reset(self, seed=None):
        assert len(self.history_renders) == len(self.config.history_actions) + 1
        self._current_step = 0
        self._episode_start_time = time.time()
        self._success = False
@@ -369,7 +372,7 @@ class VLNCEEnv(BaseEnv):
                - "multi_modal_data": Optional dictionary with image data for vision mode
        """
        if init_obs:
            assert len(self.config.history_actions) == 0 and self._current_step == 0
            assert len(self.history_renders) == 0 and self._current_step == 0
        img_placeholder = self.config.get("image_placeholder", "<image>")
        
        # Get format prompt without examples for action/init templates
+278 −0
Original line number Diff line number Diff line
from typing import Dict, List, Tuple, Optional, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
from vagen.env.base.base_service import BaseService
from vagen.env.vlnce.env import VLNCEEnv
from vagen.env.vlnce.env_config import VLNCEEnvConfig
from vagen.server.serial import serialize_observation
from .service_config import VLNCEServiceConfig
from vagen.env.utils.state_reward_text_utils import service_state_reward_wrapper

class VLNCEService(BaseService):
    """
    Service class for VLN-CE (Habitat) environments.

    Implements batch operations (create/reset/step/close) with thread-pool
    parallelism. Each worker returns an ``(env_id, result, error)`` triple and
    catches its own exceptions, so a single failing environment is reported
    per-id instead of aborting the whole batch.
    """

    def __init__(self, config: VLNCEServiceConfig):
        """
        Initialize the VLNCEService.

        Args:
            config: Service configuration; ``config.max_workers`` bounds the
                number of worker threads used for each batch operation.
        """
        self.max_workers = config.max_workers
        self.environments = {}   # env_id -> VLNCEEnv
        self.env_configs = {}    # env_id -> VLNCEEnvConfig
        self.config = config
        print(f"[DEBUG] {self.config}")

    def create_environments_batch(self, ids2configs: Dict[str, Any]) -> None:
        """
        Create multiple VLNCE environments in parallel.

        Args:
            ids2configs: A dictionary where each key is an environment ID and the corresponding
                        value is the configuration for that environment.
                Each config should contain:
                - env_name: Should be "vlnce"
                - env_config: VLNCE specific configuration

        Environments that fail to create are logged and skipped; the rest of
        the batch is still registered.
        """
        # Worker: returns (env_id, (env, env_config) | None, error | None).
        # Exceptions are converted to the error slot so future.result() never raises.
        def create_single_env(env_id, config):
            try:
                # Verify environment type
                env_name = config.get('env_name', 'vlnce')
                if env_name != 'vlnce':
                    return env_id, None, f"Expected environment type 'vlnce', got '{env_name}'"

                env_config_dict = config['env_config']
                env_config = VLNCEEnvConfig(**env_config_dict)
                env = VLNCEEnv(env_config)
                return env_id, (env, env_config), None
            except Exception as e:
                return env_id, None, str(e)

        # Use ThreadPoolExecutor for parallel creation
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {
                executor.submit(create_single_env, env_id, config): env_id
                for env_id, config in ids2configs.items()
            }

            for future in as_completed(futures):
                env_id, result, error = future.result()
                if error:
                    print(f"Error creating environment {env_id}: {error}")
                    continue

                env, env_config = result
                self.environments[env_id] = env
                self.env_configs[env_id] = env_config

    def reset_batch(self, ids2seeds: Dict[str, Any]) -> Dict[str, Tuple[Any, Any]]:
        """
        Reset multiple VLNCE environments in parallel.

        Args:
            ids2seeds: A dictionary where each key is an environment ID and the corresponding
                     value is a seed value (or None for using default seeding behavior).

        Returns:
            A dictionary mapping environment IDs to tuples of the form (observation, info).
            A failed reset maps to ``({}, {"error": <message>})``.
        """
        results = {}

        # Worker: exceptions become the error slot instead of escaping the thread.
        def reset_single_env(env_id, seed):
            try:
                env = self.environments[env_id]
                observation, info = env.reset(seed=seed)
                serialized_observation = serialize_observation(observation)
                return env_id, (serialized_observation, info), None
            except Exception as e:
                return env_id, None, str(e)

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {
                executor.submit(reset_single_env, env_id, seed): env_id
                for env_id, seed in ids2seeds.items()
            }

            for future in as_completed(futures):
                env_id, result, error = future.result()
                if error:
                    print(f"Error resetting environment {env_id}: {error}")
                    results[env_id] = ({}, {"error": error})
                else:
                    results[env_id] = result

        return results

    def step_batch(self, ids2actions: Dict[str, Any]) -> Dict[str, Tuple[Dict, float, bool, Dict]]:
        """
        Take a step in multiple VLNCE environments in parallel.

        Args:
            ids2actions: A dictionary where each key is an environment ID and the corresponding
                       value is the action to execute in that environment.

        Returns:
            A dictionary mapping environment IDs to tuples of the form
            (observation, reward, done, info). A failed step maps to
            ``({}, 0.0, True, {"error": <message>})``.
        """
        results = {}

        # Worker: a failing env.step is recovered by resetting; if the reset
        # also fails, the environment is closed and dropped from the registry.
        def step_single_env(env_id, action):
            try:
                env = self.environments[env_id]
                try:
                    observation, reward, done, info = env.step(action)
                except Exception as e:
                    print(f"Error stepping VLNCE environment {env_id}: {e}")
                    try:
                        observation, info = env.reset()
                        reward = 0.0
                        done = True
                    except Exception as e:
                        print(f"Error resetting VLNCE environment {env_id} after step failure: {e}")
                        env.close()
                        self.environments.pop(env_id, None)
                        observation = {"obs_str": "error"}
                        reward = 0.0
                        done = True
                        info = {}
                serialized_observation = serialize_observation(observation)
                return env_id, (serialized_observation, reward, done, info), None
            except Exception as e:
                return env_id, None, str(e)

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {
                executor.submit(step_single_env, env_id, action): env_id
                for env_id, action in ids2actions.items()
            }

            for future in as_completed(futures):
                env_id, result, error = future.result()
                if error:
                    print(f"Error stepping environment {env_id}: {error}")
                    results[env_id] = ({}, 0.0, True, {"error": error})
                else:
                    results[env_id] = result

        return results

    def compute_reward_batch(self, env_ids: List[str]) -> Dict[str, float]:
        """
        Compute the total reward for multiple VLNCE environments in parallel.

        Args:
            env_ids: A list of environment IDs

        Returns:
            A dictionary mapping each environment ID to its computed total reward

        Raises:
            NotImplementedError: Always; reward computation is not supported
                by this service.
        """
        raise NotImplementedError("compute_reward_batch is not implemented in VLNCEService.")

    def get_system_prompts_batch(self, env_ids: List[str]) -> Dict[str, str]:
        """
        Get system prompts for multiple VLNCE environments in parallel.

        Args:
            env_ids: A list of environment IDs

        Returns:
            A dictionary mapping each environment ID to its corresponding system
            prompt string. A failed lookup maps to ``""``.
        """
        results = {}

        # Worker: exceptions become the error slot instead of escaping the thread.
        def get_system_prompt_single_env(env_id):
            try:
                env = self.environments[env_id]
                return env_id, env.system_prompt(), None
            except Exception as e:
                return env_id, None, str(e)

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {
                executor.submit(get_system_prompt_single_env, env_id): env_id
                for env_id in env_ids
            }

            for future in as_completed(futures):
                env_id, result, error = future.result()
                if error:
                    print(f"Error getting system prompt for environment {env_id}: {error}")
                    results[env_id] = ""
                else:
                    results[env_id] = result

        return results

    def close_batch(self, env_ids: Optional[List[str]] = None) -> None:
        """
        Close multiple VLNCE environments and clean up resources in parallel.

        Args:
            env_ids: Optional list of environment IDs to close. If None, close all environments.
        """
        # If no env_ids provided, close all environments
        if env_ids is None:
            env_ids = list(self.environments.keys())

        # Worker: return the error message (or None) so close failures are logged.
        def close_single_env(env_id):
            try:
                env = self.environments[env_id]
                env.close()
                return None
            except Exception as e:
                return str(e)

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [executor.submit(close_single_env, env_id) for env_id in env_ids]

            for future in as_completed(futures):
                error = future.result()
                if error:
                    print(f"Error closing environment: {error}")

        # Remove closed environments from dictionaries
        for env_id in env_ids:
            self.environments.pop(env_id, None)
            self.env_configs.pop(env_id, None)
+110 −0
Original line number Diff line number Diff line
import ray
from typing import Dict, List, Optional, Tuple, Any

from vagen.env.vlnce.env import VLNCEEnv
from vagen.env.vlnce.env_config import VLNCEEnvConfig
from vagen.server.serial import serialize_observation
from vagen.env.utils.state_reward_text_utils import service_state_reward_wrapper
from .service_config import VLNCEServiceConfig
from vagen.env.base.base_service import BaseService

def ensure_ray_initialized():
    """Initialize Ray exactly once; subsequent calls are no-ops."""
    if ray.is_initialized():
        return
    ray.init(ignore_reinit_error=True)

# Initialize Ray at import time so actor creation below always has a runtime.
ensure_ray_initialized()

@ray.remote(num_gpus=0.001)
class VLNCEActor:
    """Ray actor hosting one VLNCEEnv instance in its own process."""

    def __init__(self, config_dict: Dict[str, Any]):
        # Build the env config from the plain dict sent over the wire.
        self.env_config = VLNCEEnvConfig(**config_dict)
        self.env = VLNCEEnv(self.env_config)

    def reset(self, seed: Optional[int] = None) -> Tuple[Dict, Dict]:
        """Reset the wrapped env and return a serialized (obs, info) pair."""
        obs, info = self.env.reset(seed=seed)
        return serialize_observation(obs), info

    def step(self, action: Any) -> Tuple[Dict, float, bool, Dict]:
        """Step the wrapped env; on failure fall back to a reset, then to a
        stub error observation, always returning done=True with zero reward."""
        try:
            obs, reward, done, info = self.env.step(action)
            return serialize_observation(obs), reward, done, info
        except Exception as e:
            print(f"[ERROR] step failed: {e}")
        try:
            obs, info = self.env.reset()
            return serialize_observation(obs), 0.0, True, info
        except Exception as e2:
            print(f"[ERROR] reset after step failed: {e2}")
            return serialize_observation({"obs_str": "error"}), 0.0, True, {}

    def get_system_prompt(self) -> str:
        """Return the env's system prompt string."""
        return self.env.system_prompt()

    def compute_reward(self):
        """Delegate total-reward computation to the wrapped env."""
        return self.env.compute_reward()

    def close(self):
        """Release the wrapped env's resources."""
        self.env.close()

class VLNCEService(BaseService):
    """
    Service class for VLN-CE environments, one Ray actor (process) per env.

    Batch operations launch all remote calls first, then gather results with a
    single ``ray.get`` on a list of ObjectRefs (``ray.get`` accepts an
    ObjectRef or a *list* of ObjectRefs — not a dict view).
    """

    def __init__(self, config: VLNCEServiceConfig):
        """
        Args:
            config: Service configuration.
        """
        self.config = config
        self.actors: Dict[str, ray.actor.ActorHandle] = {}  # env_id -> actor
        print(f"[DEBUG] {self.config}")

    def create_environments_batch(self, ids2configs: Dict[str, Any]) -> None:
        """
        Create one VLNCEActor per environment ID.

        Args:
            ids2configs: Maps env_id -> config dict with keys 'env_name'
                (must be 'vlnce') and 'env_config'. Unsupported types are
                logged and skipped.
        """
        for env_id, config in ids2configs.items():
            env_name = config.get('env_name', 'vlnce')
            if env_name != 'vlnce':
                print(f"[ERROR] Unsupported environment type: {env_name}")
                continue
            actor = VLNCEActor.remote(config['env_config'])
            self.actors[env_id] = actor

    def reset_batch(self, ids2seeds: Dict[str, Any]) -> Dict[str, Tuple[Any, Any]]:
        """
        Reset multiple environments in parallel.

        Args:
            ids2seeds: Maps env_id -> seed (or None for default seeding).

        Returns:
            Maps env_id -> (serialized_observation, info).
        """
        futures = {
            env_id: self.actors[env_id].reset.remote(seed)
            for env_id, seed in ids2seeds.items()
        }
        # Single batched get: waits for all resets concurrently.
        values = ray.get(list(futures.values()))
        return dict(zip(futures.keys(), values))

    def step_batch(self, ids2actions: Dict[str, Any]) -> Dict[str, Tuple[Dict, float, bool, Dict]]:
        """
        Step multiple environments in parallel.

        Args:
            ids2actions: Maps env_id -> action.

        Returns:
            Maps env_id -> (serialized_observation, reward, done, info).
        """
        futures = {
            env_id: self.actors[env_id].step.remote(action)
            for env_id, action in ids2actions.items()
        }
        values = ray.get(list(futures.values()))
        return dict(zip(futures.keys(), values))

    def get_system_prompts_batch(self, env_ids: List[str]) -> Dict[str, str]:
        """
        Fetch system prompts for multiple environments in parallel.

        Args:
            env_ids: Environment IDs to query.

        Returns:
            Maps env_id -> system prompt string.
        """
        futures = {
            env_id: self.actors[env_id].get_system_prompt.remote()
            for env_id in env_ids
        }
        values = ray.get(list(futures.values()))
        return dict(zip(futures.keys(), values))

    def compute_reward_batch(self, env_ids: List[str]) -> Dict[str, float]:
        """
        Compute total rewards for multiple environments in parallel.

        Args:
            env_ids: Environment IDs; IDs without a live actor are skipped.

        Returns:
            Maps env_id -> total reward.
        """
        futures = {
            env_id: self.actors[env_id].compute_reward.remote()
            for env_id in env_ids if env_id in self.actors
        }
        # BUG FIX: ray.get needs a list, not a dict_values view.
        results = ray.get(list(futures.values()))
        return {env_id: result for env_id, result in zip(futures.keys(), results)}

    def close_batch(self, env_ids: Optional[List[str]] = None) -> None:
        """
        Close environments and drop their actor handles.

        Args:
            env_ids: IDs to close; None closes all environments.
        """
        if env_ids is None:
            env_ids = list(self.actors.keys())
        futures = [self.actors[env_id].close.remote() for env_id in env_ids if env_id in self.actors]
        ray.get(futures)
        for env_id in env_ids:
            self.actors.pop(env_id, None)