Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve tracer based on experience with testing #437

Merged
merged 3 commits into from
Sep 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions monitoring/monitorlib/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ def issue_token(self, intended_audience: str, scopes: List[str]) -> str:
url = '{}?grant_type=client_credentials&scope={}&intended_audience={}&issuer=dummy&sub={}'.format(
self._oauth_token_endpoint, urllib.parse.quote(' '.join(scopes)),
urllib.parse.quote(intended_audience), self._sub)
response = self._oauth_session.post(url).json()
return response['access_token']
response = self._oauth_session.post(url)
if response.status_code != 200:
raise AccessTokenError('Request to get access token returned {} {}'.format(response.status_code, response.content.decode('utf-8')))
return response.json()['access_token']


class ServiceAccount(AuthAdapter):
Expand All @@ -55,8 +57,10 @@ def issue_token(self, intended_audience: str, scopes: List[str]) -> str:
url = '{}?grant_type=client_credentials&scope={}&intended_audience={}'.format(
self._oauth_token_endpoint, urllib.parse.quote(' '.join(scopes)),
urllib.parse.quote(intended_audience))
response = self._oauth_session.post(url).json()
return response['access_token']
response = self._oauth_session.post(url)
if response.status_code != 200:
raise AccessTokenError('Request to get access token returned {} {}'.format(response.status_code, response.content.decode('utf-8')))
return response.json()['access_token']


class UsernamePassword(AuthAdapter):
Expand All @@ -79,8 +83,10 @@ def issue_token(self, intended_audience: str, scopes: List[str]) -> str:
'password': self._password,
'client_id': self._client_id,
'scope': ' '.join(scopes),
}).json()
return response['access_token']
})
if response.status_code != 200:
raise AccessTokenError('Request to get access token returned {} {}'.format(response.status_code, response.content.decode('utf-8')))
return response.json()['access_token']


class SignedRequest(AuthAdapter):
Expand Down Expand Up @@ -116,7 +122,7 @@ def __init__(self, token_endpoint: str, client_id: str, key_path: str, cert_url:
elif cert_url[-4:].lower() == '.crt':
cert = cryptography.x509.load_pem_x509_certificate(response.content, self._backend)
else:
raise ValueError('cert_url must end with .der or .crt')
raise AccessTokenError('cert_url must end with .der or .crt')
cert_public_key = cert.public_key().public_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM,
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo)
Expand All @@ -129,13 +135,13 @@ def __init__(self, token_endpoint: str, client_id: str, key_path: str, cert_url:
key_content, password=None, backend=self._backend)
private_key_bytes = key_content
else:
raise ValueError('key_path must end with .key or .pem')
raise AccessTokenError('key_path must end with .key or .pem')
public_key = private_key.public_key().public_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM,
cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo)

if cert_public_key != public_key:
raise ValueError('Public key in certificate does not match private key provided')
raise AccessTokenError('Public key in certificate does not match private key provided')

self._private_jwk = jwcrypto.jwk.JWK.from_pem(private_key_bytes)
self._public_jwk = jwcrypto.jwk.JWK.from_pem(public_key)
Expand Down Expand Up @@ -175,7 +181,7 @@ def issue_token(self, intended_audience: str, scopes: List[str]) -> str:
try:
jws_check.verify(self._public_jwk, 'RS256')
except jwcrypto.jws.InvalidJWSSignature:
raise ValueError('Could not construct a valid cryptographic signature for JWS')
raise AccessTokenError('Could not construct a valid cryptographic signature for JWS')

# Construct signature
signature = re.sub(r'\.[^.]*\.', '..', signed)
Expand All @@ -187,7 +193,7 @@ def issue_token(self, intended_audience: str, scopes: List[str]) -> str:
}
response = requests.post(self._token_endpoint, data=payload, headers=request_headers)
if response.status_code != 200:
raise ValueError('Unable to retrieve access token:\n' + response.content.decode('utf-8'))
raise AccessTokenError('Unable to retrieve access token:\n' + response.content.decode('utf-8'))
return response.json()['access_token']


