Commit c9b93fac authored by jameskrw's avatar jameskrw
Browse files

updated state reward

parent 2df505f2
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
from vagen.env.base.base_env import BaseEnv
from vagen.env.alfworld.alfworld_utils import load_alfworld_dataset
from vagen.env.utils.context_utils import parse_llm_raw_response, convert_numpy_to_PIL
from vagen.env.utils.parse_utils import parse_function_map
from vagen.env.utils.parse_utils import PARSE_FUNC_MAP
from .env_config import AlfEnvConfig
from .prompt import (
    system_prompt,
@@ -61,7 +61,7 @@ class AlfEnv(BaseEnv):
        self.format_prompt_func = format_prompt[self.config.get('prompt_format', 'free_think')]
        
        # Get the parse function based on the prompt format
        self.parse_func = parse_function_map[self.config.get('prompt_format', 'free_think')]
        self.parse_func = PARSE_FUNC_MAP[self.config.get('prompt_format', 'free_think')]
        
        # Initialize the dataset
        self.dataset = self._load_dataset()

CrossViewQA @ 3b8a92a0

Original line number Diff line number Diff line
Subproject commit 3b8a92a025e47c10e3881c3940c98cf5262edd43
+6 −7
Original line number Diff line number Diff line
@@ -6,11 +6,11 @@ from gymnasium.utils import seeding
from gymnasium.envs.toy_text.frozen_lake import FrozenLakeEnv as GymFrozenLakeEnv
from vagen.env.utils.env_utils import NoLoggerWarnings, set_seed
from vagen.env.utils.context_utils import convert_numpy_to_PIL
from vagen.env.utils.parse_utils_4 import parse_function_map
from vagen.env.utils.parse_utils import PARSE_FUNC_MAP
from .prompt import system_prompt, init_observation_template, action_template, format_prompt
from .env_config import FrozenLakeEnvConfig
from .utils import generate_random_map, is_valid

from vagen.env.utils.state_reward_utils import state_reward_wrapper
class FrozenLakeEnv(BaseEnv):
    """
    FrozenLake Environment for training and evaluating language models as agents.
@@ -78,7 +78,7 @@ class FrozenLakeEnv(BaseEnv):
        # Store the format prompt function for later use
        self.format_prompt_func = format_prompt[self.config.prompt_format]
        
        self.parse_func = parse_function_map[self.config.prompt_format]
        self.parse_func = PARSE_FUNC_MAP[self.config.prompt_format]

    def reset(self, seed=None):
        """
@@ -103,6 +103,7 @@ class FrozenLakeEnv(BaseEnv):
        self.total_reward = 0
        return self._render(init_obs=True), {}

    @state_reward_wrapper
    def step(self, action_str: str):
        """
        Take a step in the environment based on the agent's action.
@@ -325,13 +326,11 @@ class FrozenLakeEnv(BaseEnv):
        # Get dimensions of the grid
        nrow, ncol = self.gym_env.desc.shape
        
        # Get player position
        player_position = self._get_player_position()  # Already returns (row, col)

        # Find target/goal position (marked as 'G')
        player_position = player_position = tuple(map(int, self._get_player_position()))
        
        target_position = tuple(map(int, np.argwhere(self.gym_env.desc == b'G')[0]))
        
        # Find all hole positions (marked as 'H')
        hole_positions = [tuple(map(int, pos)) for pos in np.argwhere(self.gym_env.desc == b'H')]
        
        return {
+4 −3
Original line number Diff line number Diff line
@@ -5,10 +5,10 @@ import time
import math
from ai2thor.platform import CloudRendering
from vagen.env.utils.context_utils import convert_numpy_to_PIL
from vagen.env.utils.parse_utils_4 import parse_function_map
from vagen.env.utils.parse_utils import PARSE_FUNC_MAP
from .env_config import NavigationEnvConfig
from .prompt import system_prompt,init_observation_template, action_template, format_prompt

from vagen.env.utils.state_reward_utils import state_reward_wrapper

class NavigationEnv(BaseEnv):
    """Navigation environment from embodied bench. """   
@@ -98,7 +98,7 @@ class NavigationEnv(BaseEnv):
        self.format_prompt_func = format_prompt[self.config.prompt_format]
        
        # Get the parse function based on the prompt format
        self.parse_func = parse_function_map[self.config.prompt_format]
        self.parse_func = PARSE_FUNC_MAP[self.config.prompt_format]
        
    def _get_dataset_path(self, eval_set):
        """Get the path to the dataset file."""
@@ -184,6 +184,7 @@ class NavigationEnv(BaseEnv):
        
        return self._render(init_obs=True), {}
    
    @state_reward_wrapper
    def step(self, action_str: str):
        """Execute an action in the environment.
        
+4 −2
Original line number Diff line number Diff line
@@ -4,12 +4,13 @@ import copy
from typing import Dict, List, Optional, Tuple, Any
from gymnasium.utils import seeding
from vagen.env.utils.context_utils import convert_numpy_to_PIL
from vagen.env.utils.parse_utils_4 import parse_function_map
from vagen.env.utils.parse_utils import PARSE_FUNC_MAP
from .env_config import PrimitiveSkillEnvConfig
from .maniskill.utils import build_env, handle_info, get_workspace_limits
from .prompt import system_prompt, init_observation_template, action_template, format_prompt
import vagen.env.primitive_skill.maniskill.env
import random
from vagen.env.utils.state_reward_utils import state_reward_wrapper
class PrimitiveSkillEnv(BaseEnv):
    def __init__(self, config: PrimitiveSkillEnvConfig):
        """
@@ -28,7 +29,7 @@ class PrimitiveSkillEnv(BaseEnv):
        
        # Store the format prompt function for later use based on the configuration
        self.format_prompt_func = format_prompt[self.config.prompt_format]
        self.parse_func = parse_function_map[self.config.prompt_format]
        self.parse_func = PARSE_FUNC_MAP[self.config.prompt_format]
        # Define the state keys for the environment
        self.state_keys = self.env.state_keys
        self.last_info = None
@@ -54,6 +55,7 @@ class PrimitiveSkillEnv(BaseEnv):
        self.steps = 0
        return obs, {}
    
    @state_reward_wrapper
    def step(self, action_str):
        """
        Take a step in the environment based on the agent's action.
Loading