diff --git a/app/tests/test_modbus_tcp.py b/app/tests/test_modbus_tcp.py index 52aabee..41640ec 100644 --- a/app/tests/test_modbus_tcp.py +++ b/app/tests/test_modbus_tcp.py @@ -4,6 +4,7 @@ import asyncio from mock import patch from enum import Enum +from enum import Enum from app.src.singleton import Singleton from app.src.config import Config from app.src.infos import Infos @@ -11,6 +12,10 @@ from app.src.mqtt import Mqtt from app.src.messages import Message, State from app.src.inverter import Inverter from app.src.modbus_tcp import ModbusConn, ModbusTcp +from app.src.mqtt import Mqtt +from app.src.messages import Message, State +from app.src.inverter import Inverter +from app.src.modbus_tcp import ModbusConn, ModbusTcp pytest_plugins = ('pytest_asyncio',) @@ -75,6 +80,47 @@ class TestType(Enum): RD_TEST_TIMEOUT = 2 +test = TestType.RD_TEST_0_BYTES +def config_conn(test_hostname, test_port): + Config.act_config = { + 'mqtt':{ + 'host': test_hostname, + 'port': test_port, + 'user': '', + 'passwd': '' + }, + 'ha':{ + 'auto_conf_prefix': 'homeassistant', + 'discovery_prefix': 'homeassistant', + 'entity_prefix': 'tsun', + 'proxy_node_id': 'test_1', + 'proxy_unique_id': '' + }, + 'inverters':{ + 'allow_all': True, + "R170000000000001":{ + 'node_id': 'inv_1' + }, + "Y170000000000001":{ + 'node_id': 'inv_2', + 'monitor_sn': 2000000000, + 'modbus_polling': True, + 'suggested_area': "", + 'sensor_list': 0x2b0, + 'client_mode':{ + 'host': '192.168.0.1', + 'port': 8899 + } + } + } + } + + +class TestType(Enum): + RD_TEST_0_BYTES = 1 + RD_TEST_TIMEOUT = 2 + + test = TestType.RD_TEST_0_BYTES class FakeReader(): @@ -88,6 +134,16 @@ class FakeReader(): raise TimeoutError def feed_eof(self): return + def __init__(self): + self.on_recv = asyncio.Event() + async def read(self, max_len: int): + await self.on_recv.wait() + if test == TestType.RD_TEST_0_BYTES: + return b'' + elif test == TestType.RD_TEST_TIMEOUT: + raise TimeoutError + def feed_eof(self): + return class FakeWriter(): @@ -105,6 +161,20 @@ class FakeWriter(): return async def wait_closed(self): return + def write(self, buf: bytes): + return + def get_extra_info(self, sel: str): + if sel == 'peername': + return 'remote.intern' + elif sel == 'sockname': + return 'sock:1234' + assert False + def is_closing(self): + return False + def close(self): + return + async def wait_closed(self): + return @pytest.fixture @@ -115,6 +185,9 @@ def patch_open(): def new_open(host: str, port: int): global test + if test == TestType.RD_TEST_TIMEOUT: + raise TimeoutError + global test if test == TestType.RD_TEST_TIMEOUT: raise TimeoutError return new_conn(None) @@ -127,6 +200,11 @@ def patch_no_mqtt(): with patch.object(Mqtt, 'publish') as conn: yield conn +@pytest.fixture +def patch_no_mqtt(): + with patch.object(Mqtt, 'publish') as conn: + yield conn + @pytest.mark.asyncio async def test_modbus_conn(patch_open):