Expand All @@ -211,10 +217,15 @@ def issue_token(self, intended_audience: str, scopes: List[str]) -> str:
'scope': ' '.join(scopes),
})
if response.status_code != 200:
raise ValueError('Unable to retrieve access token:\n' + response.content.decode('utf-8'))
raise AccessTokenError('Unable to retrieve access token:\n' + response.content.decode('utf-8'))
return response.json()['access_token']


class AccessTokenError(RuntimeError):
def __init__(self, msg):
super(AccessTokenError, self).__init__(msg)


def make_auth_adapter(spec: str) -> AuthAdapter:
"""Make an AuthAdapter according to a string specification.

Expand Down
19 changes: 18 additions & 1 deletion monitoring/monitorlib/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_headers(self, url: str, scopes: List[str] = None) -> Dict[str, str]:
else:
token = self._tokens[intended_audience][scope_string]
payload = jwt.decode(token, verify=False)
expires = EPOCH + datetime.timedelta(milliseconds=payload['exp'])
expires = EPOCH + datetime.timedelta(seconds=payload['exp'])
if expires < datetime.datetime.utcnow() - TOKEN_REFRESH:
token = self.issue_token(intended_audience, scopes)
self._tokens[intended_audience][scope_string] = token
Expand Down Expand Up @@ -149,3 +149,20 @@ def default_scope(scope: str):
decorated test.
"""
return default_scopes([scope])


def get_token_claims(headers: Dict) -> Dict:
auth_key = [key for key in headers if key.lower() == 'authorization']
if len(auth_key) == 0:
return {'error': 'Missing Authorization header'}
if len(auth_key) > 1:
return {'error': 'Multiple Authorization headers: ' + ', '.join(auth_key)}
token: str = headers[auth_key[0]]
if token.lower().startswith('bearer '):
token = token[len('bearer '):]
try:
return jwt.decode(token, verify=False)
except ValueError as e:
return {'error': 'ValueError: ' + str(e)}
except jwt.exceptions.DecodeError as e:
return {'error': 'DecodeError: ' + str(e)}
4 changes: 3 additions & 1 deletion monitoring/tracer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ a single polling period, this tool would not create an record of that ISA.

