Commit 3045c90f authored by jameskrw
Browse files

Fix special-token filtering bug: strip tokens with literal str.replace instead of re.sub

parent cf1d5e57
Loading
Loading
Loading
Loading
+3 −3
Original line number Original line Diff line number Diff line
@@ -39,13 +39,13 @@ class QwenVLRolloutManager():
        1. Filter out special tokens: <image> and special tokens marking environment observation in the llm generated response
        1. Filter out special tokens: <image> and special tokens marking environment observation in the llm generated response
        2. prep_for_loss_mask: if true, add special tokens to the beginning and end of the response if compute_loss_mask is True
        2. prep_for_loss_mask: if true, add special tokens to the beginning and end of the response if compute_loss_mask is True
        """
        """
        llm_raw_response = re.sub(r'<image>', '', llm_raw_response)
        llm_raw_response = llm_raw_response.replace('<image>', '')
        if prep_for_loss_mask:
        if prep_for_loss_mask:
            # filtering special tokens for llm_raw_response, then adding them to the beginning and end of the response for loss mask computation
            # filtering special tokens for llm_raw_response, then adding them to the beginning and end of the response for loss mask computation
            sptk_b = self.config.special_token_for_loss_mask[0]
            sptk_b = self.config.special_token_for_loss_mask[0]
            sptk_e = self.config.special_token_for_loss_mask[1]
            sptk_e = self.config.special_token_for_loss_mask[1]
            llm_raw_response = re.sub(sptk_e, '', llm_raw_response)
            llm_raw_response = llm_raw_response.replace(sptk_b, '')
            llm_raw_response = re.sub(sptk_b, '', llm_raw_response)
            llm_raw_response = llm_raw_response.replace(sptk_e, '')
            llm_raw_response = sptk_b + llm_raw_response + sptk_e
            llm_raw_response = sptk_b + llm_raw_response + sptk_e
        return llm_raw_response
        return llm_raw_response
    
    
+3 −3
Original line number Original line Diff line number Diff line
@@ -43,13 +43,13 @@ class QwenVLRolloutManagerService():
        1. Filter out special tokens: <image> and special tokens marking environment observation in the llm generated response
        1. Filter out special tokens: <image> and special tokens marking environment observation in the llm generated response
        2. prep_for_loss_mask: if true, add special tokens to the beginning and end of the response if compute_loss_mask is True
        2. prep_for_loss_mask: if true, add special tokens to the beginning and end of the response if compute_loss_mask is True
        """
        """
        llm_raw_response = re.sub(r'<image>', '', llm_raw_response)
        llm_raw_response = llm_raw_response.replace('<image>', '')
        if prep_for_loss_mask:
        if prep_for_loss_mask:
            # filtering special tokens for llm_raw_response, then adding them to the beginning and end of the response for loss mask computation
            # filtering special tokens for llm_raw_response, then adding them to the beginning and end of the response for loss mask computation
            sptk_b = self.config.special_token_for_loss_mask[0]
            sptk_b = self.config.special_token_for_loss_mask[0]
            sptk_e = self.config.special_token_for_loss_mask[1]
            sptk_e = self.config.special_token_for_loss_mask[1]
            llm_raw_response = re.sub(sptk_e, '', llm_raw_response)
            llm_raw_response = llm_raw_response.replace(sptk_b, '')
            llm_raw_response = re.sub(sptk_b, '', llm_raw_response)
            llm_raw_response = llm_raw_response.replace(sptk_e, '')
            llm_raw_response = sptk_b + llm_raw_response + sptk_e
            llm_raw_response = sptk_b + llm_raw_response + sptk_e
        return llm_raw_response
        return llm_raw_response