#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Database layer
#  Abstraction layer between the database and the rest of the project
#
# Software is free software released under the "Modified BSD license"
#
# Copyright (c) 2014-2017  Pieter-Jan Moreels - pieterjan.moreels@gmail.com

# Imports
import math
import re
import uuid

from collections   import defaultdict
from passlib.hash  import pbkdf2_sha256

from lib.Config    import Configuration as conf
from lib.Database  import Database
from lib.Objects   import CVE, CPE, CWE, CAPEC, VIA4
from lib.Singleton import Singleton
from lib.Toolkit   import toStringFormattedCPE, exploitabilityScore, impactScore, hashableDict
from lib.Toolkit   import compile as _compile

# Code
class DatabaseLayer(metaclass=Singleton):
    """Facade bundling one accessor object per database collection.

    Created through the project's ``Singleton`` metaclass (presumably one
    shared instance per process -- see lib.Singleton to confirm).
    """

    def __init__(self, hash_rounds=8000, salt_size=10, _db=None):
        """Open the backend and build the per-collection accessors.

        :param hash_rounds: stored on the instance; not used by this class.
        :param salt_size:   stored on the instance; not used by this class.
        :param _db:         forwarded as ``Database(db=_db)`` when truthy,
                            otherwise ``Database()`` is used.
        """
        self.hash_rounds = hash_rounds
        self.salt_size   = salt_size

        # Conditional expression rather than `_db and X or Y`: the and/or form
        # would silently fall back to Database() if X evaluated as falsy.
        self.db        = Database(db=_db) if _db else Database()
        self.CVE       = CVEs()
        self.CPE       = CPEs()
        self.CWE       = CWEs()
        self.CAPEC     = CAPECs()
        self.VIA4      = VIA4s()
        self.Whitelist = MarkList('whitelist')
        self.Blacklist = MarkList('blacklist')
        self.Users     = Users()
        self.Ranking   = Ranking()
        self.Plugins   = Plugins()
        self.Redis     = Redis()

    def db_info(self, include_admin=False):
        """Return the backend's ``db_getStats`` result; *include_admin* is forwarded."""
        return self.db.db_getStats(include_admin)

    def ensure_index(self, collection, field):
        """Ask the backend to ensure an index on *field* of *collection*."""
        self.db.db_ensureIndex(collection, field)

    def drop_metadata(self):
        """Drop the backend's metadata via ``db_metadata_drop``."""
        self.db.db_metadata_drop()

