import asyncio import logging import traceback import time from asyncio import StreamReader, StreamWriter from typing import Self from itertools import count from proxy import Proxy from byte_fifo import ByteFifo from async_ifc import AsyncIfc from infos import Infos import gc logger = logging.getLogger('conn') class AsyncIfcImpl(AsyncIfc): _ids = count(0) def __init__(self) -> None: logger.debug('AsyncIfcImpl.__init__') self.fwd_fifo = ByteFifo() self.tx_fifo = ByteFifo() self.rx_fifo = ByteFifo() self.conn_no = next(self._ids) self.node_id = '' self.timeout_cb = None self.init_new_client_conn_cb = None self.update_header_cb = None def close(self): self.timeout_cb = None self.fwd_fifo.reg_trigger(None) self.tx_fifo.reg_trigger(None) self.rx_fifo.reg_trigger(None) def set_node_id(self, value: str): self.node_id = value def get_conn_no(self): return self.conn_no def tx_add(self, data: bytearray): ''' add data to transmit queue''' self.tx_fifo += data def tx_flush(self): ''' send transmit queue and clears it''' self.tx_fifo() def tx_peek(self, size: int = None) -> bytearray: '''returns size numbers of byte without removing them''' return self.tx_fifo.peek(size) def tx_log(self, level, info): ''' log the transmit queue''' self.tx_fifo.logging(level, info) def tx_clear(self): ''' clear transmit queue''' self.tx_fifo.clear() def tx_len(self): ''' get numner of bytes in the transmit queue''' return len(self.tx_fifo) def fwd_add(self, data: bytearray): ''' add data to forward queue''' self.fwd_fifo += data def fwd_log(self, level, info): ''' log the forward queue''' self.fwd_fifo.logging(level, info) def rx_get(self, size: int = None) -> bytearray: '''removes size numbers of bytes and return them''' return self.rx_fifo.get(size) def rx_peek(self, size: int = None) -> bytearray: '''returns size numbers of byte without removing them''' return self.rx_fifo.peek(size) def rx_log(self, level, info): ''' logs the receive queue''' self.rx_fifo.logging(level, info) def rx_clear(self): ''' clear receive queue''' self.rx_fifo.clear() def rx_len(self): ''' get numner of bytes in the receive queue''' return len(self.rx_fifo) def rx_set_cb(self, callback): self.rx_fifo.reg_trigger(callback) def prot_set_timeout_cb(self, callback): self.timeout_cb = callback def prot_set_init_new_client_conn_cb(self, callback): self.init_new_client_conn_cb = callback def prot_set_update_header_cb(self, callback): self.update_header_cb = callback class StreamPtr(): '''Descr StreamPtr''' def __init__(self, _stream, _ifc=None): self.stream = _stream self.ifc = _ifc @property def ifc(self): return self._ifc @ifc.setter def ifc(self, value): self._ifc = value @property def stream(self): return self._stream @stream.setter def stream(self, value): self._stream = value class AsyncStream(AsyncIfcImpl): MAX_PROC_TIME = 2 '''maximum processing time for a received msg in sec''' MAX_START_TIME = 400 '''maximum time without a received msg in sec''' MAX_INV_IDLE_TIME = 120 '''maximum time without a received msg from the inverter in sec''' MAX_DEF_IDLE_TIME = 360 '''maximum default time without a received msg in sec''' def __init__(self, reader: StreamReader, writer: StreamWriter, rstream: "StreamPtr") -> None: AsyncIfcImpl.__init__(self) logger.debug('AsyncStream.__init__') self.remote = rstream self.tx_fifo.reg_trigger(self.__write_cb) self._reader = reader self._writer = writer self.r_addr = writer.get_extra_info('peername') self.l_addr = writer.get_extra_info('sockname') self.proc_start = None # start processing start timestamp self.proc_max = 0 self.async_publ_mqtt = None # will be set AsyncStreamServer only def __write_cb(self): self._writer.write(self.tx_fifo.get()) def __timeout(self) -> int: if self.timeout_cb: return self.timeout_cb() return 360 async def loop(self) -> Self: """Async loop handler for precessing all received messages""" self.proc_start = time.time() while True: try: self.__calc_proc_time() dead_conn_to = self.__timeout() await asyncio.wait_for(self.__async_read(), dead_conn_to) await self.__async_write() await self.__async_forward() if self.async_publ_mqtt: await self.async_publ_mqtt() except asyncio.TimeoutError: logger.warning(f'[{self.node_id}:{self.conn_no}] Dead ' f'connection timeout ({dead_conn_to}s) ' f'for {self.l_addr}') await self.disc() return self except OSError as error: logger.error(f'[{self.node_id}:{self.conn_no}] ' f'{error} for l{self.l_addr} | ' f'r{self.r_addr}') await self.disc() return self except RuntimeError as error: logger.info(f'[{self.node_id}:{self.conn_no}] ' f'{error} for {self.l_addr}') await self.disc() return self except Exception: Infos.inc_counter('SW_Exception') logger.error( f"Exception for {self.r_addr}:\n" f"{traceback.format_exc()}") await asyncio.sleep(0) # be cooperative to other task def __calc_proc_time(self): if self.proc_start: proc = time.time() - self.proc_start if proc > self.proc_max: self.proc_max = proc self.proc_start = None async def disc(self) -> None: """Async disc handler for graceful disconnect""" if self._writer.is_closing(): return logger.debug(f'AsyncStream.disc() l{self.l_addr} | r{self.r_addr}') self._writer.close() await self._writer.wait_closed() def close(self) -> None: logging.debug(f'AsyncStream.close() l{self.l_addr} | r{self.r_addr}') """close handler for a no waiting disconnect hint: must be called before releasing the connection instance """ super().close() self._reader.feed_eof() # abort awaited read if self._writer.is_closing(): return self._writer.close() def healthy(self) -> bool: elapsed = 0 if self.proc_start is not None: elapsed = time.time() - self.proc_start if elapsed > self.MAX_PROC_TIME: logging.debug(f'[{self.node_id}:{self.conn_no}:' f'{type(self).__name__}]' f' act:{round(1000*elapsed)}ms' f' max:{round(1000*self.proc_max)}ms') logging.debug(f'Healthy()) refs: {gc.get_referrers(self)}') return elapsed < 5 ''' Our private methods ''' async def __async_read(self) -> None: """Async read handler to read received data from TCP stream""" data = await self._reader.read(4096) if data: self.proc_start = time.time() self.rx_fifo += data wait = self.rx_fifo() # call read in parent class if wait and wait > 0: await asyncio.sleep(wait) else: raise RuntimeError("Peer closed.") async def __async_write(self, headline: str = 'Transmit to ') -> None: """Async write handler to transmit the send_buffer""" if len(self.tx_fifo) > 0: self.tx_fifo.logging(logging.INFO, f'{headline}{self.r_addr}:') self._writer.write(self.tx_fifo.get()) await self._writer.drain() async def __async_forward(self) -> None: """forward handler transmits data over the remote connection""" if len(self.fwd_fifo) == 0: return try: await self._async_forward() except OSError as error: if self.remote.stream: rmt = self.remote logger.error(f'[{rmt.stream.node_id}:{rmt.stream.conn_no}] ' f'Fwd: {error} for ' f'l{rmt.ifc.l_addr} | r{rmt.ifc.r_addr}') await rmt.ifc.disc() if rmt.ifc.close_cb: rmt.ifc.close_cb() except RuntimeError as error: if self.remote.stream: rmt = self.remote logger.info(f'[{rmt.stream.node_id}:{rmt.stream.conn_no}] ' f'Fwd: {error} for {rmt.ifc.l_addr}') await rmt.ifc.disc() if rmt.ifc.close_cb: rmt.ifc.close_cb() except Exception: Infos.inc_counter('SW_Exception') logger.error( f"Fwd Exception for {self.r_addr}:\n" f"{traceback.format_exc()}") async def publish_outstanding_mqtt(self): '''Publish all outstanding MQTT topics''' try: await self.async_publ_mqtt() await Proxy._async_publ_mqtt_proxy_stat('proxy') except Exception: pass class AsyncStreamServer(AsyncStream): def __init__(self, reader: StreamReader, writer: StreamWriter, async_publ_mqtt, create_remote, rstream: "StreamPtr") -> None: AsyncStream.__init__(self, reader, writer, rstream) self.create_remote = create_remote self.async_publ_mqtt = async_publ_mqtt def close(self) -> None: logging.debug('AsyncStreamServer.close()') self.create_remote = None self.async_publ_mqtt = None super().close() async def server_loop(self) -> None: '''Loop for receiving messages from the inverter (server-side)''' logger.info(f'[{self.node_id}:{self.conn_no}] ' f'Accept connection from {self.r_addr}') Infos.inc_counter('Inverter_Cnt') await self.publish_outstanding_mqtt() await self.loop() Infos.dec_counter('Inverter_Cnt') await self.publish_outstanding_mqtt() logger.info(f'[{self.node_id}:{self.conn_no}] Server loop stopped for' f' r{self.r_addr}') # if the server connection closes, we also have to disconnect # the connection to te TSUN cloud if self.remote and self.remote.stream: logger.info(f'[{self.node_id}:{self.conn_no}] disc client ' f'connection: [{self.remote.ifc.node_id}:' f'{self.remote.ifc.conn_no}]') await self.remote.ifc.disc() async def _async_forward(self) -> None: """forward handler transmits data over the remote connection""" if not self.remote.stream: await self.create_remote() if self.remote.stream and \ self.remote.ifc.init_new_client_conn_cb(): await self.remote.ifc._AsyncStream__async_write() if self.remote.stream: self.remote.ifc.update_header_cb(self.fwd_fifo.peek()) self.fwd_fifo.logging(logging.INFO, 'Forward to ' f'{self.remote.ifc.r_addr}:') self.remote.ifc._writer.write(self.fwd_fifo.get()) await self.remote.ifc._writer.drain() class AsyncStreamClient(AsyncStream): def __init__(self, reader: StreamReader, writer: StreamWriter, rstream: "StreamPtr", close_cb) -> None: AsyncStream.__init__(self, reader, writer, rstream) self.close_cb = close_cb async def disc(self) -> None: logging.debug('AsyncStreamClient.disc()') self.remote = None await super().disc() def close(self) -> None: logging.debug('AsyncStreamClient.close()') self.close_cb = None super().close() async def client_loop(self, _: str) -> None: '''Loop for receiving messages from the TSUN cloud (client-side)''' Infos.inc_counter('Cloud_Conn_Cnt') await self.publish_outstanding_mqtt() await self.loop() Infos.dec_counter('Cloud_Conn_Cnt') await self.publish_outstanding_mqtt() logger.info(f'[{self.node_id}:{self.conn_no}] ' 'Client loop stopped for' f' l{self.l_addr}') if self.close_cb: self.close_cb() async def _async_forward(self) -> None: """forward handler transmits data over the remote connection""" if self.remote.stream: self.remote.ifc.update_header_cb(self.fwd_fifo.peek()) self.fwd_fifo.logging(logging.INFO, 'Forward to ' f'{self.remote.ifc.r_addr}:') self.remote.ifc._writer.write(self.fwd_fifo.get()) await self.remote.ifc._writer.drain()