Python实战-编写Web App-Day5-编写Web框架

在正式开始Web开发前,我们需要编写一个Web框架。

由于aiohttp相对底层,直接用它编写URL处理函数比较繁琐,所以我们需要基于aiohttp自己封装一个处理URL的Web框架。

## 定义add_route函数,来注册一个URL处理函数

def add_route(app, fn):
    """Register the URL handler *fn* on ``app.router``.

    *fn* must carry the ``__method__`` and ``__route__`` attributes set by the
    ``@get``/``@post`` decorators. Plain (non-async, non-generator) functions
    are wrapped in a coroutine function before being handed to
    ``RequestHandler``; generator functions are passed through unchanged.

    Raises:
        ValueError: if *fn* lacks ``__method__`` or ``__route__``.
    """
    method = getattr(fn, '__method__', None)
    path = getattr(fn, '__route__', None)
    if path is None or method is None:
        raise ValueError('@get or @post not defined in %s.' % str(fn))
    if not inspect.iscoroutinefunction(fn) and not inspect.isgeneratorfunction(fn):
        # asyncio.coroutine() was removed in Python 3.11, so build the async
        # adapter by hand. Like asyncio.coroutine(), it awaits an awaitable
        # result: a sync @get/@post wrapper around an ``async def`` handler
        # returns a coroutine object that still has to be awaited.
        sync_fn = fn

        @functools.wraps(sync_fn)
        async def async_fn(*args, **kw):
            result = sync_fn(*args, **kw)
            if inspect.isawaitable(result):
                result = await result
            return result

        # functools.wraps copies __name__, __dict__ (__method__/__route__) and
        # sets __wrapped__, so inspect.signature() still sees sync_fn's params.
        fn = async_fn
    logging.info('add route %s %s => %s(%s)' % (method, path, fn.__name__, ', '.join(inspect.signature(fn).parameters.keys())))
    app.router.add_route(method, path, RequestHandler(app, fn))

接下来编写get和post两个装饰器,它们把HTTP方法和URL路径作为属性(__method__、__route__)附加到URL处理函数上:

## 编写装饰函数 @get()

def get(path):
    """Return a decorator that marks a URL handler for ``GET path``.

    The decorated function gets ``__method__ = 'GET'`` and
    ``__route__ = path``; ``add_route`` reads these to register it.
    ``async def`` handlers get an ``async def`` wrapper so they remain
    coroutine functions (a plain wrapper would hide that from
    ``iscoroutinefunction`` checks).
    """
    def decorator(func):
        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def wrapper(*args, **kw):
                return await func(*args, **kw)
        else:
            @functools.wraps(func)
            def wrapper(*args, **kw):
                return func(*args, **kw)
        wrapper.__method__ = 'GET'
        wrapper.__route__ = path
        return wrapper
    return decorator


## 编写装饰函数 @post()

def post(path):
    """Return a decorator that marks a URL handler for ``POST path``.

    The decorated function gets ``__method__ = 'POST'`` and
    ``__route__ = path``; ``add_route`` reads these to register it.
    ``async def`` handlers get an ``async def`` wrapper so they remain
    coroutine functions (a plain wrapper would hide that from
    ``iscoroutinefunction`` checks).
    """
    def decorator(func):
        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def wrapper(*args, **kw):
                return await func(*args, **kw)
        else:
            @functools.wraps(func)
            def wrapper(*args, **kw):
                return func(*args, **kw)
        wrapper.__method__ = 'POST'
        wrapper.__route__ = path
        return wrapper
    return decorator

在www目录下新建coroweb.py,完整代码如下:

#!/usr/bin/env python3

# -*- coding: utf-8 -*-




import asyncio, os, inspect, logging, functools


from urllib import parse


from aiohttp import web


## apis是处理分页的模块,后面会编写

## APIError 是指API调用时发生逻辑错误

from apis import APIError


## 编写装饰函数 @get()

def get(path):
    """Return a decorator that marks a URL handler for ``GET path``.

    The decorated function gets ``__method__ = 'GET'`` and
    ``__route__ = path``; ``add_route`` reads these to register it.
    ``async def`` handlers get an ``async def`` wrapper so they remain
    coroutine functions (a plain wrapper would hide that from
    ``iscoroutinefunction`` checks).
    """
    def decorator(func):
        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def wrapper(*args, **kw):
                return await func(*args, **kw)
        else:
            @functools.wraps(func)
            def wrapper(*args, **kw):
                return func(*args, **kw)
        wrapper.__method__ = 'GET'
        wrapper.__route__ = path
        return wrapper
    return decorator


