datastructures.py
定义了大量数据类型(URL、请求头、查询参数、表单数据、上传文件等),方便对 HTTP 报文进行操作。
URL类
class URL:
    """
    Wrapper around a URL string exposing its parsed components.

    An instance is built from exactly one of: a URL string, a connection
    ``scope`` mapping, or keyword components (the same keywords accepted by
    ``replace``). Methods that "modify" the URL return a new instance.
    """

    def __init__(
        self, url: str = "", scope: Scope = None, **components: typing.Any
    ) -> None:
        """
        Build the URL.

        url: the full URL as a string.
        scope: mapping from which the URL is reconstructed. Keys read here:
            "scheme" (default "http"), "server", "root_path", "path",
            "query_string" and "headers".
        components: keyword components forwarded to ``replace``.

        Only one of the three sources may be supplied (enforced by asserts).
        """
        if scope is not None:
            # "scope" is mutually exclusive with both "url" and components.
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope.get("root_path", "") + scope["path"]
            query_string = scope.get("query_string", b"")

            # Headers are a sequence of (name, value) byte pairs, not a dict,
            # so the host header has to be searched for.
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                # A host header wins: scheme + host header + path.
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                # No host header and no server info: only the path is known.
                url = path
            else:
                # Fall back to the server address, omitting the port when it
                # is the default one for the scheme.
                host, port = server
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            # Start from an empty URL and let replace() assemble the parts;
            # geturl() is urllib's SplitResult.geturl().
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """Return the ``urlsplit`` result for this URL, computed once and cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> typing.Union[None, str]:
        return self.components.username

    @property
    def password(self) -> typing.Union[None, str]:
        return self.components.password

    @property
    def hostname(self) -> typing.Union[None, str]:
        return self.components.hostname

    @property
    def port(self) -> typing.Optional[int]:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """Return True when the scheme is "https" or "wss"."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: typing.Any) -> "URL":
        """
        Return a new URL with the given components replaced.

        Accepts the ``SplitResult`` field names (scheme, netloc, path, query,
        fragment) plus "username", "password", "hostname" and "port", which
        are combined into a new netloc.
        """
        if any(key in kwargs for key in ("username", "password", "hostname", "port")):
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                # Recover the host text from the raw netloc instead of
                # ``self.hostname``: the latter strips the brackets of an
                # IPv6 literal (which would produce an unparsable netloc once
                # a port is appended) and is None when there is no host.
                _, _, hostname = self.netloc.rpartition("@")
                if not hostname.endswith("]"):
                    # Drop a trailing ":port" if present.
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: typing.Any) -> "URL":
        """Return a new URL with the given query parameters added or overridden."""
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: typing.Any) -> "URL":
        """Return a new URL whose query string consists only of the given parameters."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(
        self, keys: typing.Union[str, typing.Sequence[str]]
    ) -> "URL":
        """Return a new URL without the named query parameter(s)."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: typing.Any) -> bool:
        # Compared by string form, so a URL equals a plain str with the same text.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        url = str(self)
        if self.password:
            # Never show the real password in a repr.
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"
URLPath类
class URLPath(str):
    """
    A URL path string that may also carry an associated protocol and/or host.
    Used by the routing to return `url_path_for` matches (reverse lookup).
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
        # Only `path` becomes the string value; the protocol is validated
        # here and stored, together with the host, by __init__.
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)  # type: ignore

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: typing.Union[str, "URL"]) -> str:
        """
        Join this path onto `base_url` and return the absolute URL as a string.

        The scheme comes from this path's protocol (secure or plain variant,
        following the base URL) when one is set, otherwise from the base URL.
        The host comes from this path when set, otherwise from the base URL.
        """
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if not self.protocol:
            scheme = base.scheme
        else:
            # (secure, plain) scheme pair for each supported protocol.
            secure_scheme, plain_scheme = {
                "http": ("https", "http"),
                "websocket": ("wss", "ws"),
            }[self.protocol]
            scheme = secure_scheme if base.is_secure else plain_scheme

        netloc = self.host if self.host else base.netloc
        full_path = base.path.rstrip("/") + str(self)
        return str(URL(scheme=scheme, netloc=netloc, path=full_path))
Secret类
class Secret:
    """
    Holds a string value that should not be revealed in tracebacks etc.
    Convert the value to `str` at the point where it is actually needed.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        # The repr deliberately shows a fixed mask instead of the value.
        return "{}('**********')".format(type(self).__name__)

    def __str__(self) -> str:
        return self._value
