"""User sessions for aiohttp.web."""
__version__ = "2.11.0"
import abc
import json
import sys
import time
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterator,
MutableMapping,
Optional,
Union,
cast,
)
from aiohttp import web
Handler = Callable[[web.Request], Awaitable[web.StreamResponse]]
Middleware = Callable[[web.Request, Handler], Awaitable[web.StreamResponse]]
if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict
class _CookieParams(TypedDict, total=False):
domain: Optional[str]
max_age: Optional[int]
path: str
secure: Optional[bool]
httponly: bool
samesite: Optional[str]
expires: str
class SessionData(TypedDict, total=False):
created: int
session: Dict[str, Any]
[docs]
class Session(MutableMapping[str, Any]):
"""Session dict-like object."""
def __init__(
self,
identity: Optional[Any],
*,
data: Optional[SessionData],
new: bool,
max_age: Optional[int] = None,
) -> None:
self._changed: bool = False
self._mapping: Dict[str, Any] = {}
self._identity = identity if data != {} else None
self._new = new if data != {} else True
self._max_age = max_age
created = data.get("created", None) if data else None
session_data = data.get("session", None) if data else None
now = int(time.time())
age = now - created if created else now
if max_age is not None and age > max_age:
session_data = None
if self._new or created is None:
self._created = now
else:
self._created = created
if session_data is not None:
self._mapping.update(session_data)
def __repr__(self) -> str:
return "<{} [new:{}, changed:{}, created:{}] {!r}>".format(
self.__class__.__name__,
self.new,
self._changed,
self.created,
self._mapping,
)
@property
def new(self) -> bool:
return self._new
@property
def identity(self) -> Optional[Any]: # type: ignore[misc]
return self._identity
@property
def created(self) -> int:
return self._created
@property
def empty(self) -> bool:
return not bool(self._mapping)
@property
def max_age(self) -> Optional[int]:
return self._max_age
@max_age.setter
def max_age(self, value: Optional[int]) -> None:
self._max_age = value
[docs]
def changed(self) -> None:
self._changed = True
[docs]
def invalidate(self) -> None:
self._changed = True
self._mapping = {}
[docs]
def set_new_identity(self, identity: Optional[Any]) -> None:
if not self._new:
raise RuntimeError("Can't change identity for a session which is not new")
self._identity = identity
def __len__(self) -> int:
return len(self._mapping)
def __iter__(self) -> Iterator[str]:
return iter(self._mapping)
def __contains__(self, key: object) -> bool:
return key in self._mapping
def __getitem__(self, key: str) -> Any:
return self._mapping[key]
def __setitem__(self, key: str, value: Any) -> None:
self._mapping[key] = value
self._changed = True
self._created = int(time.time())
def __delitem__(self, key: str) -> None:
del self._mapping[key]
self._changed = True
self._created = int(time.time())
SESSION_KEY = "aiohttp_session"
STORAGE_KEY = "aiohttp_session_storage"
[docs]
async def get_session(request: web.Request) -> Session:
session = request.get(SESSION_KEY)
if session is None:
storage = request.get(STORAGE_KEY)
if storage is None:
raise RuntimeError(
"Install aiohttp_session middleware " "in your aiohttp.web.Application"
)
session = await storage.load_session(request)
if not isinstance(session, Session):
raise RuntimeError(
"Installed {!r} storage should return session instance "
"on .load_session() call, got {!r}.".format(storage, session)
)
request[SESSION_KEY] = session
return session
[docs]
async def new_session(request: web.Request) -> Session:
storage = request.get(STORAGE_KEY)
if storage is None:
raise RuntimeError(
"Install aiohttp_session middleware " "in your aiohttp.web.Application"
)
session = await storage.new_session()
if not isinstance(session, Session):
raise RuntimeError(
"Installed {!r} storage should return session instance "
"on .load_session() call, got {!r}.".format(storage, session)
)
request[SESSION_KEY] = session
return session
[docs]
def session_middleware(storage: "AbstractStorage") -> Middleware:
if not isinstance(storage, AbstractStorage):
raise RuntimeError(f"Expected AbstractStorage got {storage}")
@web.middleware
async def factory(request: web.Request, handler: Handler) -> web.StreamResponse:
request[STORAGE_KEY] = storage
raise_response = False
# TODO aiohttp 4:
# Remove Union from response, and drop the raise_response variable
response: Union[web.StreamResponse, web.HTTPException]
try:
response = await handler(request)
except web.HTTPException as exc:
response = exc
raise_response = True
if not isinstance(response, (web.StreamResponse, web.HTTPException)):
raise RuntimeError(f"Expect response, not {type(response)!r}")
if not isinstance(response, (web.Response, web.HTTPException)):
# likely got websocket or streaming
return response
if response.prepared:
raise RuntimeError("Cannot save session data into prepared response")
session = request.get(SESSION_KEY)
if session is not None:
if session._changed:
await storage.save_session(request, response, session)
if raise_response:
raise cast(web.HTTPException, response)
return response
return factory
[docs]
def setup(app: web.Application, storage: "AbstractStorage") -> None:
"""Setup the library in aiohttp fashion."""
app.middlewares.append(session_middleware(storage))
[docs]
class AbstractStorage(metaclass=abc.ABCMeta):
def __init__(
self,
*,
cookie_name: str = "AIOHTTP_SESSION",
domain: Optional[str] = None,
max_age: Optional[int] = None,
path: str = "/",
secure: Optional[bool] = None,
httponly: bool = True,
samesite: Optional[str] = None,
encoder: Callable[[object], str] = json.dumps,
decoder: Callable[[str], Any] = json.loads,
) -> None:
self._cookie_name = cookie_name
self._cookie_params = _CookieParams(
domain=domain,
max_age=max_age,
path=path,
secure=secure,
httponly=httponly,
samesite=samesite,
)
self._max_age = max_age
self._encoder = encoder
self._decoder = decoder
@property
def cookie_name(self) -> str:
return self._cookie_name
@property
def max_age(self) -> Optional[int]:
return self._max_age
@property
def cookie_params(self) -> _CookieParams:
return self._cookie_params
def _get_session_data(self, session: Session) -> SessionData:
if session.empty:
return {}
return {"created": session.created, "session": session._mapping}
[docs]
async def new_session(self) -> Session:
return Session(None, data=None, new=True, max_age=self.max_age)
[docs]
@abc.abstractmethod
async def load_session(self, request: web.Request) -> Session:
pass
[docs]
@abc.abstractmethod
async def save_session(
self, request: web.Request, response: web.StreamResponse, session: Session
) -> None:
pass
[docs]
def load_cookie(self, request: web.Request) -> Optional[str]:
# TODO: Remove explicit type anotation when aiohttp 3.8 is out
cookie: Optional[str] = request.cookies.get(self._cookie_name)
return cookie
[docs]
def save_cookie(
self,
response: web.StreamResponse,
cookie_data: str,
*,
max_age: Optional[int] = None,
) -> None:
params = self._cookie_params.copy()
if max_age is not None:
params["max_age"] = max_age
t = time.gmtime(time.time() + max_age)
params["expires"] = time.strftime("%a, %d-%b-%Y %T GMT", t)
if not cookie_data:
response.del_cookie(
self._cookie_name, domain=params["domain"], path=params["path"]
)
else:
# Ignoring type for params until aiohttp#4238 is released
response.set_cookie(self._cookie_name, cookie_data, **params)
[docs]
class SimpleCookieStorage(AbstractStorage):
"""Simple JSON storage.
Doesn't any encryption/validation, use it for tests only"""
def __init__(
self,
*,
cookie_name: str = "AIOHTTP_SESSION",
domain: Optional[str] = None,
max_age: Optional[int] = None,
path: str = "/",
secure: Optional[bool] = None,
httponly: bool = True,
samesite: Optional[str] = None,
encoder: Callable[[object], str] = json.dumps,
decoder: Callable[[str], Any] = json.loads,
) -> None:
super().__init__(
cookie_name=cookie_name,
domain=domain,
max_age=max_age,
path=path,
secure=secure,
httponly=httponly,
samesite=samesite,
encoder=encoder,
decoder=decoder,
)
async def load_session(self, request: web.Request) -> Session:
cookie = self.load_cookie(request)
if cookie is None:
return Session(None, data=None, new=True, max_age=self.max_age)
data = self._decoder(cookie)
return Session(None, data=data, new=False, max_age=self.max_age)
async def save_session(
self, request: web.Request, response: web.StreamResponse, session: Session
) -> None:
cookie_data = self._encoder(self._get_session_data(session))
self.save_cookie(response, cookie_data, max_age=session.max_age)