## 编写装饰函数 @post()

def post(path):
    """Return a decorator that marks a URL handler for ``POST path``.

    The decorated function gets ``__method__ = 'POST'`` and
    ``__route__ = path``; ``add_route`` reads these to register it.
    ``async def`` handlers get an ``async def`` wrapper so they remain
    coroutine functions (a plain wrapper would hide that from
    ``iscoroutinefunction`` checks).
    """
    def decorator(func):
        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def wrapper(*args, **kw):
                return await func(*args, **kw)
        else:
            @functools.wraps(func)
            def wrapper(*args, **kw):
                return func(*args, **kw)
        wrapper.__method__ = 'POST'
        wrapper.__route__ = path
        return wrapper
    return decorator


## 以下是RequestHandler需要定义的一些函数

def get_required_kw_args(fn):
    """Return, in declaration order, the keyword-only parameters of *fn* without a default.

    ``RequestHandler`` uses this to reject requests that omit a mandatory argument.
    """
    params = inspect.signature(fn).parameters.values()
    # Compare with ``is``: ``Parameter.empty`` is a sentinel class, and ``==``
    # would invoke the default value's own __eq__ (which may raise or return a
    # non-bool, e.g. for array-like defaults).
    return tuple(
        param.name
        for param in params
        if param.kind == inspect.Parameter.KEYWORD_ONLY
        and param.default is inspect.Parameter.empty
    )


def get_named_kw_args(fn):
    """Return the names of every keyword-only parameter of *fn*, in declaration order."""
    kw_only = inspect.Parameter.KEYWORD_ONLY
    return tuple(
        name
        for name, param in inspect.signature(fn).parameters.items()
        if param.kind == kw_only
    )


def has_named_kw_args(fn):
    """Return True if *fn* declares at least one keyword-only parameter, else False.

    (The previous version fell off the end and returned ``None`` instead of ``False``.)
    """
    return any(
        param.kind == inspect.Parameter.KEYWORD_ONLY
        for param in inspect.signature(fn).parameters.values()
    )


def has_var_kw_arg(fn):
    """Return True if *fn* accepts ``**kwargs`` (a VAR_KEYWORD parameter), else False.

    (The previous version fell off the end and returned ``None`` instead of ``False``.)
    """
    return any(
        param.kind == inspect.Parameter.VAR_KEYWORD
        for param in inspect.signature(fn).parameters.values()
    )


def has_request_arg(fn):
    """Return whether *fn* has a parameter called ``request``.

    Only ``*args``, keyword-only parameters and ``**kwargs`` may follow
    ``request``; any other parameter after it raises ``ValueError``.
    """
    sig = inspect.signature(fn)
    permitted_after = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.KEYWORD_ONLY,
        inspect.Parameter.VAR_KEYWORD,
    )
    names = list(sig.parameters)
    if 'request' not in names:
        return False
    trailing = list(sig.parameters.values())[names.index('request') + 1:]
    for param in trailing:
        if param.kind not in permitted_after:
            raise ValueError('request parameter must be the last named parameter in function: %s%s' % (fn.__name__, str(sig)))
    return True


## 定义RequestHandler从URL函数中分析其需要接受的参数