CommaSeparatedStrings类
用到了标准库 shlex(按类 Unix shell 的语法做词法分析)。这里把逗号设为分隔符,用它来分割字符串,因此引号内的逗号不会被拆开。
class CommaSeparatedStrings(Sequence):
    """
    A sequence of strings, parsed from a comma-separated string or copied
    from an existing sequence of strings.
    """

    def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        # Tokenize with shlex, treating "," as the only separator, so quoted
        # sections are kept together; surrounding spaces are stripped.
        lexer = shlex(value, posix=True)
        lexer.whitespace = ","
        lexer.whitespace_split = True
        self._items = [token.strip() for token in lexer]

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any:
        return self._items[index]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        return "{}({!r})".format(self.__class__.__name__, list(self))

    def __str__(self) -> str:
        return ", ".join(repr(entry) for entry in self)
ImmutableMultiDict
class ImmutableMultiDict(typing.Mapping):
    """
    An immutable multi-value dictionary.

    Internally it keeps a list of (key, value) pairs holding every value,
    plus a plain dict that maps each key to its last value.
    """

    def __init__(
        self,
        *args: typing.Union[
            "ImmutableMultiDict",
            typing.Mapping,
            typing.List[typing.Tuple[typing.Any, typing.Any]],
        ],
        **kwargs: typing.Any,
    ) -> None:
        # At most one positional source is accepted.
        assert len(args) < 2, "Too many arguments."

        source = args[0] if args else []
        if kwargs:
            # Keyword items are appended after the positional source's items.
            source = (
                ImmutableMultiDict(source).multi_items()
                + ImmutableMultiDict(kwargs).multi_items()
            )

        if not source:
            pairs = []  # type: typing.List[typing.Tuple[typing.Any, typing.Any]]
        elif hasattr(source, "multi_items"):
            # Another multi dict: copy every (key, value) pair. The cast only
            # informs the type checker; it does nothing at runtime.
            pairs = list(typing.cast(ImmutableMultiDict, source).multi_items())
        elif hasattr(source, "items"):
            pairs = list(typing.cast(typing.Mapping, source).items())
        else:
            pairs = list(
                typing.cast(typing.List[typing.Tuple[typing.Any, typing.Any]], source)
            )

        # For repeated keys the dict ends up holding the last value.
        self._dict = dict(pairs)
        self._list = pairs

    def getlist(self, key: typing.Any) -> typing.List[str]:
        """Return every value stored under `key`, in insertion order."""
        return [value for name, value in self._list if name == key]

    def keys(self) -> typing.KeysView:
        return self._dict.keys()

    def values(self) -> typing.ValuesView:
        return self._dict.values()

    def items(self) -> typing.ItemsView:
        return self._dict.items()

    def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
        """Return a copy of the full (key, value) list, duplicates included."""
        return list(self._list)

    def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        return self._dict.get(key, default)

    def __getitem__(self, key: typing.Any) -> str:
        return self._dict[key]

    def __contains__(self, key: typing.Any) -> bool:
        return key in self._dict

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: typing.Any) -> bool:
        # Equal only to instances of the same class (or a subclass) holding
        # the same pairs, regardless of order.
        if isinstance(other, self.__class__):
            return sorted(self._list) == sorted(other._list)
        return False

    def __repr__(self) -> str:
        return "{}({!r})".format(self.__class__.__name__, self.multi_items())
