# Source code for redup_servicekit.grpc.client

"""Async gRPC client with TLS, compression and configurable limits.

The :mod:`redup_servicekit.grpc.client` module contains:

- :class:`redup_servicekit.grpc.client.BasicAsyncClient`
"""
from typing import Tuple
from urllib.parse import urlparse

import grpc

# Default cap, in bytes (100 MiB), used for the send and receive message-size
# channel options when the caller does not supply its own limit.
DEFAULT_MAX_MESSAGE_LENGTH = 100 * 1024**2

# Maps the compression names accepted by the client ("gzip", "deflate",
# "no_compression") to the corresponding grpc.Compression values.
COMPRESSION = {
    "gzip": grpc.Compression.Gzip,
    "deflate": grpc.Compression.Deflate,
    "no_compression": grpc.Compression.NoCompression,
}

# Keepalive tuning. Each value is passed unchanged as the gRPC channel option
# named in the trailing comment.
KEEPALIVE_MS = 10_000  # grpc.keepalive_time_ms
KEEPALIVE_TIMEOUT_MS = 5_000  # grpc.keepalive_timeout_ms
KEEPALIVE_PERMIT_WITHOUT_CALLS = 1  # grpc.keepalive_permit_without_calls
HTTP2_MAX_PINGS_WITHOUT_DATA = 0  # grpc.http2.max_pings_without_data


