Commit ff25a625 authored by Pieter-Jan Moreels's avatar Pieter-Jan Moreels
Browse files

api only script

parent a5241681
Loading
Loading
Loading
Loading

web/api.py

0 → 100644
+202 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3.3
# -*- coding: utf-8 -*-
#
# Simple web interface to cve-search to display the last entries
# and view a specific CVE.
#
# Software is free software released under the "Modified BSD license"
#

# Copyright (c) 2013-2016 	Alexandre Dulaunoy - a@foo.be
# Copyright (c) 2014-2016 	Pieter-Jan Moreels - pieterjan.moreels@gmail.com

# imports
import os
import sys
_runPath = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(_runPath, ".."))

from tornado.wsgi import WSGIContainer
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from flask import Flask, render_template, jsonify

import json
import argparse
import time
import urllib
import random
import signal
import logging
from logging.handlers import RotatingFileHandler

from lib.Config import Configuration
from lib.Toolkit import toStringFormattedCPE, toOldCPE, convertDateToDBFormat
import lib.CVEs as cves
import lib.DatabaseLayer as dbLayer

# parse command line arguments
argparser = argparse.ArgumentParser(description='Start CVE-Search web component')
argparser.add_argument('-v', action='store_true', help='verbose output')
args = argparser.parse_args()

# variables
app = Flask(__name__, static_folder='static', static_url_path='/static')
app.config['MONGO_DBNAME'] = Configuration.getMongoDB()
app.config['SECRET_KEY'] = str(random.getrandbits(256))
pageLength = Configuration.getPageLength()

# db connectors
redisdb = Configuration.getRedisVendorConnection()

# functions
def getBrowseList(vendor):
    result = {}
    if (vendor is None) or type(vendor) == list:
        v1 = redisdb.smembers("t:/o")
        v2 = redisdb.smembers("t:/a")
        v3 = redisdb.smembers("t:/h")
        vendor = sorted(list(set(list(v1) + list(v2) + list(v3))))
        cpe = None
    else:
        cpenum = redisdb.scard("v:" + vendor)
        if cpenum < 1:
            return page_not_found(404)
        p = redisdb.smembers("v:" + vendor)
        cpe = sorted(list(p))
    result["vendor"] = vendor
    result["product"] = cpe
    return result

# Routes
@app.route('/api/cpe2.3/<path:cpe>', methods=['GET'])
def cpe23(cpe):
    cpe = toStringFormattedCPE(cpe)
    if not cpe: cpe='None'
    return cpe

@app.route('/api/cpe2.2/<path:cpe>', methods=['GET'])
def cpe22(cpe):
    cpe = toOldCPE(cpe)
    if not cpe: cpe='None'
    return cpe

@app.route('/api/cvefor/<path:cpe>', methods=['GET'])
def apiCVEFor(cpe):
    cpe=urllib.parse.unquote_plus(cpe)
    cpe=toStringFormattedCPE(cpe)
    if not cpe: cpe='None'
    r = []
    cvesp = cves.last(rankinglookup=False, namelookup=False, vfeedlookup=True, capeclookup=False)
    for x in dbLayer.cvesForCPE(cpe):
        r.append(cvesp.getcve(x['id']))
    return json.dumps(r)

@app.route('/api/cve/<cveid>', methods=['GET'])
def apiCVE(cveid):
    cvesp = cves.last(rankinglookup=True, namelookup=True, vfeedlookup=True, capeclookup=True)
    cve = cvesp.getcve(cveid=cveid)
    if cve is None:
        cve = {}
    return (jsonify(cve))

@app.route('/api/browse/<vendor>', methods=['GET'])
@app.route('/api/browse/', methods=['GET'])
@app.route('/api/browse', methods=['GET'])
def apibrowse(vendor=None):
    if vendor is not None:
        vendor = urllib.parse.quote_plus(vendor).lower()
    browseList = getBrowseList(vendor)
    if isinstance(browseList, dict):
        return (jsonify(browseList))
    else:
        return (jsonify({}))

@app.route('/api/last/', methods=['GET'])
@app.route('/api/last', methods=['GET'])
def apilast():
    limit = 30
    cvesp = cves.last(rankinglookup=True, namelookup=True, vfeedlookup=True, capeclookup=True)
    cve = cvesp.get(limit=limit)
    return (jsonify({"results": cve} ))

@app.route('/api/search/<vendor>/<path:product>', methods=['GET'])
def apisearch(vendor=None, product=None):
    if vendor is None or product is None:
        return (jsonify({}))
    search = vendor + ":" + product
    return (json.dumps(dbLayer.cvesForCPE(search)))

@app.route('/api/dbInfo', methods=['GET'])
def apidbInfo():
    return (json.dumps(dbLayer.getDBStats()))

# error handeling
@app.errorhandler(404)
def page_not_found(e):
    return render_template('404.html', minimal=True), 404

# signal handlers
def sig_handler(sig, frame):
    print('Caught signal: %s' % sig)
    IOLoop.instance().add_callback(shutdown)


def shutdown():
    MAX_WAIT_SECONDS_BEFORE_SHUTDOWN = 3
    print('Stopping http server')
    http_server.stop()

    print('Will shutdown in %s seconds ...' % MAX_WAIT_SECONDS_BEFORE_SHUTDOWN)
    io_loop = IOLoop.instance()
    deadline = time.time() + MAX_WAIT_SECONDS_BEFORE_SHUTDOWN

    def stop_loop():
        now = time.time()
        if now < deadline and (io_loop._callbacks or io_loop._timeouts):
            io_loop.add_timeout(now + 1, stop_loop)
        else:
            io_loop.stop()
            print('Shutdown')
    stop_loop()

if __name__ == '__main__':
    # get properties
    flaskHost = Configuration.getFlaskHost()
    flaskPort = Configuration.getFlaskPort()
    flaskDebug = Configuration.getFlaskDebug()
    # logging
    if Configuration.getLogging():
        logfile = Configuration.getLogfile()
        pathToLog = logfile.rsplit('/', 1)[0]
        if not os.path.exists(pathToLog):
            os.makedirs(pathToLog)
        maxLogSize = Configuration.getMaxLogSize()
        backlog = Configuration.getBacklog()
        file_handler = RotatingFileHandler(logfile, maxBytes=maxLogSize, backupCount=backlog)
        file_handler.setLevel(logging.ERROR)
        formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        file_handler.setFormatter(formatter)
        app.logger.addHandler(file_handler)

    if flaskDebug:
        # start debug flask server
        app.run(host=flaskHost, port=flaskPort, debug=flaskDebug)
    else:
        # start asynchronous server using tornado wrapper for flask
        # ssl connection
        print("Server starting...")
        if Configuration.useSSL():
            cert = os.path.join(_runPath, "../", Configuration.getSSLCert())
            key = os.path.join(_runPath, "../", Configuration.getSSLKey())
            ssl_options = {"certfile": cert,
                           "keyfile": key}
        else:
            ssl_options = None
        signal.signal(signal.SIGTERM, sig_handler)
        signal.signal(signal.SIGINT, sig_handler)
        global http_server
        http_server = HTTPServer(WSGIContainer(app), ssl_options=ssl_options)
        http_server.bind(flaskPort, address=flaskHost)
        http_server.start(0)  # Forks multiple sub-processes
        IOLoop.instance().start()