Unverified Commit ee1838ae authored by Kangrui Wang's avatar Kangrui Wang Committed by GitHub
Browse files

Update __init__.py

parent 138152e3
Loading
Loading
Loading
Loading
+47 −24
Original line number Diff line number Diff line
# First, import the modules that are assumed to be always available
from .sokoban import SokobanEnv, SokobanEnvConfig, SokobanService, SokobanServiceConfig
from .frozenlake import FrozenLakeEnv, FrozenLakeEnvConfig, FrozenLakeService
from .navigation import NavigationEnv, NavigationEnvConfig, NavigationServiceConfig, NavigationService
from .svg import SVGEnv, SvgEnvConfig, SVGService, SVGServiceConfig
from .primitive_skill import PrimitiveSkillEnv, PrimitiveSkillEnvConfig, PrimitiveSkillService, PrimitiveSkillServiceConfig
# from .alfworld import ALFWorldEnv, ALFWorldEnvConfig, ALFWorldService, ALFWorldServiceConfig
# from .crossview import CrossViewEnv, CrossViewEnvConfig

REGISTERED_ENV = {
    "sokoban": {
        "env_cls": SokobanEnv,
@@ -16,33 +13,59 @@ REGISTERED_ENV = {
        "env_cls": FrozenLakeEnv,
        "config_cls": FrozenLakeEnvConfig,
        "service_cls": FrozenLakeService
    },
    "navigation": {
    }
}

try:
    from .navigation import NavigationEnv, NavigationEnvConfig, NavigationServiceConfig, NavigationService
    REGISTERED_ENV["navigation"] = {
        "env_cls": NavigationEnv,
        "config_cls": NavigationEnvConfig,
        "service_cls": NavigationService,
        "service_config_cls": NavigationServiceConfig
    },
    "svg": {
    }
except ImportError:
    pass

try:
    from .svg import SVGEnv, SvgEnvConfig, SVGService, SVGServiceConfig
    REGISTERED_ENV["svg"] = {
        "env_cls": SVGEnv,
        "config_cls": SvgEnvConfig,
        "service_cls": SVGService,
        "service_config_cls": SVGServiceConfig
    },
    "primitive_skill": {
    }
except ImportError:
    pass

try:
    from .primitive_skill import PrimitiveSkillEnv, PrimitiveSkillEnvConfig, PrimitiveSkillService, PrimitiveSkillServiceConfig
    REGISTERED_ENV["primitive_skill"] = {
        "env_cls": PrimitiveSkillEnv,
        "config_cls": PrimitiveSkillEnvConfig,
        "service_cls": PrimitiveSkillService,
        "service_config_cls": PrimitiveSkillServiceConfig
    },
    # "alfworld": {
    #     "env_cls": ALFWorldEnv,
    #     "config_cls": ALFWorldEnvConfig,
    #     "service_cls": ALFWorldService,
    #     "service_config_cls": ALFWorldServiceConfig
    # },
    # "crossview": {
    #     "env_cls": CrossViewEnv,
    #     "config_cls": CrossViewEnvConfig
    # }
    }
except ImportError:
    pass


try:
    from .alfworld import ALFWorldEnv, ALFWorldEnvConfig, ALFWorldService, ALFWorldServiceConfig
    REGISTERED_ENV["alfworld"] = {
        "env_cls": ALFWorldEnv,
        "config_cls": ALFWorldEnvConfig,
        "service_cls": ALFWorldService,
        "service_config_cls": ALFWorldServiceConfig
    }
except ImportError:
    pass

try:
    from .crossview import CrossViewEnv, CrossViewEnvConfig
    REGISTERED_ENV["crossview"] = {
        "env_cls": CrossViewEnv,
        "config_cls": CrossViewEnvConfig
    }
except ImportError:
    pass