class BasicAsyncClient:
    r"""Async gRPC client with on-demand channels, TLS and compression.

    A fresh channel is created for every :meth:`send` call. Unary calls close
    the channel as soon as the response has been received; streaming calls
    keep it open until the returned async generator is exhausted or closed.

    :param host: Server address with protocol: ``grpc://host:port`` (non-TLS)
        or ``https://host:port`` (TLS). An address without ``://`` is used
        verbatim as a non-TLS endpoint.
    :type host: str
    :param ServiceStub: Generated gRPC stub class (e.g. ``MyServiceStub``).
    :param max_message_length: Max send/receive message length in bytes.
        Default 100 MiB.
    :type max_message_length: int
    :param request_compression_algo: ``gzip``, ``deflate``, or
        ``no_compression``. Unknown names fall back to gzip.
    :type request_compression_algo: str
    :param response_compression_algo: Same options; sent as the
        ``response-compression`` metadata entry.
    :type response_compression_algo: str

    Example:

    >>> from redup_servicekit.grpc import BasicAsyncClient
    >>> client = BasicAsyncClient("grpc://localhost:50051", MyServiceStub, request_compression_algo="gzip")
    >>> response = await client.send(MyRequest(field="value"), Method="MyRpc", timeout=30)
    """

    @staticmethod
    def _channel_options(max_message_length: int):
        """Return the list of ``(name, value)`` channel options.

        :param max_message_length: Value used for the send, receive and
            legacy ``grpc.max_message_length`` size limits.
        """
        return [
            ("grpc.max_send_message_length", max_message_length),
            ("grpc.max_message_length", max_message_length),
            ("grpc.max_receive_message_length", max_message_length),
            ("grpc.keepalive_time_ms", KEEPALIVE_MS),
            ("grpc.keepalive_timeout_ms", KEEPALIVE_TIMEOUT_MS),
            ("grpc.http2.max_pings_without_data", HTTP2_MAX_PINGS_WITHOUT_DATA),
            ("grpc.keepalive_permit_without_calls", KEEPALIVE_PERMIT_WITHOUT_CALLS),
            ("grpc.lb_policy_name", "round_robin"),
        ]

    @staticmethod
    def _parse_endpoint_and_tls(host: str):
        """Split *host* into ``(endpoint, use_tls)``.

        TLS is enabled only for the ``https`` scheme.

        :param host: Address such as ``grpc://host:port`` or
            ``https://host:port``.
        :return: Tuple of the channel target string and a TLS flag.
        """
        if "://" not in host:
            # Without "//", urlparse() reads "localhost:50051" as
            # scheme="localhost", path="50051" and the host name is lost.
            # Use the address as given instead.
            return host.strip(), False
        parsed = urlparse(host)
        location = parsed.netloc.strip() or parsed.path
        return location, parsed.scheme == "https"

    def __init__(
        self,
        host: str,
        ServiceStub,
        max_message_length: int = DEFAULT_MAX_MESSAGE_LENGTH,
        request_compression_algo: str = "gzip",
        response_compression_algo: str = "gzip",
    ):
        """Store the configuration; no channel is opened here."""
        self._server_address = host
        self._stub_cls = ServiceStub
        self._compression_in = request_compression_algo
        self._compression_out = response_compression_algo
        # Appended to the caller's metadata on every request.
        self._base_metadata = (("response-compression", response_compression_algo),)
        self._channel_opts = BasicAsyncClient._channel_options(max_message_length)
        self._endpoint, self._use_tls = BasicAsyncClient._parse_endpoint_and_tls(host)
        # Channels backing streaming calls that have not finished yet,
        # keyed by id(channel).
        self._open_channels = {}

    def _create_channel(self, channel_options, compression_algorithm):
        """Build a new secure or insecure ``grpc.aio`` channel for the endpoint."""
        if self._use_tls:
            return grpc.aio.secure_channel(
                self._endpoint,
                grpc.ssl_channel_credentials(),
                options=channel_options,
                compression=compression_algorithm,
            )
        return grpc.aio.insecure_channel(
            self._endpoint,
            options=channel_options,
            compression=compression_algorithm,
        )

    async def _invoke(self, channel, method_name: str, request, metadata, timeout, stream: bool):
        """Call *method_name* on a stub bound to *channel* and yield the result(s).

        Yields every message of a streaming call, or exactly one message for
        a unary call.
        """
        stub = self._stub_cls(channel)
        stub_method = getattr(stub, method_name)
        if stream:
            async for response_message in stub_method(request, metadata=metadata, timeout=timeout):
                yield response_message
        else:
            yield await stub_method(request, metadata=metadata, timeout=timeout)

    async def send(
        self,
        request,
        Method: str,
        metadata: Tuple[Tuple[str, str], ...] = (),
        max_message_length: int = DEFAULT_MAX_MESSAGE_LENGTH,
        timeout=None,
        stream: bool = False,
    ):
        r"""Send a request and return a single response (unary) or async stream.

        :param request: The request message instance.
        :param Method: Method name string (e.g. ``"MyRpc"``).
        :type Method: str
        :param metadata: Optional sequence of (key, value) metadata pairs; the
            client's ``response-compression`` entry is appended to it.
        :param max_message_length: Override max message length for this call.
            When left at the module default, the limit given to the
            constructor is used.
        :param timeout: Optional timeout in seconds.
        :param stream: If True, return an async generator of response messages.
        :type stream: bool
        :return: Single response message (unary) or async generator (stream).
        """
        # tuple() lets callers pass a list (or None) without a TypeError on "+".
        request_metadata = tuple(metadata or ()) + self._base_metadata
        if max_message_length != DEFAULT_MAX_MESSAGE_LENGTH:
            channel_options = BasicAsyncClient._channel_options(max_message_length)
        else:
            channel_options = self._channel_opts
        compression_algorithm = COMPRESSION.get(self._compression_in, grpc.Compression.Gzip)

        if not stream:
            async with self._create_channel(channel_options, compression_algorithm) as channel:
                responses = self._invoke(channel, Method, request, request_metadata, timeout, False)
                try:
                    return await responses.__anext__()
                finally:
                    # Finish the helper generator instead of leaving it
                    # suspended at its yield.
                    await responses.aclose()

        channel = self._create_channel(channel_options, compression_algorithm)
        channel_identifier = id(channel)
        self._open_channels[channel_identifier] = channel

        async def stream_response_generator():
            responses = self._invoke(channel, Method, request, request_metadata, timeout, True)
            try:
                async for response_message in responses:
                    yield response_message
            finally:
                self._open_channels.pop(channel_identifier, None)
                try:
                    await responses.aclose()
                finally:
                    # Dropping the dict entry alone leaves the channel open;
                    # close it explicitly once the stream ends or is abandoned.
                    await channel.close()

        return stream_response_generator()