398 lines
13 KiB
Python
398 lines
13 KiB
Python
import asyncio
|
|
import logging
|
|
import traceback
|
|
import time
|
|
from asyncio import StreamReader, StreamWriter
|
|
from typing import Self
|
|
from itertools import count
|
|
|
|
from proxy import Proxy
|
|
from byte_fifo import ByteFifo
|
|
from async_ifc import AsyncIfc
|
|
from infos import Infos
|
|
|
|
|
|
import gc
|
|
# module-level logger for the connection/stream handling (logger name 'conn')
logger = logging.getLogger('conn')
|
|
|
|
|
|
class AsyncIfcImpl(AsyncIfc):
    '''FIFO based implementation of the AsyncIfc interface.

    Each instance owns three byte queues:

    * ``tx_fifo``  - data queued for transmission to the peer
    * ``rx_fifo``  - data received from the peer
    * ``fwd_fifo`` - data queued for forwarding to the remote connection

    and a connection number taken from a counter shared by all instances.
    '''

    _ids = count(0)  # class-wide source of unique connection numbers

    def __init__(self) -> None:
        logger.debug('AsyncIfcImpl.__init__')
        self.fwd_fifo = ByteFifo()
        self.tx_fifo = ByteFifo()
        self.rx_fifo = ByteFifo()
        self.conn_no = next(self._ids)
        self.node_id = ''
        # protocol callbacks, registered through the prot_set_*_cb() methods
        self.timeout_cb = None
        self.init_new_client_conn_cb = None
        self.update_header_cb = None

    def close(self):
        '''drop the timeout callback and unregister all FIFO triggers'''
        self.timeout_cb = None
        self.fwd_fifo.reg_trigger(None)
        self.tx_fifo.reg_trigger(None)
        self.rx_fifo.reg_trigger(None)

    def set_node_id(self, value: str):
        '''set the node id used as prefix in log messages'''
        self.node_id = value

    def get_conn_no(self):
        '''return the unique connection number of this instance'''
        return self.conn_no

    def tx_add(self, data: bytearray):
        ''' add data to transmit queue'''
        self.tx_fifo += data

    def tx_flush(self):
        ''' send transmit queue and clears it

        Calls the transmit FIFO, which fires its registered trigger.
        '''
        self.tx_fifo()

    def tx_peek(self, size: int | None = None) -> bytearray:
        '''returns size numbers of byte without removing them'''
        return self.tx_fifo.peek(size)

    def tx_log(self, level, info):
        ''' log the transmit queue'''
        self.tx_fifo.logging(level, info)

    def tx_clear(self):
        ''' clear transmit queue'''
        self.tx_fifo.clear()

    def tx_len(self):
        ''' get number of bytes in the transmit queue'''
        return len(self.tx_fifo)

    def fwd_add(self, data: bytearray):
        ''' add data to forward queue'''
        self.fwd_fifo += data

    def fwd_log(self, level, info):
        ''' log the forward queue'''
        self.fwd_fifo.logging(level, info)

    def rx_get(self, size: int | None = None) -> bytearray:
        '''removes size numbers of bytes and return them'''
        return self.rx_fifo.get(size)

    def rx_peek(self, size: int | None = None) -> bytearray:
        '''returns size numbers of byte without removing them'''
        return self.rx_fifo.peek(size)

    def rx_log(self, level, info):
        ''' logs the receive queue'''
        self.rx_fifo.logging(level, info)

    def rx_clear(self):
        ''' clear receive queue'''
        self.rx_fifo.clear()

    def rx_len(self):
        ''' get number of bytes in the receive queue'''
        return len(self.rx_fifo)

    def rx_set_cb(self, callback):
        '''register the trigger of the receive FIFO'''
        self.rx_fifo.reg_trigger(callback)

    def prot_set_timeout_cb(self, callback):
        '''register the callback returning the dead-connection timeout'''
        self.timeout_cb = callback

    def prot_set_init_new_client_conn_cb(self, callback):
        '''register the callback used when a new client connection starts'''
        self.init_new_client_conn_cb = callback

    def prot_set_update_header_cb(self, callback):
        '''register the callback that receives forwarded data for header
        updates'''
        self.update_header_cb = callback