MultiDict类
class MultiDict(ImmutableMultiDict):
    """
    Mutable variant of the multi-value dictionary: adds mutating methods that
    keep the pair list and the last-value dict in step.
    """

    def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
        # Assigning replaces every existing value for the key.
        self.setlist(key, [value])

    def __delitem__(self, key: typing.Any) -> None:
        self._list = [(name, item) for name, item in self._list if name != key]
        del self._dict[key]

    def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Remove all values for `key`; return its last value, or `default`."""
        self._list = [(name, item) for name, item in self._list if name != key]
        return self._dict.pop(key, default)

    def popitem(self) -> typing.Tuple:
        """Remove one key (as chosen by dict.popitem) and return (key, last value)."""
        popped_key, popped_value = self._dict.popitem()
        self._list = [
            (name, item) for name, item in self._list if name != popped_key
        ]
        return popped_key, popped_value

    def poplist(self, key: typing.Any) -> typing.List:
        """Remove `key` and return all the values it had."""
        removed = [item for name, item in self._list if name == key]
        self.pop(key)
        return removed

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Store `default` under `key` if the key is absent; return the key's value."""
        if key not in self:
            self._dict[key] = default
            self._list.append((key, default))
        return self[key]

    def setlist(self, key: typing.Any, values: typing.List) -> None:
        """Replace all values of `key` with `values`; an empty list removes the key."""
        if not values:
            self.pop(key, None)
            return
        kept = [(name, item) for (name, item) in self._list if name != key]
        self._list = kept + [(key, item) for item in values]
        self._dict[key] = values[-1]

    def append(self, key: typing.Any, value: typing.Any) -> None:
        """Add one more value for `key`, keeping the existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: typing.Union[
            "MultiDict",
            typing.Mapping,
            typing.List[typing.Tuple[typing.Any, typing.Any]],
        ],
        **kwargs: typing.Any,
    ) -> None:
        """Merge in new items; keys present in the new data replace existing ones."""
        incoming = MultiDict(*args, **kwargs)
        kept = [
            (name, item)
            for (name, item) in self._list
            if name not in incoming.keys()
        ]
        self._list = kept + incoming.multi_items()
        self._dict.update(incoming)
QueryParams类
class QueryParams(ImmutableMultiDict):
    """
    An immutable multi-value dictionary of query parameters.

    Besides the sources accepted by the base class, a raw query string
    (str or bytes) may be given; all keys and values are stored as str.
    """

    def __init__(
        self,
        *args: typing.Union[
            "ImmutableMultiDict",
            typing.Mapping,
            typing.List[typing.Tuple[typing.Any, typing.Any]],
            str,
            bytes,
        ],
        **kwargs: typing.Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        value = args[0] if args else []

        if isinstance(value, (str, bytes)):
            # A raw query string: bytes are decoded as latin-1 before parsing.
            text = value.decode("latin-1") if isinstance(value, bytes) else value
            super().__init__(parse_qsl(text, keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore

        # Normalize every key and value to str.
        self._list = [(str(name), str(item)) for name, item in self._list]
        self._dict = {str(name): str(item) for name, item in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        return "{}({!r})".format(self.__class__.__name__, str(self))
UploadFile类
这里用到了 tempfile 库中 SpooledTemporaryFile 的概念(数据先缓存在内存中,超过阈值后才写入磁盘上的临时文件),可参照下列文章:
https://blog.csdn.net/zhtysw/article/details/82778354
class UploadFile:
    """
    An uploaded file included as part of the request data.
    """

    # Size threshold passed as max_size to the default SpooledTemporaryFile.
    spool_max_size = 1024 * 1024

    def __init__(
        self, filename: str, file: typing.IO = None, content_type: str = ""
    ) -> None:
        self.filename = filename
        self.content_type = content_type
        # Without an explicit file object, buffer the data in a spooled
        # temporary file, which only goes to disk past its size threshold.
        self.file = (
            tempfile.SpooledTemporaryFile(max_size=self.spool_max_size)
            if file is None
            else file
        )

    @property
    def _in_memory(self) -> bool:
        # Reads the file object's `_rolled` flag; objects without that
        # attribute are treated as already on disk.
        return not getattr(self.file, "_rolled", True)

    async def write(self, data: typing.Union[bytes, str]) -> None:
        """Write `data`; done inline while in memory, otherwise in the thread pool."""
        if not self._in_memory:
            await run_in_threadpool(self.file.write, data)
            return
        self.file.write(data)  # type: ignore

    async def read(self, size: int = -1) -> typing.Union[bytes, str]:
        """Read up to `size` units (-1 for everything) from the file."""
        if not self._in_memory:
            return await run_in_threadpool(self.file.read, size)
        return self.file.read(size)

    async def seek(self, offset: int) -> None:
        """Move the file position to `offset`."""
        if not self._in_memory:
            await run_in_threadpool(self.file.seek, offset)
            return
        self.file.seek(offset)

    async def close(self) -> None:
        """Close the underlying file."""
        if not self._in_memory:
            await run_in_threadpool(self.file.close)
            return
        self.file.close()
FormData类
class FormData(ImmutableMultiDict):
    """
    An immutable multi-value dictionary containing both file uploads and text input.
    """

    def __init__(
        self,
        *args: typing.Union[
            "FormData",
            typing.Mapping[str, typing.Union[str, UploadFile]],
            typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]],
        ],
        **kwargs: typing.Union[str, UploadFile],
    ) -> None:
        # Same construction as the base class; only the annotations narrow.
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every uploaded file held in the form, one after another."""
        for _, field in self.multi_items():
            if not isinstance(field, UploadFile):
                continue
            await field.close()