####################
# Black-/Whitelist #
####################
class MarkList:
    """Accessor for a CPE blacklist or whitelist.

    Every operation is forwarded to the ``Database`` method named
    ``<marktype>_<operation>`` (for example ``blacklist_insert``).
    """

    def __init__(self, marktype):
        """Bind the accessor to one list.

        :param marktype: 'blacklist' or 'whitelist' (case-insensitive).
        :raises ValueError: for any other value.
        """
        if marktype.lower() not in ('blacklist', 'whitelist'):
            raise ValueError("marktype must be 'blacklist' or 'whitelist', got %r" % (marktype,))
        self.marktype = marktype.lower()
        self.db = Database()

    def _call(self, operation, *args):
        """Invoke the backend method ``<marktype>_<operation>`` with *args*."""
        return getattr(self.db, self.marktype + "_" + operation)(*args)

    def get(self):
        """Return every entry of the list as a Python list."""
        return list(self._call("get"))

    def rules(self, compiled=True):
        """Build regex rules from the list's entries.

        'cpe' entries are used verbatim as patterns. 'targethardware' and
        'targetsoftware' entries are escaped and prefixed by a pattern that
        skips the first 9 (resp. 8) colon-terminated fields after 'cpe:2.3:'
        (presumably the target_hw / target_sw field of a CPE 2.3 string).
        Entries of any other type are ignored.

        :param compiled: when False return the raw pattern strings, otherwise
                         the result of ``lib.Toolkit.compile`` on them.
        """
        rules = []
        for entry in self.get():
            if entry['type'] == "cpe":
                rules.append(entry['id'])
            elif entry['type'] == "targethardware":
                rules.append("cpe:2.3:([^:]*:){9}" + re.escape(entry['id']))
            elif entry['type'] == "targetsoftware":
                rules.append("cpe:2.3:([^:]*:){8}" + re.escape(entry['id']))
        if not compiled:
            return rules
        return _compile(rules)

    def size(self):
        """Return the number of entries reported by the backend."""
        return self._call("size")

    def contains(self, cpe):
        """Return the backend's answer to whether *cpe* is in the list."""
        return self._call("contains", cpe)

    def insert(self, cpe, cpeType, comments=None):
        """Add *cpe* to the list unless it is already present.

        Text after each '#' in *cpe* is split off and appended to the
        comments; the caller's *comments* list is never modified. Values of
        type 'cpe' are normalised with ``toStringFormattedCPE`` first.

        :returns: True when inserted; False when the value is empty after
                  formatting or already present.
        :raises: any backend/formatting error, after printing it.
        """
        try:
            if '#' in cpe:
                cpe, *inline = cpe.split('#')
                # Build a new list instead of extend(): don't mutate the caller's list
                comments = list(comments or []) + inline
            if cpeType.lower() == 'cpe':
                cpe = toStringFormattedCPE(cpe)
            if cpe and not self.contains(cpe):
                self._call("insert", cpe, cpeType, comments)
                return True
            return False
        except Exception as ex:
            print("Error inserting item in database: %s" % ex)
            raise

    def remove(self, cpe):
        """Remove *cpe* (stripped and normalised) from the list.

        :returns: True when removed, False when empty or not present.
        :raises: any backend/formatting error, after printing it.
        """
        try:
            cpe = toStringFormattedCPE(cpe.strip())
            if cpe and self.contains(cpe):
                self._call("remove", cpe)
                return True
            return False
        except Exception as ex:
            print("Error removing item from database: %s" % ex)
            raise

    def update(self, cpeOld, cpeNew, cpeType):
        """Replace *cpeOld* with *cpeNew*; comments come from '#' parts of *cpeNew*.

        :returns: True when updated, False when either value is empty or
                  *cpeOld* is not in the list.
        :raises Exception: when *cpeNew* already exists (and differs from
                           *cpeOld*), or on any backend error; printed first.
        """
        try:
            comments = cpeNew.split("#")[1:]
            if cpeType == "cpe":
                cpeOld = toStringFormattedCPE(cpeOld.split('#')[0].strip())
                cpeNew = toStringFormattedCPE(cpeNew.split('#')[0].strip())
            if cpeOld and cpeNew and self.contains(cpeOld):
                if self.contains(cpeNew) and cpeNew != cpeOld:
                    raise Exception("Value already exists in database")
                self._call("update", cpeOld, cpeNew, cpeType, comments)
                return True
            return False
        except Exception as ex:
            print("Error updating item in database: %s" % (ex))
            raise

    def clear(self):
        """Drop every entry of the list.

        :returns: the number of entries present before the drop.
        :raises: any backend error, after printing it.
        """
        try:
            size = self.size()
            # The drop method must actually be called, not merely looked up
            self._call("drop")
            return size
        except Exception as ex:
            print("Error cleaning out collection in database: %s" % (ex))
            raise

