Commit e3d33e16 authored by 张泽凯

fix

parent 7a6d0498
@@ -599,7 +599,9 @@ class RayPPOTrainer(object):
                                         max_prompt_length=self.config.data.max_prompt_length,
                                         filter_prompts=True,
                                         return_raw_chat=self.config.data.get('return_raw_chat', False),
-                                         truncation='error')
+                                         truncation='error',
+                                         max_pixels=self.config.data.max_pixels,
+                                         min_pixels=self.config.data.min_pixels,)
        # use sampler for better ckpt resume
        if self.config.data.shuffle:
            train_dataloader_generator = torch.Generator()
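For context on the first hunk: `max_pixels` and `min_pixels` bound the pixel budget of image observations before they reach the model's image processor. The dataset implementation is not part of this diff; the sketch below only illustrates the intended clamping effect, assuming a Qwen2-VL-style resize rule (the helper name `clamp_pixels` is hypothetical, not from this repo):

```python
import math
from PIL import Image

def clamp_pixels(img: Image.Image, min_pixels: int, max_pixels: int) -> Image.Image:
    """Rescale img so that width * height lands inside [min_pixels, max_pixels].

    Hypothetical stand-in for whatever the dataset's image processor does with
    the max_pixels/min_pixels arguments threaded through in this commit.
    """
    w, h = img.size
    area = w * h
    if area > max_pixels:
        scale = math.sqrt(max_pixels / area)   # shrink oversized screenshots
    elif area < min_pixels:
        scale = math.sqrt(min_pixels / area)   # upscale tiny images
    else:
        return img                             # already within budget
    return img.resize((max(1, round(w * scale)), max(1, round(h * scale))))
```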
@@ -1033,8 +1035,11 @@ class RayPPOTrainer(object):
            ]
            
            # dynamically set the size of history (ground-truth action sequence)
+            assert 'gt_actions' in mini_batch_env_configs[0]
+            mini_batch_gt_actions = [item["gt_actions"] for item in mini_batch_env_configs]
+
            for i in range(start_idx, end_idx):
-                gt_actions = batch.non_tensor_batch['extra_info'][i]['gt_actions']
+                gt_actions = mini_batch_gt_actions[i - start_idx]
                num_actions = len(gt_actions)
                # At least 1 action should be dropped for training
                num_dropped_actions = min(1, math.ceil(num_actions * training_percent))
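A worked example of the schedule computed in this loop (a sketch using the expressions as committed; note that `min(1, math.ceil(...))` pins `num_dropped_actions` to exactly 1 whenever `training_percent > 0` and there is at least one action, so `max(1, ...)` may have been the intent if the dropped suffix should scale with `training_percent`):

```python
import math

def drop_schedule(num_actions: int, training_percent: float):
    # As committed: at least one action is dropped, but min() also caps it at one.
    num_dropped_actions = min(1, math.ceil(num_actions * training_percent))
    num_gt_actions = num_actions - num_dropped_actions  # assumed from the asserts below
    # Tolerance of ~2 extra env steps per dropped action, capped at 150 (next hunk).
    max_episode_steps = min(num_gt_actions + math.ceil(2 * num_dropped_actions), 150)
    return num_dropped_actions, num_gt_actions, max_episode_steps

print(drop_schedule(12, 0.25))  # -> (1, 11, 13)
```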
@@ -1043,12 +1048,13 @@ class RayPPOTrainer(object):
                assert num_gt_actions >= 0 and num_gt_actions < num_actions
                assert num_gt_actions + num_dropped_actions == num_actions
                # TODO: add tolerance ratio and episode_step_limit as training arguments
-                max_episode_steps = min(num_gt_actions + math.ceil(2 * len(num_dropped_actions)), 150)
+                max_episode_steps = min(num_gt_actions + math.ceil(2 * num_dropped_actions), 150)
                # Update the history_actions in the mini-batch env_configs
-                mini_batch_env_configs[i - start_idx]['history_actions'] = gt_actions[:num_gt_actions]
-                mini_batch_env_configs[i - start_idx]['max_episode_steps'] = max_episode_steps
-                print(f"{training_percent=}, {num_actions=}, {num_dropped_actions=}, {num_gt_actions=} {max_episode_steps=}")
+                mini_batch_env_configs[i - start_idx]['env_config']['history_actions'] = gt_actions[:num_gt_actions].tolist()
+                mini_batch_env_configs[i - start_idx].pop("gt_actions", None) # gt_actions is converted to np.array which is not serializable, so we delete it

+                mini_batch_env_configs[i - start_idx]['env_config']['max_episode_steps'] = max_episode_steps
+                print(f"{training_percent=}, {num_actions=}, {num_dropped_actions=}, {num_gt_actions=} {max_episode_steps=}")

            # Reset and process this mini-batch
            rollout_manager.reset(mini_batch_env_configs)
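The `.tolist()` call and the `pop("gt_actions", None)` in this hunk both deal with serialization: the env configs are handed to `rollout_manager.reset(...)`, and a NumPy array in that payload can fail to serialize, as the commit's own comment notes. A minimal reproduction, using JSON as a stand-in for whatever serializer the rollout workers actually use:

```python
import json
import numpy as np

gt_actions = np.array(["CLICK", "SCROLL", "TYPE"])

try:
    json.dumps({"history_actions": gt_actions[:2]})
except TypeError as err:
    print(err)  # Object of type ndarray is not JSON serializable

# Converting the slice to a plain Python list restores serializability,
# mirroring gt_actions[:num_gt_actions].tolist() above.
print(json.dumps({"history_actions": gt_actions[:2].tolist()}))
```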
@@ -1139,7 +1145,6 @@ class RayPPOTrainer(object):
                batch = batch.repeat(repeat_times=self.config.rollout_manager.n_trajectory, interleave=True)
                
                
-
                with _timer('step', timing_raw):
                    # generate a batch
                    with _timer('gen', timing_raw):
@@ -1149,7 +1154,6 @@ class RayPPOTrainer(object):
                        metrics.update(train_metrics)
                    print(f"[DEBUG] step {self.global_steps} rollout ends")
                    batch = batch.union(final_gen_batch_output)
-
                    # balance the number of valid tokens on each dp rank.
                    # Note that this breaks the order of data inside the batch.
                    # Please take care when you implement group based adv computation such as GRPO and rloo
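The closing comment's warning interacts with the `batch.repeat(..., interleave=True)` call earlier in this diff: interleaving places the `n_trajectory` rollouts of one prompt adjacently, and the token-balancing shuffle destroys that adjacency. A toy illustration of why a group key must travel with the data (all names here are illustrative, not from the repo):

```python
import numpy as np

# 4 prompts x 2 interleaved trajectories -> group ids 0,0,1,1,2,2,3,3
group_ids = np.repeat(np.arange(4), 2)
token_counts = np.array([120, 80, 200, 60, 150, 90, 70, 110])

# Balancing permutes samples by token count, breaking prompt adjacency...
order = np.argsort(token_counts)

# ...so group-based advantages (GRPO / RLOO) must regroup through the ids.
for gid in np.unique(group_ids):
    members = order[group_ids[order] == gid]
    print(f"group {gid}: indices {members}, tokens {token_counts[members]}")
```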