Commit 095645e6 authored by jameskrw's avatar jameskrw
Browse files

updated cross view

parent e97b6a58
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ class BaseEnvConfig(ABC):
        return getattr(self, key, default)
    
    def generate_seeds(self, size, seed=0, n_candidate: int = 20000) -> list:
        """Return ``size`` distinct, reproducible integer seeds.

        Subclasses may override this method in their own env config to
        customize seed generation.

        Args:
            size: Number of seeds to draw.
            seed: Seed for the private generator; identical arguments always
                produce the identical list.
            n_candidate: Widens the sampling pool, which is
                ``range(0, n_candidate + size)``.

        Returns:
            A list of ``size`` unique integers drawn from the pool.

        Raises:
            ValueError: If ``size`` is negative or larger than the pool
                (raised by ``random.Random.sample``).
        """
        # A dedicated generator produces the same sequence as seeding the
        # module-level one, but does not reset the process-wide random state
        # that other code may be relying on.
        rng = random.Random(seed)
        return rng.sample(range(n_candidate + size), size)
 No newline at end of file
+55 −7
Original line number Diff line number Diff line
@@ -15,19 +15,62 @@ class CrossViewEnv(BaseEnv):
    def __init__(self, config: CrossViewEnvConfig):
        """Load the dataset for the configured split and reset episode state.

        Args:
            config: Environment configuration. ``split``, ``image_dir`` and
                the per-split data file names (``train_data_path`` /
                ``test_data_path``) are read from it. Relative paths are
                resolved against the ``CrossViewQA`` directory that sits
                beside this module.

        Raises:
            ValueError: If ``config.split`` is neither ``"train"`` nor
                ``"test"``.
        """
        self.config = config
        self.script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "CrossViewQA")
        self.split = config.split
        self.image_dir = os.path.join(self.script_dir, config.image_dir)

        # Load dataset
        if self.split == "train":
            # The train split is stored as a single JSON document.
            self.data_path = os.path.join(self.script_dir, config.train_data_path)
            with open(self.data_path, "r", encoding="utf-8") as f:
                self.dataset = json.load(f)
        elif self.split == "test":
            # The test split is a JSONL file: one JSON object per line.
            self.data_path = os.path.join(self.script_dir, config.test_data_path)
            with open(self.data_path, "r", encoding="utf-8") as f:
                # Skip blank lines (e.g. a trailing newline) that json.loads rejects.
                records = [json.loads(line) for line in f if line.strip()]
            self.dataset = self._convert_data_format(records)
        else:
            # Fail here rather than later with a missing ``self.dataset``.
            raise ValueError(
                f"Unsupported split {self.split!r}; expected 'train' or 'test'"
            )
        print(f"Loaded {len(self.dataset)} examples from {self.data_path}")

        self.current_data = None
        self.current_seed = None
        self.done = False
        self.total_reward = 0
    
    def _convert_data_format(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert the JSONL format to the required JSON structure.
        
        Args:
            data: List of dictionaries from the JSONL file
            
        Returns:
            List of dictionaries in the required format with id, conversation, and images
        """
        converted_data = []
        
        for item in data:
            # Create the conversation structure
            conversation = [
                {
                    "role": "user",
                    "content": item["question"]
                },
                {
                    "role": "assistant",
                    "content": item["gt_answer"]
                }
            ]
            
            # Create the converted item
            converted_item = {
                "id": item["id"],
                "conversation": conversation,
                "images": item["images"]
            }
            
            converted_data.append(converted_item)
        
        return converted_data
        
        
        
    def reset(self, seed=None) -> Tuple[Dict, Dict]:
        """Reset environment with new seed"""
        if seed is not None:
@@ -69,7 +112,12 @@ class CrossViewEnv(BaseEnv):
        
        # Create observation string with image placeholders
        image_placeholders = " ".join([self.config.image_placeholder] * len(images))
        obs_str = f"Question: {question}\n{image_placeholders}\nPlease look at the images and answer the question."
        obs_str = f"""Question: {question}
{image_placeholders}
Please look at the images and answer the question. 
Your answer should be in the format of <think>...</think><answer>...</answer>. 
Please give your thought first then answer.
e.g. <think>I can see there're multiple images with different view. I can see from the second view the object is on the target's left.I think the correct answer is A</think><answer>A</answer>"""
        
        return {
            'obs_str': obs_str,
@@ -94,7 +142,7 @@ class CrossViewEnv(BaseEnv):
        
        # Simple exact match (case-insensitive)
        action_is_valid = action_content != ""
        success = action_is_valid and action_content.lower() == ground_truth.lower()
        success = action_is_valid and action_content.strip().lower()[0] == ground_truth.strip().lower()[0]
        action_is_effective = action_is_valid
        
        # Compute reward - base reward + format reward if applicable
+8 −2
Original line number Diff line number Diff line
@@ -4,14 +4,20 @@ from typing import Optional, List, Union

@dataclass
class CrossViewEnvConfig(BaseEnvConfig):
    """Configuration for the CrossView QA environment.

    File and directory names are plain strings; how they are resolved to
    full paths is up to the environment that consumes this config.
    """

    # Original single data-file setting, kept so existing callers that pass
    # or read ``data_path`` keep working.
    data_path: str = "crossviewQA_train_qwenformat_singleletter.json"
    image_dir: str = "extracted_images"
    # NOTE(review): presumably (width, height) in pixels — confirm against the renderer.
    image_size: tuple = (300, 300)
    render_mode: str = "vision"

    # Data file used for each split.
    train_data_path: str = "crossviewQA_train_qwenformat_singleletter.json"
    test_data_path: str = "crossviewQA_tinybench.jsonl"
    # Which split to load. Annotated so it is a real dataclass field that can
    # be set through the constructor; declared last so the positional order of
    # the fields above is unchanged.
    split: str = "train"

    def config_id(self) -> str:
        """Return the fixed identifier string for this environment config."""
        return "CrossViewQAEnv"

    def generate_seeds(self, size, seed=0, n_candidate=20000):
        """Return the seeds ``0 .. size - 1`` in ascending order.

        ``seed`` and ``n_candidate`` are accepted for signature compatibility
        but are not used here.
        """
        return list(range(size))
            
            
            
if __name__ == "__main__":
    # Smoke check: build a default config and show its identifier.
    print(CrossViewEnvConfig().config_id())
 No newline at end of file