class RequestHandler(object):
    """Wrap a URL handler function as an aiohttp request handler.

    The signature of *fn* is inspected once, at construction. On each request
    ``__call__`` gathers keyword arguments from the request body (POST), the
    query string (GET) and the URL path variables (``request.match_info``),
    then awaits ``fn(**kw)``. An ``APIError`` raised by *fn* is converted into
    a dict with ``error``/``data``/``message`` keys.
    """

    def __init__(self, app, fn):
        self._app = app
        self._func = fn
        # Pre-computed signature facts, so requests need no re-introspection.
        self._has_request_arg = has_request_arg(fn)
        self._has_var_kw_arg = has_var_kw_arg(fn)
        self._has_named_kw_args = has_named_kw_args(fn)
        self._named_kw_args = get_named_kw_args(fn)
        self._required_kw_args = get_required_kw_args(fn)

    async def __call__(self, request):
        """Build keyword arguments from *request* and call the wrapped function.

        Raises:
            web.HTTPBadRequest: on a missing/unsupported Content-Type, a
                malformed or non-object JSON body, or a missing required
                keyword-only argument. (HTTP exceptions are raised, not
                returned: returning them is deprecated in aiohttp 3.)
        """
        kw = None
        if self._has_var_kw_arg or self._has_named_kw_args or self._required_kw_args:
            if request.method == 'POST':
                if not request.content_type:
                    raise web.HTTPBadRequest(text='Missing Content-Type.')
                ct = request.content_type.lower()
                if ct.startswith('application/json'):
                    try:
                        params = await request.json()
                    except ValueError:
                        # JSONDecodeError / UnicodeDecodeError are ValueErrors:
                        # a malformed body is a client error, not a server crash.
                        raise web.HTTPBadRequest(text='Malformed JSON body.') from None
                    if not isinstance(params, dict):
                        raise web.HTTPBadRequest(text='JSON body must be object.')
                    kw = params
                elif ct.startswith('application/x-www-form-urlencoded') or ct.startswith('multipart/form-data'):
                    params = await request.post()
                    # dict(**params) raised TypeError for repeated fields
                    # (a=1&a=2); dict(params) keeps the first value of each
                    # key, matching the GET branch below.
                    kw = dict(params)
                else:
                    raise web.HTTPBadRequest(text='Unsupported Content-Type: %s' % request.content_type)
            if request.method == 'GET':
                qs = request.query_string
                if qs:
                    kw = dict()
                    # keep_blank_values=True; only the first value of a repeated key is used.
                    for k, v in parse.parse_qs(qs, True).items():
                        kw[k] = v[0]
        if kw is None:
            kw = dict(**request.match_info)
        else:
            if not self._has_var_kw_arg and self._named_kw_args:
                # Without **kw the function cannot accept unknown names: drop them.
                kw = {name: kw[name] for name in self._named_kw_args if name in kw}
            # Path variables override same-named body/query values.
            for k, v in request.match_info.items():
                if k in kw:
                    logging.warning('Duplicate arg name in named arg and kw args: %s' % k)
                kw[k] = v
        if self._has_request_arg:
            kw['request'] = request
        # Check required keyword-only arguments.
        for name in self._required_kw_args:
            if name not in kw:
                raise web.HTTPBadRequest(text='Missing argument: %s' % name)
        # NOTE(review): this logs every argument verbatim, including any
        # credentials posted by sign-in forms -- consider redacting.
        logging.info('call with args: %s' % str(kw))
        try:
            return await self._func(**kw)
        except APIError as e:
            return dict(error=e.error, data=e.data, message=e.message)

## 定义add_static函数,来注册static文件夹下的文件

def add_static(app):
    """Serve the ``static`` directory located next to this module under ``/static/``."""
    here = os.path.dirname(os.path.abspath(__file__))
    static_dir = os.path.join(here, 'static')
    prefix = '/static/'
    app.router.add_static(prefix, static_dir)
    logging.info('add static %s => %s' % (prefix, static_dir))


## 定义add_route函数,来注册一个URL处理函数

def add_route(app, fn):
    """Register the URL handler *fn* on ``app.router``.

    *fn* must carry the ``__method__`` and ``__route__`` attributes set by the
    ``@get``/``@post`` decorators. Plain (non-async, non-generator) functions
    are wrapped in a coroutine function before being handed to
    ``RequestHandler``; generator functions are passed through unchanged.

    Raises:
        ValueError: if *fn* lacks ``__method__`` or ``__route__``.
    """
    method = getattr(fn, '__method__', None)
    path = getattr(fn, '__route__', None)
    if path is None or method is None:
        raise ValueError('@get or @post not defined in %s.' % str(fn))
    if not inspect.iscoroutinefunction(fn) and not inspect.isgeneratorfunction(fn):
        # asyncio.coroutine() was removed in Python 3.11, so build the async
        # adapter by hand. Like asyncio.coroutine(), it awaits an awaitable
        # result: a sync @get/@post wrapper around an ``async def`` handler
        # returns a coroutine object that still has to be awaited.
        sync_fn = fn

        @functools.wraps(sync_fn)
        async def async_fn(*args, **kw):
            result = sync_fn(*args, **kw)
            if inspect.isawaitable(result):
                result = await result
            return result

        # functools.wraps copies __name__, __dict__ (__method__/__route__) and
        # sets __wrapped__, so inspect.signature() still sees sync_fn's params.
        fn = async_fn
    logging.info('add route %s %s => %s(%s)' % (method, path, fn.__name__, ', '.join(inspect.signature(fn).parameters.keys())))
    app.router.add_route(method, path, RequestHandler(app, fn))


## 定义add_routes函数,自动把handler模块的所有符合条件的URL函数注册了