|
|
|
|
|
|
class StreamPtr():
    """Mutable reference to a stream object and its connection interface.

    The same instance is meant to be shared between owners, so that
    (re)assigning ``stream`` or ``ifc`` is seen by every holder. A falsy
    ``stream`` is treated by callers as "no remote connection".
    """

    def __init__(self, _stream, _ifc=None):
        # assign through the property setters so both backing fields exist
        self.stream = _stream
        self.ifc = _ifc

    @property
    def stream(self):
        """The referenced stream object (may be None)."""
        return self._stream

    @stream.setter
    def stream(self, value):
        self._stream = value

    @property
    def ifc(self):
        """The connection interface belonging to ``stream`` (may be None)."""
        return self._ifc

    @ifc.setter
    def ifc(self, value):
        self._ifc = value
|
|
|
|
|
|
class AsyncStream(AsyncIfcImpl):
    '''TCP stream endpoint driving the receive / transmit / forward cycle.

    Wraps an asyncio StreamReader/StreamWriter pair. ``loop()`` reads data
    into the receive FIFO, triggers the receive callback, then flushes the
    transmit FIFO and forwards pending data to the connection referenced by
    ``self.remote``.
    '''

    # maximum processing time for a received msg in sec
    MAX_PROC_TIME = 2
    # maximum time without a received msg in sec
    MAX_START_TIME = 400
    # maximum time without a received msg from the inverter in sec
    MAX_INV_IDLE_TIME = 120
    # maximum default time without a received msg in sec
    MAX_DEF_IDLE_TIME = 360
    # processing time in sec from which healthy() reports False
    HEALTHY_PROC_LIMIT = 5

    def __init__(self, reader: StreamReader, writer: StreamWriter,
                 rstream: "StreamPtr") -> None:
        AsyncIfcImpl.__init__(self)

        logger.debug('AsyncStream.__init__')

        self.remote = rstream
        # flushing the transmit FIFO writes its content to the socket
        self.tx_fifo.reg_trigger(self.__write_cb)
        self._reader = reader
        self._writer = writer
        self.r_addr = writer.get_extra_info('peername')
        self.l_addr = writer.get_extra_info('sockname')
        self.proc_start = None  # start processing start timestamp
        self.proc_max = 0
        self.async_publ_mqtt = None  # will be set AsyncStreamServer only

    def __write_cb(self):
        '''transmit FIFO trigger: move the whole queue into the writer'''
        self._writer.write(self.tx_fifo.get())

    def __timeout(self) -> int:
        '''dead connection timeout in sec; the protocol callback wins'''
        if self.timeout_cb:
            return self.timeout_cb()
        return self.MAX_DEF_IDLE_TIME

    async def loop(self) -> Self:
        """Async loop handler for processing all received messages

        Runs until the connection times out, the peer closes it, or an
        OSError occurs; the connection is then disconnected and ``self``
        is returned. Any other exception is counted, logged, and the loop
        continues.
        """
        self.proc_start = time.time()
        while True:
            try:
                self.__calc_proc_time()
                dead_conn_to = self.__timeout()
                await asyncio.wait_for(self.__async_read(),
                                       dead_conn_to)

                await self.__async_write()
                await self.__async_forward()
                if self.async_publ_mqtt:
                    await self.async_publ_mqtt()

            except asyncio.TimeoutError:
                logger.warning(f'[{self.node_id}:{self.conn_no}] Dead '
                               f'connection timeout ({dead_conn_to}s) '
                               f'for {self.l_addr}')
                await self.disc()
                return self

            except OSError as error:
                logger.error(f'[{self.node_id}:{self.conn_no}] '
                             f'{error} for l{self.l_addr} | '
                             f'r{self.r_addr}')
                await self.disc()
                return self

            except RuntimeError as error:
                logger.info(f'[{self.node_id}:{self.conn_no}] '
                            f'{error} for {self.l_addr}')
                await self.disc()
                return self

            except Exception:
                Infos.inc_counter('SW_Exception')
                logger.error(
                    f"Exception for {self.r_addr}:\n"
                    f"{traceback.format_exc()}")
                await asyncio.sleep(0)  # be cooperative to other task

    def __calc_proc_time(self):
        '''update proc_max with the finished processing and reset the start'''
        if self.proc_start:
            proc = time.time() - self.proc_start
            if proc > self.proc_max:
                self.proc_max = proc
            self.proc_start = None

    async def disc(self) -> None:
        """Async disc handler for graceful disconnect"""
        if self._writer.is_closing():
            return
        logger.debug(f'AsyncStream.disc() l{self.l_addr} | r{self.r_addr}')
        self._writer.close()
        await self._writer.wait_closed()

    def close(self) -> None:
        """close handler for a no waiting disconnect

        hint: must be called before releasing the connection instance
        """
        logger.debug(f'AsyncStream.close() l{self.l_addr} | r{self.r_addr}')
        super().close()
        self._reader.feed_eof()  # abort awaited read
        if self._writer.is_closing():
            return
        self._writer.close()

    def healthy(self) -> bool:
        '''check that processing of a received message is not stuck

        Returns False when the current processing has been running for
        HEALTHY_PROC_LIMIT seconds or longer.
        '''
        elapsed = 0
        if self.proc_start is not None:
            elapsed = time.time() - self.proc_start
        if elapsed > self.MAX_PROC_TIME:
            logger.debug(f'[{self.node_id}:{self.conn_no}:'
                         f'{type(self).__name__}]'
                         f' act:{round(1000*elapsed)}ms'
                         f' max:{round(1000*self.proc_max)}ms')
            # gc.get_referrers() scans every tracked object, so only pay
            # for it when the debug message is actually emitted
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'Healthy() refs: {gc.get_referrers(self)}')
        return elapsed < self.HEALTHY_PROC_LIMIT

    # Our private methods

    async def __async_read(self) -> None:
        """Async read handler to read received data from TCP stream

        Raises RuntimeError when the peer has closed the connection.
        """
        data = await self._reader.read(4096)
        if data:
            self.proc_start = time.time()
            self.rx_fifo += data
            wait = self.rx_fifo()  # call read in parent class
            if wait and wait > 0:
                await asyncio.sleep(wait)
        else:
            raise RuntimeError("Peer closed.")

    async def __async_write(self, headline: str = 'Transmit to ') -> None:
        """Async write handler to transmit the send_buffer"""
        if len(self.tx_fifo) > 0:
            self.tx_fifo.logging(logging.INFO, f'{headline}{self.r_addr}:')
            self._writer.write(self.tx_fifo.get())
            await self._writer.drain()

    async def __async_forward(self) -> None:
        """forward handler transmits data over the remote connection

        On OSError/RuntimeError the remote connection (if any) is
        disconnected; other exceptions are counted and logged.
        """
        if len(self.fwd_fifo) == 0:
            return
        try:
            await self._async_forward()

        except OSError as error:
            rmt = self.remote
            # AsyncStreamClient.disc() sets self.remote to None
            if rmt and rmt.stream:
                logger.error(f'[{rmt.stream.node_id}:{rmt.stream.conn_no}] '
                             f'Fwd: {error} for '
                             f'l{rmt.ifc.l_addr} | r{rmt.ifc.r_addr}')
                await self.__disc_remote_ifc(rmt.ifc)

        except RuntimeError as error:
            rmt = self.remote
            if rmt and rmt.stream:
                logger.info(f'[{rmt.stream.node_id}:{rmt.stream.conn_no}] '
                            f'Fwd: {error} for {rmt.ifc.l_addr}')
                await self.__disc_remote_ifc(rmt.ifc)

        except Exception:
            Infos.inc_counter('SW_Exception')
            logger.error(
                f"Fwd Exception for {self.r_addr}:\n"
                f"{traceback.format_exc()}")

    @staticmethod
    async def __disc_remote_ifc(ifc) -> None:
        '''disconnect the remote interface after a failed forward'''
        await ifc.disc()
        # AsyncStreamServer defines no close_cb, so don't assume it exists
        close_cb = getattr(ifc, 'close_cb', None)
        if close_cb:
            close_cb()

    async def publish_outstanding_mqtt(self):
        '''Publish all outstanding MQTT topics

        Best effort: failures are logged but never propagated, so MQTT
        problems cannot abort the connection handling.
        '''
        try:
            # only AsyncStreamServer sets a publisher; clients still need
            # the proxy statistics to be published
            if self.async_publ_mqtt:
                await self.async_publ_mqtt()
            await Proxy._async_publ_mqtt_proxy_stat('proxy')
        except Exception as error:
            logger.warning(f'[{self.node_id}:{self.conn_no}] '
                           f'Publish outstanding MQTT failed: {error!r}')
