# Source code for broker.broker_utils.broker_utils.testing

#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
"""Classes and functions used for Pitt-Google broker testing."""

import base64
import json
import logging
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import NamedTuple, Union

import numpy as np

from .data_utils import decode_alert, load_alert
from .gcp_utils import pull_pubsub
from .types import AlertIds, _AlertIds

# Module-level logger, named after this module so its output can be filtered per module.
logger = logging.getLogger(name=__name__)

@dataclass
class AlertPaths:
    """Paths to alerts stored locally."""

    def __init__(self, alert_dir: Union[str, Path] = "ztf"):
        """Load path to alerts directory and initialize variables.

        Args:
            alert_dir:
                Either the path to a directory containing alerts or one of the following strings.

                "ALERT_DIR" or "alert_dir":
                    The path will be obtained from the environment variable ALERT_DIR.

                "ztf" or "elasticc":
                    The path will be obtained from the environment variable given by
                    ``f"ALERT_DIR_{alert_dir.upper()}"``

        Raises:
            ValueError: If ``alert_dir`` names an environment variable that is unset or empty.
        """
        if alert_dir in ["ALERT_DIR", "alert_dir"]:
            self.alert_dir = self._dir_from_env("ALERT_DIR")
        elif alert_dir in ["ztf", "elasticc"]:
            self.alert_dir = self._dir_from_env(f"ALERT_DIR_{alert_dir.upper()}")
        else:
            self.alert_dir = Path(alert_dir)
        self.__mygen = None

    @staticmethod
    def _dir_from_env(var: str) -> Path:
        """Return the directory named by environment variable `var`; raise if it is unset."""
        value = os.getenv(var)
        # Path(None) would raise an opaque TypeError, and Path("") silently means the cwd
        if not value:
            raise ValueError(
                f"Environment variable {var} is not set; cannot locate the alert directory."
            )
        return Path(value)

    def gen(self):
        """Return a new generator of paths to Avro files in self.alert_dir (searched recursively)."""
        return (p for p in self.alert_dir.glob("**/*.avro"))

    @property
    def path(self):
        """Return a path to an alert Avro file.

        This can be used to obtain a random path to an alert, or within a loop to continuously
        iterate over paths in the alert directory.
        The returned path is the next one in self._mygen (a path generator maintained internally).
        Raises StopIteration once that generator is exhausted (or if there are no Avro files).
        If you want more control (e.g., to ensure that you iterate over all files in the directory
        and/or do not receive duplicate paths), use ``AlertPaths().gen()`` instead.
        """
        return next(self._mygen)

    @property
    def _mygen(self):
        """Path generator used by self, created on first use."""
        if self.__mygen is None:
            self.__mygen = self.gen()
        return self.__mygen
@dataclass
class Mock:
    """Mock data and objects useful for testing the broker."""

    def __init__(self, **kwargs):
        """Initialize attributes.

        kwargs may contain keys:
            modules:
                List of module names to mock.

            test_alert:
                An instance of TestAlert. This will be used to generate appropriate mock data
                (e.g., Pub/Sub message attributes containing the TestAlert's alert IDs).

            args and kwargs for TestAlert:
                These will be used to create a new instance of TestAlert when one is needed
                and none was given. "schema_map" is required for this; "path" is optional.
        """
        self.kwargs = kwargs
        self._modules = kwargs.get("modules")
        self._module_results = None
        self._my_test_alert = kwargs.get("test_alert")
        self._id_tuples = None
        self._attrs = None
        self._cfinput = None

    def _reset_id_dependents(self):
        """Drop cached values that embed the alert IDs so they are regenerated on next access."""
        self._attrs = None
        self._module_results = None

    # modules to mock
    @property
    def modules(self):
        """List of modules for which mock results should be created.

        If this is set manually, self._module_results will be recreated to accommodate.
        """
        return self._modules

    @modules.setter
    def modules(self, value):
        self._modules = value
        # force recreation of other properties
        self._module_results = None

    @modules.deleter
    def modules(self):
        self._modules = None
        # force recreation of other properties
        self._module_results = None

    # mocked module results
    @property
    def module_results(self):
        """Mock results for the pipeline module(s) given by self.modules (None if no modules)."""
        if self._module_results is None and self.modules is not None:
            self._module_results = self._generate_module_results()
        return self._module_results

    @module_results.setter
    def module_results(self, value):
        self._module_results = value

    @module_results.deleter
    def module_results(self):
        self._module_results = None

    def _generate_module_results(self):
        """Return a dict of mocked results keyed by module name.

        Only "SuperNNova" is currently mocked; other names in self.modules are ignored.
        """
        results = {}
        if "SuperNNova" in self.modules:
            results["SuperNNova"] = self._supernnova_results
        return results

    @property
    def _supernnova_results(self):
        """Return mocked results for SuperNNova module (fresh random probabilities each call)."""
        prob_class1 = np.random.uniform()
        sid, oid = self.id_tuples
        return dict(
            [
                (oid[0], oid[1]),
                (sid[0], sid[1]),
                ("prob_class0", 1 - prob_class1),
                ("prob_class1", prob_class1),
                ("predicted_class", round(prob_class1)),
            ]
        )

    @property
    def my_test_alert(self):
        """Instance of TestAlert used by self to generate appropriate data.

        If none was provided, one is created from self.kwargs. "schema_map" is required
        (KeyError otherwise). If "path" is absent, the next path from ``AlertPaths()`` is used.
        """
        if self._my_test_alert is None:
            kwargs = dict(self.kwargs)
            # Look up a default path only when none was given. AlertPaths() reads env vars and
            # consumes a path, so it must not be evaluated eagerly as a pop() default.
            path = kwargs.pop("path") if "path" in kwargs else AlertPaths().path
            schema_map = kwargs.pop("schema_map")
            self._my_test_alert = TestAlert(path, schema_map, **kwargs)
        return self._my_test_alert

    @my_test_alert.setter
    def my_test_alert(self, value):
        self._my_test_alert = value
        # force recreation of other properties (IDs and everything that embeds them)
        self._id_tuples = None
        self._cfinput = None
        self._reset_id_dependents()

    @my_test_alert.deleter
    def my_test_alert(self):
        self._my_test_alert = None
        # force recreation of other properties (IDs and everything that embeds them)
        self._id_tuples = None
        self._cfinput = None
        self._reset_id_dependents()

    # tuples of mocked alert ID keys and IDs
    @property
    def id_tuples(self):
        """Tuples of alert ID keys and IDs: ``((sourceId key, id), (objectId key, id))``.

        The IDs are taken from the test alert if one was provided; otherwise they are random.
        """
        if self._id_tuples is None:
            alert = self._my_test_alert
            if alert is None:
                sid = ("sourceId", f"mocksrc{int(1e9 * np.random.uniform())}")
                oid = ("objectId", f"mockobj{int(1e9 * np.random.uniform())}")
            else:
                sid = (alert.id_keys.sourceId, str(alert.ids.sourceId))
                oid = (alert.id_keys.objectId, str(alert.ids.objectId))
            self._id_tuples = (sid, oid)
        return self._id_tuples

    @id_tuples.setter
    def id_tuples(self, value):
        self._id_tuples = value
        self._reset_id_dependents()

    @id_tuples.deleter
    def id_tuples(self):
        self._id_tuples = None
        self._reset_id_dependents()

    # mocked attributes
    @property
    def attrs(self):
        """Mock attributes (dict) for a Pub/Sub message.

        Contains the two ID tuples plus a mock Kafka topic and Kafka/ingest timestamps,
        given as strings of integer POSIX timestamps in microseconds.
        """
        if self._attrs is None:
            sid, oid = self.id_tuples
            # create int POSIX timestamps in microseconds
            # to convert back: datetime.fromtimestamp(tkafka / 1e6)
            now = datetime.now(timezone.utc)
            tkafka = int(now.timestamp() * 1e6)
            tingest = int((now + timedelta(seconds=0.345678)).timestamp() * 1e6)
            self._attrs = dict(
                [
                    sid,
                    oid,
                    ("kafka.topic", "mock_topic_pittgoogle_test"),
                    ("kafka.timestamp", str(tkafka)),
                    ("ingest.timestamp", str(tingest)),
                ]
            )
        return self._attrs

    @attrs.setter
    def attrs(self, value):
        self._attrs = value

    @attrs.deleter
    def attrs(self):
        self._attrs = None

    # mocked Cloud Functions input
    class _CFInput(NamedTuple):
        """Alert data for input to a Cloud Function."""

        msg: dict
        context: NamedTuple

    class _CFContext(NamedTuple):
        """Mock context input for a Cloud Function.

        Use Mock._new_cfcontext() to build one with a fresh event_id and timestamp.
        """

        event_id: int
        timestamp: str
        # event_type = "google.pubsub.topic.publish"
        event_type: str = "broker_utils.testing.Mock"
        resource: str = "mock_resource"

    @classmethod
    def _new_cfcontext(cls):
        """Return a _CFContext with a random event_id and the current UTC time (ms, "Z" suffix)."""
        # Computed per call. Un-annotated NamedTuple class attributes would be evaluated only
        # once, at import time.
        return cls._CFContext(
            event_id=int(1e12 * np.random.uniform()),
            timestamp=(
                datetime.now(timezone.utc)
                .isoformat(timespec="milliseconds")
                .replace("+00:00", "Z")
            ),
        )

    @property
    def cfinput(self):
        """Mock input to a Cloud Function: ``_CFInput(msg, context)``.

        ``msg["data"]`` is the base64-encoded message payload. A non-bytes payload (the "json"
        serialization) is JSON-encoded to UTF-8 bytes first, because base64 requires bytes.
        This assumes the payload is JSON-serializable.
        ``msg["attributes"]`` comes from the test alert's own Mock.
        Returns None (and logs an error) if no test alert was provided and there is no
        schema map with which to create one.
        """
        if self._cfinput is None:
            # a schema map is only needed when the test alert must be built from self.kwargs
            if self._my_test_alert is None and self.kwargs.get("schema_map") is None:
                logger.error("A schema map is required to mock Cloud Function input.")
                return None
            test_alert = self.my_test_alert
            payload = test_alert.msg_payload
            if not isinstance(payload, (bytes, bytearray)):
                payload = json.dumps(payload).encode("utf-8")
            self._cfinput = self._CFInput(
                msg=dict(
                    data=base64.b64encode(payload),
                    attributes=test_alert.mock.attrs,
                ),
                context=self._new_cfcontext(),
            )
        return self._cfinput

    @cfinput.setter
    def cfinput(self, value):
        self._cfinput = value

    @cfinput.deleter
    def cfinput(self):
        self._cfinput = None
@dataclass
class TestAlert:
    """An alert packet and related functions useful for testing the broker."""

    def __init__(self, path, schema_map, **kwargs):
        """Load the alert from `path` and initialize attributes.

        kwargs may contain keys:
            drop_cutouts (bool):
                Whether to drop the image cutouts from the alert. Default True.

            serialize (str):
                "avro" or "json". Serialization format of the Pub/Sub message. Default "json".

            mock (Mock):
                An instance of Mock containing mocked data which should be attached to the
                message. When given, its ``modules`` become ``self.mock_modules``.

            mock_modules (List[str]):
                List of modules for which data should be mocked.
                This will be passed to Mock() as keyword argument "modules".
                This is ignored if an instance of Mock is passed using the "mock" keyword argument.
        """
        self.path = path
        abytes = load_alert(path, "bytes")
        adict = decode_alert(
            abytes, drop_cutouts=kwargs.get("drop_cutouts", True), schema_map=schema_map
        )
        # the alert in both forms; msg_payload picks one according to self.serialize
        self.data = dict(bytes=abytes, dict=adict)
        aids = AlertIds(schema_map, self.data["dict"])
        self.id_keys = aids.id_keys
        self.ids = aids.ids

        self._serialize = kwargs.get("serialize", "json")
        self._msg_payload = None
        self._mock = kwargs.get("mock", None)
        if self._mock is not None:
            # documented contract: a supplied Mock overrides "mock_modules"
            self._mock_modules = self._mock.modules
        else:
            self._mock_modules = kwargs.get("mock_modules", None)

    # message payload
    @property
    def msg_payload(self):
        """Pub/Sub message payload containing the alert and any mocked results.

        "avro": the raw alert bytes; mocked results cannot be attached.
        "json": the alert dict, or ``dict(alert=<alert dict>, **mock.module_results)`` when
        self.mock_modules is set.
        """
        if self._msg_payload is None:
            # alert data type appropriate for the format given by self.serialize
            dtype = {"json": "dict", "avro": "bytes"}
            alert = self.data[dtype[self.serialize]]

            if self.serialize == "avro":
                self._msg_payload = alert
                # currently can't publish an Avro serialized message with mock results
                if self.mock_modules is not None:
                    logger.warning(
                        (
                            "The published message will be an Avro serialized alert. "
                            "Mocked results will NOT be attached."
                        )
                    )
            elif self.mock_modules is None:
                self._msg_payload = alert
            else:
                self._msg_payload = dict(alert=alert, **self.mock.module_results)
        return self._msg_payload

    # message serialization format
    @property
    def serialize(self):
        """One of "json" or "avro". Determines the format of the published message.

        If this is set manually, self._msg_payload will be recreated to accommodate.
        If this is manually deleted, it will be set to the default ("json") and
        self._msg_payload will be recreated to accommodate.
        """
        return self._serialize

    @serialize.setter
    def serialize(self, value):
        self._serialize = value
        # force recreation of other properties
        self._msg_payload = None

    @serialize.deleter
    def serialize(self):
        self._serialize = "json"
        # force recreation of other properties
        self._msg_payload = None

    @staticmethod
    def guess_serializer(topic):
        """Use the topic to guess the format that the message is expected to be in.

        The stub is the part after the first "-" (or the whole topic if there is none);
        "alerts" means "avro", everything else "json".
        """
        avro_topics = ["alerts"]  # all others are json
        try:
            topic_name_stub = topic.split("-")[1]
        except IndexError:
            topic_name_stub = topic
        serialize = "avro" if topic_name_stub in avro_topics else "json"
        return serialize

    @property
    def mock_modules(self):
        """List of modules for which results should be mocked.

        If this is set manually, self._mock and self._msg_payload will be recreated to accommodate.
        """
        return self._mock_modules

    @mock_modules.setter
    def mock_modules(self, value):
        self._mock_modules = value
        # force recreation of other properties
        self._mock = None
        self._msg_payload = None

    @mock_modules.deleter
    def mock_modules(self):
        self._mock_modules = None
        # force recreation of other properties
        self._mock = None
        self._msg_payload = None

    @property
    def mock(self):
        """Mock data.

        If this is set manually, this will also set self.mock_modules = mock.modules
        """
        if self._mock is None:
            self._mock = Mock(modules=self.mock_modules, test_alert=self)
        return self._mock

    @mock.setter
    def mock(self, value):
        self._mock = value
        # Sync the modules directly. Going through the mock_modules setter would reset
        # self._mock to None and discard `value`.
        self._mock_modules = value.modules
        self._msg_payload = None

    @mock.deleter
    def mock(self):
        self._mock = None
        self._msg_payload = None
@dataclass
class IntegrationTestValidator:
    """Functions to validate an integration test.

    Pulls a Pub/Sub subscription and checks that the alert IDs carried by the pulled
    messages match the IDs of the alerts reported as published.
    """

    def __init__(self, subscrip, published_alert_ids, schema_map, **kwargs):
        """Initialize attributes.

        Args:
            subscrip: The Pub/Sub subscription to pull (passed to ``pull_pubsub``).
            published_alert_ids: Sized collection of per-alert ID sequences. Each is converted
                to an ``_AlertIds`` of strings for comparison.
            schema_map: Passed to ``AlertIds``, which is used to extract IDs from messages.

        kwargs can include keys:
            max_pulls (int):
                maximum number of times to pull the Pub/Sub subscription before quitting.
                Default 4.
        """
        self.subscrip = subscrip
        self.max_pulls = kwargs.get("max_pulls", 4)
        self.published_alert_ids = published_alert_ids
        self.alert_ids = AlertIds(schema_map)  # load once, use to extract all msg IDs
        # results, populated by run()
        self.pulled_msg_ids = None
        self.unmatched_ids = None
        self.success = None

    def run(self):
        """Pull the subscription and validate that ids match the published alerts.

        Returns:
            bool: True if the set of pulled IDs equals the set of published IDs.
        """
        self.pulled_msg_ids = self._pull()
        self.success, self.unmatched_ids = self._compare_ids()
        return self.success

    def _pull(self):
        """Pull until as many IDs as published alerts are collected, or max_pulls is reached.

        At least one pull is made whenever any alerts were published.
        """
        n_expected = len(self.published_alert_ids)
        pulled_msg_ids, n_pulls = [], 0
        while len(pulled_msg_ids) < n_expected:
            max_messages = n_expected - len(pulled_msg_ids)
            msgs = pull_pubsub(self.subscrip, max_messages=max_messages, msg_only=False)
            if msgs:
                pulled_msg_ids += self._extract_ids(msgs)
            n_pulls += 1
            if n_pulls >= self.max_pulls:
                break
        return pulled_msg_ids

    def _extract_ids(self, msgs):
        """Extract the alert IDs from msgs.

        This method first guesses whether the IDs should be extracted from message attributes or
        alert filenames by checking the objectId attribute of the first message in msgs.

        Background: Messages generated by a GCS bucket contain an attribute called objectId which
        contains the filename of the object (i.e., file) in the bucket that triggered the message.
        This is a name collision with the alert's DIA objectId, which broker pipeline modules place
        in an attribute of the same name.
        """
        if not msgs:
            return []

        oid0 = msgs[0].message.attributes["objectId"]
        oid_is_filename = oid0.split(".")[-1] == "avro"

        if oid_is_filename:
            return [
                self.alert_ids.extract_ids(filename=msg.message.attributes["objectId"])
                for msg in msgs
            ]
        return [self.alert_ids.extract_ids(attrs=msg.message.attributes) for msg in msgs]

    def _compare_ids(self):
        """Compare pulled vs. published IDs and log the outcome.

        Returns:
            tuple: (success, unmatched_ids). success is True when the two ID sets are equal;
            unmatched_ids is their symmetric difference.
        """
        idsout = set(self.pulled_msg_ids)
        # self._pull() gets the ids from the message attributes or filenames,
        # which are always strings.
        # convert the published_alert_ids to the same type.
        idsin = {_AlertIds(*[str(x) for x in ids]) for ids in self.published_alert_ids}

        # compare ID sets
        unmatched_ids = idsout.symmetric_difference(idsin)
        success = not unmatched_ids

        # log results
        tmp = f"The message IDs pulled from subscription {self.subscrip}"
        if success:
            logger.info("Success! %s match the input.", tmp)
        else:
            logger.warning("Something went wrong. %s do not match the input.", tmp)
            # warn if idsout contains IDs that were not published
            if idsout - idsin:
                logger.warning(
                    (
                        "Some IDs were pulled that were not reported as published. "
                        "You may want to purge the subscription using\n"
                        "\tbroker_utils.gcp_utils.purge_subscription(%s)\n"
                        "and run the test again."
                    ),
                    self.subscrip,
                )

        return success, unmatched_ids