Commit f246249e authored by jameskrw
Browse files

minor

parent b6da0609
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -78,7 +78,7 @@ class FrozenLakeEnv(BaseEnv):
        # Store the format prompt function for later use
        self.format_prompt_func = format_prompt[self.config.prompt_format]
        
        self.parse_func = parse_function_map[self.config.prompt_format.rstrip("_symbol")]
        self.parse_func = parse_function_map[self.config.prompt_format]

    def reset(self, seed=None):
        """
+1 −1
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@ class FrozenLakeEnvConfig(BaseEnvConfig):
    min_actions_to_succeed: int = 5
    prompt_format: str = "free_think" 
    # "free_think", "no_think", "grounding", "worldmodeling", "grounding_worldmodeling"
    # "grounding_symbol", "worldmodeling_symbol", "grounding_worldmodeling_symbol"
    # "grounding_symbolic", "worldmodeling_symbolic", "grounding_worldmodeling_symbolic"
    # "grounding_structured", "worldmodeling_structured", "grounding_worldmodeling_structured"
    use_accuracy_reward: bool = False
    
+3 −3
Original line number Diff line number Diff line
@@ -57,21 +57,21 @@ FORMAT_CONFIGS = {
        "example": "<think><observation>The player is on the above the target</observation><reasoning>I should go down then left to reach the target</reasoning><prediction>The player will reach the target</prediction></think><answer>Down{action_sep}Left</answer>"
    },
    
    "grounding_symbol": {
    "grounding_symbolic": {
        "format": "<think><observation>...</observation><reasoning>...</reasoning></think><answer>...</answer>",
        "description": "You should first describe the observation as a grid, then your reasoning, and finally your answer.",
        "additional_info": "The observation should be represented as a grid using the symbols: _ Frozen | O Hole | G Goal | P Player | X Player fell into hole | √ Player on goal.",
        "example": "<think><observation>_P__\nG___\n*OO*\n____</observation><reasoning>I should go down then left to reach the target</reasoning></think><answer>Down{action_sep}Left</answer>"
    },
    
    "worldmodeling_symbol": {
    "worldmodeling_symbolic": {
        "format": "<think><reasoning>...</reasoning><prediction>...</prediction></think>",
        "description": "You should first give your reasoning, then predict the next state, and finally your answer.",
        "additional_info": "The prediction should be represented as a grid using the symbols: _ Frozen | O Hole | G Goal | P Player | X Player fell into hole | √ Player on goal.",
        "example": "<think><reasoning>I can see the target is on my down left, I should go down then left</reasoning><prediction>____\n√___\n*OO*\n____</prediction></think><answer>Down{action_sep}Left</answer>"
    },
    
    "grounding_worldmodeling_symbol": {
    "grounding_worldmodeling_symbolic": {
        "format": "<think><observation>...</observation><reasoning>...</reasoning><prediction>...</prediction></think>",
        "description": "You should first describe the observation as a grid, then your reasoning, then predict the next state, and finally your answer.",
        "additional_info": "The observation and state should be represented as grids using the symbols: _ Frozen | O Hole | G Goal | P Player | X Player fell into hole | √ Player on goal.",
+1 −2
Original line number Diff line number Diff line
@@ -370,8 +370,7 @@ if __name__ == "__main__":
    while True:
        i += 1
        action = input("Enter action:")
        action = f"<think>Let me try this direction.</think><answer>{action}</answer>"
        
        #action = f"<think>Let me try this direction.</think><answer>{action}</answer>"
        obs, reward, done, info = env.step(action)
        print(obs["obs_str"])
        
+1 −1
Original line number Diff line number Diff line
@@ -49,7 +49,7 @@ class SokobanEnv(BaseEnv):
        # Call the function with add_example=True for system prompt
    
        
        self.parse_func = parse_function_map[self.config.prompt_format.rstrip("_symbol")]
        self.parse_func = parse_function_map[self.config.prompt_format]
        
    def reset(self, seed=None):
        with NoLoggerWarnings():
Loading