Commit fef7d8a8 authored by jameskrw's avatar jameskrw
Browse files

updated env

parent 58612539
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -62,9 +62,11 @@ class BaseEnv(ABC):
        """
        pass
    
    @abstractmethod
   
    def compute_reward(self) -> float:
        """
        give final reward
        Currently the reward calculation in rollout manager will be: sum(step_rewards)+env.compute_reward()
        In most cases you can set this to 0.0 since the step rewards are already accumulated, but if you want to add some extra reward for the final step, define it here.
        """
        pass
 No newline at end of file
        return 0.0
 No newline at end of file
+0 −3
Original line number Diff line number Diff line
@@ -140,9 +140,6 @@ Given images and a question, first give your thought then answer.
Your answer should be in the format of <think>...</think><answer>...</answer>.
e.g. <think>I can see there're multiple images with different view. I can see from the second view the object is on the target's left.I think the correct answer is A</think><answer>A</answer>"""

    def compute_reward(self) -> float:
        """Extra reward granted at the end of the episode.

        Per-step rewards are already accumulated by the rollout manager,
        so no additional terminal reward is given here.
        """
        return 0.0


if __name__ == "__main__":
    # Create config
+3 −11
Original line number Diff line number Diff line
@@ -180,6 +180,9 @@ class FrozenLakeEnv(BaseEnv):
        # Add format reward if actions were valid
        if metrics["turn_metrics"]['action_is_valid'] and rst["format_correct"]:
            self.reward += self.config.format_reward
            info["is_format_rewarded"] = True
        else:
            info["is_format_rewarded"] = False
        
        
        # Check if position changed to determine if action was effective
@@ -210,17 +213,6 @@ class FrozenLakeEnv(BaseEnv):
        
        return system_prompt() + '\n' + format_prompt_text

    def compute_reward(self):
        """Return the trajectory-level bonus reward for the episode.

        Per-step rewards are summed in the rollout manager, so this is
        0.0; make it non-zero only to grant a special end-of-trajectory
        reward on top of the accumulated step rewards.
        """
        return 0.0

    def close(self):
        """Release resources by closing the wrapped gym environment."""
        self.gym_env.close()

+9 −13
Original line number Diff line number Diff line
@@ -227,11 +227,9 @@ class NavigationEnv(BaseEnv):
        info = {}
        info.update(rst)
            
            
        # Execute valid actions
        if metrics["turn_metrics"]["action_is_valid"]:
            # Add format reward if actions were valid and format is correct
            if rst.get("format_correct", True):
                self.reward += self.config.format_reward
        if metrics["turn_metrics"]["action_is_valid"] and rst.get("format_correct", True):
            
            for action in action_list:
                action_lower = action.lower()
@@ -259,6 +257,12 @@ class NavigationEnv(BaseEnv):
                    done = True
                    break
        
        if metrics['turn_metrics']['action_is_valid'] and rst.get("format_correct", True):
            self.reward += self.config.format_reward
            info["is_format_rewarded"] = True
        else:
            info["is_format_rewarded"] = False
            
        # Check if the agent position has changed (action was effective)
        curr_pos = self.env.last_event.metadata["agent"]["position"]
        metrics['turn_metrics']['action_is_effective'] = curr_pos["x"] != prev_pos["x"] or curr_pos["z"] != prev_pos["z"]
@@ -388,14 +392,6 @@ class NavigationEnv(BaseEnv):
    
        return system_prompt(format=self.config.prompt_format) + '\n' + format_prompt_text
    
    def compute_reward(self):
        """Terminal reward for the episode.

        Returns:
            0.0 — rewards are accumulated per step elsewhere, so no
            extra trajectory-level bonus is added here.
        """
        return 0.0
    
    def close(self):
        """Close the environment by stopping the wrapped ``env`` instance."""
        self.env.stop()
+5 −2
Original line number Diff line number Diff line
@@ -111,7 +111,10 @@ class PrimitiveSkillEnv(BaseEnv):
        metrics["turn_metrics"]['action_is_valid'] = len(valid_actions) > 0 and len(valid_actions) == len(rst['actions'])
        if metrics["turn_metrics"]['action_is_valid'] and rst["format_correct"]:
            reward += self.config.format_reward
        # Check for success
            output_info["is_format_rewarded"] = True
        else:
            output_info["is_format_rewarded"] = False
        
        if info.get('is_success', False):
            metrics["traj_metrics"]['success'] = True
        
Loading