|
|
|
|
|
|
class AsyncStreamServer(AsyncStream):
    '''Server-side stream for a connection accepted from an inverter.

    Forwarded data goes to the remote (client) connection, which is created
    on demand by awaiting ``create_remote()``.
    '''

    def __init__(self, reader: StreamReader, writer: StreamWriter,
                 async_publ_mqtt, create_remote,
                 rstream: "StreamPtr") -> None:
        AsyncStream.__init__(self, reader, writer, rstream)
        self.create_remote = create_remote
        self.async_publ_mqtt = async_publ_mqtt

    def close(self) -> None:
        '''drop the server callbacks, then close the stream'''
        logger.debug('AsyncStreamServer.close()')
        self.create_remote = None
        self.async_publ_mqtt = None
        super().close()

    async def server_loop(self) -> None:
        '''Loop for receiving messages from the inverter (server-side)'''
        logger.info(f'[{self.node_id}:{self.conn_no}] '
                    f'Accept connection from {self.r_addr}')
        Infos.inc_counter('Inverter_Cnt')
        await self.publish_outstanding_mqtt()
        try:
            await self.loop()
        finally:
            # keep the counter balanced even if this task gets cancelled
            Infos.dec_counter('Inverter_Cnt')
        await self.publish_outstanding_mqtt()
        logger.info(f'[{self.node_id}:{self.conn_no}] Server loop stopped for'
                    f' r{self.r_addr}')

        # if the server connection closes, we also have to disconnect
        # the connection to the TSUN cloud
        if self.remote and self.remote.stream:
            logger.info(f'[{self.node_id}:{self.conn_no}] disc client '
                        f'connection: [{self.remote.ifc.node_id}:'
                        f'{self.remote.ifc.conn_no}]')
            await self.remote.ifc.disc()

    async def _async_forward(self) -> None:
        """forward handler transmits data over the remote connection

        Creates the remote connection first if it does not exist yet.
        """
        if not self.remote.stream:
            await self.create_remote()
            # a freshly created client may have queued initial data to send
            if self.remote.stream and \
               self.remote.ifc.init_new_client_conn_cb():
                await self.remote.ifc._AsyncStream__async_write()
        if self.remote.stream:
            self.remote.ifc.update_header_cb(self.fwd_fifo.peek())
            self.fwd_fifo.logging(logging.INFO, 'Forward to '
                                  f'{self.remote.ifc.r_addr}:')
            self.remote.ifc._writer.write(self.fwd_fifo.get())
            await self.remote.ifc._writer.drain()
