#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Objects for CVE-Search internal workings
#
# Software is free software released under the "Modified BSD license"
#

# Copyright (c) 2017  Pieter-Jan Moreels - pieterjan.moreels@gmail.com

# imports
from datetime import datetime

import lib.Toolkit as tk

#######
# CVE #
#######
class Impact:
    """Impact metrics of a CVE.

    Each of the three components is upper-cased and must be one of
    ``COMPLETE``, ``PARTIAL`` or ``NONE``; any other value raises
    ``ValueError``.
    """
    def __init__(self, confidentiality, integrity, availability):
        tk.assertType(str, confidentiality=confidentiality,
                           integrity=integrity, availability=availability)
        normalised = (confidentiality.upper(), integrity.upper(),
                      availability.upper())

        if any(value not in ("COMPLETE", "PARTIAL", "NONE")
               for value in normalised):
            raise ValueError("incorrect values given")

        self.confidentiality, self.integrity, self.availability = normalised

    def dict(self):
        """Return the three components keyed by their names."""
        result = {}
        result['confidentiality'] = self.confidentiality
        result['integrity'] = self.integrity
        result['availability'] = self.availability
        return result

    @classmethod
    def fromDict(cls, data):
        """Build an Impact from a mapping shaped like :meth:`dict` output."""
        names = ('confidentiality', 'integrity', 'availability')
        return cls(*[data[name] for name in names])


class Access:
    """Access metrics of a CVE (complexity, authentication, vector).

    Values are compared and stored upper-cased; anything outside the accepted
    sets raises ``ValueError``.
    """
    def __init__(self, complexity, authentication, vector):
        tk.assertType(str, complexity=complexity,
                      authentication=authentication, vector=vector)

        if (complexity.upper() not in ("HIGH", "MEDIUM", "LOW")
                or authentication.upper() not in ("NONE", "SINGLE_INSTANCE",
                                                  "MULTIPLE_INSTANCES")
                or vector.upper() not in ("NETWORK", "LOCAL",
                                          "ADJACENT_NETWORK")):
            raise ValueError("incorrect values given")

        self.complexity = complexity.upper()
        self.authentication = authentication.upper()
        self.vector = vector.upper()

    def dict(self):
        """Return the three components keyed by their names."""
        result = {}
        result['complexity'] = self.complexity
        result['authentication'] = self.authentication
        result['vector'] = self.vector
        return result

    @classmethod
    def fromDict(cls, data):
        """Build an Access from a mapping shaped like :meth:`dict` output."""
        names = ('complexity', 'authentication', 'vector')
        return cls(*[data[name] for name in names])