Headers类
class Headers(typing.Mapping[str, str]):
    """
    An immutable, case-insensitive multi-value dictionary of headers.

    Headers are stored as a list of (name, value) byte pairs; lookups
    lowercase the requested name and encode it as latin-1.
    """

    def __init__(
        self,
        headers: typing.Mapping[str, str] = None,
        raw: typing.List[typing.Tuple[bytes, bytes]] = None,
        scope: Scope = None,
    ) -> None:
        # At most one of `headers`, `raw` and `scope` may be supplied.
        self._list = []  # type: typing.List[typing.Tuple[bytes, bytes]]
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            encoded = []
            for name, content in headers.items():
                encoded.append(
                    (name.lower().encode("latin-1"), content.encode("latin-1"))
                )
            self._list = encoded
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            # The given list is used as-is (not copied).
            self._list = raw
        elif scope is not None:
            self._list = scope["headers"]

    @property
    def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
        """Return a copy of the underlying list of byte pairs."""
        return list(self._list)

    def keys(self) -> typing.List[str]:  # type: ignore
        return [name.decode("latin-1") for name, _ in self._list]

    def values(self) -> typing.List[str]:  # type: ignore
        return [content.decode("latin-1") for _, content in self._list]

    def items(self) -> typing.List[typing.Tuple[str, str]]:  # type: ignore
        decoded = []
        for name, content in self._list:
            decoded.append((name.decode("latin-1"), content.decode("latin-1")))
        return decoded

    def get(self, key: str, default: typing.Any = None) -> typing.Any:
        try:
            found = self[key]
        except KeyError:
            return default
        return found

    def getlist(self, key: str) -> typing.List[str]:
        """Return every value stored for header `key`, in order."""
        wanted = key.lower().encode("latin-1")
        return [
            content.decode("latin-1")
            for name, content in self._list
            if name == wanted
        ]

    def mutablecopy(self) -> "MutableHeaders":
        """Return a MutableHeaders built on a copy of the header list."""
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        # Returns the first matching header's value.
        wanted = key.lower().encode("latin-1")
        for name, content in self._list:
            if name == wanted:
                return content.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: typing.Any) -> bool:
        wanted = key.lower().encode("latin-1")
        return any(name == wanted for name, _ in self._list)

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        # Counts header entries, so repeated names count more than once.
        return len(self._list)

    def __eq__(self, other: typing.Any) -> bool:
        if isinstance(other, Headers):
            return sorted(self._list) == sorted(other._list)
        return False

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        as_dict = dict(self.items())
        # The dict form is only faithful when no header name repeats;
        # otherwise fall back to the raw list.
        if len(as_dict) != len(self):
            return f"{class_name}(raw={self.raw!r})"
        return f"{class_name}({as_dict!r})"
MutableHeaders类
class MutableHeaders(Headers):
    """Mutable variant of Headers; changes are applied to the header list in place."""

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        matches = [
            position
            for position, (name, _) in enumerate(self._list)
            if name == set_key
        ]
        if not matches:
            self._list.append((set_key, set_value))
            return

        first, *duplicates = matches
        # Delete from the back so the earlier indexes stay valid.
        for position in reversed(duplicates):
            del self._list[position]
        self._list[first] = (set_key, set_value)

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key`.
        """
        del_key = key.lower().encode("latin-1")
        # Walk backwards so deletions do not shift the indexes still to visit.
        for position in range(len(self._list) - 1, -1, -1):
            name, _ = self._list[position]
            if name == del_key:
                del self._list[position]

    @property
    def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
        # Unlike the base class, this exposes the live list, not a copy.
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        for name, content in self._list:
            if name == set_key:
                return content.decode("latin-1")
        self._list.append((set_key, set_value))
        return value

    def update(self, other: dict) -> None:
        """Set each header from `other`, replacing existing entries."""
        for name, content in other.items():
            self[name] = content

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        self._list.append((key.lower().encode("latin-1"), value.encode("latin-1")))

    def add_vary_header(self, vary: str) -> None:
        """Add `vary` to the "vary" header, joining with any existing value."""
        existing = self.get("vary")
        self["vary"] = vary if existing is None else ", ".join([existing, vary])
State类
class State(object):
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`. Attributes are kept in an
    internal dict rather than in the instance's own attribute dict.
    """

    def __init__(self, state: typing.Dict = None):
        """Use `state` as the backing dict, or a fresh empty dict when omitted."""
        if state is None:
            state = {}
        # Bypass our own __setattr__, which would try to write into
        # self._state before it exists.
        super(State, self).__setattr__("_state", state)

    def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: typing.Any) -> typing.Any:
        """
        Look `key` up in the stored state.

        Raises AttributeError when the key is not present.
        """
        message = "'{}' object has no attribute '{}'"
        if key == "_state":
            # __getattr__ is only invoked when normal lookup fails, so getting
            # here means the instance was created without __init__ (as copy
            # and pickle do). Reading self._state below would then call
            # __getattr__ again without end; fail like a missing attribute.
            raise AttributeError(message.format(self.__class__.__name__, key))
        try:
            return self._state[key]
        except KeyError:
            raise AttributeError(message.format(self.__class__.__name__, key))

    def __delattr__(self, key: typing.Any) -> None:
        # Raises KeyError (not AttributeError) when the key is absent.
        del self._state[key]