#########
# Users #
#########
class Users:
    """Accessor for user accounts: password hashes, admin/local-only flags and API tokens."""

    def __init__(self):
        self.db = Database()

    @staticmethod
    def _hash_password(pwd):
        """Hash *pwd* with PBKDF2-SHA256 using the configured settings."""
        # NOTE(review): passlib >= 1.7 deprecates encrypt() in favour of hash();
        # encrypt() is kept here so older passlib releases keep working.
        return pbkdf2_sha256.encrypt(pwd, **conf.getPBKDFSettings())

    def size(self):
        """Return the number of stored users."""
        return self.db.user_size()

    def insert(self, user, passwd, admin=False, localOnly=False):
        """Create user *user* with a hashed copy of *passwd*.

        The 'master' and 'local_only' keys are only written when the
        corresponding flag is set.
        """
        data = {'username': user, 'password': self._hash_password(passwd)}
        if admin:
            data['master'] = True
        if localOnly:
            data['local_only'] = True
        self.db.user_add(data)

    def remove(self, username):
        """Delete *username*."""
        self.db.user_remove(username)

    def changePassword(self, username, pwd):
        """Store a new hash of *pwd* for *username*.

        :returns: True when the user exists and was updated, else False.
        """
        user = self.get(username)
        if not user:
            return False
        user['password'] = self._hash_password(pwd)
        self.db.user_update(user)
        return True

    def setAdmin(self, username, admin=True):
        """Set or clear the admin flag of *username*."""
        self.db.user_setAdmin(username, admin)

    def isAdmin(self, username):
        """Return a truthy value when *username* exists and has the 'master' flag."""
        user = self.get(username)
        return user and user.get('master')

    def isOnlyAdmin(self, username):
        """Return the backend's ``user_isOnlyAdmin`` answer for *username*."""
        return self.db.user_isOnlyAdmin(username)

    def setLocalOnly(self, username, localOnly=True):
        """Set or clear the local-only flag of *username*."""
        self.db.user_setLocalOnly(username, localOnly)

    def isLocalOnly(self, username):
        """Return a truthy value when *username* exists and has the 'local_only' flag."""
        user = self.get(username)
        return user and user.get('local_only')

    def isOnlyMaster(self, username):
        """Alias of ``isOnlyAdmin``."""
        return self.db.user_isOnlyAdmin(username)

    def exists(self, username):
        """Return True if *username* is stored."""
        return bool(self.get(username))

    def get(self, username):
        """Return the stored user record for *username* (falsy when absent)."""
        return self.db.user_get(username)

    def getAll(self):
        """Return all stored user records."""
        return self.db.user_getAll()

    def verifyPassword(self, username, pwd):
        """Return a truthy value when *username* exists and *pwd* matches its hash."""
        user = self.get(username)
        return user and pbkdf2_sha256.verify(pwd, user['password'])

    def getToken(self, username):
        """Return *username*'s API token, generating one on first use.

        :returns: the token string, or None when the user does not exist.
        """
        user = self.get(username)
        if not user:
            return None
        if 'token' in user:
            return user['token']
        # generateToken() takes the username, not the user record
        return self.generateToken(username)

    def generateToken(self, username):
        """Assign and persist a fresh random token for *username*.

        :returns: the new 32-character hex token, or None when the user does not exist.
        """
        user = self.get(username)
        if not user:
            return None
        user['token'] = uuid.uuid4().hex
        self.db.user_update(user)
        return user['token']

