diff --git a/CHANGELOG.md b/CHANGELOG.md index ac164e1..96a9ac8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +- add exception handling for message forwarding [#94](https://github.com/s-allius/tsun-gen3-proxy/issues/94) - print imgae build time during proxy start - add type annotations - improve async unit test and fix pytest warnings diff --git a/app/src/async_stream.py b/app/src/async_stream.py index 513ab8e..7cbca8e 100644 --- a/app/src/async_stream.py +++ b/app/src/async_stream.py @@ -3,6 +3,7 @@ import traceback import time from asyncio import StreamReader, StreamWriter from messages import hex_dump_memory +from typing import Self logger = logging.getLogger('conn') @@ -20,7 +21,7 @@ class AsyncStream(): self.proc_start = None # start processing start timestamp self.proc_max = 0 - async def server_loop(self, addr): + async def server_loop(self, addr: str) -> None: '''Loop for receiving messages from the inverter (server-side)''' logging.info(f'[{self.node_id}] Accept connection from {addr}') self.inc_counter('Inverter_Cnt') @@ -39,7 +40,7 @@ class AsyncStream(): except Exception: pass - async def client_loop(self, addr): + async def client_loop(self, addr: str) -> None: '''Loop for receiving messages from the TSUN cloud (client-side)''' clientStream = await self.remoteStream.loop() logging.info(f'[{self.node_id}] Client loop stopped for' @@ -59,7 +60,8 @@ class AsyncStream(): # than erase client connection self.remoteStream = None - async def loop(self): + async def loop(self) -> Self: + """Async loop handler for precessing all received messages""" self.r_addr = self.writer.get_extra_info('peername') self.l_addr = self.writer.get_extra_info('sockname') self.proc_start = time.time() @@ -96,14 +98,29 @@ class AsyncStream(): f"Exception for {self.addr}:\n" f"{traceback.format_exc()}") + async def async_write(self, headline: str = 'Transmit to ') -> None: + """Async write handler to transmit the send_buffer""" + if self._send_buffer: + hex_dump_memory(logging.INFO, f'{headline}{self.addr}:', + self._send_buffer, len(self._send_buffer)) + self.writer.write(self._send_buffer) + await self.writer.drain() + self._send_buffer = bytearray(0) # self._send_buffer[sent:] + async def disc(self) -> None: + """Async disc handler for graceful disconnect""" if self.writer.is_closing(): return logger.debug(f'AsyncStream.disc() l{self.l_addr} | r{self.r_addr}') self.writer.close() await self.writer.wait_closed() - def close(self): + def close(self) -> None: + """close handler for a no waiting disconnect + + hint: must be called before releasing the connection instance + """ + self.reader.feed_eof() # abort awaited read if self.writer.is_closing(): return logger.debug(f'AsyncStream.close() l{self.l_addr} | r{self.r_addr}') @@ -122,6 +139,7 @@ class AsyncStream(): Our private methods ''' async def __async_read(self) -> None: + """Async read handler to read received data from TCP stream""" data = await self.reader.read(4096) if data: self.proc_start = time.time() @@ -130,16 +148,11 @@ class AsyncStream(): else: raise RuntimeError("Peer closed.") - async def async_write(self, headline='Transmit to ') -> None: - if self._send_buffer: - hex_dump_memory(logging.INFO, f'{headline}{self.addr}:', - self._send_buffer, len(self._send_buffer)) - self.writer.write(self._send_buffer) - await self.writer.drain() - self._send_buffer = bytearray(0) # self._send_buffer[sent:] - async def __async_forward(self) -> None: - if self._forward_buffer: + """forward handler transmits data over the remote connection""" + if not self._forward_buffer: + return + try: if not self.remoteStream: await self.async_create_remote() if self.remoteStream: @@ -156,6 +169,29 @@ class AsyncStream(): await self.remoteStream.writer.drain() self._forward_buffer = bytearray(0) + except OSError as error: + if self.remoteStream: + rmt = self.remoteStream + self.remoteStream = None + logger.error(f'[{rmt.node_id}] Fwd: {error} for ' + f'l{rmt.l_addr} | r{rmt.r_addr}') + await rmt.disc() + rmt.close() + + except RuntimeError as error: + if self.remoteStream: + rmt = self.remoteStream + self.remoteStream = None + logger.info(f"[{rmt.node_id}] Fwd: {error} for {rmt.l_addr}") + await rmt.disc() + rmt.close() + + except Exception: + self.inc_counter('SW_Exception') + logger.error( + f"Fwd Exception for {self.addr}:\n" + f"{traceback.format_exc()}") + def __del__(self): logger.debug( f"AsyncStream.__del__ l{self.l_addr} | r{self.r_addr}")