class CVE:
    """A single CVE entry together with its scoring and related objects.

    The ``ranking``, ``via4`` and ``reason`` slots are declared but not set
    in ``__init__`` (presumably attached by other code — confirm).
    """
    __slots__ = ('id', 'cvss', 'summary', 'vulnerable_configuration',
                 'published', 'modified', 'impact', 'access', 'cwe',
                 'references', 'cvss_time', 'ranking', 'via4', 'reason')

    def __init__(self, id, summary, vulnerable_configuration, published,
                 modified=None, impact=None, access=None, cvss=None,
                 cwe=None, references=None, cvss_time=None):
        """Validate the arguments with ``tk.assertType`` and store them.

        :param id: CVE identifier; stored upper-cased.
        :param summary: description text.
        :param vulnerable_configuration: list of :class:`CPE` objects.
        :param published: publication ``datetime``.
        :param modified: modification ``datetime`` or ``None``.
        :param impact: :class:`Impact` or ``None``.
        :param access: :class:`Access` or ``None``.
        :param cvss: score as a float or numeric string; ``None`` or ``""``
            mean "no score" and are stored as ``None``.
        :param cwe: :class:`CWE` or ``None``.
        :param references: list of reference strings (falsy -> ``[]``).
        :param cvss_time: ``datetime`` or ``None``.
        """
        if not references:
            references = []
        tk.assertType(str, id=id, summary=summary)
        tk.assertType((float, str, None), cvss=cvss)
        tk.assertType(datetime, published=published)
        tk.assertType((datetime, None), modified=modified, cvss_time=cvss_time)
        tk.assertType(list, vulnerable_configuration=vulnerable_configuration,
                            references=references)
        tk.assertType((Impact, None), impact=impact)
        tk.assertType((Access, None), access=access)
        tk.assertType((CWE, None), cwe=cwe)
        tk.assertTypeForAllIn(CPE, vulnerable_configuration)
        tk.assertTypeForAllIn(str, references)

        self.id                       = id.upper()
        # Only "missing" inputs map to None.
        # ``cvss and float(cvss) or None`` would also turn a legitimate
        # score of 0.0 into None.
        self.cvss                     = (float(cvss)
                                         if cvss not in (None, "") else None)
        self.summary                  = summary
        self.vulnerable_configuration = vulnerable_configuration
        self.published                = published
        self.modified                 = modified
        self.impact                   = impact
        self.access                   = access
        self.cwe                      = cwe
        self.references               = references
        self.cvss_time                = cvss_time

    def dict(self, capec=False, human_dates=False, backwards_compatible=True,
             database=False):
        """Serialise the CVE to a plain dict.

        :param capec: add a ``'capec'`` list with the ``dict()`` of every
            CAPEC linked to the CWE. It is empty when there is no known CWE
            or the CWE has no CAPEC entries attached yet.
        :param human_dates: convert 'Published', 'Modified' and 'cvss-time'
            with ``str()``. A ``None`` date therefore becomes ``'None'``.
        :param backwards_compatible: emit 'vulnerable_configuration' as a
            list of CPE ids. Unless ``database`` is set, also add
            'vulnerable_configuration_cpe_2_2' with the CPE 2.2 ids.
        :param database: emit 'vulnerable_configuration' as CPE ids only.
        :returns: a new dict; the CWE id gets a ``"CWE-"`` prefix, and a
            missing CWE is reported as ``"Unknown"``.
        """
        vuln_conf = [x.dict() for x in self.vulnerable_configuration]
        cwe_id = self.cwe.id if self.cwe else None
        data = {'id':                       self.id,
                'cvss':                     self.cvss,
                'summary':                  self.summary,
                'vulnerable_configuration': vuln_conf,
                'Published':                self.published,
                'Modified':                 self.modified,
                'impact':                   (self.impact.dict()
                                             if self.impact else None),
                'access':                   (self.access.dict()
                                             if self.access else None),
                'cwe':                      cwe_id or "Unknown",
                'references':               self.references,
                'cvss-time':                self.cvss_time}
        if data['cwe'] != "Unknown":
            data['cwe'] = "CWE-" + data['cwe']

        if capec:
            data['capec'] = []
            if self.cwe and self.cwe.id.lower() != "unknown":
                # ``capec`` is None until the real CAPEC entries have been
                # linked onto the CWE (e.g. the dud CWE built by fromDict).
                data['capec'] = [c.dict() for c in (self.cwe.capec or [])]
        if human_dates:
            for field in ['Published', 'Modified', 'cvss-time']:
                data[field] = str(data[field])

        # To be removed in the newest release
        if backwards_compatible or database:
            vuln_conf = [x.id for x in self.vulnerable_configuration]
            if backwards_compatible and not database:
                b = [x.id_2_2 for x in self.vulnerable_configuration]
                data['vulnerable_configuration_cpe_2_2'] = b
            data['vulnerable_configuration'] = vuln_conf
        return data

    @classmethod
    def fromDict(cls, data):
        """Build a CVE from a dict shaped like :meth:`dict` output.

        Date fields may be ``datetime`` objects or strings in the form
        ``%Y-%m-%dT%H:%M:%S`` (optionally with ``.%f``). The input mapping is
        not modified.
        """
        def toDate(date):
            if isinstance(date, str):
                try:
                    return datetime.strptime(date, "%Y-%m-%dT%H:%M:%S.%f")
                except ValueError:
                    return datetime.strptime(date, "%Y-%m-%dT%H:%M:%S")
            return date

        # Placeholder CPEs carrying the correct ID, to be replaced later.
        vc = [CPE(x) if isinstance(x, str) else CPE.fromDict(x)
              for x in data['vulnerable_configuration']]
        i = Impact.fromDict(data['impact']) if data.get('impact') else None
        a = Access.fromDict(data['access']) if data.get('access') else None

        cwe = None
        if data.get('cwe'):  # .get() keeps entries without 'cwe' loadable
            # Create a dud CWE with just the ID, so the real object can be
            # linked later. Remove only a literal "CWE-" prefix: str.strip()
            # removes a *set of characters* from both ends.
            cwe_id = data['cwe']
            if cwe_id.startswith("CWE-"):
                cwe_id = cwe_id[len("CWE-"):]
            cwe = CWE(cwe_id, 'dud', 'dud', 'dud', 'dud')

        modified = toDate(data['Modified'])
        published = toDate(data['Published'])
        cvss_time = data.get('cvss-time')
        if cvss_time:
            cvss_time = toDate(cvss_time)
        return cls(data['id'], data['summary'], vc, published,
                   modified, i, a, data.get('cvss'), cwe,
                   data['references'], cvss_time)