########
# CVEs #
########
class CVEs:
    """Accessor for the CVE collection.

    Read methods return CVE objects whose vulnerable configuration has been
    resolved to CPE objects; further enrichment is opt-in through keyword
    flags forwarded to ``_enhance``.
    """

    def __init__(self):
        self.db = Database()

    def upsert(self, cve):
        """Insert or update a single ``CVE`` or a list of them.

        :raises ValueError: if any element is not a ``CVE`` instance.
        """
        if not isinstance(cve, list):
            cve = [cve]
        if not all(isinstance(x, CVE) for x in cve):
            raise ValueError()
        self.db.cve_upsert(cve)

    # Data retrieval
    def get(self, cveID, **kwargs):
        """Return the CVE with id *cveID* (upper-cased before lookup), enhanced.

        The CWE stored on the object is swapped for the cached CWE object and
        *kwargs* are forwarded to ``_enhance``. A falsy lookup result is
        returned unchanged.
        """
        cve = self.db.cve_get(cveID.upper())
        if cve:
            # Replace the dud stored at reconstruction time with the cached CWE
            if cve.cwe:
                cve.cwe = DatabaseLayer().CWE.get(cve.cwe.id)
            # _enhance also resolves vulnerable_configuration to CPE objects
            self._enhance(cve, **kwargs)
        return cve

    def query(self, limit=False, skip=0, sort=None, query=None, **kwargs):
        """Run a backend query and enhance the results.

        *limit*, *skip*, *sort* and *query* (``{}`` when omitted) are
        forwarded to ``cve_query``; *kwargs* go to ``_enhance``.
        """
        if query is None:
            query = {}
        cves = self.db.cve_query(limit=limit, skip=skip, sort=sort, query=query)
        self._enhance(cves, **kwargs)
        return cves

    def last(self, limit=-1, skip=0, query=None, **kwargs):
        """Like ``query`` but never sorted and with ``limit=-1`` by default.

        Results are enhanced once with *kwargs*; a ``sort`` keyword in
        *kwargs* is not forwarded to the backend.
        """
        if query is None:
            query = {}
        cves = self.db.cve_query(limit=limit, skip=skip, sort=None, query=query)
        self._enhance(cves, **kwargs)
        return cves

    @staticmethod
    def _roundScore(score):
        """Round *score* up to one decimal place; strings are returned unchanged."""
        if isinstance(score, str):
            return score
        return math.ceil(score * 10) / 10

    def _enhance(self, cve, via4=False, subscore=False, ranking=False, **kwargs):
        """Enrich a CVE, or each CVE of an iterable, in place.

        Always replaces ``vulnerable_configuration`` with CPE objects fetched
        by id (for their titles). Optionally:
          * via4     -- attach the VIA4 record as ``c.via4``;
          * ranking  -- attach a set of rank tuples, one per ranked CPE, as ``c.ranking``;
          * subscore -- store the exploitability / impact scores, rounded up
                        to one decimal, on ``c.access.cvss`` / ``c.impact.cvss``.
        Unknown keyword arguments are ignored.
        """
        if isinstance(cve, CVE):
            cve = [cve]
        for c in cve:
            c.vulnerable_configuration = [DatabaseLayer().CPE.get(vuln.id)
                                          for vuln in c.vulnerable_configuration]
            if via4:
                c.via4 = DatabaseLayer().VIA4.get(c.id)
            if ranking:
                ranks = set()
                for config in c.vulnerable_configuration:
                    rank = DatabaseLayer().CPE.ranking(config.id)
                    if rank:
                        # dicts and lists are unhashable: wrap each dict and
                        # freeze the list into a tuple so it fits in a set
                        ranks.add(tuple(hashableDict(x) for x in rank))
                c.ranking = ranks
            if subscore:
                # Score this CVE (c), not the batch: `cve` is always a list here
                exploitCVSS = exploitabilityScore(c)
                impactCVSS  = impactScore(c)
                c.access.cvss = self._roundScore(exploitCVSS)
                c.impact.cvss = self._roundScore(impactCVSS)

    def forCPE(self, cpe):
        """Return the backend's CVEs for *cpe*."""
        return self.db.cve_forCPE(cpe)

    def textSearch(self, search):
        """Return the backend's text-search results for *search*."""
        return self.db.cve_textSearch(search)

    def via4links(self, text, **kwargs):
        """Return CVEs whose VIA4 data matches *text* literally (case-insensitive)."""
        cves = [x['id'] for x in self.db.via4_search(re.compile(re.escape(text), re.I))]
        return self.query(query={'id': {'$in': cves}}, **kwargs)

    # Info
    def info(self):
        """Return the collection info from the backend."""
        return self.db.cve_info()

    def size(self):
        """Return the number of stored CVEs."""
        return self.db.cve_size()

    def updated(self, date=None):
        """With *date*, store it as the update time; otherwise return the stored one."""
        if date:
            self.db.cve_setUpdate(date)
        else:
            return self.db.cve_getUpdate()

    def drop(self):
        """Drop the CVE collection."""
        self.db.cve_drop()

