Commit 5b0fb83c authored by jameskrw's avatar jameskrw
Browse files

minor

parent 08291d44
Loading
Loading
Loading
Loading
+11 −7
Original line number Diff line number Diff line
@@ -62,16 +62,20 @@ def create_dataset_from_yaml(yaml_file_path: str, force_gen=False,seed=42,train_
        env_name = value.get('env_name')
        custom_env_config = value.get('env_config', {})
        train_size,test_size = (value.get('train_size', 100), value.get('test_size', 100))
        env_size = train_size + test_size
       
        
        env_config = REGISTERED_ENV[env_name]["config_cls"](**custom_env_config)
        seeds_for_env = None
        seeds_for_env_train = None
        seeds_for_env_test = None
        if hasattr(env_config, 'generate_seeds'):
            seeds_for_env = env_config.generate_seeds(env_size)
            print(f"Using {len(seeds_for_env)} seeds generated by {env_name} config's generate_seeds method")
            seeds_for_env_train = env_config.generate_seeds(train_size)
            seeds_for_env_test = env_config.generate_seeds(test_size)
            print(f"Using {len(seeds_for_env_train)} trian seeds generated by {env_name} config's generate_seeds method")
            print(f"Using {len(seeds_for_env_test)} test seeds generated by {env_name} config's generate_seeds method")
        else:
            seeds_for_env = np.random.randint(0, 2**31 - 1, size=env_size).tolist()
        for seed in seeds_for_env[:train_size]:
            seeds_for_env_train = np.random.randint(0, 2**31 - 1, size=train_size).tolist()
            seeds_for_env_test = np.random.randint(0, 2**31 - 1, size=test_size).tolist()
        for seed in seeds_for_env_train:
            env_settings = {
                'env_name': env_name,
                'env_config': custom_env_config,
@@ -83,7 +87,7 @@ def create_dataset_from_yaml(yaml_file_path: str, force_gen=False,seed=42,train_
                "extra_info": {"split": "train", **env_settings}
            }
            train_instances.append(instance)
        for seed in seeds_for_env[train_size:]:
        for seed in seeds_for_env_test:
            env_settings = {
                'env_name': env_name,
                'env_config': custom_env_config,