#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Config reader to read the configuration file
#
# Software is free software released under the "Modified BSD license"
#
# Copyright (c) 2013-2014 	Alexandre Dulaunoy - a@foo.be
# Copyright (c) 2014-2017 	Pieter-Jan Moreels - pieterjan.moreels@gmail.com

# imports
import bz2
import configparser
import datetime
import gzip
import os
import pymongo
import re
import redis
import sys
import urllib.parse
import urllib.request as req
import zipfile

runPath = os.path.dirname(os.path.realpath(__file__))

from io import BytesIO

class Configuration():
    """Class-level accessor for the application settings.

    ``../etc/configuration.ini`` (relative to this module) is read once, when
    the class body executes. Every getter falls back to the matching entry in
    ``default`` when the option is missing or cannot be converted. Feed URLs
    and feed toggles are read from ``../etc/sources.ini`` on each call, into a
    separate parser, so the main settings are never discarded.
    """
    ConfigParser = configparser.ConfigParser()
    ConfigParser.read(os.path.join(runPath, "../etc/configuration.ini"))
    default = {'salt_size': 10,          'hash_rounds': 8000,
               'redisHost': 'localhost', 'redisPort': 6379,
               'redisVendorDB': 10,      'redisNotificationsDB': 11,
               'redisRefDB': 12,
               'mongoHost': 'localhost', 'mongoPort': 27017,
               'mongoDB': "cvedb",
               'mongoUsername': '', 'mongoPassword': '',
               'dbHost': 'localhost', 'dbPort': 27017,
               'dbName': "cvedb",
               'dbUsername': '', 'dbPassword': '',
               'flaskHost': "127.0.0.1", 'flaskPort': 5000,
               'flaskDebug': True,       'pageLength': 50,
               'loginRequired': False,   'listLogin': True,
               'ssl': False,             'sslCertificate': "./ssl/cve-search.crt",
               # NOTE(review): same path as sslCertificate -- confirm the key
               # really lives in the .crt file and not in a separate .key.
                                         'sslKey': "./ssl/cve-search.crt",
               'CVEStartYear': 2002,
               'logging': True,           'logfile': "./log/cve-search.log",
               'maxLogSize': '100MB',     'backlog': 5,
               'Indexdir': './indexdir',  'updatelogfile': './log/update.log',
               'Tmpdir': './tmp',
               'http_proxy': '',
               'plugin_load': './etc/plugins.txt',
               'plugin_config': './etc/plugins.ini',
               'auth_load': './etc/auth.txt'
               }
    sources={'cve':        "https://static.nvd.nist.gov/feeds/xml/cve/",
             'cpe':        "https://static.nvd.nist.gov/feeds/xml/cpe/dictionary/official-cpe-dictionary_v2.2.xml",
             'cwe':        "http://cwe.mitre.org/data/xml/cwec_v2.8.xml.zip",
             'capec':      "http://capec.mitre.org/data/xml/capec_v2.6.xml",
             'via4':       "http://www.cve-search.org/feeds/via4.json",
             'includecve':   True, 'includecpe':  True, 'includecwe': True,
             'includecapec': True, 'includevia4': True}

    @staticmethod
    def _readFrom(parser, section, item, default):
        """Read ``section``/``item`` from ``parser``, typed after ``default``.

        A bool default selects ``getboolean``, an int default ``getint`` and
        anything else ``get``. Returns ``default`` when the section or option
        is missing or the stored value cannot be converted.
        """
        try:
            # bool must be tested before int: bool is a subclass of int.
            if isinstance(default, bool):
                return parser.getboolean(section, item)
            if isinstance(default, int):
                return parser.getint(section, item)
            return parser.get(section, item)
        except (configparser.Error, ValueError):
            # NoSectionError / NoOptionError / InterpolationError derive from
            # configparser.Error; getint/getboolean raise ValueError.
            return default

    @classmethod
    def readSetting(cls, section, item, default):
        """Read a setting from configuration.ini, falling back to ``default``.

        The type of ``default`` (bool, int, other) selects the conversion.
        """
        return cls._readFrom(cls.ConfigParser, section, item, default)

    # DataBase
    @classmethod
    def getDatabaseName(cls):
        """Return [Database] Name."""
        return cls.readSetting("Database", "Name", cls.default['dbName'])

    @classmethod
    def getDatabaseServer(cls):
        """Return ``(host, port)`` from the [Database] section."""
        host = cls.readSetting("Database", "Host", cls.default['dbHost'])
        port = cls.readSetting("Database", "Port", cls.default['dbPort'])
        return (host, port)

    @classmethod
    def getDatabaseAuth(cls):
        """Return ``(username, password)`` from the [Database] section."""
        user = cls.readSetting("Database", "Username", cls.default['dbUsername'])
        pwd  = cls.readSetting("Database", "Password", cls.default['dbPassword'])
        return (user, pwd)

    @classmethod
    def getMongoConnection(cls):
        """Return the pymongo database handle for the configured database.

        Host, port and credentials come from the [Mongo] section, and the
        database name from [Database] Name. Credentials are used only when
        both username and password are non-empty. Exits the process if the
        client cannot be created.
        """
        mongoHost = cls.readSetting("Mongo", "Host", cls.default['mongoHost'])
        mongoPort = cls.readSetting("Mongo", "Port", cls.default['mongoPort'])
        mongoDB = cls.getDatabaseName()
        mongoUsername = cls.readSetting("Mongo", "Username", cls.default['mongoUsername'])
        mongoPassword = cls.readSetting("Mongo", "Password", cls.default['mongoPassword'])

        # Credentials embedded in a URI must be fully percent-escaped.
        # quote() keeps '/' unescaped; pymongo documents quote_plus for this.
        mongoUsername = urllib.parse.quote_plus(mongoUsername)
        mongoPassword = urllib.parse.quote_plus(mongoPassword)
        try:
            if mongoUsername and mongoPassword:
                mongoURI = "mongodb://{username}:{password}@{host}:{port}/{db}".format(
                    username = mongoUsername, password = mongoPassword,
                    host = mongoHost, port = mongoPort,
                    db = mongoDB
                )
                connect = pymongo.MongoClient(mongoURI, connect=False)
            else:
                connect = pymongo.MongoClient(mongoHost, mongoPort, connect=False)
        except Exception:
            sys.exit("Unable to connect to Mongo. Is it running on %s:%s?"%(mongoHost,mongoPort))
        return connect[mongoDB]

    @classmethod
    def toPath(cls, path):
        """Return ``path`` unchanged if absolute, else resolve it against the
        parent directory of this module's directory."""
        return path if os.path.isabs(path) else os.path.join(runPath, "..", path)

    # Redis
    @classmethod
    def getRedisHost(cls):
        """Return [Redis] Host."""
        return cls.readSetting("Redis", "Host", cls.default['redisHost'])

    @classmethod
    def getRedisPort(cls):
        """Return [Redis] Port."""
        return cls.readSetting("Redis", "Port", cls.default['redisPort'])

    @classmethod
    def _getRedisConnection(cls, dbItem, defaultKey):
        """Return a StrictRedis client for the DB number in [Redis] ``dbItem``.

        ``defaultKey`` names the ``default`` entry used when ``dbItem`` is not
        configured. Responses are decoded as UTF-8 strings.
        """
        redisHost = cls.getRedisHost()
        redisPort = cls.getRedisPort()
        redisDB = cls.readSetting("Redis", dbItem, cls.default[defaultKey])
        # ``encoding`` replaces redis-py's deprecated ``charset`` keyword.
        return redis.StrictRedis(host=redisHost, port=redisPort, db=redisDB,
                                 encoding='utf-8', decode_responses=True)

    @classmethod
    def getRedisVendorConnection(cls):
        """Return the Redis client for the vendor DB ([Redis] VendorsDB)."""
        return cls._getRedisConnection("VendorsDB", 'redisVendorDB')

    @classmethod
    def getRedisNotificationsConnection(cls):
        """Return the Redis client for the notifications DB ([Redis] NotificationsDB)."""
        return cls._getRedisConnection("NotificationsDB", 'redisNotificationsDB')

    @classmethod
    def getRedisRefConnection(cls):
        """Return the Redis client for the reference DB ([Redis] RefDB)."""
        return cls._getRedisConnection("RefDB", 'redisRefDB')

    # Password Settings
    @classmethod
    def getPBKDFSettings(cls):
        """Return ``{'rounds': ..., 'salt_size': ...}`` from [Encryption]."""
        rounds = cls.readSetting("Encryption", "HashRounds", cls.default['hash_rounds'])
        salt   = cls.readSetting("Encryption", "SaltSize",   cls.default['salt_size'])
        return {'rounds': rounds, 'salt_size': salt}

    # Flask
    @classmethod
    def getFlaskHost(cls):
        """Return [Webserver] Host."""
        return cls.readSetting("Webserver", "Host", cls.default['flaskHost'])

    @classmethod
    def getFlaskPort(cls):
        """Return [Webserver] Port."""
        return cls.readSetting("Webserver", "Port", cls.default['flaskPort'])

    @classmethod
    def getFlaskDebug(cls):
        """Return [Webserver] Debug."""
        return cls.readSetting("Webserver", "Debug", cls.default['flaskDebug'])

    # Webserver
    @classmethod
    def getPageLength(cls):
        """Return [Webserver] PageLength."""
        return cls.readSetting("Webserver", "PageLength", cls.default['pageLength'])

    # Authentication
    @classmethod
    def loginRequired(cls):
        """Return [Webserver] LoginRequired."""
        return cls.readSetting("Webserver", "LoginRequired", cls.default['loginRequired'])


    @classmethod
    def listLoginRequired(cls):
        """Return [Webserver] ListLoginRequired."""
        return cls.readSetting("Webserver", "ListLoginRequired", cls.default['listLogin'])


    @classmethod
    def getAuthLoadSettings(cls):
        """Return the resolved path from [Webserver] authSettings."""
        return cls.toPath(cls.readSetting("Webserver", "authSettings", cls.default['auth_load']))

    # SSL
    @classmethod
    def useSSL(cls):
        """Return [Webserver] SSL."""
        return cls.readSetting("Webserver", "SSL", cls.default['ssl'])

    @classmethod
    def getSSLCert(cls):
        """Return the resolved path from [Webserver] Certificate."""
        return cls.toPath(cls.readSetting("Webserver", "Certificate", cls.default['sslCertificate']))

    @classmethod
    def getSSLKey(cls):
        """Return the resolved path from [Webserver] Key."""
        return cls.toPath(cls.readSetting("Webserver", "Key", cls.default['sslKey']))

    # CVE
    @classmethod
    def getCVEStartYear(cls):
        """Return [CVE] StartYear, validated to lie in [2002, current year + 1].

        An out-of-range value prints a warning and the default year is
        returned instead.
        """
        defaultYear = cls.default['CVEStartYear']
        maxYear = datetime.datetime.now().year + 1
        year = cls.readSetting("CVE", "StartYear", defaultYear)
        if year < 2002 or year > maxYear:
            print('The year %i is not a valid year.\ndefault year %i will be used.' % (year, defaultYear))
            year = defaultYear
        # Return the validated value (previously the raw setting was re-read
        # and returned, discarding the validation above).
        return year


    # Logging
    @classmethod
    def getLogfile(cls):
        """Return the resolved path from [Logging] Logfile."""
        return cls.toPath(cls.readSetting("Logging", "Logfile", cls.default['logfile']))

    @classmethod
    def getUpdateLogFile(cls):
        """Return the resolved path from [Logging] Updatelogfile."""
        return cls.toPath(cls.readSetting("Logging", "Updatelogfile", cls.default['updatelogfile']))

    @classmethod
    def getLogging(cls):
        """Return [Logging] Logging."""
        return cls.readSetting("Logging", "Logging", cls.default['logging'])

    @classmethod
    def getMaxLogSize(cls):
        """Return [Logging] MaxSize converted to bytes.

        Accepts ``<number>[unit]`` where unit is b/kb/mb/gb (case-insensitive;
        surrounding whitespace ignored). No unit means bytes; an unrecognised
        unit is read as MB. An unparsable value prints a warning and yields
        100 MB, matching the ``'100MB'`` entry in ``default``.
        """
        units = {'b': 1, 'kb': 1024, 'mb': 1024 ** 2, 'gb': 1024 ** 3}
        size = cls.readSetting("Logging", "MaxSize", cls.default['maxLogSize'])
        split = re.findall(r'\d+|\D+', size.strip())
        try:
            if not split or len(split) > 2:
                raise ValueError("expected '<number>[unit]'")
            base = int(split[0])
            if len(split) == 1:
                multiplier = 1
            else:
                # If we cannot interpret the multiplier, we take MB as default
                multiplier = units.get(split[1].strip().lower(), units['mb'])
            return base * multiplier
        except ValueError as e:
            print("Invalid Logging MaxSize %r (%s); using 100MB" % (size, e))
            return 100 * 1024 * 1024

    @classmethod
    def getBacklog(cls):
        """Return [Logging] Backlog."""
        return cls.readSetting("Logging", "Backlog", cls.default['backlog'])

    # Indexing
    @classmethod
    def getTmpdir(cls):
        """Return the resolved path from [dbmgt] Tmpdir."""
        return cls.toPath(cls.readSetting("dbmgt", "Tmpdir", cls.default['Tmpdir']))

    # Indexing
    @classmethod
    def getIndexdir(cls):
        """Return the resolved path from [FulltextIndex] Indexdir."""
        return cls.toPath(cls.readSetting("FulltextIndex", "Indexdir", cls.default['Indexdir']))

    # Http Proxy
    @classmethod
    def getProxy(cls):
        """Return [Proxy] http (empty string when no proxy is configured)."""
        return cls.readSetting("Proxy", "http", cls.default['http_proxy'])

    @classmethod
    def getFile(cls, getfile, unpack=True):
        """Download ``getfile`` and return ``(data, response)``.

        When a proxy is configured, a global urllib opener that routes http
        and https through it is installed first. With ``unpack`` set, gzip,
        bzip2 and zip payloads (detected from the Content-Type header) are
        decompressed; for zip, the first archive member is returned.
        Otherwise ``data`` is the response object itself.
        """
        proxy = cls.getProxy()
        if proxy:
            proxyHandler = req.ProxyHandler({'http': proxy, 'https': proxy})
            auth = req.HTTPBasicAuthHandler()
            opener = req.build_opener(proxyHandler, auth, req.HTTPHandler)
            req.install_opener(opener)
        response = req.urlopen(getfile)
        data = response
        # TODO: if data == text/plain; charset=utf-8, read and decode
        if unpack:
            # A missing Content-Type header yields None; treat it as empty.
            contentType = response.info().get('Content-Type') or ''
            # 'gzip' must be tested before 'zip' since it contains it.
            if   'gzip' in contentType:
                buf = BytesIO(response.read())
                data = gzip.GzipFile(fileobj=buf)
            elif 'bzip2' in contentType:
                data = BytesIO(bz2.decompress(response.read()))
            elif 'zip' in contentType:
                fzip = zipfile.ZipFile(BytesIO(response.read()), 'r')
                names = fzip.namelist()
                if names:
                    data = BytesIO(fzip.read(names[0]))
        return (data, response)


    # Feeds
    @classmethod
    def getFeedData(cls, source, unpack=True):
        """Return ``getFile`` output for feed ``source``, or None if it has no URL."""
        source = cls.getFeedURL(source)
        return cls.getFile(source, unpack) if source else None

    @classmethod
    def _sourcesParser(cls):
        """Return a fresh parser loaded from ../etc/sources.ini.

        Kept separate from ``ConfigParser`` so reading feed settings never
        discards the settings loaded from configuration.ini.
        """
        parser = configparser.ConfigParser()
        parser.read(os.path.join(runPath, "../etc/sources.ini"))
        return parser

    @classmethod
    def getFeedURL(cls, source):
        """Return the URL of feed ``source`` from [Sources] in sources.ini.

        sources.ini is re-read on every call. Falls back to
        ``sources[source]``, or "" for an unknown feed.
        """
        return cls._readFrom(cls._sourcesParser(), "Sources", source,
                             cls.sources.get(source, ""))

    @classmethod
    def includesFeed(cls, feed):
        """Return whether feed ``feed`` is enabled.

        [EnabledFeeds] is looked up in sources.ini first, then in the main
        configuration, then ``sources['include' + feed]`` (False if absent).
        """
        fallback = cls.readSetting("EnabledFeeds", feed, cls.sources.get('include'+feed, False))
        return cls._readFrom(cls._sourcesParser(), "EnabledFeeds", feed, fallback)


    # Plugins
    @classmethod
    def getPluginLoadSettings(cls):
        """Return the resolved path from [Plugins] loadSettings."""
        return cls.toPath(cls.readSetting("Plugins", "loadSettings", cls.default['plugin_load']))

    @classmethod
    def getPluginsettings(cls):
        """Return the resolved path from [Plugins] pluginSettings."""
        return cls.toPath(cls.readSetting("Plugins", "pluginSettings", cls.default['plugin_config']))

class ConfigReader():
    """Typed reader for a caller-chosen INI file.

    The file is parsed once at construction; a missing file simply yields
    an empty configuration, so every ``read`` returns its default.
    """
    def __init__(self, file):
        self.ConfigParser = configparser.ConfigParser()
        self.ConfigParser.read(file)

    def read(self, section, item, default):
        """Return ``section``/``item``, converted to the type of ``default``.

        A bool default selects ``getboolean``, an int default ``getint`` and
        anything else ``get``. Returns ``default`` when the section or option
        is missing or the stored value cannot be converted.
        """
        try:
            # bool must be tested before int: bool is a subclass of int.
            if isinstance(default, bool):
                return self.ConfigParser.getboolean(section, item)
            if isinstance(default, int):
                return self.ConfigParser.getint(section, item)
            return self.ConfigParser.get(section, item)
        except (configparser.Error, ValueError):
            # NoSectionError / NoOptionError / InterpolationError derive from
            # configparser.Error; getint/getboolean raise ValueError.
            return default
