# Copyright (c) 2021 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps
from logging import getLogger
from json.decoder import JSONDecodeError
import re
from typing import Dict, Tuple, cast, Optional
import websocket # type: ignore
import requests
from spinn_utilities.log import FormatAdapter
from spinn_utilities.typing.json import JsonObject
from spinnman.exceptions import SpallocException
from .utils import clean_url
# Module-level logger; the log calls in this module use "{}"-style
# placeholders, which is why the standard logger is wrapped.
logger = FormatAdapter(getLogger(__name__))
#: The name of the session cookie issued by Spring Security
_SESSION_COOKIE = "JSESSIONID"
#: Enable detailed debugging by setting to True. When enabled, each request
#: made through a method wrapped by ``_may_renew``, and the response to it,
#: is printed to standard output.
_debug_pretty_print = False
def _may_renew(method):
    """
    Decorator for the HTTP request methods of a session object.

    The wrapped method is called and is expected to return a
    :py:class:`~requests.Response`. If that response carries the session
    cookie, the cookie value is stored on the session. If the response
    status is ``401``, the session is renewed *once* and the call is
    repeated; the response to the second attempt is returned whatever its
    status is.

    :param method:
        The method to wrap. It is called as ``method(self, *args, **kwargs)``
        and ``self`` must provide ``renew()`` and a ``_session_id``
        attribute.
    :return: The wrapping function.
    """
    def pp_req(request: requests.PreparedRequest):
        """
        Print a prepared request (debugging aid).

        :param ~requests.PreparedRequest request:
        """
        print(">>>>>>>>>>>START>>>>>>>>>>>\n")
        print(f"{request.method} {request.url}")
        print('\r\n'.join(
            f"{name}: {value}" for name, value in request.headers.items()))
        if request.body:
            print(request.body)

    def pp_resp(response: requests.Response):
        """
        Print a response (debugging aid).

        :param ~requests.Response response:
        """
        header_lines = '\r\n'.join(
            f"{name}: {value}" for name, value in response.headers.items())
        # Assume we only get textual responses
        body = str(response.content, "UTF-8") if response.content else ""
        status_line = str(response.status_code) + " " + response.reason
        print(f"<<<<<<<<<<<START<<<<<<<<<<<\n{status_line}\r\n"
              f"{header_lines}\r\n\r\n{body}")

    @wraps(method)
    def call(self, *args, **kwargs):
        renew_count = 0
        while True:
            r = method(self, *args, **kwargs)
            if _debug_pretty_print:
                pp_req(r.request)
                pp_resp(r)
            if _SESSION_COOKIE in r.cookies:
                # Track the session cookie if the server (re)issued it
                # pylint: disable=protected-access
                self._session_id = r.cookies[_SESSION_COOKIE]
            # Hand back anything that is not an authentication failure, and
            # also a failure that persists after we have already renewed.
            # (The previous test, "not renew_count", was true on the very
            # first pass, so the renewal below could never be reached.)
            # NOTE(review): the Session methods in this file raise
            # ValueError for any status >= 400 before returning, so a 401
            # only reaches this point if the wrapped method returns it.
            if r.status_code != 401 or renew_count:
                return r
            self.renew()
            renew_count += 1

    return call