########
# CPEs #
########
class CPEs:
    """Accessor for the CPE dictionary and the alternative ('cpeOther') CPE collection."""

    def __init__(self):
        self.db = Database()

    def upsert(self, cpe):
        """Insert or update a single ``CPE`` or a list of them.

        :raises ValueError: if any element is not a ``CPE`` instance.
        """
        if not isinstance(cpe, list):
            cpe = [cpe]
        if not all(isinstance(x, CPE) for x in cpe):
            raise ValueError()
        self.db.cpe_upsert(cpe)

    def alternative_upsert(self, cpe):
        """Insert or update ``CPE`` objects in the alternative collection.

        :raises ValueError: if any element is not a ``CPE`` instance.
        """
        if not isinstance(cpe, list):
            cpe = [cpe]
        if not all(isinstance(x, CPE) for x in cpe):
            raise ValueError()
        self.db.cpe_alternative_upsert(cpe)

    def get(self, id):
        """Return the stored CPE for *id* (normalised with ``toStringFormattedCPE``).

        Falls back to a bare ``CPE(id)`` when nothing is stored, so callers
        always receive an object.
        """
        cpe = self.db.cpe_get(toStringFormattedCPE(id))
        return cpe if cpe else CPE(id)

    def get_regex(self, regex, alternative=False):
        """Return CPEs matching *regex* case-insensitively; *alternative* is forwarded."""
        return self.db.cpe_regex(re.compile(regex, re.IGNORECASE), alternative)

    def getAll(self):
        """Return every CPE of the main collection."""
        return self.db.cpe_getAll()

    def getAllAlternative(self):
        """Return every CPE of the alternative collection."""
        return self.db.cpe_getAllAlternative()

    def alternative_updated(self, date=None, indexed=None):
        """Set or read the alternative collection's metadata.

        With *date*: store it as the update time and, when given, *indexed*.
        Without: return ``(last-modified, indexed)`` from the stored info.
        """
        if date:
            self.db.cpeOther_setUpdate(date)
            if indexed:
                self.db.cpeOther_setMetadata('indexed', indexed)
        else:
            data = self.db.cpeOther_info()
            return (data.get('last-modified'), data.get('indexed'))

    def ranking(self, cpeid, loosy=True):
        """Look up the ranking associated with *cpeid*.

        :param loosy: when True, each non-empty ':'-separated component of
                      *cpeid* is looked up and the rank of the last matching
                      component wins; when False, *cpeid* is looked up whole.
        :returns: the stored 'rank' value, or False when nothing matches.
        """
        if not loosy:
            match = self.db.ranking_find(cpeid, regex=True)
            # Guard against None before the membership test
            if match is not None and 'rank' in match:
                return match['rank']
            return False
        result = False
        for part in cpeid.split(':'):
            # Skip empty components instead of reusing (or lacking) a prior lookup
            if part == '':
                continue
            match = self.db.ranking_find(part, regex=True)
            if match is not None and 'rank' in match:
                result = match['rank']
        return result

    def updated(self, date=None):
        """With *date*, store it as the update time; otherwise return the stored one."""
        if date:
            self.db.cpe_setUpdate(date)
        else:
            return self.db.cpe_getUpdate()

    def size(self):
        """Return the number of CPEs in the main collection."""
        return self.db.cpe_size()

    def alternative_size(self):
        """Return the number of CPEs in the alternative collection."""
        return self.db.cpeOther_size()

    def drop(self):
        """Drop the main CPE collection."""
        self.db.cpe_drop()

    def alternative_drop(self):
        """Drop the alternative CPE collection."""
        self.db.cpeOther_drop()

########
# CWEs #
########
class CWEs:
    """CWE accessor backed by a lazily built in-memory cache keyed by CWE id."""

    def __init__(self):
        self.db  = Database()
        # id -> CWE object; stays None until _populate_memory_db() runs
        self.cwe = None

    def upsert(self, cwe):
        """Insert or update one ``CWE`` or a list of them.

        :raises ValueError: if any element is not a ``CWE`` instance.
        """
        batch = cwe if isinstance(cwe, list) else [cwe]
        for item in batch:
            if not isinstance(item, CWE):
                raise ValueError()
        self.db.cwe_upsert(batch)

    def updated(self, date=None):
        """Store *date* as the update time, or return the stored time when no date is given."""
        if not date:
            return self.db.cwe_getUpdate()
        self.db.cwe_setUpdate(date)

    def get(self, id):
        """Return the cached CWE for *id* (plain ints become str), or None."""
        key = str(id) if type(id) is int else id
        if not self.cwe:
            self._populate_memory_db()
        return self.cwe.get(key)

    def getAll(self):
        """Return every cached CWE as a list (building the cache if needed)."""
        if not self.cwe:
            self._populate_memory_db()
        return [*self.cwe.values()]

    def _populate_memory_db(self):
        """Load all CWEs into memory and attach their related CAPECs."""
        # Publish the cache on self *before* the CAPEC lookups: building the
        # CAPEC cache calls back into CWE.get(), which would otherwise
        # trigger this method again.
        self.cwe = dict((entry.id, entry) for entry in self.db.cwe_getAll())
        for weakness in self.cwe.values():
            weakness.capec = DatabaseLayer().CAPEC.relatedTo(weakness.id)

    def size(self):
        """Return the number of stored CWEs."""
        return self.db.cwe_size()

    def drop(self):
        """Drop the CWE collection."""
        self.db.cwe_drop()

#########
# CAPEC #
#########
class CAPECs:
    """Accessor for CAPEC attack patterns, backed by a lazily built in-memory cache.

    ``self.capec`` maps CAPEC id -> CAPEC object and ``self.related`` maps a
    weakness (CWE) id -> list of CAPECs that list it; both are filled by
    ``_populate_memory_db``.
    """
    def __init__(self):
        self.db      = Database()
        self.capec   = None  # CAPEC id -> CAPEC object; None until first populate
        self.related = None  # defaultdict(list): weakness id -> [CAPEC, ...]; None until first populate

    def upsert(self, capec):
        """Insert or update one ``CAPEC`` or a list of them.

        :raises ValueError: if any element is not a ``CAPEC`` instance.
        """
        if not isinstance(capec, list):
            capec = [capec]
        if not all(isinstance(x, CAPEC) for x in capec):
            raise ValueError()
        self.db.capec_upsert(capec)

    def updated(self, date=None):
        """With *date*, store it as the update time; otherwise return the stored one."""
        if date:
            self.db.capec_setUpdate(date)
        else:
            return self.db.capec_getUpdate()

    def get(self, id):
        """Return the cached CAPEC for *id* (ints are converted to str), or None."""
        if isinstance(id, int): id = str(id)
        if not self.capec: self._populate_memory_db()
        return self.capec.get(id)

    def relatedTo(self, cweID):
        """Return the CAPECs listing *cweID* among their weaknesses ([] if none).

        NOTE(review): `not self.related` is also true for an *empty* index, so
        when no CAPEC lists any weakness the cache is rebuilt on every call.
        """
        if isinstance(cweID, int): cweID = str(cweID)
        if not self.related: self._populate_memory_db()
        return self.related.get(cweID, [])

    def _populate_memory_db(self):
        """(Re)build the CAPEC cache and the weakness -> CAPEC reverse index.

        Each CAPEC's ``weaknesses`` (ids as loaded) is replaced by the matching
        cached CWE objects; ids missing from the CWE cache get a placeholder CWE.

        NOTE(review): this is mutually recursive with CWEs._populate_memory_db.
        If the CWE cache is not built yet, the CWE.get("0") below builds it,
        and that calls relatedTo() for every CWE. While self.related is still
        None, that call re-enters this method. The outer call then rebuilds
        self.capec and self.related a second time, so CWE.capec lists may
        reference CAPEC objects from the inner build rather than the ones now
        in self.capec (depends on whether capec_getAll() returns fresh
        objects -- confirm).
        """
        DatabaseLayer().CWE.get("0") # Force a db populate if not done yet
        self.capec   = {x.id: x for x in self.db.capec_getAll()}
        self.related = defaultdict(list)
        for c in self.capec.values():
            related_weaknesses = []
            for w in c.weaknesses:
                # Unknown weakness ids get a placeholder CWE object
                rw = DatabaseLayer().CWE.get(w) or CWE(w, "Unknown", "No CWE", "Unknown", "Unknown")
                related_weaknesses.append(rw)
                # Keyed by the raw id; relatedTo() looks up str ids, so this
                # presumably assumes weakness ids are stored as strings -- confirm
                self.related[w].append(c)
            c.weaknesses = related_weaknesses

    def size(self):
        """Return the number of stored CAPECs."""
        return self.db.capec_size()

    def drop(self):
        """Drop the CAPEC collection."""
        self.db.capec_drop()

########
# VIA4 #
########
class VIA4s:
    """Accessor for VIA4 records and the VIA4 collection's metadata."""

    def __init__(self):
        self.db = Database()

    def upsert(self, via4):
        """Insert or update one ``VIA4`` record or a list of them.

        :raises ValueError: if any element is not a ``VIA4`` instance.
        """
        entries = via4 if isinstance(via4, list) else [via4]
        if any(not isinstance(entry, VIA4) for entry in entries):
            raise ValueError()
        self.db.via4_upsert(entries)

    def get(self, cveid):
        """Return the VIA4 record for *cveid*, or None when the lookup is falsy."""
        record = self.db.via4_get(cveid)
        return record or None

    def searchables(self, data=None):
        """Store *data* as the 'searchables' metadata, or read it back ([] if unset)."""
        if not data:
            return self.db.via4_info().get('searchables', [])
        self.db.via4_setMetadata('searchables', data)

    def sources(self, data=None):
        """Store *data* as the 'sources' metadata, or read it back ([] if unset)."""
        if not data:
            return self.db.via4_info().get('sources', [])
        self.db.via4_setMetadata('sources', data)

    def updated(self, date=None):
        """Store *date* as the update time, or return the stored time when no date is given."""
        if not date:
            return self.db.via4_getUpdate()
        self.db.via4_setUpdate(date)

    def link(self, key, value):
        """Return the backend's ``via4_link`` result for *key* / *value*."""
        return self.db.via4_link(key, value)

    def size(self):
        """Return the number of stored VIA4 records."""
        return self.db.via4_size()

    def drop(self):
        """Drop the VIA4 collection."""
        self.db.via4_drop()

