Commit f8e80831 authored by jameskrw's avatar jameskrw
Browse files

minor

parent 3c36d772
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -293,8 +293,8 @@ def reduce_metrics(metrics: dict):
def _compute_response_info(batch):
    if "loss_mask" in batch.batch.keys():
        # end_of_response_position_mask=batch.batch["end_of_response_position_mask"]
        response_length = batch.batch['loss_mask'].sum(-1)
        prompt_length = batch.batch['attention_mask'].sum(-1)-batch.batch['loss_mask'].sum(-1)
        response_length = batch.batch['loss_mask'].sum(-1).float()
        prompt_length = (batch.batch['attention_mask'].sum(-1)-batch.batch['loss_mask'].sum(-1)).float()
        response_part_length = batch.batch['responses'].shape[-1]
        response_mask = batch.batch['loss_mask'][:, -response_part_length:]
    else: