#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Database
#  The database-specific code.
#
# This software is free software released under the "Modified BSD license"
#
# Copyright (c) 2017       Pieter-Jan Moreels - pieterjan.moreels@gmail.com

# Imports
import pymongo
import urllib.parse
import sys

from lib.Config    import Configuration         as conf
from lib.Objects   import CVE, CPE, CWE, CAPEC, VIA4
from lib.Singleton import Singleton

# Code
class Database(metaclass=Singleton):
    """MongoDB access layer for the CVE, CPE, CWE, CAPEC and VIA4 collections
    plus the management collections (white/blacklists, users, rankings,
    metadata and plugin storage).

    Connection parameters that are not passed explicitly fall back to the
    values returned by ``lib.Config.Configuration``. Collection handles are
    created once in ``__init__``; the ``Singleton`` metaclass is defined
    elsewhere and presumably makes repeated construction return the same
    instance — TODO confirm.
    """

    def __init__(self, host=None, port=None, db=None, user=None,
                 password=None):
        # Forward only the arguments the caller actually set, so that
        # getConnection() uses the configured defaults for the others.
        settings = (('host', host), ('port', port), ('db', db),
                    ('user', user), ('password', password))
        self.db = self.getConnection(**{k: v for k, v in settings if v})

        self.colCVE=             self.db['cves']
        self.colCPE=             self.db['cpe']
        self.colCWE=             self.db['cwe']
        self.colCPEOTHER=        self.db['cpeother']
        self.colWHITELIST=       self.db['mgmt_whitelist']
        self.colBLACKLIST=       self.db['mgmt_blacklist']
        self.colUSERS=           self.db['mgmt_users']
        self.colINFO=            self.db['info']
        self.colRANKING=         self.db['ranking']
        self.colVIA4=            self.db['via4']
        self.colCAPEC=           self.db['capec']
        self.colPlugSettings=    self.db['plugin_settings']
        self.colPlugUserSettings=self.db['plugin_user_settings']

    ###############################
    # Database Specific Functions #
    ###############################
    def getConnection(self, host=None, port=None, db=None, user=None,
                      password=None):
        """Connect to MongoDB and return the selected database handle.

        Any argument left unset (or falsy) is taken from the configuration.
        When both a user and a password are available, the connection goes
        through an authenticated ``mongodb://`` URI. Otherwise a plain
        host/port connection is used. If connecting raises, the error is
        printed and the process exits.
        """
        # Default connection
        _host, _port     = conf.getDatabaseServer()
        _user, _password = conf.getDatabaseAuth()
        _db              = conf.getDatabaseName()
        # assign needed
        host     = host     or _host
        port     = port     or _port
        user     = user     or _user
        password = password or _password
        db       = db       or _db

        try:
            if user and password:
                # Credentials must be fully percent-escaped: quote() would
                # leave '/' untouched, quote_plus() escapes it as well.
                mongoURI = "mongodb://{user}:{pwd}@{host}:{port}/{db}".format(
                    user=urllib.parse.quote_plus(user),
                    pwd=urllib.parse.quote_plus(password),
                    host=host, port=port, db=db)
                # Index by db so both branches return a Database, not a client.
                connect = pymongo.MongoClient(mongoURI)[db]
            else:
                connect = pymongo.MongoClient(host, port)[db]
        except Exception as e:
            print(e)
            sys.exit("Unable to connect to Mongo. Is it running on %s:%s?"%(host, port))
        return connect

    def sanitize(self, x):
        """Strip MongoDB's internal ``_id`` field from query results.

        A cursor is materialised into a list. Every dict, whether passed
        directly or found (recursively) inside lists, has its ``_id`` removed
        in place. Dicts nested inside other dicts are left untouched. Returns
        the possibly converted value; ``None`` passes through unchanged.
        """
        if isinstance(x, pymongo.cursor.Cursor):
            x = list(x)
        if isinstance(x, list):
            for item in x:
                self.sanitize(item)
        if isinstance(x, dict):
            x.pop("_id", None)
        return x

    def _bulk_upsert(self, collection, operations, ordered=False):
        """Apply ``$set`` upserts for each ``(selector, document)`` pair.

        All operations are sent in one bulk request. Nothing is sent when
        *operations* is empty, because executing an empty legacy bulk
        operation raises in pymongo.
        """
        bulk = None
        for selector, document in operations:
            if bulk is None:
                bulk = (collection.initialize_ordered_bulk_op() if ordered
                        else collection.initialize_unordered_bulk_op())
            bulk.find(selector).upsert().update({'$set': document})
        if bulk is not None:
            bulk.execute()


    ############
    # Database #
    ############
    def db_getStats(self, include_admin=False):
        """Return size and last-update information per data collection.

        With *include_admin*, white/blacklist sizes are added and the result
        is wrapped as ``{'stats': {...storage info...}, 'data': {...}}``.
        """
        data = {key: {'size':        self._size(self.db[key.lower()]),
                      'last_update': self._getUpdate(self.db[key.lower()])}
                for key in ('cves', 'cpe', 'cpeOther', 'capec', 'cwe', 'via4')}
        if include_admin:
            data['whitelist'] = {'size': self.colWHITELIST.count()}
            data['blacklist'] = {'size': self.colBLACKLIST.count()}
            dbstats = self.db.command("dbstats")  # run the command only once
            data = {'stats': {'size_on_disk': dbstats['storageSize'],
                              'db_size':      dbstats['dataSize'],
                              # Report the database actually connected to,
                              # which may differ from the configured default.
                              'name':         self.db.name},
                    'data':  data}
        return data

    def db_ensureIndex(self, collection, field):
        """Create an index on *field* in *collection* (no-op if it exists)."""
        self.db[collection].create_index(field)

    def db_metadata_drop(self):
        """Drop the ``info`` collection holding per-collection metadata."""
        self.colINFO.drop()


    #######################
    # Blacklist/Whitelist #
    #######################
    # Info
    def blacklist_size(self):
        """Return the number of blacklist entries."""
        return self.colBLACKLIST.count()

    def whitelist_size(self):
        """Return the number of whitelist entries."""
        return self.colWHITELIST.count()

    def witelist_size(self):
        """Misspelled alias of :meth:`whitelist_size`, kept for existing callers."""
        return self.whitelist_size()

    def blacklist_contains(self, cpe):
        """Return True if an entry with id *cpe* is blacklisted."""
        return self.colBLACKLIST.find_one({'id': cpe}) is not None

    def whitelist_contains(self, cpe):
        """Return True if an entry with id *cpe* is whitelisted."""
        return self.colWHITELIST.find_one({'id': cpe}) is not None

    # Data manipulation
    def blacklist_insert(self, cpe, type, comments=None):
        """Add *cpe* with the given *type* (and optional comments) to the blacklist."""
        data = {'id': cpe, 'type': type}
        if comments: data['comments'] = comments
        self.colBLACKLIST.insert(data)

    def whitelist_insert(self, cpe, type, comments=None):
        """Add *cpe* with the given *type* (and optional comments) to the whitelist."""
        data = {'id': cpe, 'type': type}
        if comments: data['comments'] = comments
        self.colWHITELIST.insert(data)

    def blacklist_remove(self, cpe):
        """Remove every blacklist entry with id *cpe*."""
        self.colBLACKLIST.remove({'id': cpe})

    def whitelist_remove(self, cpe):
        """Remove every whitelist entry with id *cpe*."""
        self.colWHITELIST.remove({'id': cpe})

    def blacklist_update(self, cpeOld, cpeNew, cpeType, comments=None):
        """Replace the blacklist entry *cpeOld* with a new id/type/comments document."""
        data = {'id': cpeNew, 'type': cpeType}
        if comments: data['comments'] = comments
        self.colBLACKLIST.update({'id': cpeOld}, data)

    def whitelist_update(self, cpeOld, cpeNew, cpeType, comments=None):
        """Replace the whitelist entry *cpeOld* with a new id/type/comments document."""
        data = {'id': cpeNew, 'type': cpeType}
        if comments: data['comments'] = comments
        self.colWHITELIST.update({'id': cpeOld}, data)

    def blacklist_drop(self):
        """Drop the whole blacklist collection."""
        self.colBLACKLIST.drop()

    def whitelist_drop(self):
        """Drop the whole whitelist collection."""
        self.colWHITELIST.drop()

    # Data retrieval
    def blacklist_get(self):
        """Return all blacklist entries as a list of dicts without ``_id``."""
        return self.sanitize(self.colBLACKLIST.find())

    def whitelist_get(self):
        """Return all whitelist entries as a list of dicts without ``_id``."""
        return self.sanitize(self.colWHITELIST.find())


    #########
    # Users #
    #########
    # Info
    def user_size(self):
        """Return the number of registered users."""
        return self.colUSERS.count()

    # Data manipulation
    def user_add(self, entry):
        """Insert the user document *entry*."""
        self.colUSERS.insert(entry)

    def user_update(self, entry):
        """Replace the stored user with the same ``username`` as *entry*."""
        self.colUSERS.update({'username': entry['username']}, entry)

    def user_remove(self, user):
        """Delete every user document with username *user*."""
        self.colUSERS.remove({'username': user})

    def user_setAdmin(self, user, admin=True):
        """Set (``master: True``) or unset the admin flag of *user*."""
        if admin:
            self.colUSERS.update({'username': user}, {'$set': {'master': True}})
        else:
            self.colUSERS.update({'username': user}, {'$unset': {'master': ""}})

    def user_setLocalOnly(self, user, localOnly=True):
        """Set (``local_only: True``) or unset the local-only flag of *user*."""
        if localOnly:
            self.colUSERS.update({'username': user}, {'$set': {'local_only': True}})
        else:
            self.colUSERS.update({'username': user}, {'$unset': {'local_only': ""}})

    # Data retrieval
    def user_get(self, user):
        """Return the user document for *user* (without ``_id``), or None."""
        return self.sanitize(self.colUSERS.find_one({"username": user}))

    def user_getAll(self):
        """Return all user documents as a list of dicts without ``_id``."""
        return self.sanitize(self.colUSERS.find())

    def user_isOnlyAdmin(self, user):
        """Return True if no user other than *user* has the admin flag."""
        return self.colUSERS.find_one({"username": {"$ne": user}, "master": True}) is None


    ########
    # CVEs #
    ########
    # Data manipulation
    def cve_upsert(self, data):
        """Insert or update one CVE or an iterable of CVEs, keyed on ``id``."""
        if isinstance(data, CVE): data = [data]
        self._bulk_upsert(self.colCVE,
                          (({'id': x.id}, x.dict(database=True)) for x in data))

    # Data retrieval
    def cve_get(self, id):
        """Return the CVE with the given id, or None if it is not stored."""
        cve = self.sanitize(self.colCVE.find_one({"id": id}))
        return CVE.fromDict(cve) if cve else None

    def cve_forCPE(self, cpe):
        """Return CVEs whose ``vulnerable_configuration`` matches *cpe*.

        *cpe* is used as a regular expression. Results are sorted by
        ``Modified``, newest first. An empty *cpe* yields ``[]``.
        """
        if not cpe: return []
        return [CVE.fromDict(cve) for cve in
                self.sanitize(self.colCVE.find({"vulnerable_configuration": {"$regex": cpe}}).sort("Modified", -1))]

    def cve_query(self, limit=False, skip=0, sort=None, query=None):
        """Return CVEs matching *query*, sorted, paginated.

        :param limit: maximum number of results; falsy means no limit.
        :param skip: number of results to skip.
        :param sort: ``(field, direction)``. The direction is ``"asc"``,
            anything else for descending, or a pymongo direction int. A bare
            field name sorts descending. Defaults to ``("Modified", "desc")``.
        :param query: a filter dict, or a list of filters combined with
            ``$and``. None or an empty list matches everything.
        """
        if query is None: query = {}
        # "$and" must be a non-empty array, so an empty list means "no filter".
        if isinstance(query, list): query = {"$and": query} if query else {}
        if not sort: sort = ("Modified", "desc")
        if isinstance(sort, str): sort = (sort, "desc")
        field, direction = sort[0], sort[1]
        if isinstance(direction, str):
            direction = 1 if direction.lower() == "asc" else -1  # Default Descending
        cursor = self.colCVE.find(query).sort(field, direction)
        cves = list(cursor.limit(int(limit or 0)).skip(skip))
        return [CVE.fromDict(x) for x in self.sanitize(cves)]

    def cve_textSearch(self, text):
        """Return CVEs matching *text* using MongoDB full-text search."""
        try: # Before Mongo 3: dedicated "text" command
            data = self.sanitize(
                [x["obj"] for x in self.db.command("text", "cves", search=text)["results"]])
        except pymongo.errors.OperationFailure: # As of Mongo 3: $text operator
            data = self.sanitize(self.colCVE.find({"$text":{"$search":text}}))
        return [CVE.fromDict(x) for x in data]

    def cve_drop(self):
        """Drop the whole CVE collection."""
        self.colCVE.drop()

    ########
    # CPEs #
    ########
    # Data modification
    def cpe_upsert(self, data):
        """Insert or update one CPE or an iterable of CPEs, keyed on ``id``.

        Items that are not CPE instances are skipped.
        """
        if isinstance(data, CPE): data = [data]
        self._bulk_upsert(self.colCPE,
                          (({'id': x.id}, x.dict()) for x in data if isinstance(x, CPE)))

    def cpe_alternative_upsert(self, data):
        """Store the ids of one CPE or an iterable of CPEs in ``cpeother``.

        Items that are not CPE instances are skipped.
        """
        if isinstance(data, CPE): data = [data]
        self._bulk_upsert(self.colCPEOTHER,
                          (({'id': x.id}, {'id': x.id}) for x in data if isinstance(x, CPE)))

    # Data retrieval
    def cpe_get(self, id):
        """Return the CPE with the given id, or None if it is not stored."""
        cpe = self.sanitize(self.colCPE.find_one({"id": id}))
        return CPE.fromDict(cpe) if cpe else None

    def cpe_getAll(self):
        """Return every stored CPE as a list of CPE objects."""
        return [CPE.fromDict(x) for x in self.sanitize(self.colCPE.find())]

    def cpe_getAllAlternative(self):
        """Return every ``cpeother`` entry wrapped with ``CPE(...)``."""
        return [CPE(x) for x in self.sanitize(self.colCPEOTHER.find())]

    def cpe_regex(self, regex, alternative):
        """Return CPEs whose id matches *regex*, optionally including ``cpeother``."""
        data = list(self.colCPE.find({"id": {"$regex": regex}}))
        if alternative:
            data.extend(self.colCPEOTHER.find({"id": {"$regex": regex}}))
        # NOTE(review): uses CPE(x) rather than CPE.fromDict(x) as cpe_getAll
        # does; confirm CPE's constructor accepts a full document.
        return [CPE(x) for x in self.sanitize(data)]

    def cpe_drop(self):
        """Drop the whole CPE collection."""
        self.colCPE.drop()

    def cpe_dropAlternative(self):
        """Drop the whole ``cpeother`` collection."""
        self.colCPEOTHER.drop()

    ########
    # CWEs #
    ########
    def cwe_upsert(self, data):
        """Insert or update one CWE or an iterable of CWEs, keyed on ``id``.

        Items that are not CWE instances are skipped.
        """
        if isinstance(data, CWE): data = [data]
        self._bulk_upsert(self.colCWE,
                          (({'id': x.id}, x.dict()) for x in data if isinstance(x, CWE)))

    def cwe_getAll(self):
        """Return every stored CWE as a list of CWE objects."""
        return [CWE.fromDict(x) for x in self.sanitize(self.colCWE.find())]

    def cwe_drop(self):
        """Drop the whole CWE collection."""
        self.colCWE.drop()

    #########
    # CAPEC #
    #########
    def capec_upsert(self, data):
        """Insert or update one CAPEC or an iterable of CAPECs, keyed on ``id``.

        Items that are not CAPEC instances are skipped.
        """
        if isinstance(data, CAPEC): data = [data]
        self._bulk_upsert(self.colCAPEC,
                          (({'id': x.id}, x.dict()) for x in data if isinstance(x, CAPEC)))

    def capec_getAll(self):
        """Return every stored CAPEC as a list of CAPEC objects."""
        return [CAPEC.fromDict(x) for x in self.sanitize(self.colCAPEC.find())]

    def capec_drop(self):
        """Drop the whole CAPEC collection."""
        self.colCAPEC.drop()

    ########
    # VIA4 #
    ########
    def via4_upsert(self, data):
        """Insert or update one VIA4 or an iterable of VIA4 entries, keyed on ``id``.

        Items that are not VIA4 instances are skipped.
        """
        if isinstance(data, VIA4): data = [data]
        self._bulk_upsert(self.colVIA4,
                          (({'id': x.id}, x.dict()) for x in data if isinstance(x, VIA4)))

    def via4_get(self, cveid):
        """Return the VIA4 data for *cveid* (without its ``id``), or None."""
        via4 = self.sanitize(self.colVIA4.find_one({'id': cveid}))
        if via4: via4.pop("id")
        return VIA4.fromDict(via4) if via4 else None

    def via4_link(self, key, val):
        """Return the CVEs whose VIA4 document has ``key == val``."""
        cveList = [x['id'] for x in self.colVIA4.find({key: val})]
        return self.cve_query(query={'id': {'$in': cveList}})

    def via4_drop(self):
        """Drop the whole VIA4 collection."""
        self.colVIA4.drop()

    def via4_search(self, text):
        """Return raw VIA4 documents where any searchable field contains *text*.

        The searchable field names come from the ``searchables`` entry of the
        VIA4 metadata.
        """
        data = []
        for vLink in self.via4_info().get('searchables', []):
            data.extend(self.sanitize(self.colVIA4.find({vLink: {'$in': [text]}})))
        return data

    ###########
    # Ranking #
    ###########
    def ranking_add(self, cpe, key, rank):
        """Set ranking *key* to *rank* for *cpe*, creating the entry if needed.

        Always returns True.
        """
        item = self.ranking_find(cpe)
        if item is None:
            self.colRANKING.update({'cpe': cpe}, {"$push": {'rank': {key: rank}}}, upsert=True)
        else:
            # Set the key on every element of the existing rank list.
            ranks = [{**entry, key: rank} for entry in item['rank']]
            self.colRANKING.update({'cpe': cpe}, {"$set": {'rank': ranks}})
        return True

    def ranking_remove(self, cpe):
        """Remove every ranking entry whose cpe matches *cpe* as a regex."""
        self.colRANKING.remove({'cpe': {'$regex': cpe}})

    def ranking_find(self, cpe=None, regex=False):
        """Look up rankings.

        Without *cpe*, returns all ranking documents as a list. Otherwise it
        returns the first document whose cpe equals *cpe* (or matches it as a
        regex when *regex* is true), or None.
        """
        if not cpe:
            return self.sanitize(self.colRANKING.find())
        if regex:
            return self.sanitize(self.colRANKING.find_one({'cpe': {'$regex': cpe}}))
        return self.sanitize(self.colRANKING.find_one({'cpe': cpe}))

    ########
    # Info #
    ########
    def _getInfo(self, collection):
        """Return the metadata document for *collection*, or ``{}``."""
        return self.sanitize(self.colINFO.find_one({'db': collection.name})) or {}

    def _setMetadata(self, collection, field, data):
        """Set *field* to *data* in *collection*'s metadata document (upsert)."""
        self.colINFO.update({"db": collection.name},
                            {"$set": {field: data}}, upsert=True)

    def _setUpdate(self, collection, date):
        """Record *date* as the last-modified time of *collection*."""
        self._setMetadata(collection, "last-modified", date)

    def _getUpdate(self, collection):
        """Return the recorded last-modified time of *collection*, or None."""
        return self._getInfo(collection).get('last-modified')

    def _size(self, collection):
        """Return the number of documents in *collection*."""
        return collection.count()

    # CVE
    def cve_info(self):
        return self._getInfo(self.colCVE)

    def cve_setUpdate(self, date):
        self._setUpdate(self.colCVE, date)

    def cve_getUpdate(self):
        return self._getUpdate(self.colCVE)

    def cve_size(self):
        return self._size(self.colCVE)

    # CPE
    def cpe_info(self):
        return self._getInfo(self.colCPE)

    def cpe_setUpdate(self, date):
        self._setUpdate(self.colCPE, date)

    def cpe_getUpdate(self):
        return self._getUpdate(self.colCPE)

    def cpe_size(self):
        return self._size(self.colCPE)

    # CPE-Other
    def cpeOther_info(self):
        return self._getInfo(self.colCPEOTHER)

    def cpeOther_setUpdate(self, date):
        self._setUpdate(self.colCPEOTHER, date)

    def cpeOther_getUpdate(self):
        return self._getUpdate(self.colCPEOTHER)

    def cpeOther_size(self):
        return self._size(self.colCPEOTHER)

    def cpeOther_setMetadata(self, field, data):
        self._setMetadata(self.colCPEOTHER, field, data)

    # CWE
    def cwe_info(self):
        return self._getInfo(self.colCWE)

    def cwe_setUpdate(self, date):
        self._setUpdate(self.colCWE, date)

    def cwe_getUpdate(self):
        return self._getUpdate(self.colCWE)

    def cwe_size(self):
        return self._size(self.colCWE)

    # CAPEC
    def capec_info(self):
        return self._getInfo(self.colCAPEC)

    def capec_setUpdate(self, date):
        self._setUpdate(self.colCAPEC, date)

    def capec_getUpdate(self):
        return self._getUpdate(self.colCAPEC)

    def capec_size(self):
        return self._size(self.colCAPEC)

    # VIA4
    def via4_info(self):
        return self._getInfo(self.colVIA4)

    def via4_setMetadata(self, field, data):
        self._setMetadata(self.colVIA4, field, data)

    def via4_setUpdate(self, date):
        self._setUpdate(self.colVIA4, date)

    def via4_getUpdate(self):
        return self._getUpdate(self.colVIA4)

    def via4_size(self):
        return self._size(self.colVIA4)


    ###########
    # Plugins #
    ###########
    # Settings
    def plugin_setting_write(self, plugin, setting, value):
        """Store *value* under *setting* for *plugin* (upsert)."""
        self.colPlugSettings.update({"plugin": plugin},
                                    {"$set": {setting: value}}, upsert=True)

    def plugin_setting_read(self, plugin, setting):
        """Return *plugin*'s *setting*, or None if nothing is stored."""
        settings = self.colPlugSettings.find_one({'plugin': plugin})
        return settings.get(setting) if settings else None

    def plugin_settings_delete(self, plugin):
        """Delete all settings of *plugin*."""
        self.colPlugSettings.remove({'plugin': plugin})

    # User settings
    def plugin_userSetting_write(self, plugin, user, setting, value):
        """Store *value* under *setting* for *plugin* and *user* (upsert)."""
        self.colPlugUserSettings.update({"plugin": plugin, "user": user},
                                        {"$set": {setting: value}}, upsert=True)

    def plugin_userSetting_read(self, plugin, user, setting):
        """Return the per-user *setting* of *plugin*, or None if nothing is stored."""
        settings = self.colPlugUserSettings.find_one({'plugin': plugin, 'user': user})
        return settings.get(setting) if settings else None

    def plugin_userSettings_delete(self, plugin):
        """Delete all per-user settings of *plugin*."""
        self.colPlugUserSettings.remove({"plugin": plugin})

    # Query data
    def plugin_query(self, plugin, query):
        """Return all documents of *plugin*'s collection matching *query*."""
        return self.sanitize(self.db['plug_%s'%plugin].find(query))

    def plugin_query_one(self, plugin, query):
        """Return the first document of *plugin*'s collection matching *query*, or None."""
        return self.sanitize(self.db['plug_%s'%plugin].find_one(query))

    # Data Manipulation
    def plugin_drop(self, plugin):
        """Drop *plugin*'s collection."""
        self.db['plug_%s'%plugin].drop()

    def plugin_insert(self, plugin, data):
        """Insert *data* into *plugin*'s collection."""
        self.db['plug_%s'%plugin].insert(data)

    def plugin_remove(self, plugin, query):
        """Remove the documents matching *query* from *plugin*'s collection."""
        self.db['plug_%s'%plugin].remove(query)

    def plugin_addToList(self, plugin, query, listname, data):
        """Add the items of *data* to the list field *listname*, skipping duplicates.

        If no document matches *query*, a document built from *query* is
        inserted first. Then the items are added to it.
        """
        collection = self.db['plug_%s' % plugin]
        if collection.find_one(query) is None:
            # Insert a copy: pymongo adds '_id' to the dict it inserts, and
            # the caller's query dict must not be mutated.
            collection.insert(dict(query))
        if data:
            # $addToSet with $each only appends items not already present.
            collection.update(query, {"$addToSet": {listname: {"$each": list(data)}}})

    def plugin_removeFromList(self, plugin, query, listname, data):
        """Remove items from the list field *listname* of the matching document.

        A dict *data* is removed with ``$pull``. A list removes each of its
        items with ``$pullAll``. Any other value is removed as a single item.
        """
        collection = self.db['plug_%s' % plugin]
        if isinstance(data, dict):
            collection.update(query, {"$pull": {listname: data}})
        else:
            if not isinstance(data, list): data = [data]
            collection.update(query, {"$pullAll": {listname: data}})

    def plugin_bulkUpdate(self, plugin, keyword, data):
        """Upsert each dict in *data* into *plugin*'s collection, keyed on *keyword*.

        The operations run as an ordered bulk request.
        """
        self._bulk_upsert(self.db['plug_%s' % plugin],
                          (({keyword: x[keyword]}, x) for x in data),
                          ordered=True)