###########
# Ranking #
###########
class Ranking:
    """Thin accessor for CPE ranking entries."""

    def __init__(self):
        self.db = Database()

    def get(self, cpe=None, regex=False):
        """Return the backend's ranking lookup for *cpe* (with *regex* forwarded)."""
        return self.db.ranking_find(cpe=cpe, regex=regex)

    def remove(self, cpe):
        """Delete the ranking for *cpe*; a non-str or empty *cpe* is silently ignored."""
        if not isinstance(cpe, str) or not cpe:
            return
        self.db.ranking_remove(cpe)

    def add(self, cpe, key, rank):
        """Record *rank* under *key* for *cpe*."""
        self.db.ranking_add(cpe, key, rank)

###########
# Plugins #
###########
class Plugins:
    """Storage helpers for plugins: settings, per-user settings and raw data access."""

    def __init__(self):
        self.db = Database()

    @staticmethod
    def _as_list(data):
        """Return *data* unchanged if its exact type is ``list``, else wrapped in one."""
        return data if type(data) is list else [data]

    # Settings
    def setting_write(self, plugin, setting, value):
        """Persist *value* for *setting* of *plugin*."""
        self.db.plugin_setting_write(plugin, setting, value)

    def setting_read(self, plugin, setting):
        """Return the stored value of *setting* for *plugin*."""
        return self.db.plugin_setting_read(plugin, setting)

    def settings_delete(self, plugin):
        """Remove every setting of *plugin*."""
        self.db.plugin_settings_delete(plugin)

    # User settings
    def userSetting_write(self, plugin, user, setting, value):
        """Persist *value* for *user*'s *setting* of *plugin*."""
        self.db.plugin_userSetting_write(plugin, user, setting, value)

    def userSetting_read(self, plugin, user, setting):
        """Return *user*'s stored value of *setting* for *plugin*."""
        return self.db.plugin_userSetting_read(plugin, user, setting)

    def userSettings_delete(self, plugin):
        """Remove every per-user setting of *plugin*."""
        self.db.plugin_userSettings_delete(plugin)

    # Direct manipulations
    def query(self, plugin, query):
        """Return the backend's results for *query* on *plugin*'s data."""
        return self.db.plugin_query(plugin, query)

    def query_one(self, plugin, query):
        """Return the backend's single-result lookup for *query* on *plugin*'s data."""
        return self.db.plugin_query_one(plugin, query)

    # Data Manipulation
    def drop(self, plugin):
        """Drop *plugin*'s data."""
        self.db.plugin_drop(plugin)

    def insert(self, plugin, data):
        """Insert *data* into *plugin*'s data."""
        self.db.plugin_insert(plugin, data)

    def remove(self, plugin, query):
        """Remove whatever matches *query* from *plugin*'s data."""
        self.db.plugin_remove(plugin, query)

    def addToList(self, plugin, query, listname, data):
        """Append *data* (wrapped in a list if needed) to *listname* of matching records."""
        self.db.plugin_addToList(plugin, query, listname, self._as_list(data))

    def removeFromList(self, plugin, query, listname, data):
        """Remove *data* (passed through as given) from *listname* of matching records."""
        self.db.plugin_removeFromList(plugin, query, listname, data)

    def bulkUpdate(self, plugin, keyword, data):
        """Bulk-update *plugin*'s data keyed on *keyword* (data wrapped in a list if needed)."""
        self.db.plugin_bulkUpdate(plugin, keyword, self._as_list(data))

#########
# Redis #
#########
class Redis:
    """Read-only helpers over the Redis vendor/product lookup connection."""

    def __init__(self):
        self.vendor = conf.getRedisVendorConnection()

    def vendors(self):
        """Return the sorted union of the members of 't:/o', 't:/a' and 't:/h'."""
        members = set()
        for key in ("t:/o", "t:/a", "t:/h"):
            members |= set(self.vendor.smembers(key))
        return sorted(members)

    def products(self, vendor):
        """Return the sorted members of 'v:<vendor>', or None when there are none."""
        found = sorted(self.vendor.smembers("v:" + vendor))
        return found if found else None

