# generate_api_doc.py
import collections
import re

import yaml
from django.core.management import BaseCommand, CommandError
from django.urls.resolvers import get_resolver
from yaml.scanner import ScannerError

from common.api_doc import (base_part, security_definitions, responses, definitions)

# PyYAML's default tag for YAML mapping nodes; the add_constructor() call
# below registers dict_constructor for it so loaded mappings become
# OrderedDicts.
_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG


def dict_representer(dumper, data):
    """Represent an OrderedDict as a plain YAML mapping.

    The mapping's items view, not the mapping itself, is handed to
    ``dumper.represent_dict``. PyYAML only sorts arguments that expose an
    ``items`` method, so keys are emitted in insertion order (PyYAML
    behaviour; re-verify when upgrading).

    :param dumper: the PyYAML dumper performing the serialization.
    :param data: the OrderedDict being dumped.
    :return: the YAML node built by ``represent_dict``.
    """
    ordered_items = data.items()
    return dumper.represent_dict(ordered_items)


def dict_constructor(loader, node):
    """Construct an OrderedDict from a YAML mapping node.

    :param loader: the PyYAML loader performing the construction.
    :param node: the mapping node being constructed.
    :return: an OrderedDict holding the node's key/value pairs in the order
        ``loader.construct_pairs`` yields them.
    """
    key_value_pairs = loader.construct_pairs(node)
    return collections.OrderedDict(key_value_pairs)


# Install the OrderedDict hooks process-wide on PyYAML: OrderedDicts are
# dumped through dict_representer, and YAML mappings are constructed by
# dict_constructor. Neither call passes an explicit Dumper/Loader, so they
# apply to PyYAML's default dumper/loader classes only.
# NOTE(review): presumably not to yaml.safe_load's SafeLoader; confirm
# against the installed PyYAML version.
yaml.add_representer(collections.OrderedDict, dict_representer)
yaml.add_constructor(_mapping_tag, dict_constructor)


def unify_url_path_format(string):
    """Convert a Django reverse-format URL into an OpenAPI path template.

    The input is a ``%``-style format string without a leading slash (as
    taken from the URL resolver's ``reverse_dict``), e.g.
    ``'api/item/%(pk)s/'``. The result is ``'/api/item/{pk}/'``.

    :param string: URL format string with ``%(name)s`` placeholders.
    :return: the path with a leading ``/`` and every ``%(name)s``
        placeholder rewritten as ``{name}``.
    """
    # Stop each placeholder name at the first ')'. With a greedy '[^/]+',
    # two placeholders in one path segment (e.g. '%(start)s-%(end)s') were
    # merged into a single bogus '{start)s-%(end}'.
    return '/%s' % re.sub(r'%\(([^)]+)\)s', r'{\1}', string)


# Placeholder YAML operation document, used for any HTTP method handler that
# has no ``openapi_doc`` attribute (undocumented OPTIONS handlers are skipped
# instead). The summary "未填写" means "not filled in".
DEFAULT_API_DOC = '''
summary: 未填写
responses:
  200:
    description: ok
'''