### Invocation
```shell script
docker run --rm -v `pwd`/logs:/logs interuss/dss/tracer \
docker run --name tracer_run --rm -v `pwd`/logs:/logs interuss/dss/tracer \
python python tracer_poll.py \
--auth=<SPEC> \
--dss=https://example.com \
Expand Down Expand Up @@ -58,6 +58,8 @@ notifications upon DSS prompting.
### Invocation
Make a copy of [`run_subscribe.sh`](run_subscribe.sh) and edit the arguments as
appropriate. Then simply run your copy of that script (`./run_subscribe.sh`).
To stop this container gracefully (so that Subscriptions are removed):
`docker container kill --signal=INT tracer_subscribe`

### External route
One important argument in subscribe mode is `--base-url`. This should be the
Expand Down
49 changes: 18 additions & 31 deletions monitoring/tracer/check_rid_flights.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!env/bin/python3

import argparse
import datetime
import logging
from typing import Dict

Expand All @@ -9,7 +10,7 @@
import yaml

from monitoring.monitorlib import rid
from monitoring.tracer import polling
from monitoring.tracer import formatting, polling
from monitoring.tracer.resources import ResourceSet


Expand All @@ -18,32 +19,8 @@
_logger.setLevel(logging.DEBUG)


def _json_or_error(resp: requests.Response) -> Dict:
try:
json = resp.json()
except ValueError:
json = None
if resp == 200 and json:
return json
else:
info = {
'request': {
'url': resp.request.url,
'Authorization': resp.request.headers.get('Authorization', '<None>'),
},
'response': {
'code': resp.status_code,
'elapsed': resp.elapsed.total_seconds()
}
}
if json is None:
info['response']['body'] = resp.content
else:
info['response']['json'] = json
return info


def get_flights(resources: ResourceSet, flights_url: str, area: s2sphere.LatLngRect, include_recent_positions: bool) -> Dict:
t0 = datetime.datetime.utcnow()
resp = resources.dss_client.get(flights_url, params={
'view': '{},{},{},{}'.format(
area.lat_lo().degrees,
Expand All @@ -53,12 +30,19 @@ def get_flights(resources: ResourceSet, flights_url: str, area: s2sphere.LatLngR
),
'include_recent_positions': 'true' if include_recent_positions else 'false',
}, scope=rid.SCOPE_READ)
return _json_or_error(resp)
return {
'request': formatting.describe_request(resp.request, t0),
'response': formatting.describe_response(resp),
}


def get_flight_details(resources: ResourceSet, flights_url: str, id: str) -> Dict:
t0 = datetime.datetime.utcnow()
resp = resources.dss_client.get(flights_url + '/{}/details'.format(id), scope=rid.SCOPE_READ)
return _json_or_error(resp)
return {
'request': formatting.describe_request(resp.request, t0),
'response': formatting.describe_response(resp),
}


def get_all_flights(resources: ResourceSet, area: s2sphere.LatLngRect, include_recent_positions: bool) -> Dict:
Expand All @@ -85,15 +69,18 @@ def get_all_flights(resources: ResourceSet, area: s2sphere.LatLngRect, include_r
continue
isa_flights = get_flights(resources, flights_url, area, include_recent_positions)
if 'flights' not in isa_flights['response'].get('json', {}):
isa_flights['description'] = 'Missing flights field'
if isa_flights['response']['code'] != 200:
isa_flights['description'] = 'USS returned {}'.format(isa_flights['response']['code'])
else:
isa_flights['description'] = 'Missing flights field'
result[isa_id] = {'error': isa_flights}
continue
for flight in isa_flights['response']['json']['flights']:
flight_id = flight.get('id', None)
if flight_id is None:
flight['details'] = {'error': {'description': 'Missing id field'}}
flight['details (separate query)'] = {'error': {'description': 'Missing id field'}}
continue
flight['details'] = get_flight_details(resources, flights_url, flight['id'])
flight['details (separate query)'] = get_flight_details(resources, flights_url, flight['id'])
result[isa_id] = isa_flights

return result
Expand Down
38 changes: 38 additions & 0 deletions monitoring/tracer/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
import json
from typing import Dict, List, Optional, Tuple

import requests
from termcolor import colored

from monitoring.monitorlib import infrastructure


class Change(enum.Enum):
NOCHANGE = 0
Expand Down Expand Up @@ -205,3 +208,38 @@ def format_timedelta(td: datetime.timedelta) -> str:
return sign + '{:s}d{:s}:{:s}:{:s}'.format(*segments)
else:
return sign + '{:s}:{:s}:{:s}'.format(*segments[1:])


def describe_request(req: requests.PreparedRequest, initiated_at: datetime.datetime) -> Dict:
headers = {k: v for k, v in req.headers.items()}
info = {
'method': req.method,
'url': req.url,
'initiated_at': initiated_at.isoformat(),
'token': infrastructure.get_token_claims(headers),
'headers': headers,
}
body = req.body.decode('utf-8') if req.body else None
try:
if body:
info['json'] = json.loads(body)
else:
info['body'] = body
except ValueError:
info['body'] = body
return info


def describe_response(resp: requests.Response):
headers = {k: v for k, v in resp.headers.items()}
info = {
'code': resp.status_code,
'headers': headers,
'elapsed_s': resp.elapsed.total_seconds(),
'reported': datetime.datetime.utcnow().isoformat(),
}
try:
info['json'] = resp.json()
except ValueError:
info['body'] = resp.content.decode('utf-8')
return info
Loading