#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Query tools
#
# Software is free software released under the "Modified BSD license"
#
# Copyright (c) 2014-2015   Alexandre Dulaunoy - a@foo.be
# Copyright (c) 2014-2017 	Pieter-Jan Moreels - pieterjan.moreels@gmail.com

import os
import json
import requests
import sys
import urllib.parse
runPath = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(runPath, ".."))

from lib.Objects import CVE
from web.api     import API

class Query():
    """Look up CVE information through a remote HTTP API or a local API object.

    When constructed with an API base URL, every lookup is performed as an
    HTTP GET against that URL. Otherwise the lookups are delegated to an
    in-process ``API`` instance.
    """

    def __init__(self, api=None):
        """Select the backend used for all queries.

        :param api: base URL of a remote API. When falsy (``None`` or an
                    empty string), a local ``API()`` instance is created
                    and used instead.
        """
        self.apiurl = api
        # Define the attribute in both modes so it always exists; it is only
        # read by the query methods when no API URL is configured.
        self.api = None if api else API()

    @staticmethod
    def _interpret(status_code, data):
        """Translate an HTTP status code and decoded JSON body into a result.

        :param status_code: integer HTTP status code of the response.
        :param data: decoded JSON payload of the response.
        :return: ``data['data']`` for a 200 response whose status is
                 ``'success'``; ``None`` for a 404 response whose status is
                 ``'success'`` and whose reason is ``'cve not found'``;
                 ``False`` for anything else.
        """
        # A payload that is not a JSON object cannot carry the expected
        # envelope, so it is treated as a failed query.
        if not isinstance(data, dict):
            return False
        if data.get('status') != 'success':
            return False
        # Status codes must be compared by value: `is` compares object
        # identity, which is not reliable for integers.
        if status_code == 200:
            return data['data']
        if status_code == 404 and data.get('reason') == 'cve not found':
            return None
        return False

    def _query(self, url):
        """Perform a GET request against the remote API.

        :param url: path relative to the configured API base URL.
        :return: the payload, ``None`` or ``False`` as described in
                 ``_interpret``.
        :raises ValueError: if the response body is not valid JSON
                            (raised by ``json.loads``).
        """
        r = requests.get(urllib.parse.urljoin(self.apiurl, url),
                         headers={'Version': '2.0', 'Accept': '*/json'})
        return self._interpret(r.status_code, json.loads(r.text))

    def cve(self, cveid):
        """Fetch a single CVE by its identifier.

        :param cveid: identifier of the CVE; converted with ``str()`` when
                      building the remote URL.
        :return: remote mode: the object built by ``CVE.fromDict`` from the
                 query result, or ``None`` when the query result is falsy
                 (not found or failed). Local mode: the first element of the
                 value returned by ``api_cve``, or ``None`` if that call or
                 the indexing raises.
        """
        if self.apiurl:
            cve = self._query("api/cve/%s" % str(cveid))
            return CVE.fromDict(cve) if cve else None
        try:
            return self.api.api_cve(cveid)[0]
        except Exception:
            # Best-effort lookup: any failure of the local API is reported
            # as "no result". KeyboardInterrupt/SystemExit still propagate.
            return None

    def cveforcpe(self, cpe):
        """Fetch the CVEs associated with a CPE.

        :param cpe: the CPE to look up; converted with ``str()`` when
                    building the remote URL.
        :return: remote mode: a list of objects built by ``CVE.fromDict``,
                 or the falsy query result unchanged (``None``, ``False`` or
                 an empty value). Local mode: the first element of the value
                 returned by ``api_cvesFor``.
        """
        if self.apiurl:
            cpes = self._query("api/cvefor/%s" % str(cpe))
            return [CVE.fromDict(x) for x in cpes] if cpes else cpes
        return self.api.api_cvesFor(cpe)[0]

    def last(self, entries=None):
        """Fetch the latest CVEs.

        :param entries: optional value appended to the remote URL path when
                        truthy; passed as-is to ``api_last`` in local mode.
        :return: remote mode: a list of objects built by ``CVE.fromDict``,
                 or the falsy query result unchanged. Local mode: the first
                 element of the value returned by ``api_last``.
        """
        if self.apiurl:
            url = "api/last"
            if entries:
                url = "%s/%s" % (url, str(entries))
            cves = self._query(url)
            return [CVE.fromDict(x) for x in cves] if cves else cves
        return self.api.api_last(entries)[0]

    def browse(self, vendor=None):
        """Browse the API, optionally restricted to one vendor.

        :param vendor: optional vendor appended to the remote URL path when
                       truthy; passed as-is to ``api_browse`` in local mode.
        :return: remote mode: the raw result of ``_query``. Local mode: the
                 value returned by ``api_browse``, unmodified.
        """
        if self.apiurl:
            url = "api/browse"
            if vendor:
                url = "%s/%s" % (url, vendor)
            return self._query(url)
        return self.api.api_browse(vendor)

    def search(self, query):
        """Run a text search.

        :param query: the search text; converted with ``str()`` when
                      building the remote URL.
        :return: remote mode: the raw result of ``_query``. Local mode: the
                 value returned by ``api_text_search``, unmodified.
        """
        if self.apiurl:
            return self._query("api/search/%s" % str(query))
        return self.api.api_text_search(query)
