Commit 695eae29 authored by 张泽凯's avatar 张泽凯
Browse files

fix

parent e3d33e16
Loading
Loading
Loading
Loading
+12 −7
Original line number Original line Diff line number Diff line
@@ -9,14 +9,15 @@ SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
# run python -m vagen.server.server in a tmux session first
# run python -m vagen.server.server in a tmux session first


# max_trajectory_length = max_prompt_length + max_response_length
# max_trajectory_length = max_prompt_length + max_response_length

export RAY_DEBUG="1"
export RAY_DEBUG_POST_MORTEM="1"
python3 -m vagen.trainer.main_ppo \
python3 -m vagen.trainer.main_ppo \
    algorithm.adv_estimator=masked_gae \
    algorithm.adv_estimator=masked_gae \
    algorithm.high_level_gamma=0.95 \
    algorithm.high_level_gamma=0.95 \
    data.train_files=data/vlnce/train.parquet \
    data.train_files=data/vlnce/train.parquet \
    data.val_files=data/vlnce/train.parquet \
    data.val_files=data/vlnce/train.parquet \
    data.train_batch_size=32 \
    data.train_batch_size=2 \
    data.val_batch_size=32 \
    data.val_batch_size=2 \
    data.max_prompt_length=1024 \
    data.max_prompt_length=1024 \
    data.max_response_length=128 \
    data.max_response_length=128 \
    data.max_trajectory_length=20000 \
    data.max_trajectory_length=20000 \
@@ -25,7 +26,7 @@ python3 -m vagen.trainer.main_ppo \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
    actor_rollout_ref.actor.ppo_mini_batch_size=1 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
@@ -38,6 +39,7 @@ python3 -m vagen.trainer.main_ppo \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.2 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.2 \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.rollout.max_num_batched_tokens=20000 \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.free_cache_engine=False \
    actor_rollout_ref.rollout.free_cache_engine=False \
    actor_rollout_ref.rollout.n=1 \
    actor_rollout_ref.rollout.n=1 \
@@ -45,6 +47,7 @@ python3 -m vagen.trainer.main_ppo \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.top_p=0.95 \
    actor_rollout_ref.rollout.top_p=0.95 \
    actor_rollout_ref.rollout.temperature=0.7 \
    actor_rollout_ref.rollout.temperature=0.7 \
    actor_rollout_ref.rollout.limit_mm_per_prompt=200 \
    critic.optim.lr=1e-5 \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.model.use_remove_padding=True \
    critic.model.path=Qwen/Qwen2.5-VL-3B-Instruct \
    critic.model.path=Qwen/Qwen2.5-VL-3B-Instruct \
@@ -57,7 +60,7 @@ python3 -m vagen.trainer.main_ppo \
    trainer.logger=['console'] \
    trainer.logger=['console'] \
    trainer.project_name='vlnce' \
    trainer.project_name='vlnce' \
    trainer.experiment_name='vlnce_service' \
    trainer.experiment_name='vlnce_service' \
    trainer.n_gpus_per_node=4 \
    trainer.n_gpus_per_node=2 \
    trainer.nnodes=1 \
    trainer.nnodes=1 \
    trainer.save_freq=70 \
    trainer.save_freq=70 \
    trainer.test_freq=9999999 \
    trainer.test_freq=9999999 \
@@ -69,8 +72,10 @@ python3 -m vagen.trainer.main_ppo \
    rollout_manager.use_gae_mask=True \
    rollout_manager.use_gae_mask=True \
    trainer.val_before_train=False \
    trainer.val_before_train=False \
    trainer.val_generations_to_log_to_wandb=8 \
    trainer.val_generations_to_log_to_wandb=8 \
    rollout_manager.n_trajectory=8 \
    rollout_manager.n_trajectory=2 \
    rollout_manager.use_service=True \
    rollout_manager.use_service=True \
    rollout_manager.timeout=9999999 \
    rollout_manager.timeout=9999999 \
    rollout_manager.base_url="http://localhost:5001" \
    rollout_manager.base_url="http://172.18.35.200:5000" \
    rollout_manager.max_pixels=76800 \
    rollout_manager.min_pixels=1024 \
    2>&1 | tee vlnce.log
    2>&1 | tee vlnce.log
 No newline at end of file
