Commit 175b10c8 authored by jameskrw
Browse files

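updated row dict (store end_of_response_position_mask in place of reward_positions; derive the last reward index from the mask)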

parent b2aa3f2d
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -506,7 +506,7 @@ class QwenVLRolloutManager():
            for idx,reward in enumerate(rewards):
                multi_turn_token_level_rewards[reward_positions[idx]] = reward
            row_dict["multi_turn_token_level_rewards"] = multi_turn_token_level_rewards # (seq_len,) 
            row_dict["reward_positions"] = reward_positions
            row_dict["end_of_response_position_mask"] = end_of_response_position_mask
        if self.config.use_loss_mask:
            row_dict['loss_mask'] = loss_mask
        if self.config.use_gae_mask:
@@ -629,7 +629,9 @@ class QwenVLRolloutManager():
            last_reward=self.envs[env_id].compute_reward()
            row_dict['reward_model'] = {"style": "given", "ground_truth": {"reward": last_reward+step_reward_sum}}
            if self.config.use_multi_turn_reward:
                last_reward_index = row_dict['reward_positions'][-1]
                end_of_response_position_mask = row_dict['end_of_response_position_mask']
                reward_positions = torch.nonzero(end_of_response_position_mask).squeeze(-1)
                last_reward_index = reward_positions[-1]
                row_dict['multi_turn_token_level_rewards'][last_reward_index] += last_reward
            batch_list.append(row_dict)
        batch_dict = collate_fn(batch_list)
+4 −2
Original line number Diff line number Diff line
@@ -517,7 +517,7 @@ class QwenVLRolloutManagerService():
            for idx,reward in enumerate(rewards):
                multi_turn_token_level_rewards[reward_positions[idx]] = reward
            row_dict["multi_turn_token_level_rewards"] = multi_turn_token_level_rewards # (seq_len,) 
            row_dict["reward_positions"] = reward_positions
            row_dict["end_of_response_position_mask"] = end_of_response_position_mask
        if self.config.use_loss_mask:
            row_dict['loss_mask'] = loss_mask
        if self.config.use_gae_mask:
@@ -646,7 +646,9 @@ class QwenVLRolloutManagerService():
    
            row_dict['reward_model'] = {"style": "given", "ground_truth": {"reward": reward_rst[env_id]+step_reward_sum}}
            if self.config.use_multi_turn_reward:
                last_reward_index = row_dict['reward_positions'][-1]
                end_of_response_position_mask = row_dict['end_of_response_position_mask']
                reward_positions = torch.nonzero(end_of_response_position_mask).squeeze(-1)
                last_reward_index = reward_positions[-1]
                row_dict['multi_turn_token_level_rewards'][last_reward_index] += reward_rst[env_id]
            batch_list.append(row_dict)
        batch_dict = collate_fn(batch_list)