def add_routes(app, module_name):
    """Import *module_name* and register each public callable carrying route metadata.

    *module_name* may be dotted (``'pkg.handlers'``), in which case the last
    component is fetched from its parent package. Every non-underscore
    attribute that is callable and has truthy ``__method__`` and
    ``__route__`` attributes is passed to ``add_route``.
    """
    package, dot, leaf = module_name.rpartition('.')
    if not dot:
        mod = __import__(module_name, globals(), locals())
    else:
        mod = getattr(__import__(package, globals(), locals(), [leaf]), leaf)
    for attr in (a for a in dir(mod) if not a.startswith('_')):
        fn = getattr(mod, attr)
        if not callable(fn):
            continue
        method = getattr(fn, '__method__', None)
        path = getattr(fn, '__route__', None)
        if method and path:
            add_route(app, fn)

最后,在app.py中加入middleware、jinja2模板以及URL处理函数自动注册(add_routes)的支持:

#!/usr/bin/env python3

# -*- coding: utf-8 -*-




import logging; logging.basicConfig(level=logging.INFO)

import asyncio, os, json, time

from datetime import datetime

from aiohttp import web

from jinja2 import Environment, FileSystemLoader


## config 配置代码在后面会创建添加, 可先从github下载到www下,以防报错

from config import configs


import orm

from coroweb import add_routes, add_static


## handlers 是url处理模块在后面会创建编写, 可先从github下载到www下,以防报错

from handlers import cookie2user, COOKIE_NAME


## 初始化jinja2的函数

def init_jinja2(app, **kw):
    """Build a jinja2 ``Environment`` and store it as ``app['__templating__']``.

    Keyword options (all optional): ``autoescape`` (default True),
    ``block_start_string``/``block_end_string`` (``{%``/``%}``),
    ``variable_start_string``/``variable_end_string`` (``{{``/``}}``),
    ``auto_reload`` (True), ``path`` (template directory; defaults to the
    ``templates`` folder next to this file) and ``filters`` (a mapping of
    filter name to function registered on the environment).
    """
    logging.info('init jinja2...')
    defaults = (
        ('autoescape', True),
        ('block_start_string', '{%'),
        ('block_end_string', '%}'),
        ('variable_start_string', '{{'),
        ('variable_end_string', '}}'),
        ('auto_reload', True),
    )
    options = {key: kw.get(key, fallback) for key, fallback in defaults}
    path = kw.get('path', None)
    if path is None:
        here = os.path.dirname(os.path.abspath(__file__))
        path = os.path.join(here, 'templates')
    logging.info('set jinja2 template path: %s' % path)
    env = Environment(loader=FileSystemLoader(path), **options)
    filters = kw.get('filters', None)
    if filters is not None:
        for filter_name, func in filters.items():
            env.filters[filter_name] = func
    app['__templating__'] = env


## 以下是middleware,可以把通用的功能从每个URL处理函数中拿出来集中放到一个地方

## URL处理日志工厂

async def logger_factory(app, handler):
    """Old-style aiohttp middleware factory: log the method and path of every request."""
    async def logger(request):
        logging.info('Request: %s %s' % (request.method, request.path))
        response = await handler(request)
        return response
    return logger


## 认证处理工厂--把当前用户绑定到request上,并对URL/manage/进行拦截,检查当前用户是否是管理员身份

async def auth_factory(app, handler):
    """Old-style aiohttp middleware factory that attaches the current user to the request.

    Sets ``request.__user__`` to the user resolved from the ``COOKIE_NAME``
    cookie via ``cookie2user`` (or ``None``). Requests under ``/manage/``
    from a missing or non-admin user are redirected to ``/signin``.
    """
    async def auth(request):
        logging.info('check user: %s %s' % (request.method, request.path))
        # NOTE(review): aiohttp 3 warns about custom attributes on web.Request;
        # kept because handlers read request.__user__.
        request.__user__ = None
        cookie_str = request.cookies.get(COOKIE_NAME)
        if cookie_str:
            user = await cookie2user(cookie_str)
            if user:
                logging.info('set current user: %s' % user.email)
                request.__user__ = user
        if request.path.startswith('/manage/') and (request.__user__ is None or not request.__user__.admin):
            # Raise, don't return: returning HTTPException objects is
            # deprecated in aiohttp 3 and unsupported afterwards.
            raise web.HTTPFound('/signin')
        return (await handler(request))
    return auth


## 数据处理工厂

async def data_factory(app, handler):
    """Old-style aiohttp middleware factory that pre-parses POST bodies.

    For POST requests whose Content-Type is JSON or urlencoded form, the
    parsed body is stored on ``request.__data__`` before *handler* runs.
    Every other request passes straight through, and ``__data__`` is not set.
    """
    async def parse_data(request):
        if request.method != 'POST':
            return (await handler(request))
        ctype = request.content_type
        if ctype.startswith('application/json'):
            request.__data__ = await request.json()
            logging.info('request json: %s' % str(request.__data__))
        elif ctype.startswith('application/x-www-form-urlencoded'):
            request.__data__ = await request.post()
            logging.info('request form: %s' % str(request.__data__))
        return (await handler(request))
    return parse_data