+8 −8
Original line number Original line Diff line number Diff line
@@ -16,6 +16,7 @@ from verl.utils.dataset.rl_dataset import process_image, collate_fn
import vagen.env
import vagen.env
from vagen.env import REGISTERED_ENV
from vagen.env import REGISTERED_ENV
from vagen.server.client import BatchEnvClient
from vagen.server.client import BatchEnvClient
from qwen_vl_utils import smart_resize
    
    
class QwenVLRolloutManagerService():
class QwenVLRolloutManagerService():
    def __init__(self,
    def __init__(self,
@@ -264,9 +265,10 @@ class QwenVLRolloutManagerService():
        
        
        
        
        for env_id, rst in reset_results.items():
        for env_id, rst in reset_results.items():
            if isinstance(rst, list):
            obss, info = rst
            if isinstance(obss, (list, tuple)):
                # initial obs
                # initial obs
                obs, info = rst
                obs = obss[0]
                initial_obs[env_id] = obs
                initial_obs[env_id] = obs
                initial_info[env_id] = info
                initial_info[env_id] = info
                self.record(
                self.record(
@@ -277,8 +279,7 @@ class QwenVLRolloutManagerService():
                    info=info
                    info=info
                )
                )
                # execute provided gt actions
                # execute provided gt actions
                for i, item in enumerate(rst):
                for i, obs in enumerate(obss[1:]):
                    obs, info = item
                    initial_obs[env_id] = obs
                    initial_obs[env_id] = obs
                    initial_info[env_id] = info
                    initial_info[env_id] = info
                    self.record(
                    self.record(
@@ -289,7 +290,7 @@ class QwenVLRolloutManagerService():
                        info={
                        info={
                            **info,
                            **info,
                            "execlude_from_loss": True,
                            "execlude_from_loss": True,
                            "llm_raw_response": env_configs[env_id]["gt_actions"][i],
                            "llm_raw_response": self.envs[env_id].history_actions[i],
                        }
                        }
                    )
                    )
            else:
            else:
@@ -331,7 +332,7 @@ class QwenVLRolloutManagerService():
        image_placeholder = self.envs[env_id].get('image_placeholder', "<image>")
        image_placeholder = self.envs[env_id].get('image_placeholder', "<image>")
        if 'multi_modal_data' in obs:
        if 'multi_modal_data' in obs:
            if image_placeholder in obs['multi_modal_data']:
            if image_placeholder in obs['multi_modal_data']:
                record_entry['image_data'] = [process_image(image) for image in obs['multi_modal_data'][image_placeholder]]
                record_entry['image_data'] = [process_image(image, min_pixels=self.config.min_pixels, max_pixels=self.config.max_pixels) for image in obs['multi_modal_data'][image_placeholder]]
        self.recorder[env_id].append(record_entry)
        self.recorder[env_id].append(record_entry)


    @torch.no_grad()
    @torch.no_grad()
@@ -364,8 +365,7 @@ class QwenVLRolloutManagerService():
        assert len(recording) >= end_step + 1, 'History length is not enough'
        assert len(recording) >= end_step + 1, 'History length is not enough'


        len_gt_actions = len([record for record in recording if record['info'].get('execlude_from_loss', False)])
        len_gt_actions = len([record for record in recording if record['info'].get('execlude_from_loss', False)])
        assert window_size is None, "Currently we only do not support history trunication"
        assert start_step == 0, "Currently we only do not support history trunication"
        assert start_step == 0
        end_step += len_gt_actions
        end_step += len_gt_actions
        
        
        history = recording[start_step: end_step + 1]
        history = recording[start_step: end_step + 1]
+3 −3
Original line number Original line Diff line number Diff line
@@ -14,8 +14,6 @@ data:
  filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You should disable this and set `truncation='left'
  filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You should disable this and set `truncation='left'
  truncation: error
  truncation: error
  image_key: images
  image_key: images
  max_pixels: 76800
  min_pixels: 1024


actor_rollout_ref:
actor_rollout_ref:
  hybrid_engine: True
  hybrid_engine: True
@@ -208,3 +206,5 @@ rollout_manager:
  use_service: False
  use_service: False
  timeout: 1200
  timeout: 1200
  max_workers: 8
  max_workers: 8
  max_pixels: 76800
  min_pixels: 1024
 No newline at end of file
+1 −3
Original line number Original line Diff line number Diff line
@@ -599,9 +599,7 @@ class RayPPOTrainer(object):
                                         max_prompt_length=self.config.data.max_prompt_length,
                                         max_prompt_length=self.config.data.max_prompt_length,
                                         filter_prompts=True,
                                         filter_prompts=True,
                                         return_raw_chat=self.config.data.get('return_raw_chat', False),
                                         return_raw_chat=self.config.data.get('return_raw_chat', False),
                                         truncation='error',
                                         truncation='error')
                                         max_pixels=self.config.data.max_pixels,
                                         min_pixels=self.config.data.min_pixels,)
        # use sampler for better ckpt resume
        # use sampler for better ckpt resume
        if self.config.data.shuffle:
        if self.config.data.shuffle:
            train_dataloader_generator = torch.Generator()
            train_dataloader_generator = torch.Generator()