|
|
|
|
|
|
class AsyncStreamClient(AsyncStream):
    '''Client-side stream for the connection to the TSUN cloud.

    ``close_cb`` is called when the client loop ends.
    '''

    def __init__(self, reader: StreamReader, writer: StreamWriter,
                 rstream: "StreamPtr", close_cb) -> None:
        AsyncStream.__init__(self, reader, writer, rstream)
        self.close_cb = close_cb

    async def disc(self) -> None:
        '''drop the remote pointer, then disconnect gracefully'''
        logger.debug('AsyncStreamClient.disc()')
        self.remote = None
        await super().disc()

    def close(self) -> None:
        '''drop the close callback, then close the stream'''
        logger.debug('AsyncStreamClient.close()')
        self.close_cb = None
        super().close()

    async def client_loop(self, _: str) -> None:
        '''Loop for receiving messages from the TSUN cloud (client-side)'''
        Infos.inc_counter('Cloud_Conn_Cnt')
        await self.publish_outstanding_mqtt()
        try:
            await self.loop()
        finally:
            # keep the counter balanced even if this task gets cancelled
            Infos.dec_counter('Cloud_Conn_Cnt')
        await self.publish_outstanding_mqtt()
        logger.info(f'[{self.node_id}:{self.conn_no}] '
                    'Client loop stopped for'
                    f' l{self.l_addr}')

        if self.close_cb:
            self.close_cb()

    async def _async_forward(self) -> None:
        """forward handler transmits data over the remote connection"""
        rmt = self.remote
        # disc() sets self.remote to None, so it may be gone already
        if rmt and rmt.stream:
            rmt.ifc.update_header_cb(self.fwd_fifo.peek())
            self.fwd_fifo.logging(logging.INFO, 'Forward to '
                                  f'{rmt.ifc.r_addr}:')
            rmt.ifc._writer.write(self.fwd_fifo.get())
            await rmt.ifc._writer.drain()
|