## 响应返回处理工厂

async def response_factory(app, handler):
    """Old-style aiohttp middleware factory converting handler results into responses.

    Handler return values are mapped as follows:
      * ``web.StreamResponse`` -- returned unchanged;
      * ``bytes`` -- body with ``application/octet-stream``;
      * ``str`` -- ``'redirect:<url>'`` raises a 302 to ``<url>``, otherwise HTML;
      * ``dict`` -- JSON, or rendered with the jinja2 template named by its
        ``'__template__'`` key (with ``'__user__'`` added to the context);
      * ``int`` in [100, 600) -- empty response with that status;
      * ``(int, message)`` with a valid status -- that status and reason phrase;
      * anything else -- ``str(r)`` as ``text/plain``.
    """
    async def response(request):
        logging.info('Response handler...')
        r = await handler(request)
        if isinstance(r, web.StreamResponse):
            return r
        if isinstance(r, bytes):
            resp = web.Response(body=r)
            resp.content_type = 'application/octet-stream'
            return resp
        if isinstance(r, str):
            if r.startswith('redirect:'):
                # Raise, don't return: returning HTTPException objects is
                # deprecated in aiohttp 3.
                raise web.HTTPFound(r[9:])
            resp = web.Response(body=r.encode('utf-8'))
            resp.content_type = 'text/html;charset=utf-8'
            return resp
        if isinstance(r, dict):
            template = r.get('__template__')
            if template is None:
                resp = web.Response(body=json.dumps(r, ensure_ascii=False, default=lambda o: o.__dict__).encode('utf-8'))
                resp.content_type = 'application/json;charset=utf-8'
                return resp
            # getattr: tolerate a middleware chain without auth_factory.
            r['__user__'] = getattr(request, '__user__', None)
            resp = web.Response(body=app['__templating__'].get_template(template).render(**r).encode('utf-8'))
            resp.content_type = 'text/html;charset=utf-8'
            return resp
        # web.Response only takes keyword arguments; the former positional
        # calls web.Response(r) / web.Response(t, str(m)) raised TypeError.
        if isinstance(r, int) and 100 <= r < 600:
            return web.Response(status=r)
        if isinstance(r, tuple) and len(r) == 2:
            t, m = r
            if isinstance(t, int) and 100 <= t < 600:
                return web.Response(status=t, reason=str(m))
        # default:
        resp = web.Response(body=str(r).encode('utf-8'))
        resp.content_type = 'text/plain;charset=utf-8'
        return resp
    return response


## 时间转换

def datetime_filter(t):
    """Format the Unix timestamp *t* as a relative time label (jinja2 filter).

    Under a minute -> u'1分钟前'; under an hour -> minutes; under a day ->
    hours; under a week -> days; otherwise the local calendar date.
    """
    elapsed = int(time.time() - t)
    if elapsed < 60:
        return u'1分钟前'
    # (exclusive upper bound in seconds, seconds per unit, label template)
    for upper, unit, template in (
        (3600, 60, u'%s分钟前'),
        (86400, 3600, u'%s小时前'),
        (604800, 86400, u'%s天前'),
    ):
        if elapsed < upper:
            return template % (elapsed // unit)
    when = datetime.fromtimestamp(t)
    return u'%s年%s月%s日' % (when.year, when.month, when.day)


async def init(loop):
    """Create the DB pool and the aiohttp app, then listen on 127.0.0.1:9000.

    Returns the server object produced by ``loop.create_server``.
    """
    await orm.create_pool(loop=loop, **configs.db)
    # Application(loop=...) and app.make_handler() are deprecated aiohttp
    # APIs; AppRunner.setup() prepares the app and exposes the equivalent
    # protocol factory as runner.server.
    app = web.Application(middlewares=[
        logger_factory, auth_factory, response_factory
    ])
    init_jinja2(app, filters=dict(datetime=datetime_filter))
    add_routes(app, 'handlers')
    add_static(app)
    runner = web.AppRunner(app)
    await runner.setup()
    srv = await loop.create_server(runner.server, '127.0.0.1', 9000)
    logging.info('server started at http://127.0.0.1:9000...')
    return srv


# Create and install a fresh event loop explicitly. asyncio.get_event_loop()
# with no current loop is deprecated (warns on 3.12, raises on 3.14).
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(init(loop))
loop.run_forever()

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容