Commit 9cf69d98 authored by PidgeyL's avatar PidgeyL
Browse files

more database layer abstraction

parent c0a828a8
Loading
Loading
Loading
Loading
+25 −2
Original line number Diff line number Diff line
@@ -55,7 +55,22 @@ def addSeenCVEs(user, CVEs):
    if seen:
      colSEEN.update({"user": user},{"$addToSet": {"seen_cves": { "$each": seen}}})

def removeSeenCVEs(user, CVEs):
  if type(CVEs) == str: CVEs=[CVEs]
  if type(CVEs) == list:
    colSEEN.update({"user": user}, {"$pullAll": {"seen_cves": CVEs}})

def isMasterAccount(user):
  return False if colUSERS.find({"username": user, "master": True}).count() == 0 else True

def userExists(user):
  return True if colUSERS.find({"username": user}).count() > 0 else False

def isSingleMaster(user):
  return True if len(list(colUSERS.find({"username": {"$ne": user}, "master": True}))) == 0 else False

# Query Functions
# Generic data
def getCVEs(limit=-1, query=[], skip=0):
  if type(query) == dict: query=[query]
  if len(query) == 0:
@@ -72,11 +87,19 @@ def getCPE(id):
def getAlternativeCPE(id):
  return sanitize(colCPEOTHER.find_one({"id": id}))

def getUsers():
  return sanitize(colUSERS.find())
def getFreeText(text):
  return [x["obj"] for x in db.command("text", "cves", search=text)["results"]]

# Dynamic data
def getWhitelist():
  return sanitize(colWHITELIST.find())

def getBlacklist():
  return sanitize(colBLACKLIST.find())

# Users
def getUsers():
  return sanitize(colUSERS.find())

def getUser(user):
  return sanitize(colUSERS.find_one({"username": user}))
+8 −12
Original line number Diff line number Diff line
@@ -52,9 +52,9 @@ exits = {'userInDb': 'User already exists in database',


def verifyPass(password, user):
    if not existsInDB(user):
    if not dbLayer.userExists(user):
        sys.exit(exits['userNotInDb'])
    dbPass = (list(collection.find({'username': user}))[0])['password']
    dbPass = dbLayer.getUser(user)['password']
    if not pbkdf2_sha256.verify(password, dbPass):
        sys.exit(exits['userpasscombo'])
    return True
@@ -71,19 +71,15 @@ def promptNewPass():
def masterLogin():
    master = input("Master account username: ")
    if verifyPass(getpass.getpass("Master password:"), master):
        if collection.find({'username': master, 'master': True}).count() == 0:
        if not dbLayer.isMasterAccount(master):
            sys.exit(exits['noMaster'])
    else:
        sys.exit('Master user/password combination does not exist')
    return True


def existsInDB(user):
    return True if collection.find({'username': user}).count() > 0 else False


def isLastAdmin(user):
    if len(list(collection.find({'username': {'$ne': user}, 'master': True}))) == 0:
    if dbLayer.isSingleMaster(user):
        sys.exit(exits['lastMaster'])

# script run
@@ -92,7 +88,7 @@ try:
        username = args.a
        if username.strip() == "_dummy_":
            sys.exit(exits['dummy'])
        if existsInDB(username):
        if dbLayer.userExists(username):
            sys.exit(exits['userInDb'])
        # set master if db is empty
        if(collection.count() > 0):
@@ -111,7 +107,7 @@ try:
        sys.exit("Password updated")
    elif args.r:
        username = args.r
        if not existsInDB(username):
        if not dbLayer.userExists(username):
            sys.exit(exits['userNotInDb'])
        masterLogin()
        isLastAdmin(username)
@@ -119,7 +115,7 @@ try:
        sys.exit('User removed from database')
    elif args.p:
        username = args.p
        if not existsInDB(username):
        if not dbLayer.userExists(username):
            sys.exit(exits['userNotInDb'])
        masterLogin()
        # promote
@@ -127,7 +123,7 @@ try:
        sys.exit('User promoted')
    elif args.d:
        username = args.d
        if not existsInDB(username):
        if not dbLayer.userExists(username):
            sys.exit(exits['userNotInDb'])
        masterLogin()
        isLastAdmin(username)
+9 −23
Original line number Diff line number Diff line
@@ -128,7 +128,8 @@ def blacklist_mark(cve):


def seen_mark(cve):
    seen=getSeenCVEs()
    if current_user.is_authenticated():
        seen=dbLayer.seenCVEs(current_user.get_id())
        for c in cve:
            if c["id"] in seen: cve[cve.index(c)]['seen'] = 'yes'

@@ -150,17 +151,6 @@ def getBlacklistRegexes():
    return regexes


def getSeenCVEs():
  cu=current_user.get_id()
  collection = db.mgmt_seen
  userdata = collection.find({"user":cu})
  if userdata.count()==0:
    collection.insert({"user":cu, "seen_cves":[]})
    return [] 
  else:
    return userdata[0]["seen_cves"]


def addCPEToList(cpe, listType, cpeType=None):
    if not cpeType:
        cpeType='cpe'
@@ -204,7 +194,6 @@ def adminStats():

def filter_logic(blacklist, whitelist, unlisted, timeSelect, startDate, endDate,
                 timeTypeSelect, cvssSelect, cvss, rejectedSelect, hideSeen, limit, skip):
    collection = db.cves
    query = []
    # retrieving lists
    if blacklist == "on":
@@ -234,7 +223,7 @@ def filter_logic(blacklist, whitelist, unlisted, timeSelect, startDate, endDate,

    if current_user.is_authenticated():
      if hideSeen == "hide":
        query.append({'id': {"$nin":getSeenCVEs()}})
        query.append({'id': {"$nin":dbLayer.seenCVEs(current_user.get_id)}})

    # cvss logic
    if cvssSelect != "all":
@@ -363,8 +352,7 @@ def unseen(r):
    seenlist=request.form.get('list').split(",")
    # retrieving data
    if current_user.is_authenticated():
        col = db.mgmt_seen
        col.update({"user":current_user.get_id()},{"$pullAll":{"seen_cves":seenlist}})
        dbLayer.removeSeenCVEs(current_user.get_id(), seenlist)
    settings, cve = getFilterSettingsFromPost(r)
    return render_template('index.html', settings=settings, cve=cve, r=r, pageLength=pageLength)

@@ -428,7 +416,7 @@ def cve(cveid):
        return render_template('error.html',status={'except':'cve-not-found','info':{'cve':cveid}}) 
    cve = markCPEs(cve)
    if current_user.is_authenticated():
        dbLayer.addSeenCVEs(cveid)
        dbLayer.addSeenCVEs(current_user.get_id(), cveid)
    return render_template('cve.html', cve=cve)

@app.route('/browse/<vendor>')
@@ -450,12 +438,10 @@ def browse(vendor=None):
@app.route('/search', methods=['POST'])
def searchText():
    search = request.form.get('search')
    collection = db.cves
    try:
        cvelist = db.command("text", "cves", search=search)["results"]
        cve=dbLayer.getFreeText(search)
    except:
        return render_template('error.html', status={'except':'textsearch-not-enabled'})
    cve=[x["obj"] for x in cvelist]
    return render_template('search.html', cve=cve)