Commit fd93bc98 authored by wayne's avatar wayne
Browse files

feat: support gt action sequence and dynamically tuning the length of gt action seq

parent 803238af
Loading
Loading
Loading
Loading
+50 −11
Original line number Diff line number Diff line
@@ -254,6 +254,7 @@ class QwenVLRolloutManagerService():
        #print(f"[DEBUG] ids2seeds_reset: {ids2seeds_reset}")
        reset_results=self.env_client.reset_batch(ids2seeds_reset)
        
        assert len(ids2seeds_reset) == len(env_configs) 
        
        if self.recorder is not None:
            del self.recorder
@@ -263,6 +264,37 @@ class QwenVLRolloutManagerService():
        
        
        for env_id, rst in reset_results.items():
            if isinstance(rst, list):
                # initial obs
                obs, info = rst
                initial_obs[env_id] = obs
                initial_info[env_id] = info
                self.record(
                    env_id, 
                    obs=obs, 
                    reward=0, 
                    done=False, 
                    info=info
                )
                # execute provided gt actions
                for i, item in enumerate(rst):
                    obs, info = item
                    initial_obs[env_id] = obs
                    initial_info[env_id] = info
                    self.record(
                        env_id, 
                        obs=obs, 
                        reward=0, 
                        done=False, 
                        info={
                            **info,
                            "execlude_from_loss": True,
                            "llm_raw_response": env_configs[env_id]["gt_actions"][i],
                        }
                    )
            else:
                # Currently this branch is blocked for debugging, since resetting the VLNCE always returns a list
                assert False
                obs, info = rst
                initial_obs[env_id] = obs
                initial_info[env_id] = info
@@ -330,6 +362,12 @@ class QwenVLRolloutManagerService():
        start_step = max(0, step - window_size) if window_size is not None else 0
        end_step = step
        assert len(recording) >= end_step + 1, 'History length is not enough'

        len_gt_actions = len([record for record in recording if record['info'].get('execlude_from_loss', False)])
        assert window_size is None, "Currently we only do not support history trunication"
        assert start_step == 0
        end_step += len_gt_actions
        
        history = recording[start_step: end_step + 1]
        rewards=[]
        chat = []
@@ -341,7 +379,8 @@ class QwenVLRolloutManagerService():
        for i, record in enumerate(history):
            if i>0:
                llm_raw_response = record['info']['llm_raw_response']
                filtered_llm_raw_response = self._handle_special_tokens(llm_raw_response, prep_for_loss_mask=prep_for_loss_mask)
                execlude_from_loss = record['info'].get('execlude_from_loss', False)
                filtered_llm_raw_response = self._handle_special_tokens(llm_raw_response, prep_for_loss_mask=prep_for_loss_mask and not execlude_from_loss)
                chat.append({"role": "assistant", "content": filtered_llm_raw_response})
                rewards.append(record['reward'])
            if i<len(history)-1 or not is_final:
+7 −0
Original line number Diff line number Diff line
@@ -15,8 +15,12 @@ def serialize_observation(observation: Dict[str, Any]) -> Dict[str, Any]:
    Returns:
        Serialized observation
    """
    if isinstance(observation, List):
        return [serialize_observation(obs) for obs in observation]
    
    serialized_obs = observation.copy()

    
    # Handle multi_modal_data if present
    if "multi_modal_data" in serialized_obs:
        serialized_multi_modal = {}
@@ -44,6 +48,9 @@ def deserialize_observation(serialized_obs: Dict[str, Any]) -> Dict[str, Any]:
    Returns:
        Deserialized observation
    """
    if isinstance(serialized_obs, List):
        return [deserialize_observation(obs) for obs in serialized_obs]
        
    deserialized_obs = serialized_obs.copy()
    
    # Handle multi_modal_data if present
+22 −2
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with huggingface
"""

import math
import os
import uuid
from contextlib import contextmanager
@@ -1017,6 +1018,8 @@ class RayPPOTrainer(object):
        all_final_gen_batch_outputs = []
        all_rst = []
        
        training_percent = self.global_steps / self.total_training_steps
        
        for i in range(num_mini_batches):
            start_idx = i * mini_batch_size
            end_idx = min((i + 1) * mini_batch_size, batch_size)
@@ -1029,6 +1032,24 @@ class RayPPOTrainer(object):
                for j in range(start_idx, end_idx)
            ]
            
            # dynamically set the size of history (ground-truth action sequence)
            for i in range(start_idx, end_idx):
                gt_actions = batch.non_tensor_batch['extra_info'][i]['gt_actions']
                num_actions = len(gt_actions)
                # At least 1 action should be dropped for training
                num_dropped_actions = min(1, math.ceil(num_actions * training_percent))
                num_gt_actions = num_actions - num_dropped_actions
                
                assert num_gt_actions >= 0 and num_gt_actions < num_actions
                assert num_gt_actions + num_dropped_actions == num_actions
                # TODO: add tolerance ratio and episode_step_limit as training arguments
                max_episode_steps = min(num_gt_actions + math.ceil(2 * len(num_dropped_actions)), 150)
                # Update the history_actions in the mini-batch env_configs
                mini_batch_env_configs[i - start_idx]['history_actions'] = gt_actions[:num_gt_actions]
                mini_batch_env_configs[i - start_idx]['max_episode_steps'] = max_episode_steps
                print(f"{training_percent=}, {num_actions=}, {num_dropped_actions=}, {num_gt_actions=} {max_episode_steps=}")

            
            # Reset and process this mini-batch
            rollout_manager.reset(mini_batch_env_configs)
            rollout_manager.rollout_loop()
@@ -1095,6 +1116,7 @@ class RayPPOTrainer(object):
        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                print(f'global_steps: {self.global_steps}')
                                
                metrics = {}
                timing_raw = {}

@@ -1118,11 +1140,9 @@ class RayPPOTrainer(object):
                
                


                with _timer('step', timing_raw):
                    # generate a batch
                    with _timer('gen', timing_raw):
                       
                        mini_batch_size=self.config.rollout_manager.get('mini_batch_size',len(batch))
                        final_gen_batch_output, rst=self._process_in_mini_batches(batch, rollout_manager, mini_batch_size) 
                        train_metrics=self.log_rst_to_metrics_dict(rst=rst,mode='train')