class Session:
    """
    Manages session credentials for the Spalloc client.

    .. warning::
        This class does not present a stable API for public consumption.
    """
    __slots__ = (
        "__login_form_url", "__login_submit_url", "__srv_base", "_service_url",
        "__username", "__password", "__token",
        "_session_id", "__csrf", "__csrf_header")

    def __init__(
            self, service_url: str,
            username: Optional[str] = None, password: Optional[str] = None,
            token: Optional[str] = None,
            session_credentials: Optional[
                Tuple[Dict[str, str], Dict[str, str]]] = None):
        """
        :param str service_url: The reference to the service.
            *Should not* include a username or password in it.
        :param str username: The user name to use
        :param str password: The password to use
        :param str token: The bearer token to use
        :param session_credentials:
            The cookies and headers of an already established session, in
            the form produced by the ``_credentials`` property. Any header
            other than ``Authorization`` is taken to be the CSRF header.
        :type session_credentials: tuple(dict(str,str),dict(str,str))
        """
        url = clean_url(service_url)
        self.__login_form_url = url + "system/login.html"
        self.__login_submit_url = url + "system/perform_login"
        self._service_url = url
        self.__srv_base = url + "srv/spalloc/"
        self.__username = username
        self.__password = password
        self.__token = token
        # Give every slot a value; an unassigned slot raises AttributeError
        # on first read, which would happen on any use before renew().
        self._session_id = None
        self.__csrf = None
        self.__csrf_header = None
        if session_credentials:
            cookies, headers = session_credentials
            self._session_id = cookies.get(_SESSION_COOKIE)
            for key, value in headers.items():
                if key == "Authorization":
                    # TODO: extract this?
                    continue
                self.__csrf_header = key
                self.__csrf = value

    def __handle_error_or_return(self, response: requests.Response):
        """
        Pass through successful and redirect responses; reject the rest.

        :param ~requests.Response response:
        :return: The response, if its status is in the range 200-399
        :rtype: ~requests.Response
        :raise ValueError: For any other status
        """
        # NOTE(review): a 401 is rejected here as well, so the decorator
        # applied to the request methods never receives a 401 response.
        code = response.status_code
        if 200 <= code < 400:
            return response
        result = response.content
        raise ValueError(f"Unexpected response from server {code}\n"
                         f"    {str(result)}")

    @_may_renew
    def get(self, url: str, timeout: int = 10, **kwargs) -> requests.Response:
        """
        Do an HTTP ``GET`` in the session.

        Only the session cookie is sent; no CSRF or authorization header.

        :param str url:
        :param int timeout:
        :param kwargs: Sent as the query parameters of the request
        :rtype: ~requests.Response
        :raise ValueError: If the server rejects a request
        """
        params = kwargs if kwargs else None
        cookies, _ = self._credentials
        r = requests.get(url, params=params, cookies=cookies,
                         allow_redirects=False, timeout=timeout)
        logger.debug("GET {} returned {}", url, r.status_code)
        return self.__handle_error_or_return(r)

    @_may_renew
    def post(self, url: str, json_dict: dict, timeout: int = 10,
             **kwargs) -> requests.Response:
        """
        Do an HTTP ``POST`` in the session.

        :param str url:
        :param int timeout:
        :param dict json_dict: Sent as the JSON body of the request
        :param kwargs: Sent as the query parameters of the request
        :rtype: ~requests.Response
        :raise ValueError: If the server rejects a request
        """
        params = kwargs if kwargs else None
        cookies, headers = self._credentials
        r = requests.post(url, params=params, json=json_dict,
                          cookies=cookies, headers=headers,
                          allow_redirects=False, timeout=timeout)
        logger.debug("POST {} returned {}", url, r.status_code)
        return self.__handle_error_or_return(r)

    @_may_renew
    def put(self, url: str, data: str, timeout: int = 10,
            **kwargs) -> requests.Response:
        """
        Do an HTTP ``PUT`` in the session. Puts plain text *OR* JSON!

        :param str url:
        :param str data:
            The body of the request; when it is a string it is labelled as
            UTF-8 plain text.
        :param int timeout:
        :param kwargs: Sent as the query parameters of the request
        :rtype: ~requests.Response
        :raise ValueError: If the server rejects a request
        """
        params = kwargs if kwargs else None
        cookies, headers = self._credentials
        if isinstance(data, str):
            headers["Content-Type"] = "text/plain; charset=UTF-8"
        r = requests.put(url, params=params, data=data,
                         cookies=cookies, headers=headers,
                         allow_redirects=False, timeout=timeout)
        logger.debug("PUT {} returned {}", url, r.status_code)
        return self.__handle_error_or_return(r)

    @_may_renew
    def delete(self, url: str, timeout: int = 10,
               **kwargs) -> requests.Response:
        """
        Do an HTTP ``DELETE`` in the session.

        :param str url:
        :param int timeout:
        :param kwargs: Sent as the query parameters of the request
        :rtype: ~requests.Response
        :raise ValueError: If the server rejects a request
        """
        params = kwargs if kwargs else None
        cookies, headers = self._credentials
        r = requests.delete(url, params=params, cookies=cookies,
                            headers=headers, allow_redirects=False,
                            timeout=timeout)
        logger.debug("DELETE {} returned {}", url, r.status_code)
        return self.__handle_error_or_return(r)

    def renew(self) -> JsonObject:
        """
        Renews the session, logging the user into it so that state modification
        operations can be performed.

        With a bearer token, the token is presented to the login page and
        the session cookie from the response is kept. Otherwise a form
        login is done with the username and password. In both cases the
        service root is then fetched to obtain the CSRF header and token.

        :returns: Description of the root of the service, without CSRF data
        :rtype: dict
        :raises SpallocException:
            If the session cannot be renewed.
        """
        if self.__token:
            r = requests.get(
                self.__login_form_url,
                headers={"Authorization": f"Bearer {self.__token}"},
                allow_redirects=False, timeout=10)
            if not r.ok:
                # Report the decoded body, not the repr of the raw bytes
                raise SpallocException(
                    f"Could not renew session: {r.text}")
            try:
                self._session_id = r.cookies[_SESSION_COOKIE]
            except KeyError as e:
                raise SpallocException(
                    "Could not renew session: "
                    "no session cookie issued") from e
        else:
            # Step one: a temporary session so we can log in
            csrf_matcher = re.compile(
                r"""<input type="hidden" name="_csrf" value="(.*)" />""")
            r = requests.get(self.__login_form_url, allow_redirects=False,
                             timeout=10)
            logger.debug("GET {} returned {}",
                         self.__login_form_url, r.status_code)
            m = csrf_matcher.search(r.text)
            session = r.cookies.get(_SESSION_COOKIE)
            # Both the form's CSRF token and a cookie are needed to log in
            if not m or session is None:
                raise SpallocException("could not establish temporary session")
            csrf = m.group(1)

            # Step two: actually do the log in
            form = {
                "_csrf": csrf,
                "username": self.__username,
                "password": self.__password,
                "submit": "submit"
            }
            # NB: returns redirect that sets a cookie
            r = requests.post(self.__login_submit_url,
                              cookies={_SESSION_COOKIE: session},
                              allow_redirects=False,
                              data=form, timeout=10)
            logger.debug("POST {} returned {}",
                         self.__login_submit_url, r.status_code)
            try:
                self._session_id = r.cookies[_SESSION_COOKIE]
            except KeyError as e:
                # No cookie means the login failed; find out why
                try:
                    json_error = r.json()
                except JSONDecodeError:
                    # Not JSON: report the body text itself
                    error = r.text
                else:
                    if (isinstance(json_error, dict)
                            and 'message' in json_error):
                        error = json_error['message']
                    else:
                        error = str(json_error)
                raise SpallocException(f"Unable to login: {error}") from e

        # Step three: get the basic service data and new CSRF token
        obj: JsonObject = self.get(self.__srv_base).json()
        self.__csrf_header = cast(str, obj["csrf-header"])
        self.__csrf = cast(str, obj["csrf-token"])
        del obj["csrf-header"]
        del obj["csrf-token"]
        return obj

    @property
    def _credentials(self) -> Tuple[Dict[str, str], Dict[str, str]]:
        """
        The credentials for requests. *Serializable.*

        A new pair of dictionaries is built on every access, so callers may
        modify what they receive. Values that are not (or no longer) known
        are left out rather than being included as ``None``.

        :return: The cookies and the headers to send with a request
        :rtype: tuple(dict(str,str),dict(str,str))
        """
        cookies: Dict[str, str] = {}
        if self._session_id is not None:
            cookies[_SESSION_COOKIE] = self._session_id
        headers: Dict[str, str] = {}
        if self.__csrf_header is not None and self.__csrf is not None:
            headers[self.__csrf_header] = self.__csrf
        if self.__token:
            # This would be better off done once per session only
            headers["Authorization"] = f"Bearer {self.__token}"
        return cookies, headers

    def websocket(
            self, url: str, header: Optional[dict] = None,
            cookie: Optional[str] = None, **kwargs) -> websocket.WebSocket:
        """
        Create a websocket that uses the session credentials to establish
        itself.

        :param str url: Actual location to open websocket at
        :param dict(str,str) header:
            Optional HTTP headers; the dictionary passed in is not modified
        :param str cookie:
            Optional cookies (composed as semicolon-separated string)
        :param kwargs: Other options to :py:func:`~websocket.create_connection`
        :rtype: ~websocket.WebSocket
        """
        # Note: *NOT* a renewable action!
        # Work on a copy so the caller's dictionary is left untouched
        header = dict(header) if header is not None else {}
        header[self.__csrf_header] = self.__csrf
        if cookie is not None:
            cookie += ";" + _SESSION_COOKIE + "=" + self._session_id
        else:
            cookie = _SESSION_COOKIE + "=" + self._session_id
        return websocket.create_connection(
            url, header=header, cookie=cookie, **kwargs)

    def _purge(self):
        """
        Clears out all credentials from this session, rendering the session
        completely inoperable henceforth.
        """
        self.__username = None
        self.__password = None
        # The bearer token is a credential too; without this it would still
        # be sent in the Authorization header after purging.
        self.__token = None
        self._session_id = None
        self.__csrf = None