class Command(BaseCommand):
    """Generate an OpenAPI/Swagger-style YAML API document from the URLconf.

    For every class-based view in the root URL resolver whose path is
    selected, it collects the following under that view's URL path:

    - the path parameters;
    - the YAML found in each HTTP method handler's ``openapi_doc`` attribute.

    The result is merged with the static fragments imported from
    ``common.api_doc``. The document is then written to ``--output_file``
    or printed to stdout.
    """

    # URL paths documented when --path is not given.
    # NOTE(review): a single hard-coded endpoint looks like a leftover
    # debugging filter; set this to None to document every class-based view.
    include_paths = ('/api/compare/v1',)

    # Loader for every YAML fragment. The inputs are developer-written
    # strings from the code base, not request data. FullLoader is what a bare
    # yaml.load() used on PyYAML 5.x. PyYAML 6 requires an explicit Loader,
    # and releases before 5.1 only provide yaml.Loader. Both classes receive
    # the module-level OrderedDict constructor.
    yaml_loader = getattr(yaml, 'FullLoader', yaml.Loader)

    def add_arguments(self, parser):
        parser.add_argument('-o', '--output_file', help='文件名,用于存储文档')
        parser.add_argument(
            '--path', action='append', dest='paths',
            help='URL path to document, e.g. /api/compare/v1 (repeatable); '
                 'defaults to Command.include_paths')

    def handle(self, *args, **kwargs):
        """Build the API document and emit it to --output_file or stdout."""
        include_paths = kwargs.get('paths') or self.include_paths
        api_doc_dct = self._collect_path_docs(include_paths)

        doc_dct = self._load_yaml(base_part)
        doc_dct['paths'] = api_doc_dct
        doc_dct['securityDefinitions'] = self._load_yaml(security_definitions)
        doc_dct['responses'] = self._load_yaml(responses)
        doc_dct['definitions'] = self._load_yaml(definitions)

        doc_str = yaml.dump(
            doc_dct, default_flow_style=False, allow_unicode=True)
        self._emit(doc_str, kwargs.get('output_file'))

    def _load_yaml(self, text):
        """Parse a trusted YAML string with the configured loader."""
        return yaml.load(text, Loader=self.yaml_loader)

    def _collect_path_docs(self, include_paths):
        """Return ``{openapi_path: path_item}`` for the selected views.

        :param include_paths: URL paths to keep, or ``None`` to keep all.
        """
        api_doc_dct = {}
        for view, pattern in get_resolver().reverse_dict.items():
            # reverse_dict may also be keyed by URL names (str). In addition,
            # function-based views have no view_class. Neither can be
            # introspected here, so skip them instead of raising AttributeError.
            view_class = getattr(view, 'view_class', None)
            if view_class is None:
                continue
            # First reverse candidate: (format string, path parameter names).
            url_path, path_parameters = pattern[0][0]
            url_path = unify_url_path_format(url_path)
            if include_paths is not None and url_path not in include_paths:
                continue
            path_item = {}
            parameters = self._path_parameters(view, path_parameters)
            if parameters:
                path_item['parameters'] = parameters
            path_item.update(self._operation_docs(view_class, url_path))
            api_doc_dct[url_path] = path_item
        return api_doc_dct

    def _path_parameters(self, view, path_parameters):
        """Return the path-level parameter list for *view*.

        An explicit YAML ``parameters_doc`` attribute on the view wins.
        Otherwise, each name in *path_parameters* becomes a required string
        path parameter.
        """
        parameters_doc = getattr(view, 'parameters_doc', None)
        if parameters_doc:
            return self._load_yaml(parameters_doc)
        return [{
            'in': 'path',
            'name': name,
            'required': True,
            'schema': {
                'type': 'string'
            }
        } for name in path_parameters]

    def _operation_docs(self, view_class, url_path):
        """Return ``{http_method: operation_doc}`` for *view_class*.

        Methods the view does not implement are omitted. Implemented
        methods without ``openapi_doc`` get DEFAULT_API_DOC, except OPTIONS
        (which Django's base View defines), which only appears when it is
        explicitly documented.

        :raises CommandError: if a handler's YAML cannot be parsed.
        """
        operations = {}
        for method in view_class.http_method_names:
            handler = getattr(view_class, method, None)
            doc = getattr(handler, 'openapi_doc', None)
            if not handler or (method == 'options' and not doc):
                continue
            source = doc or DEFAULT_API_DOC
            try:
                operations[method] = self._load_yaml(source)
            except yaml.YAMLError as err:
                # YAMLError also covers parser/composer errors, which
                # ScannerError alone did not.
                raise CommandError(
                    'failed to load doc for %s %s: """%s"""\nerr: %s'
                    % (method.upper(), url_path, source, err)) from err
        return operations

    def _emit(self, doc_str, output_file):
        """Write *doc_str* to *output_file*, or to stdout when not given."""
        if not output_file:
            self.stdout.write(doc_str)
            return
        # allow_unicode=True emits raw non-ASCII text, so pin the encoding
        # rather than depending on the platform's default locale.
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(doc_str)
        self.stdout.write(
            self.style.SUCCESS('api doc generated successfully: %s' %
                               output_file))