#######
# CPE #
#######
class CPE:
    """A platform (CPE) entry, as used in a CVE's vulnerable configuration.

    The id is normalised with ``tk.toStringFormattedCPE``. Its CPE 2.2 form
    (``tk.toOldCPE``) is stored alongside it as ``id_2_2``.
    """
    def __init__(self, id, title=None, references=None):
        """
        :param id: CPE identifier string.
        :param title: human-readable title. When empty or ``None``, it is
            derived from the normalised id with ``tk.cpeTitle``.
        :param references: list/tuple of reference strings, or ``None``.
        """
        if not references:
            references = []
        tk.assertType(str, id=id)
        tk.assertType((str, None), title=title)
        tk.assertType((list, tuple, None), references=references)
        if references:
            tk.assertTypeForAllIn(str, references)

        self.id         = tk.toStringFormattedCPE(id)
        self.id_2_2     = tk.toOldCPE(id)
        self.title      = title if title else tk.cpeTitle(self.id)
        # Always keep an independent list, even when a tuple was passed.
        self.references = list(references)

    def dict(self):
        """Return id, CPE 2.2 id, title and references as a dict."""
        return {'id':         self.id,
                'cpe_2_2':    self.id_2_2,
                'title':      self.title,
                'references': self.references}

    @classmethod
    def fromDict(cls, data):
        """Build a CPE from a dict shaped like :meth:`dict` output.

        'title' and 'references' are optional. A missing title is derived
        from the id, just as when calling the constructor without one.
        """
        return cls(data['id'], data.get('title'), data.get('references'))

#######
# CWE #
#######
class CWE:
    """A CWE weakness entry.

    ``description`` and ``weakness`` are serialised under the keys
    'description_summary' and 'weaknessabs'.
    """
    def __init__(self, id, name, description, status, weakness):
        tk.assertType(str, id=id, name=name, description=description,
                      status=status, weakness=weakness)

        self.id, self.name = id, name
        self.description, self.status = description, status
        self.weakness = weakness
        # Not filled in here; other code assigns the related CAPEC objects
        # later. CVE.dict() iterates it and calls .dict() on each element.
        self.capec = None

    def dict(self):
        """Return the CWE fields using their storage key names."""
        result = {}
        result['id'] = self.id
        result['name'] = self.name
        result['description_summary'] = self.description
        result['status'] = self.status
        result['weaknessabs'] = self.weakness
        return result

    @classmethod
    def fromDict(cls, data):
        """Build a CWE from a dict shaped like :meth:`dict` output."""
        keys = ('id', 'name', 'description_summary', 'status', 'weaknessabs')
        return cls(*[data[key] for key in keys])

#########
# CAPEC #
#########
class CAPEC:
    """A CAPEC attack-pattern entry with its related weaknesses."""
    def __init__(self, id, name, summary, prerequisites, solutions,
                 weaknesses):
        tk.assertType(str, id=id, name=name, summary=summary,
                      prerequisites=prerequisites, solutions=solutions)
        tk.assertType((list, tuple), weaknesses=weaknesses)
        tk.assertTypeForAllIn(str, weaknesses)

        self.id, self.name, self.summary = id, name, summary
        self.prerequisites, self.solutions = prerequisites, solutions
        # Stored as a fresh list. Validation only admits strings, but dict()
        # also handles CWE objects here (presumably swapped in later).
        self.weaknesses = list(weaknesses)

    def dict(self):
        """Return the CAPEC fields; CWE objects are reduced to their id."""
        related = []
        for entry in self.weaknesses:
            related.append(entry.id if isinstance(entry, CWE) else entry)

        result = {}
        result['id'] = self.id
        result['name'] = self.name
        result['summary'] = self.summary
        result['prerequisites'] = self.prerequisites
        result['solutions'] = self.solutions
        result['related_weakness'] = related
        return result

    @classmethod
    def fromDict(cls, data):
        """Build a CAPEC from a dict shaped like :meth:`dict` output."""
        keys = ('id', 'name', 'summary', 'prerequisites', 'solutions',
                'related_weakness')
        return cls(*[data[key] for key in keys])

########
# VIA4 #
########
class VIA4:
    """Attribute-style wrapper around a nested mapping of VIA4 data.

    Every keyword not starting with ``"__"`` becomes an attribute. Nested
    dicts become ``VIA4`` objects, and lists are converted element by
    element.

    NOTE(review): keys are set with ``setattr``, so a key such as 'dict'
    would shadow the method of the same name on that instance.
    """
    def __init__(self, **kwargs):
        for key, val in kwargs.items():
            # Keys starting with "__" are not copied onto the object.
            if not key.startswith("__"):
                setattr(self, key, self._getChild(val))

    def _getChild(self, data):
        """Recursively wrap dicts (including dicts inside lists) as VIA4."""
        if isinstance(data, list):
            return [self._getChild(val) for val in data]
        return VIA4(**data) if isinstance(data, dict) else data

    def _dig(self, data):
        """Return a plain dict/list copy of *data*, unwrapping VIA4 objects.

        New containers are built instead of assigning into the existing ones.
        Serialising therefore leaves the object's VIA4 children intact and
        never hands out the object's own ``__dict__``. Entries whose key
        starts with ``"__"`` are passed through unconverted.
        """
        if isinstance(data, VIA4):
            data = data.__dict__
        if isinstance(data, dict):
            return {key: val if key.startswith("__") else self._dig(val)
                    for key, val in data.items()}
        if isinstance(data, list):
            return [self._dig(item) for item in data]
        return data

    def dict(self):
        """Return a new nested dict/list representation of this object."""
        return self._dig(self.__dict__)

    @classmethod
    def fromDict(cls, data):
        """Wrap a mapping (string keys) as a VIA4 object."""
        return cls(**data)