class SessionAware:
    """
    Connects to the session.

    Holds a session and a URL, and forwards HTTP and websocket requests
    to that session.

    .. warning::
        This class does not present a stable API for public consumption.
    """
    __slots__ = ("__session", "_url")

    def __init__(self, session: Session, url: str):
        """
        :param Session session: The session that requests are made through
        :param str url:
            The URL to remember; it is passed through ``clean_url`` before
            being stored
        """
        self.__session = session
        self._url = clean_url(url)

    @property
    def _session_credentials(self) -> Tuple[Dict[str, str], Dict[str, str]]:
        """
        The current session credentials.
        Only supposed to be called by subclasses.

        :rtype: tuple(dict(str,str),dict(str,str))
        """
        # pylint: disable=protected-access
        return self.__session._credentials

    @property
    def _service_url(self) -> str:
        """
        The main service URL.

        :rtype: str
        """
        # pylint: disable=protected-access
        return self.__session._service_url

    def _get(self, url: str, **kwargs) -> requests.Response:
        """
        Do an HTTP ``GET`` via the session.

        :param str url:
        :param kwargs: Passed on to the session's ``get``
        :rtype: ~requests.Response
        """
        return self.__session.get(url, **kwargs)

    def _post(self, url: str, json_dict: dict, **kwargs) -> requests.Response:
        """
        Do an HTTP ``POST`` via the session.

        :param str url:
        :param dict json_dict:
        :param kwargs: Passed on to the session's ``post``
        :rtype: ~requests.Response
        """
        return self.__session.post(url, json_dict, **kwargs)

    def _put(self, url: str, data: str, **kwargs) -> requests.Response:
        """
        Do an HTTP ``PUT`` via the session.

        :param str url:
        :param str data:
        :param kwargs: Passed on to the session's ``put``
        :rtype: ~requests.Response
        """
        return self.__session.put(url, data, **kwargs)

    def _delete(self, url: str, **kwargs) -> requests.Response:
        """
        Do an HTTP ``DELETE`` via the session.

        :param str url:
        :param kwargs: Passed on to the session's ``delete``
        :rtype: ~requests.Response
        """
        return self.__session.delete(url, **kwargs)

    def _websocket(self, url: str, **kwargs) -> websocket.WebSocket:
        """
        Create a websocket that uses the session credentials to establish
        itself.

        :param str url: Actual location to open websocket at
        :param kwargs: Passed on to the session's ``websocket``
        :rtype: ~websocket.WebSocket
        """
        return self.__session.websocket(url, **kwargs)