Commit 0f743f3f authored by jameskrw's avatar jameskrw
Browse files

updated prompt length calculation

parent 43612654
Loading
Loading
Loading
Loading
+7 −7
Original line number Diff line number Diff line
from .sokoban import SokobanEnv, SokobanEnvConfig, SokobanService, SokobanServiceConfig
from .frozenlake import FrozenLakeEnv,FrozenLakeEnvConfig, FrozenLakeService
from .navigation import NavigationEnv, NavigationEnvConfig, NavigationServiceConfig, NavigationService
# from .navigation import NavigationEnv, NavigationEnvConfig, NavigationServiceConfig, NavigationService
# from .svg import SVGEnv, SvgEnvConfig, SVGService, SVGServiceConfig
# from .primitive_skill import PrimitiveSkillEnv, PrimitiveSkillEnvConfig, PrimitiveSkillService, PrimitiveSkillServiceConfig
# from .alfworld import ALFWorldEnv, ALFWorldEnvConfig, ALFWorldService, ALFWorldServiceConfig
@@ -17,12 +17,12 @@ REGISTERED_ENV = {
        "config_cls": FrozenLakeEnvConfig,
        "service_cls": FrozenLakeService
    },
    "navigation": {
        "env_cls": NavigationEnv,
        "config_cls": NavigationEnvConfig,
        "service_cls": NavigationService,
        "service_config_cls": NavigationServiceConfig
    },
    # "navigation": {
    #     "env_cls": NavigationEnv,
    #     "config_cls": NavigationEnvConfig,
    #     "service_cls": NavigationService,
    #     "service_config_cls": NavigationServiceConfig
    # },
    # "svg": {
    #     "env_cls": SVGEnv,
    #     "config_cls": SvgEnvConfig,
+3 −3
Original line number Diff line number Diff line
@@ -292,9 +292,9 @@ def reduce_metrics(metrics: dict):

def _compute_response_info(batch):
    if "loss_mask" in batch.batch.keys():
        end_of_response_position_mask=batch.batch["end_of_response_position_mask"]
        response_length = (batch.batch['loss_mask'].sum(-1)/ end_of_response_position_mask.sum(-1))
        prompt_length = ((batch.batch['attention_mask']-batch.batch['loss_mask']).sum(-1)/ end_of_response_position_mask.sum(-1))
        # end_of_response_position_mask=batch.batch["end_of_response_position_mask"]
        response_length = batch.batch['loss_mask'].sum(-1)
        prompt_length = batch.batch['attention_mask'].sum(-1)-batch.batch['loss_mask'].sum(-1)
        response_part_length = batch.batch['responses'].shape[-1]
        response_mask = batch.batch['loss_mask'][:, -response_part_length:]
    else: