#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
"""The ``gcp_utils`` module contains common functions used to interact with
GCP resources.
"""

from concurrent.futures import TimeoutError
from google.cloud import bigquery, pubsub_v1, storage
from google.cloud.logging_v2.logger import Logger
from google.cloud.pubsub_v1.subscriber.futures import StreamingPullFuture
from google.cloud.pubsub_v1.types import PubsubMessage, ReceivedMessage
import json
import pandas as pd
from typing import Callable, List, Optional, Union

pgb_project_id = 'ardent-cycling-243415'


# --- Pub/Sub --- #
def publish_pubsub(
    topic_name: str,
    message: Union[bytes, dict],
    project_id: Optional[str] = None,
    attrs: Optional[dict] = None,
    publisher: Optional[pubsub_v1.PublisherClient] = None,
) -> str:
    """Publish messages to a Pub/Sub topic.

    Wrapper for `google.cloud.pubsub_v1.PublisherClient().publish()`.
    See also: https://cloud.google.com/pubsub/docs/publisher#publishing_messages

    Args:
        topic_name: The Pub/Sub topic name for publishing alerts.
        message: The message to be published. A dict is JSON-encoded to bytes.
        project_id: GCP project ID for the project containing the topic.
            If None, the module's `pgb_project_id` will be used.
        attrs: Message attributes to be published.
        publisher: An instantiated PublisherClient. Use this kwarg if you are
            calling this function repeatedly. The publisher will automatically
            batch the messages over a small time window (currently 0.05
            seconds) to avoid making too many separate requests to the
            service. This helps increase throughput. See
            https://googleapis.dev/python/pubsub/1.7.0/publisher/index.html#batching

    Returns:
        The published message ID.
    """
    if project_id is None:
        project_id = pgb_project_id
    if publisher is None:
        publisher = pubsub_v1.PublisherClient()
    if attrs is None:
        attrs = {}

    # enforce bytes type for the message
    if isinstance(message, dict):
        message = json.dumps(message).encode('utf-8')
    if not isinstance(message, bytes):
        raise TypeError('`message` must be bytes or a dict.')

    topic_path = publisher.topic_path(project_id, topic_name)
    future = publisher.publish(topic_path, data=message, **attrs)

    return future.result()

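# Usage sketch for publish_pubsub (hypothetical topic name and payload; assumes
# the topic exists and GCP credentials are configured):
#
#     msg_id = publish_pubsub(
#         'test-alerts',
#         {'objectId': 'ZTF21example', 'magpsf': 18.3},  # dicts are JSON-encoded to bytes
#         attrs={'survey': 'ztf'},
#     )
#
# When publishing in a loop, instantiate one PublisherClient up front and pass
# it via the `publisher` kwarg so messages can be batched.
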
def pull_pubsub(
    subscription_name: str,
    max_messages: int = 1,
    project_id: Optional[str] = None,
    msg_only: bool = True,
    callback: Optional[Callable[[Union[ReceivedMessage, bytes]], bool]] = None,
    return_count: bool = False,
) -> Union[List[bytes], List[ReceivedMessage], int]:
    """Pull and acknowledge a fixed number of messages from a Pub/Sub topic.

    Wrapper for the synchronous
    `google.cloud.pubsub_v1.SubscriberClient().pull()`.
    See also: https://cloud.google.com/pubsub/docs/pull#synchronous_pull

    Args:
        subscription_name: Name of the Pub/Sub subscription to pull from.
        max_messages: The maximum number of messages to pull.
        project_id: GCP project ID for the project containing the subscription.
            If None, the module's `pgb_project_id` will be used.
        msg_only: Whether to work with and return the message contents only
            (bytes) or the full `ReceivedMessage` packet. If `return_count` is
            True, that setting supersedes this one for the returned object.
        callback: Function used to process each message. Its input type is
            determined by the value of `msg_only`. It should return True if
            the message should be acknowledged, else False.
        return_count: Whether to return the messages themselves or just the
            number of messages pulled.

    Returns:
        A list of messages (bytes if `msg_only` is True, else `ReceivedMessage`
        objects), or the number of pulled messages if `return_count` is True.
    """
    if project_id is None:
        project_id = pgb_project_id

    # setup for the pull
    subscriber = pubsub_v1.SubscriberClient()
    subscription_path = subscriber.subscription_path(project_id, subscription_name)
    request = {
        "subscription": subscription_path,
        "max_messages": max_messages,
    }

    # wrap in a 'with' block to automatically call close() when done
    with subscriber:
        # pull
        response = subscriber.pull(**request)

        # unpack the messages
        message_list, ack_ids = [], []
        for received_message in response.received_messages:
            if msg_only:
                # extract the message bytes and append
                msg_bytes = received_message.message.data
                message_list.append(msg_bytes)
                # perform the callback, if requested
                if callback is not None:
                    success = callback(msg_bytes)
            else:
                # append the full message
                message_list.append(received_message)
                # perform the callback, if requested
                if callback is not None:
                    success = callback(received_message)

            # collect the ack_id, if appropriate
            if (callback is None) or success:
                ack_ids.append(received_message.ack_id)

        # acknowledge the messages so they will not be sent again
        if len(ack_ids) > 0:
            ack_request = {
                "subscription": subscription_path,
                "ack_ids": ack_ids,
            }
            subscriber.acknowledge(**ack_request)

    if not return_count:
        return message_list
    else:
        return len(message_list)

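# Usage sketch for pull_pubsub (hypothetical subscription name; assumes the
# messages carry JSON payloads, as in the publish_pubsub sketch above):
#
#     msgs = pull_pubsub('test-alerts-sub', max_messages=5)  # list of bytes
#     alerts = [json.loads(m) for m in msgs]
#
#     # acknowledge only the messages a callback reports as successfully processed
#     def handle(msg_bytes):
#         return 'objectId' in json.loads(msg_bytes)
#
#     n_pulled = pull_pubsub('test-alerts-sub', max_messages=5,
#                            callback=handle, return_count=True)
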
def streamingPull_pubsub(
    subscription_name: str,
    callback: Callable[[PubsubMessage], None],
    project_id: Optional[str] = None,
    timeout: int = 10,
    block: bool = True,
    flow_control: Optional[dict] = None,
) -> Union[None, StreamingPullFuture]:
    """Pull and process Pub/Sub messages continuously in a background thread.

    Wrapper for the asynchronous
    `google.cloud.pubsub_v1.SubscriberClient().subscribe()`.
    See also: https://cloud.google.com/pubsub/docs/pull#asynchronous-pull

    Args:
        subscription_name: The Pub/Sub subscription to pull from.
        callback: The callback function containing the message processing and
            acknowledgement logic.
        project_id: GCP project ID for the project containing the subscription.
            If None, the module's `pgb_project_id` will be used.
        timeout: The number of seconds before the `subscribe` call times out
            and closes the connection. Only used if `block` is True.
        block: Whether to block while streaming messages or return the
            StreamingPullFuture object for the user to manage separately.
        flow_control: Flow control settings passed through to `subscribe()`.

    Returns:
        If `block` is False, immediately returns the StreamingPullFuture object
        that manages the background thread. Call its `cancel()` method to stop
        streaming messages. If `block` is True, returns None once the streaming
        encounters an error or timeout.
    """
    if project_id is None:
        project_id = pgb_project_id
    if flow_control is None:
        flow_control = {}

    subscriber = pubsub_v1.SubscriberClient()
    subscription_path = subscriber.subscription_path(project_id, subscription_name)

    # start receiving and processing messages in a background thread
    streaming_pull_future = subscriber.subscribe(
        subscription_path, callback, flow_control=flow_control
    )

    if block:
        # block until the timeout duration is reached or an error is encountered
        with subscriber:
            try:
                streaming_pull_future.result(timeout=timeout)
            except TimeoutError:
                streaming_pull_future.cancel()  # Trigger the shutdown.
                streaming_pull_future.result()  # Block until the shutdown is complete.
    else:
        return streaming_pull_future

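# Usage sketch for streamingPull_pubsub (hypothetical subscription name):
# stream for 30 seconds, acknowledging each message inside the callback.
#
#     def callback(message):
#         print(message.data)
#         message.ack()
#
#     streamingPull_pubsub('test-alerts-sub', callback, timeout=30)
#
#     # or manage the stream yourself:
#     future = streamingPull_pubsub('test-alerts-sub', callback, block=False)
#     # ... do other work, then stop streaming:
#     future.cancel()
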
# --- BigQuery --- #
def insert_rows_bigquery(table_id: str, rows: List[dict]):
    """Insert rows into a table using the streaming API.

    Args:
        table_id: Identifier for the BigQuery table in the form
            {dataset}.{table}. For example, 'ztf_alerts.alerts'.
        rows: Data to load into the table. Keys must include all required
            fields in the schema. Keys which do not correspond to a field in
            the schema are ignored.

    Returns:
        A list of errors reported by the API; empty if all rows were inserted.
    """
    bq_client = bigquery.Client(project=pgb_project_id)
    table = bq_client.get_table(table_id)

    errors = bq_client.insert_rows(table, rows)

    return errors

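# Usage sketch for insert_rows_bigquery (hypothetical row contents; assumes the
# 'ztf_alerts.alerts' table exists in the production project):
#
#     errors = insert_rows_bigquery(
#         'ztf_alerts.alerts',
#         [{'objectId': 'ZTF21example', 'candid': 1234567890}],
#     )
#     assert errors == [], f'streaming insert failed: {errors}'
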
def load_dataframe_bigquery(
    table_id: str,
    df: pd.DataFrame,
    use_table_schema: bool = True,
    logger: Optional[Logger] = None,
):
    """Load a dataframe to a table.

    Args:
        table_id: Identifier for the BigQuery table in the form
            {dataset}.{table}. For example, 'ztf_alerts.alerts'.
        df: Data to load into the table. If the dataframe schema does not
            already match the BigQuery table schema, leave `use_table_schema`
            set to True so the dataframe is conformed before loading.
        use_table_schema: Conform the dataframe to the table schema by
            converting dtypes and dropping extra columns.
        logger: If not None, messages will be sent to the logger; else they
            are printed.
    """
    # setup
    bq_client = bigquery.Client(project=pgb_project_id)
    table = bq_client.get_table(table_id)

    if use_table_schema:
        my_df = df.reset_index()

        # set a job_config; BigQuery will try to convert df.dtypes to match the table schema
        job_config = bigquery.LoadJobConfig(schema=table.schema)

        # make sure the df has the correct columns
        bq_col_names = [s.name for s in table.schema]
        # pad missing columns
        missing = [c for c in bq_col_names if c not in my_df.columns]
        for col in missing:
            my_df[col] = None
        # drop extra columns
        dropped = list(set(my_df.columns) - set(bq_col_names))  # grab so we can report
        my_df = my_df[bq_col_names]

        # tell the user what happened
        if len(dropped) > 0:
            msg = f'Dropping columns not in the table schema: {dropped}'
            if logger is not None:
                logger.log_text(msg, severity='INFO')
            else:
                print(msg)

    else:
        my_df = df
        job_config = None

    # load the data
    job = bq_client.load_table_from_dataframe(my_df, table_id, job_config=job_config)
    job.result()  # Wait for the job to complete.

    # report the results
    msg = (
        f"Loaded {job.output_rows} rows to BigQuery table {table_id}.\n"
        f"The following errors were generated: {job.errors}"
    )
    if logger is not None:
        severity = 'DEBUG' if job.errors is not None else 'INFO'
        logger.log_text(msg, severity=severity)
    else:
        print(msg)

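# Usage sketch for load_dataframe_bigquery (hypothetical table and columns):
# conform the dataframe to the table schema and load it, printing a report.
#
#     df = pd.DataFrame({'objectId': ['ZTF21example'], 'candid': [1234567890]})
#     load_dataframe_bigquery('ztf_alerts.alerts', df)
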
def query_bigquery(
    query: str,
    project_id: Optional[str] = None,
    job_config: Optional[bigquery.job.QueryJobConfig] = None,
) -> bigquery.job.QueryJob:
    """Query BigQuery.

    Args:
        query: SQL query statement.
        project_id: The GCP project ID that will be used to make the API call.
            If not provided, the Pitt-Google production project ID will be
            used.
        job_config: Optional job config to send with the query.

    Example query::

        query = (
            f'SELECT * '
            f'FROM `{dataset_project_id}.{dataset}.{table}` '
            f'WHERE objectId={objectId} '
        )

    Examples of working with the query_job::

        # cast it to a DataFrame
        df = query_job.to_dataframe()

        # iterate row-by-row
        for r, row in enumerate(query_job):
            # row values can be accessed by field name or index
            print(f"objectId={row[0]}, candid={row['candid']}")

    Returns:
        The `QueryJob`; iterate over it or call `to_dataframe()` on it to
        access the query results.
    """
    if project_id is None:
        project_id = pgb_project_id

    bq_client = bigquery.Client(project=project_id)
    query_job = bq_client.query(query, job_config=job_config)

    return query_job

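# Usage sketch for query_bigquery (hypothetical table, columns, and objectId):
#
#     query = (
#         'SELECT objectId, candid '
#         'FROM `ardent-cycling-243415.ztf_alerts.alerts` '
#         'WHERE objectId="ZTF21example" '
#     )
#     results_df = query_bigquery(query).to_dataframe()
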
# --- Cloud Storage --- #
def cs_download_file(localdir: str, bucket_id: str, filename: Optional[str] = None):
    """Download file(s) from a Cloud Storage bucket to a local directory.

    Args:
        localdir: Path to the local directory where the file(s) will be
            downloaded.
        bucket_id: Name of the GCS bucket, not including the project ID.
            For example, pass 'ztf-alert_avros' for the bucket
            'ardent-cycling-243415-ztf-alert_avros'.
        filename: Name or prefix of the file(s) in the bucket to download.
            If None, every file in the bucket is downloaded.
    """
    # connect to the bucket and get an iterator that finds blobs in the bucket
    storage_client = storage.Client(pgb_project_id)
    bucket_name = f'{pgb_project_id}-{bucket_id}'
    print(f'Connecting to bucket {bucket_name}')
    bucket = storage_client.get_bucket(bucket_name)
    blobs = storage_client.list_blobs(bucket, prefix=filename)  # iterator

    # download the files
    for blob in blobs:
        local_path = f'{localdir}/{blob.name}'
        blob.download_to_filename(local_path)
        print(f'Downloaded {local_path}')

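# Usage sketch for cs_download_file (hypothetical bucket contents; the local
# directory must already exist):
#
#     cs_download_file('local_avros', bucket_id='ztf-alert_avros',
#                      filename='ZTF21example')
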
def cs_upload_file(local_file: str, bucket_id: str, bucket_filename: Optional[str] = None):
    """Upload a local file to a Cloud Storage bucket.

    Args:
        local_file: Path of the file to upload.
        bucket_id: Name of the GCS bucket, not including the project ID.
            For example, pass 'ztf-alert_avros' for the bucket
            'ardent-cycling-243415-ztf-alert_avros'.
        bucket_filename: Name for the file in the bucket. If None, the base
            name of `local_file` is used.
    """
    if bucket_filename is None:
        bucket_filename = local_file.split('/')[-1]

    # connect to the bucket
    storage_client = storage.Client(pgb_project_id)
    bucket_name = f'{pgb_project_id}-{bucket_id}'
    print(f'Connecting to bucket {bucket_name}')
    bucket = storage_client.get_bucket(bucket_name)

    # upload
    blob = bucket.blob(bucket_filename)
    blob.upload_from_filename(local_file)
    print(f'Uploaded {local_file} as {bucket_filename}')

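# Usage sketch for cs_upload_file (hypothetical file and bucket):
#
#     cs_upload_file('/tmp/alert.avro', bucket_id='ztf-alert_avros')
#     # the file is stored in the bucket as 'alert.avro'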