# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import
import itertools
import re
from inspect import isclass, getdoc
from collections import OrderedDict
try:
from collections.abc import Hashable
except ImportError:
from collections import Hashable
from six import string_types, itervalues, iteritems, iterkeys
from flask import current_app
from werkzeug.routing import parse_rule
from . import fields
from .model import Model, ModelBase
from .reqparse import RequestParser
from .utils import merge, not_none, not_none_sorted
from ._http import HTTPStatus
#: Maps Flask/Werkzeug rooting types to Swagger ones
PATH_TYPES = {
'int': 'integer',
'float': 'number',
'string': 'string',
'default': 'string',
}
#: Maps Python primitives types to Swagger ones
PY_TYPES = {
int: 'integer',
float: 'number',
str: 'string',
bool: 'boolean',
None: 'void'
}
RE_URL = re.compile(r'<(?:[^:<>]+:)?([^<>]+)>')
DEFAULT_RESPONSE_DESCRIPTION = 'Success'
DEFAULT_RESPONSE = {'description': DEFAULT_RESPONSE_DESCRIPTION}
RE_RAISES = re.compile(r'^:raises\s+(?P<name>[\w\d_]+)\s*:\s*(?P<description>.*)$', re.MULTILINE)
def ref(model):
'''Return a reference to model in definitions'''
name = model.name if isinstance(model, ModelBase) else model
return {'$ref': '#/definitions/{0}'.format(name)}
def _v(value):
'''Dereference values (callable)'''
return value() if callable(value) else value
def extract_path(path):
'''
Transform a Flask/Werkzeug URL pattern in a Swagger one.
'''
return RE_URL.sub(r'{\1}', path)
def extract_path_params(path):
'''
Extract Flask-style parameters from an URL pattern as Swagger ones.
'''
params = OrderedDict()
for converter, arguments, variable in parse_rule(path):
if not converter:
continue
param = {
'name': variable,
'in': 'path',
'required': True
}
if converter in PATH_TYPES:
param['type'] = PATH_TYPES[converter]
elif converter in current_app.url_map.converters:
param['type'] = 'string'
else:
raise ValueError('Unsupported type converter: %s' % converter)
params[variable] = param
return params
def _param_to_header(param):
param.pop('in', None)
param.pop('name', None)
return _clean_header(param)
def _clean_header(header):
if isinstance(header, string_types):
header = {'description': header}
typedef = header.get('type', 'string')
if isinstance(typedef, Hashable) and typedef in PY_TYPES:
header['type'] = PY_TYPES[typedef]
elif isinstance(typedef, (list, tuple)) and len(typedef) == 1 and typedef[0] in PY_TYPES:
header['type'] = 'array'
header['items'] = {'type': PY_TYPES[typedef[0]]}
elif hasattr(typedef, '__schema__'):
header.update(typedef.__schema__)
else:
header['type'] = typedef
return not_none(header)
def parse_docstring(obj):
raw = getdoc(obj)
summary = raw.strip(' \n').split('\n')[0].split('.')[0] if raw else None
raises = {}
details = raw.replace(summary, '').lstrip('. \n').strip(' \n') if raw else None
for match in RE_RAISES.finditer(raw or ''):
raises[match.group('name')] = match.group('description')
if details:
details = details.replace(match.group(0), '')
parsed = {
'raw': raw,
'summary': summary or None,
'details': details or None,
'returns': None,
'params': [],
'raises': raises,
}
return parsed
def is_hidden(resource, route_doc=None):
'''
Determine whether a Resource has been hidden from Swagger documentation
i.e. by using Api.doc(False) decorator
'''
if route_doc is False:
return True
else:
return hasattr(resource, "__apidoc__") and resource.__apidoc__ is False
[docs]class Swagger(object):
'''
A Swagger documentation wrapper for an API instance.
'''
def __init__(self, api):
self.api = api
self._registered_models = {}
def as_dict(self):
'''
Output the specification as a serializable ``dict``.
:returns: the full Swagger specification in a serializable format
:rtype: dict
'''
basepath = self.api.base_path
if len(basepath) > 1 and basepath.endswith('/'):
basepath = basepath[:-1]
infos = {
'title': _v(self.api.title),
'version': _v(self.api.version),
}
if self.api.description:
infos['description'] = _v(self.api.description)
if self.api.terms_url:
infos['termsOfService'] = _v(self.api.terms_url)
if self.api.contact and (self.api.contact_email or self.api.contact_url):
infos['contact'] = {
'name': _v(self.api.contact),
'email': _v(self.api.contact_email),
'url': _v(self.api.contact_url),
}
if self.api.license:
infos['license'] = {'name': _v(self.api.license)}
if self.api.license_url:
infos['license']['url'] = _v(self.api.license_url)
paths = {}
tags = self.extract_tags(self.api)
# register errors
responses = self.register_errors()
for ns in self.api.namespaces:
for resource, urls, route_doc, kwargs in ns.resources:
for url in self.api.ns_urls(ns, urls):
path = extract_path(url)
serialized = self.serialize_resource(
ns,
resource,
url,
route_doc=route_doc,
**kwargs
)
paths[path] = serialized
# merge in the top-level authorizations
for ns in self.api.namespaces:
if ns.authorizations:
if self.api.authorizations is None:
self.api.authorizations = {}
self.api.authorizations = merge(self.api.authorizations, ns.authorizations)
specs = {
'swagger': '2.0',
'basePath': basepath,
'paths': not_none_sorted(paths),
'info': infos,
'produces': list(iterkeys(self.api.representations)),
'consumes': ['application/json'],
'securityDefinitions': self.api.authorizations or None,
'security': self.security_requirements(self.api.security) or None,
'tags': tags,
'definitions': self.serialize_definitions() or None,
'responses': responses or None,
'host': self.get_host(),
}
return not_none(specs)
def get_host(self):
hostname = current_app.config.get('SERVER_NAME', None) or None
if hostname and self.api.blueprint and self.api.blueprint.subdomain:
hostname = '.'.join((self.api.blueprint.subdomain, hostname))
return hostname
def extract_tags(self, api):
tags = []
by_name = {}
for tag in api.tags:
if isinstance(tag, string_types):
tag = {'name': tag}
elif isinstance(tag, (list, tuple)):
tag = {'name': tag[0], 'description': tag[1]}
elif isinstance(tag, dict) and 'name' in tag:
pass
else:
raise ValueError('Unsupported tag format for {0}'.format(tag))
tags.append(tag)
by_name[tag['name']] = tag
for ns in api.namespaces:
# hide namespaces without any Resources
if not ns.resources:
continue
# hide namespaces with all Resources hidden from Swagger documentation
if all(
is_hidden(r.resource, route_doc=r.route_doc)
for r in ns.resources
):
continue
if ns.name not in by_name:
tags.append({
'name': ns.name,
'description': ns.description
} if ns.description else {'name': ns.name})
elif ns.description:
by_name[ns.name]['description'] = ns.description
return tags
def extract_resource_doc(self, resource, url, route_doc=None):
route_doc = {} if route_doc is None else route_doc
if route_doc is False:
return False
doc = merge(getattr(resource, '__apidoc__', {}), route_doc)
if doc is False:
return False
# ensure unique names for multiple routes to the same resource
# provides different Swagger operationId's
doc["name"] = (
"{}_{}".format(resource.__name__, url)
if route_doc
else resource.__name__
)
params = merge(self.expected_params(doc), doc.get('params', OrderedDict()))
params = merge(params, extract_path_params(url))
# Track parameters for late deduplication
up_params = {(n, p.get('in', 'query')): p for n, p in params.items()}
need_to_go_down = set()
methods = [m.lower() for m in resource.methods or []]
for method in methods:
method_doc = doc.get(method, OrderedDict())
method_impl = getattr(resource, method)
if hasattr(method_impl, 'im_func'):
method_impl = method_impl.im_func
elif hasattr(method_impl, '__func__'):
method_impl = method_impl.__func__
method_doc = merge(method_doc, getattr(method_impl, '__apidoc__', OrderedDict()))
if method_doc is not False:
method_doc['docstring'] = parse_docstring(method_impl)
method_params = self.expected_params(method_doc)
method_params = merge(method_params, method_doc.get('params', {}))
inherited_params = OrderedDict((k, v) for k, v in iteritems(params) if k in method_params)
method_doc['params'] = merge(inherited_params, method_params)
for name, param in method_doc['params'].items():
key = (name, param.get('in', 'query'))
if key in up_params:
need_to_go_down.add(key)
doc[method] = method_doc
# Deduplicate parameters
# For each couple (name, in), if a method overrides it,
# we need to move the paramter down to each method
if need_to_go_down:
for method in methods:
method_doc = doc.get(method)
if not method_doc:
continue
params = {
(n, p.get('in', 'query')): p
for n, p in (method_doc['params'] or {}).items()
}
for key in need_to_go_down:
if key not in params:
method_doc['params'][key[0]] = up_params[key]
doc['params'] = OrderedDict(
(k[0], p) for k, p in up_params.items() if k not in need_to_go_down
)
return doc
def expected_params(self, doc):
params = OrderedDict()
if 'expect' not in doc:
return params
for expect in doc.get('expect', []):
if isinstance(expect, RequestParser):
parser_params = OrderedDict((p['name'], p) for p in expect.__schema__)
params.update(parser_params)
elif isinstance(expect, ModelBase):
params['payload'] = not_none({
'name': 'payload',
'required': True,
'in': 'body',
'schema': self.serialize_schema(expect),
})
elif isinstance(expect, (list, tuple)):
if len(expect) == 2:
# this is (payload, description) shortcut
model, description = expect
params['payload'] = not_none({
'name': 'payload',
'required': True,
'in': 'body',
'schema': self.serialize_schema(model),
'description': description
})
else:
params['payload'] = not_none({
'name': 'payload',
'required': True,
'in': 'body',
'schema': self.serialize_schema(expect),
})
return params
def register_errors(self):
responses = {}
for exception, handler in iteritems(self.api.error_handlers):
doc = parse_docstring(handler)
response = {
'description': doc['summary']
}
apidoc = getattr(handler, '__apidoc__', {})
self.process_headers(response, apidoc)
if 'responses' in apidoc:
_, model = list(apidoc['responses'].values())[0]
response['schema'] = self.serialize_schema(model)
responses[exception.__name__] = not_none(response)
return responses
def serialize_resource(self, ns, resource, url, route_doc=None, **kwargs):
doc = self.extract_resource_doc(resource, url, route_doc=route_doc)
if doc is False:
return
path = {
'parameters': self.parameters_for(doc) or None
}
for method in [m.lower() for m in resource.methods or []]:
methods = [m.lower() for m in kwargs.get('methods', [])]
if doc[method] is False or methods and method not in methods:
continue
path[method] = self.serialize_operation(doc, method)
path[method]['tags'] = [ns.name]
return not_none(path)
def serialize_operation(self, doc, method):
operation = {
'responses': self.responses_for(doc, method) or None,
'summary': doc[method]['docstring']['summary'],
'description': self.description_for(doc, method) or None,
'operationId': self.operation_id_for(doc, method),
'parameters': self.parameters_for(doc[method]) or None,
'security': self.security_for(doc, method),
}
# Handle 'produces' mimetypes documentation
if 'produces' in doc[method]:
operation['produces'] = doc[method]['produces']
# Handle deprecated annotation
if doc.get('deprecated') or doc[method].get('deprecated'):
operation['deprecated'] = True
# Handle form exceptions:
doc_params = list(doc.get('params', {}).values())
all_params = doc_params + (operation['parameters'] or [])
if all_params and any(p['in'] == 'formData' for p in all_params):
if any(p['type'] == 'file' for p in all_params):
operation['consumes'] = ['multipart/form-data']
else:
operation['consumes'] = ['application/x-www-form-urlencoded', 'multipart/form-data']
operation.update(self.vendor_fields(doc, method))
return not_none(operation)
def vendor_fields(self, doc, method):
'''
Extract custom 3rd party Vendor fields prefixed with ``x-``
See: http://swagger.io/specification/#specification-extensions-128
'''
return dict(
(k if k.startswith('x-') else 'x-{0}'.format(k), v)
for k, v in iteritems(doc[method].get('vendor', {}))
)
def description_for(self, doc, method):
'''Extract the description metadata and fallback on the whole docstring'''
parts = []
if 'description' in doc:
parts.append(doc['description'] or "")
if method in doc and 'description' in doc[method]:
parts.append(doc[method]['description'])
if doc[method]['docstring']['details']:
parts.append(doc[method]['docstring']['details'])
return '\n'.join(parts).strip()
def operation_id_for(self, doc, method):
'''Extract the operation id'''
return doc[method]['id'] if 'id' in doc[method] else self.api.default_id(doc['name'], method)
def parameters_for(self, doc):
params = []
for name, param in iteritems(doc['params']):
param['name'] = name
if 'type' not in param and 'schema' not in param:
param['type'] = 'string'
if 'in' not in param:
param['in'] = 'query'
if 'type' in param and 'schema' not in param:
ptype = param.get('type', None)
if isinstance(ptype, (list, tuple)):
typ = ptype[0]
param['type'] = 'array'
param['items'] = {'type': PY_TYPES.get(typ, typ)}
elif isinstance(ptype, (type, type(None))) and ptype in PY_TYPES:
param['type'] = PY_TYPES[ptype]
params.append(param)
# Handle fields mask
mask = doc.get('__mask__')
if (mask and current_app.config['RESTPLUS_MASK_SWAGGER']):
param = {
'name': current_app.config['RESTPLUS_MASK_HEADER'],
'in': 'header',
'type': 'string',
'format': 'mask',
'description': 'An optional fields mask',
}
if isinstance(mask, string_types):
param['default'] = mask
params.append(param)
return params
def responses_for(self, doc, method):
# TODO: simplify/refactor responses/model handling
responses = {}
for d in doc, doc[method]:
if 'responses' in d:
for code, response in iteritems(d['responses']):
if isinstance(response, string_types):
description = response
model = None
kwargs = {}
elif len(response) == 3:
description, model, kwargs = response
elif len(response) == 2:
description, model = response
kwargs = {}
else:
raise ValueError('Unsupported response specification')
description = description or DEFAULT_RESPONSE_DESCRIPTION
if code in responses:
responses[code].update(description=description)
else:
responses[code] = {'description': description}
if model:
responses[code]['schema'] = self.serialize_schema(model)
self.process_headers(responses[code], doc, method, kwargs.get('headers'))
if 'model' in d:
code = str(d.get('default_code', HTTPStatus.OK))
if code not in responses:
responses[code] = self.process_headers(DEFAULT_RESPONSE.copy(), doc, method)
responses[code]['schema'] = self.serialize_schema(d['model'])
if 'docstring' in d:
for name, description in iteritems(d['docstring']['raises']):
for exception, handler in iteritems(self.api.error_handlers):
error_responses = getattr(handler, '__apidoc__', {}).get('responses', {})
code = list(error_responses.keys())[0] if error_responses else None
if code and exception.__name__ == name:
responses[code] = {'$ref': '#/responses/{0}'.format(name)}
break
if not responses:
responses[str(HTTPStatus.OK.value)] = self.process_headers(DEFAULT_RESPONSE.copy(), doc, method)
return responses
def process_headers(self, response, doc, method=None, headers=None):
method_doc = doc.get(method, {})
if 'headers' in doc or 'headers' in method_doc or headers:
response['headers'] = dict(
(k, _clean_header(v)) for k, v
in itertools.chain(
iteritems(doc.get('headers', {})),
iteritems(method_doc.get('headers', {})),
iteritems(headers or {})
)
)
return response
def serialize_definitions(self):
return dict(
(name, model.__schema__)
for name, model in iteritems(self._registered_models)
)
def serialize_schema(self, model):
if isinstance(model, (list, tuple)):
model = model[0]
return {
'type': 'array',
'items': self.serialize_schema(model),
}
elif isinstance(model, ModelBase):
self.register_model(model)
return ref(model)
elif isinstance(model, string_types):
self.register_model(model)
return ref(model)
elif isclass(model) and issubclass(model, fields.Raw):
return self.serialize_schema(model())
elif isinstance(model, fields.Raw):
return model.__schema__
elif isinstance(model, (type, type(None))) and model in PY_TYPES:
return {'type': PY_TYPES[model]}
raise ValueError('Model {0} not registered'.format(model))
def register_model(self, model):
name = model.name if isinstance(model, ModelBase) else model
if name not in self.api.models:
raise ValueError('Model {0} not registered'.format(name))
specs = self.api.models[name]
self._registered_models[name] = specs
if isinstance(specs, ModelBase):
for parent in specs.__parents__:
self.register_model(parent)
if isinstance(specs, Model):
for field in itervalues(specs):
self.register_field(field)
return ref(model)
def register_field(self, field):
if isinstance(field, fields.Polymorph):
for model in itervalues(field.mapping):
self.register_model(model)
elif isinstance(field, fields.Nested):
self.register_model(field.nested)
elif isinstance(field, fields.List):
self.register_field(field.container)
def security_for(self, doc, method):
security = None
if 'security' in doc:
auth = doc['security']
security = self.security_requirements(auth)
if 'security' in doc[method]:
auth = doc[method]['security']
security = self.security_requirements(auth)
return security
def security_requirements(self, value):
if isinstance(value, (list, tuple)):
return [self.security_requirement(v) for v in value]
elif value:
requirement = self.security_requirement(value)
return [requirement] if requirement else None
else:
return []
def security_requirement(self, value):
if isinstance(value, (string_types)):
return {value: []}
elif isinstance(value, dict):
return dict(
(k, v if isinstance(v, (list, tuple)) else [v])
for k, v in iteritems(value)
)
else:
return None