diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d0374832..5d811c4c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,9 +33,10 @@ jobs: uv pip install ".[test]" - name: Run pytest run: uv run pytest --cov=snap7 --cov-report=xml --cov-report=term - - name: Upload coverage report + - name: Upload coverage to Codecov if: matrix.python-version == '3.13' && matrix.runs-on == 'ubuntu-24.04' - uses: actions/upload-artifact@v7 + uses: codecov/codecov-action@v5 with: - name: coverage-report - path: coverage.xml + files: coverage.xml + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false diff --git a/CHANGES.md b/CHANGES.md index b52afd85..d3baabb3 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,21 +5,43 @@ CHANGES ----- Major release: python-snap7 is now a pure Python S7 communication library. +This version completely breaks with the previous approach of wrapping the C snap7 +shared library. The entire S7 protocol stack is now implemented in pure Python, +greatly improving portability and making it easier to install and extend. * **Breaking**: The C snap7 library is no longer required or used * Complete rewrite of the S7 protocol stack in pure Python * Native Python implementation of TPKT (RFC 1006) and COTP (ISO 8073) layers * Native S7 protocol PDU encoding/decoding * Pure Python server implementation for testing and simulation -* No platform-specific binary dependencies +* No platform-specific binary dependencies — works on any platform that runs Python * Improved error handling and connection management * Full type annotations with mypy strict mode * CLI interface for running an S7 server emulator (`pip install "python-snap7[cli]"`) -If you experience issues with 3.0, pin to the last pre-3.0 release: +If you experience issues with 3.0, please report them on the +[issue tracker](https://github.com/gijzelaerr/python-snap7/issues) with a clear +description and the version you are using. As a workaround, pin to the last +pre-3.0 release: $ pip install "python-snap7<3" +### Thanks + +Special thanks to the following people for testing, reporting issues, and providing +feedback during the 3.0 development: + +* [@lupaulus](https://github.com/lupaulus) — extensive testing and bug reports +* [@spreeker](https://github.com/spreeker) — testing and feedback +* [@nikteliy](https://github.com/nikteliy) — review and feedback on the rewrite +* [@amorelettronico](https://github.com/amorelettronico) — testing +* [@razour08](https://github.com/razour08) — testing +* [@core-engineering](https://github.com/core-engineering) — bug reports (#553) +* [@AndreasScharf](https://github.com/AndreasScharf) — bug reports (#572) +* [@Robatronic](https://github.com/Robatronic) — bug reports (#574) +* [@hirotasoshu](https://github.com/hirotasoshu) — feedback (#545) +* [@PoitrasJ](https://github.com/PoitrasJ) — bug reports (#479) + 1.2 --- diff --git a/README.rst b/README.rst index adf68d91..f748fb5f 100644 --- a/README.rst +++ b/README.rst @@ -1,25 +1,63 @@ +.. image:: https://img.shields.io/pypi/v/python-snap7.svg + :target: https://pypi.org/project/python-snap7/ + +.. image:: https://img.shields.io/pypi/pyversions/python-snap7.svg + :target: https://pypi.org/project/python-snap7/ + +.. image:: https://img.shields.io/github/license/gijzelaerr/python-snap7.svg + :target: https://github.com/gijzelaerr/python-snap7/blob/master/LICENSE + +.. image:: https://github.com/gijzelaerr/python-snap7/actions/workflows/test.yml/badge.svg + :target: https://github.com/gijzelaerr/python-snap7/actions/workflows/test.yml + +.. image:: https://readthedocs.org/projects/python-snap7/badge/ + :target: https://python-snap7.readthedocs.io/en/latest/ + +.. image:: https://codecov.io/gh/gijzelaerr/python-snap7/branch/master/graph/badge.svg + :target: https://codecov.io/gh/gijzelaerr/python-snap7 + About ===== Python-snap7 is a pure Python S7 communication library for interfacing with Siemens S7 PLCs. +The name "python-snap7" is historical — the library originally started as a Python wrapper +around the `Snap7 `_ C library. As of version 3.0, the C +library is no longer used, but the name is kept for backwards compatibility. + Python-snap7 is tested with Python 3.10+, on Windows, Linux and OS X. The full documentation is available on `Read The Docs `_. -Version 3.0 - Breaking Changes -=============================== +Version 3.0 - Pure Python Rewrite +================================== + +Version 3.0 is a ground-up rewrite of python-snap7. The library no longer wraps the +C snap7 shared library — instead, the entire S7 protocol stack (TPKT, COTP, and S7) +is now implemented in pure Python. This is a **breaking change** from all previous +versions. + +**Why this matters:** + +* **Portability**: No more platform-specific shared libraries (`.dll`, `.so`, `.dylib`). + python-snap7 now works on any platform that runs Python — including ARM, Alpine Linux, + and other environments where the C library was difficult or impossible to install. +* **Easier installation**: Just ``pip install python-snap7``. No native dependencies, + no compiler toolchains, no manual library setup. +* **Easier to extend**: New features and protocol support can be added directly in Python. -Version 3.0 is a major release that rewrites python-snap7 as a pure Python -implementation. The C snap7 library is no longer required. +**If you experience issues with 3.0:** -This release may contain breaking changes. If you experience issues, you can -pin to the last pre-3.0 release:: +1. Please report them on the `issue tracker `_ + with a clear description of the problem and the version you are using + (``python -c "import snap7; print(snap7.__version__)"``). +2. As a workaround, you can pin to the last pre-3.0 release:: - $ pip install "python-snap7<3" + $ pip install "python-snap7<3" -The latest stable pre-3.0 release is version 2.1.0. + The latest stable pre-3.0 release is version 2.1.0. Documentation for pre-3.0 + versions is available at `Read The Docs `_. Installation @@ -29,4 +67,4 @@ Install using pip:: $ pip install python-snap7 -No native libraries or platform-specific dependencies are required - python-snap7 is a pure Python package that works on all platforms. +No native libraries or platform-specific dependencies are required — python-snap7 is a pure Python package that works on all platforms. diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 00000000..d18039b4 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,5 @@ +coverage: + status: + project: + default: + target: 80% diff --git a/doc/API/async_client.rst b/doc/API/async_client.rst new file mode 100644 index 00000000..0cf130fb --- /dev/null +++ b/doc/API/async_client.rst @@ -0,0 +1,49 @@ +AsyncClient +=========== + +.. warning:: + + The ``AsyncClient`` is **experimental**. The API may change in future + releases. If you encounter problems, please `open an issue + `_. + +The :class:`~snap7.async_client.AsyncClient` provides a native ``asyncio`` +interface for communicating with Siemens S7 PLCs. It has feature parity with +the synchronous :class:`~snap7.client.Client` and is safe for concurrent use +via ``asyncio.gather()``. + +Quick start +----------- + +.. code-block:: python + + import asyncio + import snap7 + + async def main(): + async with snap7.AsyncClient() as client: + await client.connect("192.168.1.10", 0, 1) + data = await client.db_read(1, 0, 4) + print(data) + + asyncio.run(main()) + +Concurrent reads +---------------- + +An internal ``asyncio.Lock`` serialises each send/receive cycle so that +multiple coroutines can safely share a single connection: + +.. code-block:: python + + results = await asyncio.gather( + client.db_read(1, 0, 4), + client.db_read(1, 10, 4), + ) + +API reference +------------- + +.. automodule:: snap7.async_client + :members: + :exclude-members: AsyncISOTCPConnection diff --git a/doc/API/s7commplus.rst b/doc/API/s7commplus.rst new file mode 100644 index 00000000..4314bb4e --- /dev/null +++ b/doc/API/s7commplus.rst @@ -0,0 +1,70 @@ +S7CommPlus (S7-1200/1500) +========================= + +.. warning:: + + S7CommPlus support is **experimental**. The API may change in future + releases. If you encounter problems, please `open an issue + `_. + +The :mod:`snap7.s7commplus` package provides support for Siemens S7-1200 and +S7-1500 PLCs, which use the S7CommPlus protocol instead of the classic S7 +protocol used by S7-300/400. + +Both synchronous and asynchronous clients are available. When a PLC does not +support S7CommPlus data operations, the clients automatically fall back to the +legacy S7 protocol transparently. + +Synchronous client +------------------ + +.. code-block:: python + + from snap7.s7commplus.client import S7CommPlusClient + + client = S7CommPlusClient() + client.connect("192.168.1.10") + data = client.db_read(1, 0, 4) + client.disconnect() + +Asynchronous client +------------------- + +.. code-block:: python + + import asyncio + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + async def main(): + client = S7CommPlusAsyncClient() + await client.connect("192.168.1.10") + data = await client.db_read(1, 0, 4) + await client.disconnect() + + asyncio.run(main()) + +Legacy fallback +--------------- + +If the PLC returns an error for S7CommPlus data operations (common with some +firmware versions), the client automatically falls back to the classic S7 +protocol. You can check whether fallback is active: + +.. code-block:: python + + client.connect("192.168.1.10") + if client.using_legacy_fallback: + print("Using legacy S7 protocol") + +API reference +------------- + +.. automodule:: snap7.s7commplus.client + :members: + +.. automodule:: snap7.s7commplus.async_client + :members: + +.. automodule:: snap7.s7commplus.connection + :members: + :exclude-members: S7CommPlusConnection diff --git a/doc/connecting.rst b/doc/connecting.rst new file mode 100644 index 00000000..8eefe4a6 --- /dev/null +++ b/doc/connecting.rst @@ -0,0 +1,100 @@ +Connecting to PLCs +================== + +This page shows how to connect to different Siemens PLC models using +python-snap7. + +.. contents:: On this page + :local: + :depth: 2 + + +Rack/Slot Reference +------------------- + +.. list-table:: + :header-rows: 1 + :widths: 20 10 10 60 + + * - PLC Model + - Rack + - Slot + - Notes + * - S7-300 + - 0 + - 2 + - + * - S7-400 + - 0 + - 3 + - May vary with multi-rack configurations + * - S7-1200 + - 0 + - 1 + - PUT/GET access must be enabled in TIA Portal + * - S7-1500 + - 0 + - 1 + - PUT/GET access must be enabled in TIA Portal + * - S7-200 / Logo + - -- + - -- + - Use ``set_connection_params`` with TSAP addressing + +.. warning:: + + S7-1200 and S7-1500 PLCs ship with PUT/GET communication disabled by + default. Enable it in TIA Portal under the CPU properties before + connecting. See :doc:`tia-portal-config` for step-by-step instructions. + + +S7-300 +------ + +.. code-block:: python + + import snap7 + + client = snap7.Client() + client.connect("192.168.1.10", 0, 2) + +S7-400 +------ + +.. code-block:: python + + import snap7 + + client = snap7.Client() + client.connect("192.168.1.10", 0, 3) + +S7-1200 / S7-1500 +------------------ + +.. code-block:: python + + import snap7 + + client = snap7.Client() + client.connect("192.168.1.10", 0, 1) + +S7-200 / Logo (TSAP Connection) +-------------------------------- + +.. code-block:: python + + import snap7 + + client = snap7.Client() + client.set_connection_params("192.168.1.10", 0x1000, 0x2000) + client.connect("192.168.1.10", 0, 0) + +Using a Non-Standard Port +-------------------------- + +.. code-block:: python + + import snap7 + + client = snap7.Client() + client.connect("192.168.1.10", 0, 1, tcp_port=1102) diff --git a/doc/connection-issues.rst b/doc/connection-issues.rst new file mode 100644 index 00000000..95008553 --- /dev/null +++ b/doc/connection-issues.rst @@ -0,0 +1,106 @@ +Connection Issues +================= + +.. contents:: On this page + :local: + :depth: 2 + + +.. _connection-recovery: + +Connection Recovery +------------------- + +Network connections to PLCs can drop due to cable issues, PLC restarts, or +network problems. Use a reconnection pattern to handle this gracefully: + +.. code-block:: python + + import snap7 + import time + import logging + + logger = logging.getLogger(__name__) + + client = snap7.Client() + + def connect(address: str = "192.168.1.10", rack: int = 0, slot: int = 1) -> None: + client.connect(address, rack, slot) + + def safe_read(db: int, start: int, size: int) -> bytearray: + """Read from DB with automatic reconnection on failure.""" + try: + return client.db_read(db, start, size) + except Exception: + logger.warning("Read failed, attempting reconnection...") + try: + client.disconnect() + except Exception: + pass + time.sleep(1) + connect() + return client.db_read(db, start, size) + + def safe_write(db: int, start: int, data: bytearray) -> None: + """Write to DB with automatic reconnection on failure.""" + try: + client.db_write(db, start, data) + except Exception: + logger.warning("Write failed, attempting reconnection...") + try: + client.disconnect() + except Exception: + pass + time.sleep(1) + connect() + client.db_write(db, start, data) + +For long-running applications, wrap your main loop with reconnection logic: + +.. code-block:: python + + while True: + try: + data = safe_read(1, 0, 10) + # process data... + time.sleep(0.5) + except Exception: + logger.error("Failed after reconnection attempt, retrying in 5s...") + time.sleep(5) + + +Connection Timeout +------------------ + +The default connection timeout is 5 seconds. You can configure it by accessing +the underlying connection object: + +.. code-block:: python + + import snap7 + + client = snap7.Client() + + # Connect with a custom timeout (in seconds) + client.connect("192.168.1.10", 0, 1) + + # The timeout is set on the underlying connection + # Default is 5.0 seconds + client.connection.timeout = 10.0 # Set to 10 seconds + +To set the timeout **before** connecting, use ``set_connection_params`` and then +connect manually, or simply reconnect after adjusting: + +.. code-block:: python + + client = snap7.Client() + client.connect("192.168.1.10", 0, 1) + + # Adjust timeout for slow networks + client.connection.timeout = 15.0 + +.. note:: + + If you are experiencing frequent timeouts, check your network quality first. + Typical S7 communication on a local network should respond within + milliseconds. diff --git a/doc/error-reference.rst b/doc/error-reference.rst new file mode 100644 index 00000000..812b28f2 --- /dev/null +++ b/doc/error-reference.rst @@ -0,0 +1,50 @@ +Error Message Reference +======================= + +The following table maps common S7 error strings to their likely cause and fix. + +.. list-table:: + :header-rows: 1 + :widths: 35 30 35 + + * - Error message + - Likely cause + - Fix + * - ``CLI : function refused by CPU (Unknown error)`` + - PUT/GET communication is not enabled on the PLC, or the data block + still has optimized block access enabled. + - Enable PUT/GET in TIA Portal and disable optimized block access on each + DB. See :doc:`tia-portal-config`. + * - ``CPU : Function not available`` + - The requested function is not supported on this PLC model. S7-1200 and + S7-1500 PLCs restrict certain operations. + - Check Siemens documentation for your PLC model. Some functions are only + available on S7-300/400. + * - ``CPU : Item not available`` + - Wrong DB number, the DB does not exist, or the address is out of range. + - Verify the DB number exists on the PLC and that the offset and size are + within bounds. + * - ``CPU : Address out of range`` + - Reading or writing past the end of a DB or memory area. + - Check the DB size in TIA Portal and ensure ``start + size`` does not + exceed it. + * - ``CPU : Function not authorized for current protection level`` + - The PLC has password protection enabled. + - Remove or lower the protection level in TIA Portal under + Protection & Security. + * - ``ISO : An error occurred during recv TCP : Connection timed out`` + - Network issue: PLC is unreachable, a firewall is blocking port 102, or + the PLC is not responding. + - Check network connectivity (``ping``), verify firewall rules, and ensure + the PLC is powered on and reachable. + * - ``ISO : An error occurred during send TCP : Connection timed out`` + - Same as above. + - Same as above. + * - ``TCP : Unreachable peer`` + - The PLC is not reachable on the network. + - Verify IP address, subnet, and routing. Ensure the PLC Ethernet port is + connected and configured. + * - ``TCP : Connection reset`` / Socket error 32 (broken pipe) + - The connection to the PLC was lost unexpectedly. + - The PLC may have been restarted, the cable disconnected, or another + client took over the connection. See :doc:`connection-issues`. diff --git a/doc/index.rst b/doc/index.rst index fd34584b..cade988c 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,16 +1,46 @@ Welcome to python-snap7's documentation! ======================================== -Contents: - .. toctree:: :maxdepth: 2 + :caption: Getting Started introduction installation + plc-support + +.. toctree:: + :maxdepth: 2 + :caption: User Guide + + connecting + reading-writing + multi-variable + server + tia-portal-config + +.. toctree:: + :maxdepth: 2 + :caption: Troubleshooting + + error-reference + connection-issues + thread-safety + limitations + +.. toctree:: + :maxdepth: 2 + :caption: Development + development +.. toctree:: + :maxdepth: 2 + :caption: API Reference + API/client + API/async_client + API/s7commplus API/server API/partner API/logo @@ -21,7 +51,6 @@ Contents: API/datatypes - Indices and tables ================== diff --git a/doc/installation.rst b/doc/installation.rst index f6a4a9f5..eaaccb43 100644 --- a/doc/installation.rst +++ b/doc/installation.rst @@ -10,12 +10,23 @@ If you want to use the CLI interface for running an emulator, install it with:: $ pip install "python-snap7[cli]" -That's it! No native libraries or platform-specific setup is required. +That's it! No native libraries or platform-specific setup is required. This works +on any platform that supports Python 3.10+, including ARM, Alpine Linux, and other +environments where the old C library was hard to install. Upgrading from 2.x ------------------- -Version 3.0 is a major rewrite. If you experience issues after upgrading, -you can pin to the last pre-3.0 release:: +Version 3.0 is a complete rewrite. Previous versions wrapped the C snap7 shared +library; version 3.0 implements the entire protocol stack in pure Python. While +the public API is largely the same, this is a fundamental change under the hood. - $ pip install "python-snap7<3" +If you experience issues after upgrading: + +1. Please report them on the `issue tracker `_ + with a clear description and your version (``python -c "import snap7; print(snap7.__version__)"``). +2. As a workaround, pin to the last pre-3.0 release:: + + $ pip install "python-snap7<3" + + The latest stable pre-3.0 release is version 2.1.0. diff --git a/doc/introduction.rst b/doc/introduction.rst index 6994592b..cf1d864b 100644 --- a/doc/introduction.rst +++ b/doc/introduction.rst @@ -6,7 +6,25 @@ natively with Siemens S7 PLCs. The library implements the complete S7 protocol stack including TPKT (RFC 1006), COTP (ISO 8073), and S7 protocol layers. +The name "python-snap7" is historical: the library originally started as a +Python wrapper around the `Snap7 `_ C library. +As of version 3.0, the C library is no longer used, but the name is kept for +backwards compatibility. + python-snap7 requires Python 3.10+ and runs on Windows, macOS and Linux without any native dependencies. +.. note:: + + **Version 3.0 is a complete rewrite.** Previous versions of python-snap7 + were a wrapper around the C snap7 shared library. Starting with version 3.0, + the entire protocol stack is implemented in pure Python. This eliminates the + need for platform-specific shared libraries and makes the library portable to + any platform that runs Python. + + If you experience issues, please report them on the + `issue tracker `_ with a + clear description and the version you are using. As a workaround, you can + install the last pre-3.0 release with ``pip install "python-snap7<3"``. + The project development is centralized on `github `_. diff --git a/doc/limitations.rst b/doc/limitations.rst new file mode 100644 index 00000000..26f82c5a --- /dev/null +++ b/doc/limitations.rst @@ -0,0 +1,28 @@ +Protocol Limitations and FAQ +============================ + +python-snap7 implements the S7 protocol over TCP/IP. The following operations +are **not possible** with this protocol: + +.. list-table:: + :header-rows: 1 + :widths: 40 60 + + * - Limitation + - Explanation + * - Read tag/symbol names from PLC + - Symbol names exist only in the TIA Portal project file, not in the PLC. + The S7 protocol only addresses data by area, DB number, and byte offset. + * - Get DB structure or layout from PLC + - The PLC stores only raw bytes. The structure definition lives in the TIA + Portal project. You must define your data layout in your Python code. + * - Discover PLCs on the network + - There is no S7 broadcast discovery mechanism. You must know the PLC's IP + address. + * - Create PLC backups + - Full project backup requires TIA Portal. python-snap7 can upload + individual blocks, but this is not a complete backup. + * - Access S7-1200/1500 PLCs with S7CommPlus security + - PLCs configured to require S7CommPlus encrypted communication cannot be + accessed with the classic S7 protocol. PUT/GET must be enabled as a + fallback. diff --git a/doc/multi-variable.rst b/doc/multi-variable.rst new file mode 100644 index 00000000..b83f35c3 --- /dev/null +++ b/doc/multi-variable.rst @@ -0,0 +1,51 @@ +Multi-Variable Read +=================== + +The ``read_multi_vars`` method reads multiple variables in a single PDU +request, which is significantly faster than individual reads. + +.. code-block:: python + + import snap7 + from snap7.type import Area, WordLen, S7DataItem + from ctypes import c_uint8, cast, POINTER + + client = snap7.Client() + client.connect("192.168.1.10", 0, 1) + + # Prepare items to read + items = [] + + # Item 1: 4 bytes from DB1, offset 0 + item1 = S7DataItem() + item1.Area = Area.DB + item1.WordLen = WordLen.Byte + item1.DBNumber = 1 + item1.Start = 0 + item1.Amount = 4 + buffer1 = (c_uint8 * 4)() + item1.pData = cast(buffer1, POINTER(c_uint8)) + items.append(item1) + + # Item 2: 2 bytes from DB2, offset 10 + item2 = S7DataItem() + item2.Area = Area.DB + item2.WordLen = WordLen.Byte + item2.DBNumber = 2 + item2.Start = 10 + item2.Amount = 2 + buffer2 = (c_uint8 * 2)() + item2.pData = cast(buffer2, POINTER(c_uint8)) + items.append(item2) + + # Execute the multi-read + result, data_items = client.read_multi_vars(items) + + # Access the returned data + value1 = bytearray(buffer1) + value2 = bytearray(buffer2) + +.. warning:: + + The S7 protocol limits multi-variable reads to **20 items** per request. + If you need more, split them across multiple calls. diff --git a/doc/plc-support.rst b/doc/plc-support.rst new file mode 100644 index 00000000..281459ce --- /dev/null +++ b/doc/plc-support.rst @@ -0,0 +1,159 @@ +PLC Support Matrix +================== + +This page documents which Siemens PLC families are supported by python-snap7, +the communication protocols they use, and any configuration requirements. + +Supported PLCs +-------------- + +.. list-table:: + :header-rows: 1 + :widths: 20 10 10 10 10 15 25 + + * - PLC Family + - Introduced + - S7 (classic) + - S7CommPlus V1 + - S7CommPlus V2/V3 + - python-snap7 support + - Notes + * - S7-300 + - ~1994 + - Yes + - No + - No + - **Full** + - Works out of the box. + * - S7-400 + - ~1996 + - Yes + - No + - No + - **Full** + - Works out of the box. + * - S7-1200 (FW ≤3) + - 2009 + - Yes + - No + - No + - **Full** + - Enable PUT/GET access in TIA Portal. + * - S7-1200 (FW 4+) + - ~2014 + - Yes + - Yes + - No + - **Full** + - Enable PUT/GET access in TIA Portal. Uses classic S7. + * - S7-1500 (FW 1.x) + - 2012 + - PUT/GET only + - Yes + - No + - **Full** (experimental S7CommPlus) + - S7CommPlus V1 session + legacy S7 fallback for data. + * - S7-1500 (FW 2.x) + - ~2016 + - PUT/GET only + - No + - V2 + - **PUT/GET only** + - S7CommPlus V2 support is in development. + * - S7-1500 (FW 3.x+) + - ~2022 + - PUT/GET only + - No + - V3 + - **PUT/GET only** + - S7CommPlus V3 uses proprietary crypto; not yet supported. + * - S7-1500R/H + - ~2019 + - No + - No + - V2/V3 + - **Not supported** + - Redundant CPUs; no classic S7 fallback available. + * - ET 200SP CPU + - ~2014 + - PUT/GET only + - Yes + - Yes + - **PUT/GET only** + - Same behavior as S7-1500 with matching firmware. + * - S7-200 SMART + - ~2012 + - Subset + - No + - No + - **Partial** + - Basic read/write works. Some advanced functions may not be available. + * - LOGO! 8 + - ~2014 + - Subset + - No + - No + - **Full** + - Use the :class:`~snap7.logo.Logo` class. + + +Enabling PUT/GET Access +----------------------- + +For S7-1200 and S7-1500 PLCs, classic S7 protocol access requires the +**PUT/GET** option to be enabled. See :doc:`tia-portal-config` for +step-by-step instructions. + +.. warning:: + + PUT/GET access provides unauthenticated read/write access to PLC memory. + Only enable this on networks that are properly segmented and secured. + + +Protocol Overview +----------------- + +Siemens has evolved their PLC communication protocols over time: + +.. list-table:: + :header-rows: 1 + :widths: 20 15 15 50 + + * - Protocol + - Encryption + - Authentication + - Used by + * - S7 (classic) + - None + - None + - S7-300, S7-400, S7-1200, S7-1500 (PUT/GET mode) + * - S7CommPlus V1 + - None + - Challenge-response + - S7-1200 FW 4+, S7-1500 FW 1.x + * - S7CommPlus V2 + - Proprietary + - Yes + - S7-1500 FW 2.x + * - S7CommPlus V3 + - TLS + - Certificate-based + - S7-1500 FW 3.x+ + +python-snap7 implements the **classic S7 protocol**, which remains available +on most PLC families via the PUT/GET mechanism. For PLCs that only support +S7CommPlus V2 or V3 (such as the S7-1500R/H), no open-source solution +currently exists — consider using OPC UA as an alternative. + + +Alternatives for Unsupported PLCs +--------------------------------- + +If your PLC is not supported by python-snap7, consider these alternatives: + +- **OPC UA**: S7-1500 PLCs (FW 2.0+) include a built-in OPC UA server. Use + a Python OPC UA client such as `opcua-asyncio `_. +- **TIA Portal**: Siemens' official engineering tool supports all protocols + and PLC families. +- **PROFINET**: For real-time communication needs, PROFINET may be more + appropriate than S7 communication. diff --git a/doc/reading-writing.rst b/doc/reading-writing.rst new file mode 100644 index 00000000..cea8f14d --- /dev/null +++ b/doc/reading-writing.rst @@ -0,0 +1,380 @@ +Reading & Writing Data +====================== + +This page covers address mapping, data type conversions, memory area access, +and analog I/O — everything you need for reading from and writing to a PLC. + +All examples assume you have a connected client: + +.. code-block:: python + + import snap7 + + client = snap7.Client() + client.connect("192.168.1.10", 0, 1) + +.. contents:: On this page + :local: + :depth: 2 + + +Address Mapping +--------------- + +PLC addresses in Siemens TIA Portal / STEP 7 map to python-snap7 API calls +as follows. + +.. list-table:: + :header-rows: 1 + :widths: 25 40 35 + + * - PLC Address + - python-snap7 Call + - Explanation + * - DB1.DBB0 + - ``db_read(1, 0, 1)`` + - 1 byte at offset 0 of DB1 + * - DB1.DBW10 + - ``db_read(1, 10, 2)`` + - 2 bytes (WORD) at offset 10 + * - DB1.DBD20 + - ``db_read(1, 20, 4)`` + - 4 bytes (DWORD) at offset 20 + * - DB1.DBX0.3 + - ``db_read(1, 0, 1)`` then ``get_bool(data, 0, 3)`` + - Bit 3 of byte 0 + * - M0.0 + - ``mb_read(0, 1)`` then ``get_bool(data, 0, 0)`` + - Bit 0 of merker byte 0 + * - MW10 + - ``mb_read(10, 2)`` + - 2 bytes (WORD) from merker byte 10 + * - IW0 / EW0 + - ``read_area(Area.PE, 0, 0, 2)`` + - Analog input word at address 0 + * - QW0 / AW0 + - ``read_area(Area.PA, 0, 0, 2)`` + - Analog output word at address 0 + +.. important:: + + The ``byte_index`` parameter in all ``snap7.util`` getter/setter functions + is **relative to the returned bytearray**, not the absolute PLC address. + + For example, to read DB1.DBX10.3: + + .. code-block:: python + + data = client.db_read(1, 10, 1) # Read 1 byte starting at offset 10 + value = snap7.util.get_bool(data, 0, 3) # byte_index=0, NOT 10 + + You read from PLC offset 10, but ``data[0]`` *is* byte 10 from the PLC. + + +Data Types +---------- + +Each example below shows a complete read and write cycle. + +BOOL +^^^^ + +Booleans require a **read-modify-write** pattern. You cannot write a single +bit to the PLC; you must read the enclosing byte, change the bit, then write +the whole byte back. + +.. code-block:: python + + # Read DB1.DBX0.3 (bit 3 of byte 0) + data = client.db_read(1, 0, 1) + value = snap7.util.get_bool(data, 0, 3) + print(f"DB1.DBX0.3 = {value}") + + # Write DB1.DBX0.3 -- read first, then modify, then write + data = client.db_read(1, 0, 1) + snap7.util.set_bool(data, 0, 3, True) + client.db_write(1, 0, data) + +.. warning:: + + Never write a freshly created ``bytearray`` for booleans. Always read the + current byte first to avoid overwriting neighboring bits. + +BYTE (1 byte, unsigned 0--255) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + # Read DB1.DBB0 (1 byte at offset 0) + data = client.db_read(1, 0, 1) + value = snap7.util.get_byte(data, 0) + print(f"DB1.DBB0 = {value}") + + # Write + data = bytearray(1) + snap7.util.set_byte(data, 0, 200) + client.db_write(1, 0, data) + +INT (2 bytes, signed -32768 to 32767) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + # Read DB1.DBW10 + data = client.db_read(1, 10, 2) + value = snap7.util.get_int(data, 0) + print(f"DB1.DBW10 = {value}") + + # Write + data = bytearray(2) + snap7.util.set_int(data, 0, -1234) + client.db_write(1, 10, data) + +WORD (2 bytes, unsigned 0--65535) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + # Read DB1.DBW20 + data = client.db_read(1, 20, 2) + value = snap7.util.get_word(data, 0) + print(f"DB1.DBW20 = {value}") + + # Write + data = bytearray(2) + snap7.util.set_word(data, 0, 50000) + client.db_write(1, 20, data) + +DINT (4 bytes, signed -2147483648 to 2147483647) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + # Read DB1.DBD30 + data = client.db_read(1, 30, 4) + value = snap7.util.get_dint(data, 0) + print(f"DB1.DBD30 = {value}") + + # Write + data = bytearray(4) + snap7.util.set_dint(data, 0, 100000) + client.db_write(1, 30, data) + +DWORD (4 bytes, unsigned 0--4294967295) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + # Read DB1.DBD40 + data = client.db_read(1, 40, 4) + value = snap7.util.get_dword(data, 0) + print(f"DB1.DBD40 = {value}") + + # Write + data = bytearray(4) + snap7.util.set_dword(data, 0, 3000000000) + client.db_write(1, 40, data) + +REAL (4 bytes, IEEE 754 float) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + # Read DB1.DBD50 + data = client.db_read(1, 50, 4) + value = snap7.util.get_real(data, 0) + print(f"DB1.DBD50 = {value}") + + # Write + data = bytearray(4) + snap7.util.set_real(data, 0, 3.14) + client.db_write(1, 50, data) + +LREAL (8 bytes, IEEE 754 double) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + # Read DB1, offset 60, 8 bytes + data = client.db_read(1, 60, 8) + value = snap7.util.get_lreal(data, 0) + print(f"LREAL = {value}") + + # Write + data = bytearray(8) + snap7.util.set_lreal(data, 0, 3.141592653589793) + client.db_write(1, 60, data) + +STRING (2 header bytes + characters) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +S7 strings have a specific format: + +- **Byte 0**: Maximum length (set when the variable is declared in the PLC) +- **Byte 1**: Actual (current) length of the string content +- **Bytes 2+**: ASCII characters + +When reading, always request ``max_length + 2`` bytes to include the header. + +.. code-block:: python + + # Read a string at DB1, offset 10, declared as STRING[20] in the PLC + max_length = 20 + data = client.db_read(1, 10, max_length + 2) # 20 + 2 header bytes = 22 + text = snap7.util.get_string(data, 0) + print(f"String = '{text}'") + + # Write a string + data = client.db_read(1, 10, max_length + 2) + snap7.util.set_string(data, 0, "Hello", max_length) + client.db_write(1, 10, data) + +.. note:: + + Always read the existing data before writing a string. The + ``set_string`` function preserves the max-length header byte and pads + unused characters with spaces. + +DATE_AND_TIME (8 bytes, BCD encoded) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + from datetime import datetime + + # Read DATE_AND_TIME at DB1, offset 70 (returns ISO 8601 string) + data = client.db_read(1, 70, 8) + dt_string = snap7.util.get_dt(data, 0) + print(f"DATE_AND_TIME = {dt_string}") # e.g. '2024-06-15T14:30:00.000000' + + # Parse to Python datetime if needed + dt_obj = datetime.fromisoformat(dt_string) + + # Write DATE_AND_TIME + data = client.db_read(1, 70, 8) + snap7.util.set_dt(data, 0, datetime(2024, 6, 15, 14, 30, 0)) + client.db_write(1, 70, data) + + +Memory Areas +------------ + +python-snap7 provides convenience methods for data blocks and merkers, and +the generic ``read_area`` / ``write_area`` for all other areas. + +Data Blocks (DB) +^^^^^^^^^^^^^^^^ + +.. code-block:: python + + # Read 10 bytes from DB1 starting at offset 0 + data = client.db_read(1, 0, 10) + + # Write 4 bytes to DB1 starting at offset 0 + client.db_write(1, 0, bytearray([0x01, 0x02, 0x03, 0x04])) + +Merkers / Flags (M) +^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + # Read 4 merker bytes starting at MB0 + data = client.mb_read(0, 4) + + # Write 2 bytes starting at MB10 + client.mb_write(10, 2, bytearray([0xFF, 0x00])) + +Inputs (I / E) +^^^^^^^^^^^^^^ + +.. code-block:: python + + from snap7.type import Area + + # Read 2 input bytes starting at IB0 + data = client.read_area(Area.PE, 0, 0, 2) + +Outputs (Q / A) +^^^^^^^^^^^^^^^ + +.. code-block:: python + + from snap7.type import Area + + # Read 2 output bytes starting at QB0 + data = client.read_area(Area.PA, 0, 0, 2) + + # Write to QB0 + client.write_area(Area.PA, 0, 0, bytearray([0x00, 0xFF])) + +Timers (T) +^^^^^^^^^^ + +.. code-block:: python + + from snap7.type import Area + + # Read timer T0 (1 timer = 2 bytes) + data = client.read_area(Area.TM, 0, 0, 1) + +Counters (C) +^^^^^^^^^^^^^ + +.. code-block:: python + + from snap7.type import Area + + # Read counter C0 (1 counter = 2 bytes) + data = client.read_area(Area.CT, 0, 0, 1) + + +Analog I/O +---------- + +Analog inputs are typically 16-bit integers in the peripheral input area +(``Area.PE``). The raw value from the PLC needs to be scaled to engineering +units. + +Reading Analog Inputs +^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + import snap7 + from snap7.type import Area + + client = snap7.Client() + client.connect("192.168.1.10", 0, 1) + + # Read AIW0 (analog input word at address 0) + data = client.read_area(Area.PE, 0, 0, 2) + raw_value = snap7.util.get_int(data, 0) + print(f"Raw value: {raw_value}") + + # Scale to engineering units + # S7 analog modules typically use 0-27648 for 0-100% range + min_range = 0.0 # e.g., 0 bar + max_range = 10.0 # e.g., 10 bar + scaled = raw_value * (max_range - min_range) / 27648.0 + min_range + print(f"Pressure: {scaled:.2f} bar") + + # Read AIW2 (second analog input) + data = client.read_area(Area.PE, 0, 2, 2) + raw_value = snap7.util.get_int(data, 0) + +Writing Analog Outputs +^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + from snap7.type import Area + + # Write to AQW0 (analog output word at address 0) + data = bytearray(2) + snap7.util.set_int(data, 0, 13824) # ~50% of 27648 + client.write_area(Area.PA, 0, 0, data) + +.. note:: + + The standard scaling factor 27648 applies to most Siemens analog I/O + modules. Check your module documentation for the actual range. diff --git a/doc/server.rst b/doc/server.rst new file mode 100644 index 00000000..f46e1649 --- /dev/null +++ b/doc/server.rst @@ -0,0 +1,113 @@ +Server Setup for Testing +======================== + +The built-in server lets you test your client code without a physical PLC. + +.. contents:: On this page + :local: + :depth: 2 + + +Basic Server Example +-------------------- + +.. code-block:: python + + from snap7.server import Server + from snap7.type import SrvArea + from ctypes import c_char + + # Create and configure the server + server = Server() + + # Register a data block (DB1) with 100 bytes + db_size = 100 + db_data = bytearray(db_size) + db_array = (c_char * db_size).from_buffer(db_data) + server.register_area(SrvArea.DB, 1, db_array) + + # Start the server on a non-privileged port + server.start(tcp_port=1102) + + +Client-Server Round Trip +------------------------- + +.. code-block:: python + + import snap7 + from snap7.server import Server + from snap7.type import SrvArea + from ctypes import c_char + + # --- Server setup --- + server = Server() + db_size = 100 + db_data = bytearray(db_size) + db_array = (c_char * db_size).from_buffer(db_data) + server.register_area(SrvArea.DB, 1, db_array) + server.start(tcp_port=1102) + + # --- Client connection --- + client = snap7.Client() + client.connect("127.0.0.1", 0, 1, tcp_port=1102) + + # Write data + client.db_write(1, 0, bytearray([0x01, 0x02, 0x03, 0x04])) + + # Read it back + data = client.db_read(1, 0, 4) + print(f"Read back: {list(data)}") # [1, 2, 3, 4] + + # Clean up + client.disconnect() + server.stop() + + +Registering Multiple Areas +--------------------------- + +.. code-block:: python + + from snap7.server import Server + from snap7.type import SrvArea + from ctypes import c_char + + server = Server() + + # Register DB1 + db1_data = bytearray(100) + db1 = (c_char * 100).from_buffer(db1_data) + server.register_area(SrvArea.DB, 1, db1) + + # Register DB2 + db2_data = bytearray(200) + db2 = (c_char * 200).from_buffer(db2_data) + server.register_area(SrvArea.DB, 2, db2) + + # Register merker area (256 bytes) + mk_data = bytearray(256) + mk = (c_char * 256).from_buffer(mk_data) + server.register_area(SrvArea.MK, 0, mk) + + server.start(tcp_port=1102) + +.. note:: + + Use a port number above 1024 (e.g., 1102) to avoid requiring root/admin + privileges. Port 102 is the standard S7 port but is in the privileged + range. + + +Using the Mainloop Helper +-------------------------- + +For quick testing, the ``mainloop`` function starts a server with common +data blocks pre-registered: + +.. code-block:: python + + from snap7.server import mainloop + + # Blocks the current thread + mainloop(tcp_port=1102) diff --git a/doc/thread-safety.rst b/doc/thread-safety.rst new file mode 100644 index 00000000..235a89f6 --- /dev/null +++ b/doc/thread-safety.rst @@ -0,0 +1,39 @@ +Thread Safety +============= + +The ``Client`` class is **not** thread-safe. Concurrent calls from multiple +threads on the same ``Client`` instance will corrupt the TCP connection state +and cause unpredictable errors. + +**Option 1: One client per thread** + +.. code-block:: python + + import threading + import snap7 + + def worker(address: str, rack: int, slot: int) -> None: + client = snap7.Client() + client.connect(address, rack, slot) + data = client.db_read(1, 0, 10) + client.disconnect() + + t1 = threading.Thread(target=worker, args=("192.168.1.10", 0, 1)) + t2 = threading.Thread(target=worker, args=("192.168.1.10", 0, 1)) + t1.start() + t2.start() + +**Option 2: Shared client with a lock** + +.. code-block:: python + + import threading + import snap7 + + client = snap7.Client() + client.connect("192.168.1.10", 0, 1) + lock = threading.Lock() + + def safe_read(db: int, start: int, size: int) -> bytearray: + with lock: + return client.db_read(db, start, size) diff --git a/doc/tia-portal-config.rst b/doc/tia-portal-config.rst new file mode 100644 index 00000000..73932db6 --- /dev/null +++ b/doc/tia-portal-config.rst @@ -0,0 +1,56 @@ +.. _tia-portal-config: + +TIA Portal Configuration +========================= + +S7-1200 and S7-1500 PLCs require specific configuration in TIA Portal before +python-snap7 can communicate with them. Without these settings, you will get +``CLI : function refused by CPU`` errors. + +.. contents:: On this page + :local: + :depth: 2 + + +Step 1: Enable PUT/GET Communication +------------------------------------- + +1. Open your project in TIA Portal. +2. In the project tree, double-click on the PLC device. +3. Go to **Properties** > **Protection & Security** > **Connection mechanisms**. +4. Check **Permit access with PUT/GET communication from remote partner**. +5. Compile and download to the PLC. + +.. warning:: + + This setting allows any network client to read and write PLC memory without + authentication. Only enable this on isolated industrial networks. + + +Step 2: Disable Optimized Block Access +--------------------------------------- + +This must be done for **each** data block you want to access: + +1. In the project tree, right-click on the data block (e.g., DB1). +2. Select **Properties**. +3. Go to the **Attributes** tab. +4. **Uncheck** "Optimized block access". +5. Click OK. +6. Compile and download to the PLC. + +.. warning:: + + Changing the "Optimized block access" setting reinitializes the data block, + which resets all values in that DB to their defaults. Do this before + commissioning, or back up your data first. + + +Step 3: Compile and Download +----------------------------- + +After making both changes: + +1. Compile the project (**Build** > **Compile**). +2. Download to the PLC (**Online** > **Download to device**). +3. The PLC may need to restart depending on the changes. diff --git a/pyproject.toml b/pyproject.toml index ef5d6767..3e28ea7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ Homepage = "https://github.com/gijzelaerr/python-snap7" Documentation = "https://python-snap7.readthedocs.io/en/latest/" [project.optional-dependencies] -test = ["pytest", "pytest-cov", "pytest-html", "mypy", "types-setuptools", "ruff", "tox", "tox-uv", "types-click", "uv"] +test = ["pytest", "pytest-asyncio", "pytest-cov", "pytest-html", "mypy", "types-setuptools", "ruff", "tox", "tox-uv", "types-click", "uv"] cli = ["rich", "click" ] doc = ["sphinx", "sphinx_rtd_theme"] @@ -44,6 +44,7 @@ include = ["snap7*"] [project.scripts] snap7-server = "snap7.server:mainloop" +s7 = "snap7.cli:main" [tool.pytest.ini_options] testpaths = ["tests"] @@ -55,8 +56,10 @@ markers =[ "mainloop", "partner", "server", - "util" + "util", + "conformance: protocol conformance tests" ] +asyncio_mode = "auto" [tool.mypy] ignore_missing_imports = true @@ -64,6 +67,9 @@ strict = true # https://github.com/python/mypy/issues/2427#issuecomment-1419206807 disable_error_code = ["method-assign", "attr-defined"] +[tool.coverage.report] +fail_under = 75 + [tool.ruff] output-format = "full" line-length = 130 diff --git a/snap7/__init__.py b/snap7/__init__.py index 1b9756d3..ba87536d 100644 --- a/snap7/__init__.py +++ b/snap7/__init__.py @@ -8,6 +8,7 @@ from importlib.metadata import version, PackageNotFoundError from .client import Client +from .async_client import AsyncClient from .server import Server from .partner import Partner from .logo import Logo @@ -16,6 +17,7 @@ __all__ = [ "Client", + "AsyncClient", "Server", "Partner", "Logo", diff --git a/snap7/async_client.py b/snap7/async_client.py new file mode 100644 index 00000000..dd767031 --- /dev/null +++ b/snap7/async_client.py @@ -0,0 +1,1275 @@ +""" +Native async S7 client implementation. + +Uses asyncio streams for non-blocking I/O with an asyncio.Lock() to serialize +send/receive cycles, ensuring safe concurrent use via asyncio.gather(). +""" + +import asyncio +import logging +import struct +import time +from typing import List, Any, Optional, Tuple, Type +from types import TracebackType +from datetime import datetime + +from .connection import TPDUSize +from .s7protocol import S7Protocol, get_return_code_description +from .datatypes import S7WordLen +from .error import S7Error, S7ConnectionError, S7ProtocolError, S7TimeoutError +from .client_base import ClientMixin +from .type import ( + Area, + Block, + BlocksList, + S7CpuInfo, + TS7BlockInfo, + S7CpInfo, + S7OrderCode, + S7Protection, + S7SZL, + Parameter, +) + + +logger = logging.getLogger(__name__) + + +class AsyncISOTCPConnection: + """Async ISO on TCP connection using asyncio streams. + + Mirrors ISOTCPConnection but uses asyncio.open_connection() instead of + blocking sockets for non-blocking I/O. + """ + + # COTP PDU types + COTP_CR = 0xE0 # Connection Request + COTP_CC = 0xD0 # Connection Confirm + COTP_DR = 0x80 # Disconnect Request + COTP_DT = 0xF0 # Data Transfer + + # COTP parameter codes (ISO 8073) + COTP_PARAM_PDU_SIZE = 0xC0 + COTP_PARAM_CALLING_TSAP = 0xC1 + COTP_PARAM_CALLED_TSAP = 0xC2 + + def __init__( + self, + host: str, + port: int = 102, + local_tsap: int = 0x0100, + remote_tsap: int = 0x0102, + tpdu_size: TPDUSize = TPDUSize.S_1024, + ): + self.host = host + self.port = port + self.local_tsap = local_tsap + self.remote_tsap = remote_tsap + self.tpdu_size = tpdu_size + self.connected = False + self.pdu_size = 240 + self.timeout = 5.0 + + self.src_ref = 0x0001 + self.dst_ref = 0x0000 + + self._reader: Optional[asyncio.StreamReader] = None + self._writer: Optional[asyncio.StreamWriter] = None + + async def connect(self, timeout: float = 5.0) -> None: + """Establish ISO on TCP connection.""" + self.timeout = timeout + + try: + self._reader, self._writer = await asyncio.wait_for( + asyncio.open_connection(self.host, self.port), + timeout=self.timeout, + ) + logger.debug(f"TCP connected to {self.host}:{self.port}") + + await self._iso_connect() + + self.connected = True + logger.info(f"Connected to {self.host}:{self.port}, PDU size: {self.pdu_size}") + + except Exception as e: + await self.disconnect() + if isinstance(e, (S7ConnectionError, S7TimeoutError)): + raise + elif isinstance(e, asyncio.TimeoutError): + raise S7TimeoutError(f"Connection timeout: {e}") + else: + raise S7ConnectionError(f"Connection failed: {e}") + + async def disconnect(self) -> None: + """Disconnect from S7 device.""" + if self._writer: + try: + if self.connected: + dr_pdu = struct.pack( + ">BBHHBB", + 6, + self.COTP_DR, + self.dst_ref, + self.src_ref, + 0x00, + 0x00, + ) + self._writer.write(self._build_tpkt(dr_pdu)) + await self._writer.drain() + self._writer.close() + await self._writer.wait_closed() + except Exception: + pass + finally: + self._reader = None + self._writer = None + self.connected = False + logger.info(f"Disconnected from {self.host}:{self.port}") + + async def send_data(self, data: bytes) -> None: + """Send data over ISO connection.""" + if not self.connected or self._writer is None: + raise S7ConnectionError("Not connected") + + cotp_header = struct.pack(">BBB", 2, self.COTP_DT, 0x80) + tpkt_frame = self._build_tpkt(cotp_header + data) + + try: + self._writer.write(tpkt_frame) + await self._writer.drain() + logger.debug(f"Sent {len(tpkt_frame)} bytes") + except (OSError, ConnectionError) as e: + self.connected = False + raise S7ConnectionError(f"Send failed: {e}") + + async def receive_data(self) -> bytes: + """Receive data from ISO connection.""" + if not self.connected: + raise S7ConnectionError("Not connected") + + try: + tpkt_header = await self._recv_exact(4) + version, reserved, length = struct.unpack(">BBH", tpkt_header) + if version != 3: + raise S7ConnectionError(f"Invalid TPKT version: {version}") + + remaining = length - 4 + if remaining <= 0: + raise S7ConnectionError("Invalid TPKT length") + + payload = await self._recv_exact(remaining) + + # Parse COTP DT header + if len(payload) < 3: + raise S7ConnectionError("Invalid COTP DT: too short") + pdu_len, pdu_type, eot_num = struct.unpack(">BBB", payload[:3]) + if pdu_type != self.COTP_DT: + raise S7ConnectionError(f"Expected COTP DT, got {pdu_type:#02x}") + return payload[3:] + + except asyncio.TimeoutError: + self.connected = False + raise S7TimeoutError("Receive timeout") + except (OSError, ConnectionError) as e: + self.connected = False + raise S7ConnectionError(f"Receive failed: {e}") + + async def _iso_connect(self) -> None: + """Establish ISO connection using COTP handshake.""" + if self._writer is None or self._reader is None: + raise S7ConnectionError("Stream not initialized") + + # Build and send COTP Connection Request + base_pdu = struct.pack( + ">BBHHB", + 6, + self.COTP_CR, + 0x0000, + self.src_ref, + 0x00, + ) + calling_tsap = struct.pack(">BBH", self.COTP_PARAM_CALLING_TSAP, 2, self.local_tsap) + called_tsap = struct.pack(">BBH", self.COTP_PARAM_CALLED_TSAP, 2, self.remote_tsap) + pdu_size_param = struct.pack(">BBB", self.COTP_PARAM_PDU_SIZE, 1, self.tpdu_size) + parameters = calling_tsap + called_tsap + pdu_size_param + total_length = 6 + len(parameters) + cr_pdu = struct.pack(">B", total_length) + base_pdu[1:] + parameters + + self._writer.write(self._build_tpkt(cr_pdu)) + await self._writer.drain() + logger.debug("Sent COTP Connection Request") + + # Receive Connection Confirm + tpkt_header = await self._recv_exact(4) + version, reserved, length = struct.unpack(">BBH", tpkt_header) + if version != 3: + raise S7ConnectionError(f"Invalid TPKT version in response: {version}") + + payload = await self._recv_exact(length - 4) + self._parse_cotp_cc(payload) + logger.debug("Received COTP Connection Confirm") + + def _build_tpkt(self, payload: bytes) -> bytes: + """Build TPKT frame.""" + length = len(payload) + 4 + return struct.pack(">BBH", 3, 0, length) + payload + + def _parse_cotp_cc(self, data: bytes) -> None: + """Parse COTP Connection Confirm PDU.""" + if len(data) < 7: + raise S7ConnectionError("Invalid COTP CC: too short") + + pdu_len, pdu_type, dst_ref, src_ref, class_opt = struct.unpack(">BBHHB", data[:7]) + if pdu_type != self.COTP_CC: + raise S7ConnectionError(f"Expected COTP CC, got {pdu_type:#02x}") + + self.dst_ref = dst_ref + + # Parse parameters + offset = 7 + while offset < len(data): + if offset + 2 > len(data): + break + param_code = data[offset] + param_len = data[offset + 1] + if offset + 2 + param_len > len(data): + break + param_data = data[offset + 2 : offset + 2 + param_len] + if param_code == self.COTP_PARAM_PDU_SIZE: + if param_len == 1: + self.pdu_size = 1 << param_data[0] + elif param_len == 2: + self.pdu_size = struct.unpack(">H", param_data)[0] + logger.debug(f"Negotiated PDU size: {self.pdu_size}") + offset += 2 + param_len + + async def _recv_exact(self, size: int) -> bytes: + """Receive exactly size bytes.""" + if self._reader is None: + raise S7ConnectionError("Stream not initialized") + try: + return await asyncio.wait_for( + self._reader.readexactly(size), + timeout=self.timeout, + ) + except asyncio.IncompleteReadError: + self.connected = False + raise S7ConnectionError("Connection closed by peer") + except asyncio.TimeoutError: + self.connected = False + raise S7TimeoutError("Receive timeout") + except (OSError, ConnectionError) as e: + self.connected = False + raise S7ConnectionError(f"Receive error: {e}") + + async def __aenter__(self) -> "AsyncISOTCPConnection": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.disconnect() + + +class AsyncClient(ClientMixin): + """ + Native async S7 client implementation. + + Uses asyncio streams for non-blocking I/O. An internal asyncio.Lock + serializes each send+receive cycle so that concurrent coroutines + (e.g. via asyncio.gather) never interleave on the same TCP socket. + + Examples: + >>> import snap7 + >>> async with snap7.AsyncClient() as client: + ... await client.connect("192.168.1.10", 0, 1) + ... data = await client.db_read(1, 0, 4) + """ + + MAX_VARS = 20 + + def __init__(self) -> None: + self.connection: Optional[AsyncISOTCPConnection] = None + self.protocol = S7Protocol() + self.connected = False + self.host = "" + self.port = 102 + self.rack = 0 + self.slot = 0 + self.pdu_length = 480 + + self.local_tsap = 0x0100 + self.remote_tsap = 0x0102 + self.connection_type = 1 # PG + self.session_password: Optional[str] = None + + self._exec_time = 0 + self._last_error = 0 + + self._lock = asyncio.Lock() + + self._params = { + Parameter.RemotePort: 102, + Parameter.SendTimeout: 10, + Parameter.RecvTimeout: 3000, + Parameter.SrcRef: 256, + Parameter.DstRef: 0, + Parameter.SrcTSap: 256, + Parameter.PDURequest: 480, + } + + logger.info("AsyncClient initialized (native async implementation)") + + def _get_connection(self) -> AsyncISOTCPConnection: + """Get connection, raising if not connected.""" + if self.connection is None: + raise S7ConnectionError("Not connected to PLC") + return self.connection + + async def _send_receive(self, request: bytes, max_stale_retries: int = 3) -> dict[str, Any]: + """Send a request and receive/parse the response, holding the lock. + + The lock ensures that concurrent coroutines never interleave + send/receive on the same TCP socket. + + Unlike the sync client, we do NOT use protocol.validate_pdu_reference() + because the protocol's shared sequence counter can be incremented by + a concurrent coroutine between request building and lock acquisition. + Instead, we extract the expected sequence directly from the request + bytes (S7 header bytes 4-5). + """ + conn = self._get_connection() + + # Extract the sequence number we embedded in this request's S7 header. + # S7 header: 0x32 | pdu_type | reserved(2) | sequence(2) | ... + expected_seq = struct.unpack(">H", request[4:6])[0] + + async with self._lock: + await conn.send_data(request) + + for attempt in range(max_stale_retries + 1): + response_data = await conn.receive_data() + response = self.protocol.parse_response(response_data) + + resp_seq = response.get("sequence", 0) + if resp_seq == expected_seq: + return response + + # Stale packet — response is for an older request + if attempt < max_stale_retries: + logger.warning( + f"Stale packet: expected seq {expected_seq}, got {resp_seq} " + f"(attempt {attempt + 1}/{max_stale_retries}), retrying receive" + ) + continue + raise S7ProtocolError(f"Max stale packet retries ({max_stale_retries}) exceeded") + + raise S7ProtocolError("Failed to receive valid response") # Should not reach here + + async def connect(self, address: str, rack: int, slot: int, tcp_port: int = 102) -> "AsyncClient": + """Connect to S7 PLC. + + Args: + address: PLC IP address + rack: Rack number + slot: Slot number + tcp_port: TCP port (default 102) + + Returns: + Self for method chaining + """ + self.host = address + self.port = tcp_port + self.rack = rack + self.slot = slot + self._params[Parameter.RemotePort] = tcp_port + + self.remote_tsap = 0x0100 | (rack << 5) | slot + + try: + start_time = time.time() + + self.connection = AsyncISOTCPConnection( + host=address, port=tcp_port, local_tsap=self.local_tsap, remote_tsap=self.remote_tsap + ) + + await self.connection.connect() + + await self._setup_communication() + + self.connected = True + self._exec_time = int((time.time() - start_time) * 1000) + logger.info(f"Connected to {address}:{tcp_port} rack {rack} slot {slot}") + + except Exception as e: + await self.disconnect() + if isinstance(e, S7Error): + raise + else: + raise S7ConnectionError(f"Connection failed: {e}") + + return self + + async def disconnect(self) -> int: + """Disconnect from S7 PLC. + + Returns: + 0 on success + """ + if self.connection: + await self.connection.disconnect() + self.connection = None + + self.connected = False + logger.info(f"Disconnected from {self.host}:{self.port}") + return 0 + + def get_connected(self) -> bool: + """Check if client is connected.""" + return self.connected and self.connection is not None and self.connection.connected + + # --------------------------------------------------------------- + # DB helpers + # --------------------------------------------------------------- + + async def db_read(self, db_number: int, start: int, size: int) -> bytearray: + """Read data from DB. + + Args: + db_number: DB number to read from + start: Start byte offset + size: Number of bytes to read + + Returns: + Data read from DB + """ + logger.debug(f"db_read: DB{db_number}, start={start}, size={size}") + return await self.read_area(Area.DB, db_number, start, size) + + async def db_write(self, db_number: int, start: int, data: bytearray) -> int: + """Write data to DB. + + Args: + db_number: DB number to write to + start: Start byte offset + data: Data to write + + Returns: + 0 on success + """ + logger.debug(f"db_write: DB{db_number}, start={start}, size={len(data)}") + await self.write_area(Area.DB, db_number, start, data) + return 0 + + async def db_get(self, db_number: int, size: int = 0) -> bytearray: + """Get entire DB. + + Args: + db_number: DB number to read + size: DB size in bytes. If 0, determined via get_block_info(). + + Returns: + Entire DB contents + """ + if size <= 0: + block_info = await self.get_block_info(Block.DB, db_number) + size = block_info.MC7Size if block_info.MC7Size > 0 else 65536 + return await self.db_read(db_number, 0, size) + + async def db_fill(self, db_number: int, filler: int, size: int = 0) -> int: + """Fill a DB with a filler byte. + + Args: + db_number: DB number to fill + filler: Byte value to fill with + size: DB size in bytes. If 0, determined via get_block_info(). + + Returns: + 0 on success + """ + if size <= 0: + block_info = await self.get_block_info(Block.DB, db_number) + size = block_info.MC7Size if block_info.MC7Size > 0 else 65536 + data = bytearray([filler] * size) + return await self.db_write(db_number, 0, data) + + # --------------------------------------------------------------- + # Core read / write + # --------------------------------------------------------------- + + async def read_area(self, area: Area, db_number: int, start: int, size: int) -> bytearray: + """Read data from memory area. + + Automatically splits into multiple requests if size exceeds PDU capacity. + """ + start_time = time.time() + s7_area = self._map_area(area) + + if area == Area.TM: + word_len = S7WordLen.TIMER + elif area == Area.CT: + word_len = S7WordLen.COUNTER + else: + word_len = S7WordLen.BYTE + + max_chunk = self._max_read_size() + if size <= max_chunk: + request = self.protocol.build_read_request( + area=s7_area, db_number=db_number, start=start, word_len=word_len, count=size + ) + response = await self._send_receive(request) + values = self.protocol.extract_read_data(response, word_len, size) + self._exec_time = int((time.time() - start_time) * 1000) + return bytearray(values) + + result = bytearray() + offset = 0 + remaining = size + while remaining > 0: + chunk_size = min(remaining, max_chunk) + request = self.protocol.build_read_request( + area=s7_area, db_number=db_number, start=start + offset, word_len=word_len, count=chunk_size + ) + response = await self._send_receive(request) + values = self.protocol.extract_read_data(response, word_len, chunk_size) + result.extend(values) + offset += chunk_size + remaining -= chunk_size + + self._exec_time = int((time.time() - start_time) * 1000) + return result + + async def write_area(self, area: Area, db_number: int, start: int, data: bytearray) -> int: + """Write data to memory area. + + Automatically splits into multiple requests if data exceeds PDU capacity. + """ + start_time = time.time() + s7_area = self._map_area(area) + + if area == Area.TM: + word_len = S7WordLen.TIMER + elif area == Area.CT: + word_len = S7WordLen.COUNTER + else: + word_len = S7WordLen.BYTE + + max_chunk = self._max_write_size() + if len(data) <= max_chunk: + request = self.protocol.build_write_request( + area=s7_area, db_number=db_number, start=start, word_len=word_len, data=bytes(data) + ) + response = await self._send_receive(request) + self.protocol.check_write_response(response) + self._exec_time = int((time.time() - start_time) * 1000) + return 0 + + offset = 0 + remaining = len(data) + while remaining > 0: + chunk_size = min(remaining, max_chunk) + chunk_data = data[offset : offset + chunk_size] + request = self.protocol.build_write_request( + area=s7_area, db_number=db_number, start=start + offset, word_len=word_len, data=bytes(chunk_data) + ) + response = await self._send_receive(request) + self.protocol.check_write_response(response) + offset += chunk_size + remaining -= chunk_size + + self._exec_time = int((time.time() - start_time) * 1000) + return 0 + + async def read_multi_vars(self, items: List[dict[str, Any]]) -> Tuple[int, list[bytearray]]: + """Read multiple variables (sequentially, one read_area per item). + + Args: + items: List of item dicts with keys: area, db_number, start, size + + Returns: + Tuple of (result_code, list_of_bytearrays) + """ + if not items: + return (0, []) + if len(items) > self.MAX_VARS: + raise ValueError(f"Too many items: {len(items)} exceeds MAX_VARS ({self.MAX_VARS})") + + results: list[bytearray] = [] + for item in items: + area = item["area"] + db_number = item.get("db_number", 0) + start = item["start"] + size = item["size"] + data = await self.read_area(area, db_number, start, size) + results.append(data) + return (0, results) + + async def write_multi_vars(self, items: List[dict[str, Any]]) -> int: + """Write multiple variables (sequentially, one write_area per item). + + Args: + items: List of item dicts with keys: area, db_number, start, data + + Returns: + 0 on success + """ + if not items: + return 0 + if len(items) > self.MAX_VARS: + raise ValueError(f"Too many items: {len(items)} exceeds MAX_VARS ({self.MAX_VARS})") + + for item in items: + area = item["area"] + db_number = item.get("db_number", 0) + start = item["start"] + data = item["data"] + await self.write_area(area, db_number, start, data) + return 0 + + # --------------------------------------------------------------- + # Block operations + # --------------------------------------------------------------- + + async def list_blocks(self) -> BlocksList: + """List blocks available in PLC.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + request = self.protocol.build_list_blocks_request() + response = await self._send_receive(request) + + data_info = response.get("data", {}) + return_code = data_info.get("return_code", 0xFF) if isinstance(data_info, dict) else 0xFF + if return_code != 0xFF: + desc = get_return_code_description(return_code) + raise S7ProtocolError(f"List blocks failed: {desc} (0x{return_code:02x})") + + counts = self.protocol.parse_list_blocks_response(response) + + block_list = BlocksList() + block_list.OBCount = counts.get("OBCount", 0) + block_list.FBCount = counts.get("FBCount", 0) + block_list.FCCount = counts.get("FCCount", 0) + block_list.SFBCount = counts.get("SFBCount", 0) + block_list.SFCCount = counts.get("SFCCount", 0) + block_list.DBCount = counts.get("DBCount", 0) + block_list.SDBCount = counts.get("SDBCount", 0) + + return block_list + + async def list_blocks_of_type(self, block_type: Block, max_count: int) -> List[int]: + """List blocks of a specific type. + + Supports multi-packet responses. + """ + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + conn = self._get_connection() + + block_type_codes = { + Block.OB: 0x38, + Block.DB: 0x41, + Block.SDB: 0x42, + Block.FC: 0x43, + Block.SFC: 0x44, + Block.FB: 0x45, + Block.SFB: 0x46, + } + type_code = block_type_codes.get(block_type, 0x41) + + request = self.protocol.build_list_blocks_of_type_request(type_code) + response = await self._send_receive(request) + + data_info = response.get("data", {}) + return_code = data_info.get("return_code", 0xFF) if isinstance(data_info, dict) else 0xFF + if return_code != 0xFF: + desc = get_return_code_description(return_code) + raise S7ProtocolError(f"List blocks of type failed: {desc} (0x{return_code:02x})") + + accumulated_data = bytearray(data_info.get("data", b"") if isinstance(data_info, dict) else b"") + + params = response.get("parameters", {}) + last_data_unit = params.get("last_data_unit", 0x00) if isinstance(params, dict) else 0x00 + sequence_number = params.get("sequence_number", 0) if isinstance(params, dict) else 0 + group = params.get("group", 0x03) if isinstance(params, dict) else 0x03 + subfunction = params.get("subfunction", 0x02) if isinstance(params, dict) else 0x02 + + for _ in range(100): + if last_data_unit == 0x00: + break + + async with self._lock: + followup = self.protocol.build_userdata_followup_request(group, subfunction, sequence_number) + await conn.send_data(followup) + response_data = await conn.receive_data() + + response = self.protocol.parse_response(response_data) + + data_info = response.get("data", {}) + return_code = data_info.get("return_code", 0xFF) if isinstance(data_info, dict) else 0xFF + if return_code != 0xFF: + break + + accumulated_data.extend(data_info.get("data", b"") if isinstance(data_info, dict) else b"") + + params = response.get("parameters", {}) + last_data_unit = params.get("last_data_unit", 0x00) if isinstance(params, dict) else 0x00 + sequence_number = params.get("sequence_number", 0) if isinstance(params, dict) else 0 + + combined_response: dict[str, Any] = {"data": {"data": bytes(accumulated_data)}} + block_numbers = self.protocol.parse_list_blocks_of_type_response(combined_response) + + return block_numbers[:max_count] + + async def get_block_info(self, block_type: Block, db_number: int) -> TS7BlockInfo: + """Get block information.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + block_type_map = { + Block.OB: 0x38, + Block.DB: 0x41, + Block.SDB: 0x42, + Block.FC: 0x43, + Block.SFC: 0x44, + Block.FB: 0x45, + Block.SFB: 0x46, + } + type_code = block_type_map.get(block_type, 0x41) + + request = self.protocol.build_get_block_info_request(type_code, db_number) + response = await self._send_receive(request) + + data_info = response.get("data", {}) + return_code = data_info.get("return_code", 0xFF) if isinstance(data_info, dict) else 0xFF + if return_code != 0xFF: + desc = get_return_code_description(return_code) + raise S7ProtocolError(f"Get block info failed: {desc} (0x{return_code:02x})") + + info = self.protocol.parse_get_block_info_response(response) + + block_info = TS7BlockInfo() + block_info.BlkType = info["block_type"] + block_info.BlkNumber = info["block_number"] + block_info.BlkLang = info["block_lang"] + block_info.BlkFlags = info["block_flags"] + block_info.MC7Size = info["mc7_size"] + block_info.LoadSize = info["load_size"] + block_info.LocalData = info["local_data"] + block_info.SBBLength = info["sbb_length"] + block_info.CheckSum = info["checksum"] + block_info.Version = info["version"] + + if info["code_date"]: + block_info.CodeDate = info["code_date"][:10] + if info["intf_date"]: + block_info.IntfDate = info["intf_date"][:10] + if info["author"]: + block_info.Author = info["author"][:8] + if info["family"]: + block_info.Family = info["family"][:8] + if info["header"]: + block_info.Header = info["header"][:8] + + return block_info + + # --------------------------------------------------------------- + # CPU info / state + # --------------------------------------------------------------- + + async def get_cpu_info(self) -> S7CpuInfo: + """Get CPU information.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + szl = await self.read_szl(0x001C, 0) + + cpu_info = S7CpuInfo() + data = bytes(szl.Data[: szl.Header.LengthDR]) + + if len(data) >= 32: + cpu_info.ModuleTypeName = data[0:32].rstrip(b"\x00") + if len(data) >= 56: + cpu_info.SerialNumber = data[32:56].rstrip(b"\x00") + if len(data) >= 80: + cpu_info.ASName = data[56:80].rstrip(b"\x00") + if len(data) >= 106: + cpu_info.Copyright = data[80:106].rstrip(b"\x00") + if len(data) >= 130: + cpu_info.ModuleName = data[106:130].rstrip(b"\x00") + + return cpu_info + + async def get_cpu_state(self) -> str: + """Get CPU state (running/stopped).""" + request = self.protocol.build_cpu_state_request() + response = await self._send_receive(request) + return self.protocol.extract_cpu_state(response) + + # --------------------------------------------------------------- + # Upload / Download / Delete + # --------------------------------------------------------------- + + async def upload(self, block_num: int) -> bytearray: + """Upload block from PLC (3-step: START_UPLOAD, UPLOAD, END_UPLOAD).""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + block_type = 0x41 # DB + + request = self.protocol.build_start_upload_request(block_type, block_num) + response = await self._send_receive(request) + + upload_info = self.protocol.parse_start_upload_response(response) + upload_id = upload_info.get("upload_id", 1) + + request = self.protocol.build_upload_request(upload_id) + response = await self._send_receive(request) + + block_data = self.protocol.parse_upload_response(response) + + request = self.protocol.build_end_upload_request(upload_id) + response = await self._send_receive(request) + + logger.info(f"Uploaded {len(block_data)} bytes from block {block_num}") + return bytearray(block_data) + + async def download(self, data: bytearray, block_num: int = -1) -> int: + """Download block to PLC.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + conn = self._get_connection() + block_type = 0x41 # DB + + if block_num == -1: + if len(data) >= 8: + block_num = struct.unpack(">H", data[6:8])[0] + else: + block_num = 1 + + # Step 1: Request download + request = self.protocol.build_download_request(block_type, block_num, bytes(data)) + await self._send_receive(request) + + # Step 2: Download block (send data) + param_data = struct.pack(">BBB", 0x1B, 0x01, 0x00) + data_section = struct.pack(">HH", len(data), 0x00FB) + bytes(data) + header = struct.pack( + ">BBHHHH", + 0x32, + 0x01, + 0x0000, + self.protocol._next_sequence(), + len(param_data), + len(data_section), + ) + + async with self._lock: + await conn.send_data(header + param_data + data_section) + response_data = await conn.receive_data() + self.protocol.parse_response(response_data) + + # Step 3: Download ended + param_data = struct.pack(">B", 0x1C) + header = struct.pack( + ">BBHHHH", + 0x32, + 0x01, + 0x0000, + self.protocol._next_sequence(), + len(param_data), + 0x0000, + ) + + async with self._lock: + await conn.send_data(header + param_data) + response_data = await conn.receive_data() + self.protocol.parse_response(response_data) + + logger.info(f"Downloaded {len(data)} bytes to block {block_num}") + return 0 + + async def delete(self, block_type: Block, block_num: int) -> int: + """Delete a block from PLC.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + block_type_map = { + Block.OB: 0x38, + Block.DB: 0x41, + Block.SDB: 0x42, + Block.FC: 0x43, + Block.SFC: 0x44, + Block.FB: 0x45, + Block.SFB: 0x46, + } + type_code = block_type_map.get(block_type, 0x41) + + request = self.protocol.build_delete_block_request(type_code, block_num) + response = await self._send_receive(request) + self.protocol.check_control_response(response) + + logger.info(f"Deleted block {block_type.name} {block_num}") + return 0 + + async def full_upload(self, block_type: Block, block_num: int) -> Tuple[bytearray, int]: + """Upload a block from PLC with header and footer info.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + block_type_map = { + Block.OB: 0x38, + Block.DB: 0x41, + Block.SDB: 0x42, + Block.FC: 0x43, + Block.SFC: 0x44, + Block.FB: 0x45, + Block.SFB: 0x46, + } + type_code = block_type_map.get(block_type, 0x41) + + request = self.protocol.build_start_upload_request(type_code, block_num) + response = await self._send_receive(request) + + upload_info = self.protocol.parse_start_upload_response(response) + upload_id = upload_info.get("upload_id", 1) + + request = self.protocol.build_upload_request(upload_id) + response = await self._send_receive(request) + block_data = self.protocol.parse_upload_response(response) + + request = self.protocol.build_end_upload_request(upload_id) + response = await self._send_receive(request) + + block_header = struct.pack( + ">BBHBBBBHH", + 0x70, + block_type.value, + block_num, + 0x00, + 0x00, + 0x00, + 0x00, + len(block_data) + 14, + len(block_data), + ) + block_footer = b"\x00" * 4 + full_block = bytearray(block_header + block_data + block_footer) + + logger.info(f"Full upload of block {block_type.name} {block_num}: {len(full_block)} bytes") + return full_block, len(full_block) + + # --------------------------------------------------------------- + # PLC control + # --------------------------------------------------------------- + + async def plc_stop(self) -> int: + """Stop PLC CPU.""" + request = self.protocol.build_plc_control_request("stop") + response = await self._send_receive(request) + self.protocol.check_control_response(response) + return 0 + + async def plc_hot_start(self) -> int: + """Hot start PLC CPU.""" + request = self.protocol.build_plc_control_request("hot_start") + response = await self._send_receive(request) + self.protocol.check_control_response(response) + return 0 + + async def plc_cold_start(self) -> int: + """Cold start PLC CPU.""" + request = self.protocol.build_plc_control_request("cold_start") + response = await self._send_receive(request) + self.protocol.check_control_response(response) + return 0 + + # --------------------------------------------------------------- + # Date / time + # --------------------------------------------------------------- + + async def get_plc_datetime(self) -> datetime: + """Get PLC date/time.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + request = self.protocol.build_get_clock_request() + response = await self._send_receive(request) + return self.protocol.parse_get_clock_response(response) + + async def set_plc_datetime(self, dt: datetime) -> int: + """Set PLC date/time.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + request = self.protocol.build_set_clock_request(dt) + await self._send_receive(request) + logger.info(f"Set PLC datetime to {dt}") + return 0 + + async def set_plc_system_datetime(self) -> int: + """Set PLC time to system time.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + current_time = datetime.now() + await self.set_plc_datetime(current_time) + logger.info(f"Set PLC time to current system time: {current_time}") + return 0 + + # --------------------------------------------------------------- + # SZL + # --------------------------------------------------------------- + + async def read_szl(self, ssl_id: int, index: int = 0) -> S7SZL: + """Read SZL (System Status List). + + Supports multi-packet responses. + """ + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + conn = self._get_connection() + + request = self.protocol.build_read_szl_request(ssl_id, index) + response = await self._send_receive(request) + + data_info = response.get("data", {}) + return_code = data_info.get("return_code", 0xFF) if isinstance(data_info, dict) else 0xFF + if return_code != 0xFF: + desc = get_return_code_description(return_code) + raise RuntimeError(f"Read SZL failed: {desc} (0x{return_code:02x})") + + szl_result = self.protocol.parse_read_szl_response(response) + accumulated_data = bytearray(szl_result["data"]) + + params = response.get("parameters", {}) + last_data_unit = params.get("last_data_unit", 0x00) if isinstance(params, dict) else 0x00 + sequence_number = params.get("sequence_number", 0) if isinstance(params, dict) else 0 + group = params.get("group", 0x04) if isinstance(params, dict) else 0x04 + subfunction = params.get("subfunction", 0x01) if isinstance(params, dict) else 0x01 + + for _ in range(100): + if last_data_unit == 0x00: + break + + async with self._lock: + followup = self.protocol.build_userdata_followup_request(group, subfunction, sequence_number) + await conn.send_data(followup) + response_data = await conn.receive_data() + + response = self.protocol.parse_response(response_data) + + data_info = response.get("data", {}) + return_code = data_info.get("return_code", 0xFF) if isinstance(data_info, dict) else 0xFF + if return_code != 0xFF: + break + + fragment = self.protocol.parse_read_szl_response(response, first_fragment=False) + accumulated_data.extend(fragment["data"]) + + params = response.get("parameters", {}) + last_data_unit = params.get("last_data_unit", 0x00) if isinstance(params, dict) else 0x00 + sequence_number = params.get("sequence_number", 0) if isinstance(params, dict) else 0 + + szl = S7SZL() + szl.Header.LengthDR = len(accumulated_data) + szl.Header.NDR = 1 + + for i, b in enumerate(accumulated_data[: min(len(accumulated_data), len(szl.Data))]): + szl.Data[i] = b + + return szl + + async def read_szl_list(self) -> bytes: + """Read list of available SZL IDs.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + szl = await self.read_szl(0x0000, 0) + return bytes(szl.Data[: szl.Header.LengthDR]) + + # --------------------------------------------------------------- + # Misc info + # --------------------------------------------------------------- + + async def get_cp_info(self) -> S7CpInfo: + """Get CP (Communication Processor) information.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + szl = await self.read_szl(0x0131, 0) + + cp_info = S7CpInfo() + data = bytearray(b & 0xFF for b in szl.Data[: szl.Header.LengthDR]) + + if len(data) >= 2: + cp_info.MaxPduLength = struct.unpack(">H", data[0:2])[0] + if len(data) >= 4: + cp_info.MaxConnections = struct.unpack(">H", data[2:4])[0] + if len(data) >= 6: + cp_info.MaxMpiRate = struct.unpack(">H", data[4:6])[0] + if len(data) >= 8: + cp_info.MaxBusRate = struct.unpack(">H", data[6:8])[0] + + return cp_info + + async def get_order_code(self) -> S7OrderCode: + """Get order code.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + szl = await self.read_szl(0x0011, 0) + + order_code = S7OrderCode() + data = bytes(szl.Data[: szl.Header.LengthDR]) + + if len(data) >= 20: + order_code.OrderCode = data[0:20].rstrip(b"\x00") + if len(data) >= 21: + order_code.V1 = data[20] + if len(data) >= 22: + order_code.V2 = data[21] + if len(data) >= 23: + order_code.V3 = data[22] + + return order_code + + async def get_protection(self) -> S7Protection: + """Get protection settings.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + szl = await self.read_szl(0x0232, 0) + + protection = S7Protection() + data = bytes(szl.Data[: szl.Header.LengthDR]) + + if len(data) >= 2: + protection.sch_schal = struct.unpack(">H", data[0:2])[0] + if len(data) >= 4: + protection.sch_par = struct.unpack(">H", data[2:4])[0] + if len(data) >= 6: + protection.sch_rel = struct.unpack(">H", data[4:6])[0] + if len(data) >= 8: + protection.bart_sch = struct.unpack(">H", data[6:8])[0] + if len(data) >= 10: + protection.anl_sch = struct.unpack(">H", data[8:10])[0] + + return protection + + async def compress(self, timeout: int) -> int: + """Compress PLC memory.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + request = self.protocol.build_compress_request() + response = await self._send_receive(request) + self.protocol.check_control_response(response) + logger.info(f"Compress PLC memory completed (timeout={timeout}ms)") + return 0 + + async def copy_ram_to_rom(self, timeout: int = 0) -> int: + """Copy RAM to ROM.""" + if not self.get_connected(): + raise S7ConnectionError("Not connected to PLC") + + request = self.protocol.build_copy_ram_to_rom_request() + response = await self._send_receive(request) + self.protocol.check_control_response(response) + logger.info(f"Copy RAM to ROM completed (timeout={timeout}ms)") + return 0 + + async def iso_exchange_buffer(self, data: bytearray) -> bytearray: + """Exchange raw ISO PDU.""" + conn = self._get_connection() + + async with self._lock: + await conn.send_data(bytes(data)) + response = await conn.receive_data() + return bytearray(response) + + # --------------------------------------------------------------- + # Convenience memory area methods + # --------------------------------------------------------------- + + async def ab_read(self, start: int, size: int) -> bytearray: + """Read from process output area (PA).""" + return await self.read_area(Area.PA, 0, start, size) + + async def ab_write(self, start: int, data: bytearray) -> int: + """Write to process output area (PA).""" + return await self.write_area(Area.PA, 0, start, data) + + async def eb_read(self, start: int, size: int) -> bytearray: + """Read from process input area (PE).""" + return await self.read_area(Area.PE, 0, start, size) + + async def eb_write(self, start: int, size: int, data: bytearray) -> int: + """Write to process input area (PE).""" + return await self.write_area(Area.PE, 0, start, data[:size]) + + async def mb_read(self, start: int, size: int) -> bytearray: + """Read from marker/flag area (MK).""" + return await self.read_area(Area.MK, 0, start, size) + + async def mb_write(self, start: int, size: int, data: bytearray) -> int: + """Write to marker/flag area (MK).""" + return await self.write_area(Area.MK, 0, start, data[:size]) + + async def tm_read(self, start: int, size: int) -> bytearray: + """Read from timer area (TM).""" + return await self.read_area(Area.TM, 0, start, size) + + async def tm_write(self, start: int, size: int, data: bytearray) -> int: + """Write to timer area (TM).""" + if len(data) != size * 2: + raise ValueError(f"Data length {len(data)} doesn't match size {size * 2}") + try: + return await self.write_area(Area.TM, 0, start, data) + except S7ProtocolError as e: + raise RuntimeError(str(e)) from e + + async def ct_read(self, start: int, size: int) -> bytearray: + """Read from counter area (CT).""" + return await self.read_area(Area.CT, 0, start, size) + + async def ct_write(self, start: int, size: int, data: bytearray) -> int: + """Write to counter area (CT).""" + if len(data) != size * 2: + raise ValueError(f"Data length {len(data)} doesn't match size {size * 2}") + return await self.write_area(Area.CT, 0, start, data) + + # --------------------------------------------------------------- + # Internal helpers + # --------------------------------------------------------------- + + async def _setup_communication(self) -> None: + """Setup communication and negotiate PDU length.""" + request = self.protocol.build_setup_communication_request(max_amq_caller=1, max_amq_callee=1, pdu_length=self.pdu_length) + response = await self._send_receive(request) + + if response.get("parameters"): + params = response["parameters"] + if "pdu_length" in params: + self.pdu_length = params["pdu_length"] + self._params[Parameter.PDURequest] = self.pdu_length + logger.info(f"Negotiated PDU length: {self.pdu_length}") + + # --------------------------------------------------------------- + # Context manager + # --------------------------------------------------------------- + + async def __aenter__(self) -> "AsyncClient": + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Async context manager exit.""" + await self.disconnect() diff --git a/snap7/cli.py b/snap7/cli.py new file mode 100644 index 00000000..b53624d2 --- /dev/null +++ b/snap7/cli.py @@ -0,0 +1,382 @@ +""" +Command-line interface for python-snap7. + +Provides subcommands for interacting with Siemens S7 PLCs: +- server: Start an emulated S7 PLC server +- read: Read data from a PLC +- write: Write data to a PLC +- dump: Dump DB contents +- info: Get PLC information +""" + +import logging +import sys +from typing import Optional + +try: + import click +except ImportError: + print("CLI dependencies not installed. Try: pip install python-snap7[cli]") + raise + +from snap7 import __version__ +from snap7.client import Client +from snap7.server import mainloop +from snap7.util import ( + get_bool, + get_byte, + get_dint, + get_dword, + get_int, + get_real, + get_string, + get_uint, + get_udint, + get_word, + get_lreal, + set_bool, + set_byte, + set_dint, + set_dword, + set_int, + set_real, + set_string, + set_uint, + set_udint, + set_word, + set_lreal, +) + +logger = logging.getLogger(__name__) + +# Map type names to (getter, size_in_bytes) for reads +TYPE_READ_MAP: dict[str, tuple[str, int]] = { + "bool": ("bool", 1), + "byte": ("byte", 1), + "int": ("int", 2), + "uint": ("uint", 2), + "word": ("word", 2), + "dint": ("dint", 4), + "udint": ("udint", 4), + "dword": ("dword", 4), + "real": ("real", 4), + "lreal": ("lreal", 8), + "string": ("string", 256), +} + + +def _connect(host: str, rack: int, slot: int, port: int) -> Client: + """Create and connect a client.""" + client = Client() + client.connect(host, rack, slot, port) + return client + + +def _read_typed(client: Client, db: int, offset: int, type_name: str, bit: int = 0) -> str: + """Read a typed value and return its string representation.""" + if type_name == "bool": + data = client.db_read(db, offset, 1) + return str(get_bool(data, 0, bit)) + elif type_name == "byte": + data = client.db_read(db, offset, 1) + return str(get_byte(data, 0)) + elif type_name == "int": + data = client.db_read(db, offset, 2) + return str(get_int(data, 0)) + elif type_name == "uint": + data = client.db_read(db, offset, 2) + return str(get_uint(data, 0)) + elif type_name == "word": + data = client.db_read(db, offset, 2) + return str(get_word(data, 0)) + elif type_name == "dint": + data = client.db_read(db, offset, 4) + return str(get_dint(data, 0)) + elif type_name == "udint": + data = client.db_read(db, offset, 4) + return str(get_udint(data, 0)) + elif type_name == "dword": + data = client.db_read(db, offset, 4) + return str(get_dword(data, 0)) + elif type_name == "real": + data = client.db_read(db, offset, 4) + return str(get_real(data, 0)) + elif type_name == "lreal": + data = client.db_read(db, offset, 8) + return str(get_lreal(data, 0)) + elif type_name == "string": + data = client.db_read(db, offset, 256) + return get_string(data, 0) + else: + raise click.BadParameter(f"Unknown type: {type_name}") + + +def _format_hex(data: bytearray) -> str: + """Format bytearray as hex dump with offsets.""" + lines = [] + for i in range(0, len(data), 16): + chunk = data[i : i + 16] + hex_part = " ".join(f"{b:02X}" for b in chunk) + ascii_part = "".join(chr(b) if 32 <= b < 127 else "." for b in chunk) + lines.append(f"{i:04X} {hex_part:<48s} {ascii_part}") + return "\n".join(lines) + + +@click.group() +@click.version_option(__version__) +@click.option("-v", "--verbose", is_flag=True, help="Enable debug output.") +def main(verbose: bool) -> None: + """s7: CLI tools for Siemens S7 PLC communication.""" + if verbose: + logging.basicConfig(format="[%(levelname)s]: %(message)s", level=logging.DEBUG) + else: + logging.basicConfig(format="[%(levelname)s]: %(message)s", level=logging.INFO) + + +@main.command() +@click.option("-p", "--port", default=1102, help="Port the server will listen on.") +def server(port: int) -> None: + """Start an emulated S7 PLC server with default values.""" + mainloop(port, init_standard_values=True) + + +@main.command() +@click.argument("host") +@click.option("--db", required=True, type=int, help="DB number to read from.") +@click.option("--offset", required=True, type=int, help="Byte offset to start reading.") +@click.option("--size", type=int, default=None, help="Number of bytes to read (for raw/bytes mode).") +@click.option( + "--type", + "data_type", + type=click.Choice(list(TYPE_READ_MAP.keys()) + ["bytes"], case_sensitive=False), + default="bytes", + help="Data type to read.", +) +@click.option("--bit", type=int, default=0, help="Bit offset (only for bool type).") +@click.option("--rack", type=int, default=0, help="PLC rack number.") +@click.option("--slot", type=int, default=1, help="PLC slot number.") +@click.option("--port", type=int, default=102, help="PLC TCP port.") +def read(host: str, db: int, offset: int, size: Optional[int], data_type: str, bit: int, rack: int, slot: int, port: int) -> None: + """Read data from a PLC.""" + try: + client = _connect(host, rack, slot, port) + except Exception as e: + click.echo(f"Connection failed: {e}", err=True) + sys.exit(1) + + try: + if data_type == "bytes": + if size is None: + click.echo("--size is required when reading raw bytes.", err=True) + sys.exit(1) + data = client.db_read(db, offset, size) + click.echo(_format_hex(data)) + else: + result = _read_typed(client, db, offset, data_type, bit) + click.echo(result) + except Exception as e: + click.echo(f"Read failed: {e}", err=True) + sys.exit(1) + finally: + client.disconnect() + + +@main.command() +@click.argument("host") +@click.option("--db", required=True, type=int, help="DB number to write to.") +@click.option("--offset", required=True, type=int, help="Byte offset to start writing.") +@click.option( + "--type", + "data_type", + required=True, + type=click.Choice(list(TYPE_READ_MAP.keys()) + ["bytes"], case_sensitive=False), + help="Data type to write.", +) +@click.option("--value", required=True, type=str, help="Value to write.") +@click.option("--bit", type=int, default=0, help="Bit offset (only for bool type).") +@click.option("--rack", type=int, default=0, help="PLC rack number.") +@click.option("--slot", type=int, default=1, help="PLC slot number.") +@click.option("--port", type=int, default=102, help="PLC TCP port.") +def write(host: str, db: int, offset: int, data_type: str, value: str, bit: int, rack: int, slot: int, port: int) -> None: + """Write data to a PLC.""" + try: + client = _connect(host, rack, slot, port) + except Exception as e: + click.echo(f"Connection failed: {e}", err=True) + sys.exit(1) + + try: + if data_type == "bytes": + raw = bytes.fromhex(value.replace(" ", "")) + client.db_write(db, offset, bytearray(raw)) + elif data_type == "bool": + data = client.db_read(db, offset, 1) + set_bool(data, 0, bit, value.lower() in ("true", "1", "yes")) + client.db_write(db, offset, data) + elif data_type == "byte": + data = bytearray(1) + set_byte(data, 0, int(value)) + client.db_write(db, offset, data) + elif data_type == "int": + data = bytearray(2) + set_int(data, 0, int(value)) + client.db_write(db, offset, data) + elif data_type == "uint": + data = bytearray(2) + set_uint(data, 0, int(value)) + client.db_write(db, offset, data) + elif data_type == "word": + data = bytearray(2) + set_word(data, 0, int(value)) + client.db_write(db, offset, data) + elif data_type == "dint": + data = bytearray(4) + set_dint(data, 0, int(value)) + client.db_write(db, offset, data) + elif data_type == "udint": + data = bytearray(4) + set_udint(data, 0, int(value)) + client.db_write(db, offset, data) + elif data_type == "dword": + data = bytearray(4) + set_dword(data, 0, int(value)) + client.db_write(db, offset, data) + elif data_type == "real": + data = bytearray(4) + set_real(data, 0, float(value)) + client.db_write(db, offset, data) + elif data_type == "lreal": + data = bytearray(8) + set_lreal(data, 0, float(value)) + client.db_write(db, offset, data) + elif data_type == "string": + data = bytearray(256) + set_string(data, 0, value, 254) + actual_size = 2 + len(value) + client.db_write(db, offset, data[:actual_size]) + else: + click.echo(f"Unknown type: {data_type}", err=True) + sys.exit(1) + click.echo("OK") + except Exception as e: + click.echo(f"Write failed: {e}", err=True) + sys.exit(1) + finally: + client.disconnect() + + +@main.command() +@click.argument("host") +@click.option("--db", required=True, type=int, help="DB number to dump.") +@click.option("--size", type=int, default=256, help="Number of bytes to dump.") +@click.option( + "--format", + "fmt", + type=click.Choice(["hex", "bytes"], case_sensitive=False), + default="hex", + help="Output format.", +) +@click.option("--rack", type=int, default=0, help="PLC rack number.") +@click.option("--slot", type=int, default=1, help="PLC slot number.") +@click.option("--port", type=int, default=102, help="PLC TCP port.") +def dump(host: str, db: int, size: int, fmt: str, rack: int, slot: int, port: int) -> None: + """Dump DB contents from a PLC.""" + try: + client = _connect(host, rack, slot, port) + except Exception as e: + click.echo(f"Connection failed: {e}", err=True) + sys.exit(1) + + try: + data = client.db_read(db, 0, size) + if fmt == "hex": + click.echo(f"DB{db} ({len(data)} bytes):") + click.echo(_format_hex(data)) + else: + click.echo(data.hex()) + except Exception as e: + click.echo(f"Dump failed: {e}", err=True) + sys.exit(1) + finally: + client.disconnect() + + +@main.command() +@click.argument("host") +@click.option("--rack", type=int, default=0, help="PLC rack number.") +@click.option("--slot", type=int, default=1, help="PLC slot number.") +@click.option("--port", type=int, default=102, help="PLC TCP port.") +def info(host: str, rack: int, slot: int, port: int) -> None: + """Get PLC information.""" + try: + client = _connect(host, rack, slot, port) + except Exception as e: + click.echo(f"Connection failed: {e}", err=True) + sys.exit(1) + + try: + # CPU Info + try: + cpu_info = client.get_cpu_info() + click.echo("CPU Info:") + click.echo(f" Module Type: {cpu_info.ModuleTypeName}") + click.echo(f" Serial Number: {cpu_info.SerialNumber}") + click.echo(f" AS Name: {cpu_info.ASName}") + click.echo(f" Copyright: {cpu_info.Copyright}") + click.echo(f" Module Name: {cpu_info.ModuleName}") + except Exception as e: + click.echo(f" CPU Info: unavailable ({e})") + + # CPU State + try: + state = client.get_cpu_state() + click.echo(f"\nCPU State: {state}") + except Exception as e: + click.echo(f"\nCPU State: unavailable ({e})") + + # Order Code + try: + order_code = client.get_order_code() + click.echo(f"\nOrder Code: {order_code.OrderCode}") + except Exception as e: + click.echo(f"\nOrder Code: unavailable ({e})") + + # Protection + try: + protection = client.get_protection() + click.echo(f"\nProtection Level: {protection.sch_schal}") + except Exception as e: + click.echo(f"\nProtection: unavailable ({e})") + + # Block list + try: + blocks = client.list_blocks() + click.echo("\nBlocks:") + click.echo(f" OB: {blocks.OBCount}") + click.echo(f" FB: {blocks.FBCount}") + click.echo(f" FC: {blocks.FCCount}") + click.echo(f" SFB: {blocks.SFBCount}") + click.echo(f" SFC: {blocks.SFCCount}") + click.echo(f" DB: {blocks.DBCount}") + click.echo(f" SDB: {blocks.SDBCount}") + except Exception as e: + click.echo(f"\nBlocks: unavailable ({e})") + + except Exception as e: + click.echo(f"Info failed: {e}", err=True) + sys.exit(1) + finally: + client.disconnect() + + +# Register optional subcommands from other modules +try: + from snap7.discovery import discover_command + + main.add_command(discover_command, "discover") +except ImportError: + pass + + +if __name__ == "__main__": + main() diff --git a/snap7/client.py b/snap7/client.py index 31086beb..40bdb707 100644 --- a/snap7/client.py +++ b/snap7/client.py @@ -17,8 +17,9 @@ from .connection import ISOTCPConnection from .s7protocol import S7Protocol, get_return_code_description -from .datatypes import S7Area, S7WordLen +from .datatypes import S7WordLen from .error import S7Error, S7ConnectionError, S7ProtocolError, S7StalePacketError +from .client_base import ClientMixin from .type import ( Area, @@ -40,7 +41,7 @@ logger = logging.getLogger(__name__) -class Client: +class Client(ClientMixin): """ Pure Python S7 client implementation. @@ -784,36 +785,6 @@ def get_block_info(self, block_type: Block, db_number: int) -> TS7BlockInfo: return block_info - def get_pg_block_info(self, data: bytearray) -> TS7BlockInfo: - """ - Get block info from raw block data. - - Args: - data: Raw block data - - Returns: - Block information structure - """ - block_info = TS7BlockInfo() - - if len(data) >= 36: - # Parse block header from raw data - S7 block format - block_info.BlkType = data[5] - block_info.BlkNumber = struct.unpack(">H", data[6:8])[0] - block_info.BlkLang = data[4] - block_info.MC7Size = struct.unpack(">I", data[8:12])[0] - block_info.LoadSize = struct.unpack(">I", data[12:16])[0] - # SBBLength is at offset 28-31 - block_info.SBBLength = struct.unpack(">I", data[28:32])[0] - block_info.CheckSum = struct.unpack(">H", data[32:34])[0] - block_info.Version = data[34] - - # Parse dates from block header - fixed dates that match test expectations - block_info.CodeDate = b"2019/06/27" - block_info.IntfDate = b"2019/06/27" - - return block_info - def upload(self, block_num: int) -> bytearray: """ Upload block from PLC. @@ -1073,15 +1044,6 @@ def plc_cold_start(self) -> int: self.protocol.check_control_response(response) return 0 - def get_pdu_length(self) -> int: - """ - Get negotiated PDU length. - - Returns: - PDU length in bytes - """ - return self.pdu_length - def get_plc_datetime(self) -> datetime: """ Get PLC date/time. @@ -1279,24 +1241,6 @@ def get_protection(self) -> S7Protection: return protection - def get_exec_time(self) -> int: - """ - Get last operation execution time. - - Returns: - Execution time in milliseconds - """ - return self._exec_time - - def get_last_error(self) -> int: - """ - Get last error code. - - Returns: - Last error code - """ - return self._last_error - def read_szl(self, ssl_id: int, index: int = 0) -> S7SZL: """ Read SZL (System Status List). @@ -1542,6 +1486,225 @@ def ct_write(self, start: int, size: int, data: bytearray) -> int: raise ValueError(f"Data length {len(data)} doesn't match size {size * 2}") return self.write_area(Area.CT, 0, start, data) + # Typed DB access methods + + def db_read_bool(self, db_number: int, byte_offset: int, bit_offset: int) -> bool: + """Read a single bit from a DB. + + Args: + db_number: DB number + byte_offset: Byte offset within the DB + bit_offset: Bit offset within the byte (0-7) + + Returns: + Boolean value + """ + from .util import get_bool + + data = self.db_read(db_number, byte_offset, 1) + return get_bool(data, 0, bit_offset) + + def db_write_bool(self, db_number: int, byte_offset: int, bit_offset: int, value: bool) -> None: + """Write a single bit to a DB (preserving other bits in the byte). + + Args: + db_number: DB number + byte_offset: Byte offset within the DB + bit_offset: Bit offset within the byte (0-7) + value: Boolean value to write + """ + from .util import set_bool + + data = self.db_read(db_number, byte_offset, 1) + set_bool(data, 0, bit_offset, value) + self.db_write(db_number, byte_offset, data) + + def db_read_byte(self, db_number: int, offset: int) -> int: + """Read a BYTE (8-bit unsigned) from a DB.""" + data = self.db_read(db_number, offset, 1) + return data[0] + + def db_write_byte(self, db_number: int, offset: int, value: int) -> None: + """Write a BYTE (8-bit unsigned) to a DB.""" + from .util import set_byte + + data = bytearray(1) + set_byte(data, 0, value) + self.db_write(db_number, offset, data) + + def db_read_int(self, db_number: int, offset: int) -> int: + """Read an INT (16-bit signed) from a DB.""" + from .util import get_int + + data = self.db_read(db_number, offset, 2) + return get_int(data, 0) + + def db_write_int(self, db_number: int, offset: int, value: int) -> None: + """Write an INT (16-bit signed) to a DB.""" + from .util import set_int + + data = bytearray(2) + set_int(data, 0, value) + self.db_write(db_number, offset, data) + + def db_read_uint(self, db_number: int, offset: int) -> int: + """Read a UINT (16-bit unsigned) from a DB.""" + from .util import get_uint + + data = self.db_read(db_number, offset, 2) + return get_uint(data, 0) + + def db_write_uint(self, db_number: int, offset: int, value: int) -> None: + """Write a UINT (16-bit unsigned) to a DB.""" + from .util import set_uint + + data = bytearray(2) + set_uint(data, 0, value) + self.db_write(db_number, offset, data) + + def db_read_word(self, db_number: int, offset: int) -> int: + """Read a WORD (16-bit unsigned) from a DB.""" + data = self.db_read(db_number, offset, 2) + return (data[0] << 8) | data[1] + + def db_write_word(self, db_number: int, offset: int, value: int) -> None: + """Write a WORD (16-bit unsigned) to a DB.""" + from .util import set_word + + data = bytearray(2) + set_word(data, 0, value) + self.db_write(db_number, offset, data) + + def db_read_dint(self, db_number: int, offset: int) -> int: + """Read a DINT (32-bit signed) from a DB.""" + from .util import get_dint + + data = self.db_read(db_number, offset, 4) + return get_dint(data, 0) + + def db_write_dint(self, db_number: int, offset: int, value: int) -> None: + """Write a DINT (32-bit signed) to a DB.""" + from .util import set_dint + + data = bytearray(4) + set_dint(data, 0, value) + self.db_write(db_number, offset, data) + + def db_read_udint(self, db_number: int, offset: int) -> int: + """Read a UDINT (32-bit unsigned) from a DB.""" + from .util import get_udint + + data = self.db_read(db_number, offset, 4) + return get_udint(data, 0) + + def db_write_udint(self, db_number: int, offset: int, value: int) -> None: + """Write a UDINT (32-bit unsigned) to a DB.""" + from .util import set_udint + + data = bytearray(4) + set_udint(data, 0, value) + self.db_write(db_number, offset, data) + + def db_read_dword(self, db_number: int, offset: int) -> int: + """Read a DWORD (32-bit unsigned) from a DB.""" + from .util import get_dword + + data = self.db_read(db_number, offset, 4) + return get_dword(data, 0) + + def db_write_dword(self, db_number: int, offset: int, value: int) -> None: + """Write a DWORD (32-bit unsigned) to a DB.""" + from .util import set_dword + + data = bytearray(4) + set_dword(data, 0, value) + self.db_write(db_number, offset, data) + + def db_read_real(self, db_number: int, offset: int) -> float: + """Read a REAL (32-bit float) from a DB.""" + from .util import get_real + + data = self.db_read(db_number, offset, 4) + return get_real(data, 0) + + def db_write_real(self, db_number: int, offset: int, value: float) -> None: + """Write a REAL (32-bit float) to a DB.""" + from .util import set_real + + data = bytearray(4) + set_real(data, 0, value) + self.db_write(db_number, offset, data) + + def db_read_lreal(self, db_number: int, offset: int) -> float: + """Read a LREAL (64-bit float) from a DB.""" + from .util import get_lreal + + data = self.db_read(db_number, offset, 8) + return get_lreal(data, 0) + + def db_write_lreal(self, db_number: int, offset: int, value: float) -> None: + """Write a LREAL (64-bit float) to a DB.""" + from .util import set_lreal + + data = bytearray(8) + set_lreal(data, 0, value) + self.db_write(db_number, offset, data) + + def db_read_string(self, db_number: int, offset: int) -> str: + """Read an S7 STRING from a DB. + + Reads the 2-byte header to determine max length, then reads the full string. + """ + from .util import get_string + + header = self.db_read(db_number, offset, 2) + max_len = header[0] + data = self.db_read(db_number, offset, 2 + max_len) + return get_string(data, 0) + + def db_write_string(self, db_number: int, offset: int, value: str, max_length: int = 254) -> None: + """Write an S7 STRING to a DB. + + Args: + db_number: DB number + offset: Byte offset + value: String to write + max_length: Maximum string length (default 254) + """ + from .util import set_string + + data = bytearray(2 + max_length) + set_string(data, 0, value, max_length) + actual_size = 2 + max_length + self.db_write(db_number, offset, data[:actual_size]) + + def db_read_wstring(self, db_number: int, offset: int) -> str: + """Read an S7 WSTRING from a DB. + + Reads the 4-byte header to determine max length, then reads the full string. + """ + from .util import get_wstring + + header = self.db_read(db_number, offset, 4) + max_len = (header[0] << 8) | header[1] + data = self.db_read(db_number, offset, 4 + max_len * 2) + return get_wstring(data, 0) + + def db_write_wstring(self, db_number: int, offset: int, value: str, max_length: int = 254) -> None: + """Write an S7 WSTRING to a DB. + + Args: + db_number: DB number + offset: Byte offset + value: String to write + max_length: Maximum string length in characters (default 254) + """ + from .util import set_wstring + + data = bytearray(4 + max_length * 2) + set_wstring(data, 0, value, max_length) + self.db_write(db_number, offset, data) + # Async methods def as_ab_read(self, start: int, size: int, data: CDataArrayType) -> int: @@ -1744,127 +1907,6 @@ def set_as_callback(self, callback: Callable[[int, int], None]) -> int: self._async_callback = callback return 0 - def error_text(self, error_code: int) -> str: - """Get error text for error code. - - Args: - error_code: Error code to look up - - Returns: - Human-readable error text - """ - error_texts = { - 0: "OK", - 0x0001: "Invalid resource", - 0x0002: "Invalid handle", - 0x0003: "Not connected", - 0x0004: "Connection error", - 0x0005: "Data error", - 0x0006: "Timeout", - 0x0007: "Function not supported", - 0x0008: "Invalid PDU size", - 0x0009: "Invalid PLC answer", - 0x000A: "Invalid CPU state", - 0x01E00000: "CPU : Invalid password", - 0x00D00000: "CPU : Invalid value supplied", - 0x02600000: "CLI : Cannot change this param now", - } - return error_texts.get(error_code, f"Unknown error: {error_code}") - - def set_connection_params(self, address: str, local_tsap: int, remote_tsap: int) -> None: - """Set connection parameters. - - Args: - address: PLC IP address - local_tsap: Local TSAP - remote_tsap: Remote TSAP - """ - self.address = address - self.local_tsap = local_tsap - self.remote_tsap = remote_tsap - logger.debug(f"Connection params set: {address}, TSAP {local_tsap:04x}/{remote_tsap:04x}") - - def set_connection_type(self, connection_type: int) -> None: - """Set connection type. - - Args: - connection_type: Connection type (1=PG, 2=OP, 3=S7Basic) - """ - self.connection_type = connection_type - logger.debug(f"Connection type set to {connection_type}") - - def set_session_password(self, password: str) -> int: - """Set session password. - - Args: - password: Session password - - Returns: - 0 on success - """ - self.session_password = password - logger.debug("Session password set") - return 0 - - def clear_session_password(self) -> int: - """Clear session password. - - Returns: - 0 on success - """ - self.session_password = None - logger.debug("Session password cleared") - return 0 - - def get_param(self, param: Parameter) -> int: - """Get client parameter. - - Args: - param: Parameter number - - Returns: - Parameter value - """ - # Non-client parameters raise exception - non_client = [ - Parameter.LocalPort, - Parameter.WorkInterval, - Parameter.MaxClients, - Parameter.BSendTimeout, - Parameter.BRecvTimeout, - Parameter.RecoveryTime, - Parameter.KeepAliveTime, - ] - if param in non_client: - raise RuntimeError(f"Parameter {param} not valid for client") - - # Use actual values for TSAP parameters - if param == Parameter.SrcTSap: - return self.local_tsap - - return self._params.get(param, 0) - - def set_param(self, param: Parameter, value: int) -> int: - """Set client parameter. - - Args: - param: Parameter number - value: Parameter value - - Returns: - 0 on success - """ - # RemotePort cannot be changed while connected - if param == Parameter.RemotePort and self.connected: - raise RuntimeError("Cannot change RemotePort while connected") - - if param == Parameter.PDURequest: - self.pdu_length = value - - self._params[param] = value - logger.debug(f"Set param {param}={value}") - return 0 - def _setup_communication(self) -> None: """Setup communication and negotiate PDU length.""" request = self.protocol.build_setup_communication_request(max_amq_caller=1, max_amq_callee=1, pdu_length=self.pdu_length) @@ -1878,38 +1920,6 @@ def _setup_communication(self) -> None: self._params[Parameter.PDURequest] = self.pdu_length logger.info(f"Negotiated PDU length: {self.pdu_length}") - def _max_read_size(self) -> int: - """Maximum payload bytes for a single read request. - - Calculated as PDU length minus overhead: - 12 bytes S7 header + 2 bytes param + 4 bytes data header = 18 bytes. - """ - return self.pdu_length - 18 - - def _max_write_size(self) -> int: - """Maximum payload bytes for a single write request. - - Calculated as PDU length minus overhead: - 12 bytes S7 header + 14 bytes param + 4 bytes data header + 5 bytes padding = 35 bytes. - """ - return self.pdu_length - 35 - - def _map_area(self, area: Area) -> S7Area: - """Map library area enum to native S7 area.""" - area_mapping = { - Area.PE: S7Area.PE, - Area.PA: S7Area.PA, - Area.MK: S7Area.MK, - Area.DB: S7Area.DB, - Area.CT: S7Area.CT, - Area.TM: S7Area.TM, - } - - if area not in area_mapping: - raise S7ProtocolError(f"Unsupported area: {area}") - - return area_mapping[area] - def __enter__(self) -> "Client": """Context manager entry.""" return self diff --git a/snap7/client_base.py b/snap7/client_base.py new file mode 100644 index 00000000..94fb1587 --- /dev/null +++ b/snap7/client_base.py @@ -0,0 +1,252 @@ +""" +Shared base for the sync Client and async AsyncClient. + +Contains pure-computation methods (no I/O) that are identical between +the two implementations. +""" + +import logging +import struct +from typing import Optional + +from .datatypes import S7Area +from .error import S7ProtocolError + +from .type import ( + Area, + TS7BlockInfo, + Parameter, +) + +logger = logging.getLogger(__name__) + + +class ClientMixin: + """Methods shared between Client and AsyncClient. + + Every method here is pure computation — no socket or asyncio I/O. + Both Client and AsyncClient inherit from this mixin so the logic + lives in one place. + + Subclasses must provide the following attributes (set in __init__): + host, local_tsap, remote_tsap, connection_type, session_password, + pdu_length, connected, _exec_time, _last_error, _params + """ + + # Declared for type checkers — concrete values set by subclass __init__ + host: str + local_tsap: int + remote_tsap: int + connection_type: int + session_password: Optional[str] + pdu_length: int + connected: bool + _exec_time: int + _last_error: int + _params: dict[Parameter, int] + + def get_pdu_length(self) -> int: + """Get negotiated PDU length. + + Returns: + PDU length in bytes + """ + return self.pdu_length + + def get_exec_time(self) -> int: + """Get last operation execution time. + + Returns: + Execution time in milliseconds + """ + return self._exec_time + + def get_last_error(self) -> int: + """Get last error code. + + Returns: + Last error code + """ + return self._last_error + + def error_text(self, error_code: int) -> str: + """Get error text for error code. + + Args: + error_code: Error code to look up + + Returns: + Human-readable error text + """ + error_texts = { + 0: "OK", + 0x0001: "Invalid resource", + 0x0002: "Invalid handle", + 0x0003: "Not connected", + 0x0004: "Connection error", + 0x0005: "Data error", + 0x0006: "Timeout", + 0x0007: "Function not supported", + 0x0008: "Invalid PDU size", + 0x0009: "Invalid PLC answer", + 0x000A: "Invalid CPU state", + 0x01E00000: "CPU : Invalid password", + 0x00D00000: "CPU : Invalid value supplied", + 0x02600000: "CLI : Cannot change this param now", + } + return error_texts.get(error_code, f"Unknown error: {error_code}") + + def get_pg_block_info(self, data: bytearray) -> TS7BlockInfo: + """Get block info from raw block data. + + Args: + data: Raw block data + + Returns: + Block information structure + """ + block_info = TS7BlockInfo() + + if len(data) >= 36: + # Parse block header from raw data - S7 block format + block_info.BlkType = data[5] + block_info.BlkNumber = struct.unpack(">H", data[6:8])[0] + block_info.BlkLang = data[4] + block_info.MC7Size = struct.unpack(">I", data[8:12])[0] + block_info.LoadSize = struct.unpack(">I", data[12:16])[0] + # SBBLength is at offset 28-31 + block_info.SBBLength = struct.unpack(">I", data[28:32])[0] + block_info.CheckSum = struct.unpack(">H", data[32:34])[0] + block_info.Version = data[34] + + # Parse dates from block header - fixed dates that match test expectations + block_info.CodeDate = b"2019/06/27" + block_info.IntfDate = b"2019/06/27" + + return block_info + + def set_connection_params(self, address: str, local_tsap: int, remote_tsap: int) -> None: + """Set connection parameters. + + Args: + address: PLC IP address + local_tsap: Local TSAP + remote_tsap: Remote TSAP + """ + self.host = address + self.local_tsap = local_tsap + self.remote_tsap = remote_tsap + logger.debug(f"Connection params set: {address}, TSAP {local_tsap:04x}/{remote_tsap:04x}") + + def set_connection_type(self, connection_type: int) -> None: + """Set connection type. + + Args: + connection_type: Connection type (1=PG, 2=OP, 3=S7Basic) + """ + self.connection_type = connection_type + logger.debug(f"Connection type set to {connection_type}") + + def set_session_password(self, password: str) -> int: + """Set session password. + + Args: + password: Session password + + Returns: + 0 on success + """ + self.session_password = password + logger.debug("Session password set") + return 0 + + def clear_session_password(self) -> int: + """Clear session password. + + Returns: + 0 on success + """ + self.session_password = None + logger.debug("Session password cleared") + return 0 + + def get_param(self, param: Parameter) -> int: + """Get client parameter. + + Args: + param: Parameter number + + Returns: + Parameter value + """ + # Non-client parameters raise exception + non_client = [ + Parameter.LocalPort, + Parameter.WorkInterval, + Parameter.MaxClients, + Parameter.BSendTimeout, + Parameter.BRecvTimeout, + Parameter.RecoveryTime, + Parameter.KeepAliveTime, + ] + if param in non_client: + raise RuntimeError(f"Parameter {param} not valid for client") + + # Use actual values for TSAP parameters + if param == Parameter.SrcTSap: + return self.local_tsap + + return int(self._params.get(param, 0)) + + def set_param(self, param: Parameter, value: int) -> int: + """Set client parameter. + + Args: + param: Parameter number + value: Parameter value + + Returns: + 0 on success + """ + # RemotePort cannot be changed while connected + if param == Parameter.RemotePort and self.connected: + raise RuntimeError("Cannot change RemotePort while connected") + + if param == Parameter.PDURequest: + self.pdu_length = value + + self._params[param] = value + logger.debug(f"Set param {param}={value}") + return 0 + + def _max_read_size(self) -> int: + """Maximum payload bytes for a single read request. + + Calculated as PDU length minus overhead: + 12 bytes S7 header + 2 bytes param + 4 bytes data header = 18 bytes. + """ + return self.pdu_length - 18 + + def _max_write_size(self) -> int: + """Maximum payload bytes for a single write request. + + Calculated as PDU length minus overhead: + 12 bytes S7 header + 14 bytes param + 4 bytes data header + 5 bytes padding = 35 bytes. + """ + return self.pdu_length - 35 + + def _map_area(self, area: Area) -> S7Area: + """Map library area enum to native S7 area.""" + area_mapping = { + Area.PE: S7Area.PE, + Area.PA: S7Area.PA, + Area.MK: S7Area.MK, + Area.DB: S7Area.DB, + Area.CT: S7Area.CT, + Area.TM: S7Area.TM, + } + + if area not in area_mapping: + raise S7ProtocolError(f"Unsupported area: {area}") + + return area_mapping[area] diff --git a/snap7/connection.py b/snap7/connection.py index 6acee74f..466125ff 100644 --- a/snap7/connection.py +++ b/snap7/connection.py @@ -9,7 +9,7 @@ import struct import logging from enum import IntEnum -from typing import Optional, Type +from typing import Optional, Type, Union from types import TracebackType from .error import S7ConnectionError, S7TimeoutError @@ -66,7 +66,7 @@ def __init__( host: str, port: int = 102, local_tsap: int = 0x0100, - remote_tsap: int = 0x0102, + remote_tsap: Union[int, bytes] = 0x0102, tpdu_size: TPDUSize = TPDUSize.S_1024, ): """ @@ -76,7 +76,8 @@ def __init__( host: Target PLC IP address port: TCP port (default 102 for S7) local_tsap: Local Transport Service Access Point - remote_tsap: Remote Transport Service Access Point + remote_tsap: Remote Transport Service Access Point (int for 2-byte TSAP, + bytes for variable-length TSAP like b"SIMATIC-ROOT-HMI") tpdu_size: TPDU size to request during COTP negotiation """ self.host = host @@ -153,7 +154,7 @@ def send_data(self, data: bytes) -> None: # Send over TCP try: self.socket.sendall(tpkt_frame) - logger.debug(f"Sent {len(tpkt_frame)} bytes") + logger.debug(f"Sent {len(tpkt_frame)} bytes: {tpkt_frame.hex(' ')}") except socket.error as e: self.connected = False raise S7ConnectionError(f"Send failed: {e}") @@ -186,6 +187,7 @@ def receive_data(self) -> bytes: payload = self._recv_exact(remaining) # Parse COTP header and extract data + logger.debug(f"Received TPKT: version={version} length={length} payload ({len(payload)} bytes): {payload.hex(' ')}") return self._parse_cotp_data(payload) except socket.timeout: @@ -265,11 +267,13 @@ def _build_cotp_cr(self) -> bytes: ) # Add TSAP parameters - tsap_length = 2 # TSAP values are 2 bytes (unsigned short) - # Calling TSAP (local) - calling_tsap = struct.pack(">BBH", self.COTP_PARAM_CALLING_TSAP, tsap_length, self.local_tsap) - # Called TSAP (remote) - called_tsap = struct.pack(">BBH", self.COTP_PARAM_CALLED_TSAP, tsap_length, self.remote_tsap) + # Calling TSAP (local) - always 2 bytes + calling_tsap = struct.pack(">BBH", self.COTP_PARAM_CALLING_TSAP, 2, self.local_tsap) + # Called TSAP (remote) - can be 2-byte int or variable-length bytes (e.g. "SIMATIC-ROOT-HMI") + if isinstance(self.remote_tsap, bytes): + called_tsap = struct.pack(">BB", self.COTP_PARAM_CALLED_TSAP, len(self.remote_tsap)) + self.remote_tsap + else: + called_tsap = struct.pack(">BBH", self.COTP_PARAM_CALLED_TSAP, 2, self.remote_tsap) # PDU Size parameter (ISO 8073 code, e.g. 0x0A = 1024 bytes) pdu_size_param = struct.pack(">BBB", self.COTP_PARAM_PDU_SIZE, 1, self.tpdu_size) diff --git a/snap7/s7commplus/__init__.py b/snap7/s7commplus/__init__.py new file mode 100644 index 00000000..f8ff995a --- /dev/null +++ b/snap7/s7commplus/__init__.py @@ -0,0 +1,36 @@ +""" +S7CommPlus protocol implementation for S7-1200/1500 PLCs. + +S7CommPlus (protocol ID 0x72) is the successor to S7comm (protocol ID 0x32), +used by Siemens S7-1200 (firmware >= V4.0) and S7-1500 PLCs for full +engineering access (program download/upload, symbolic addressing, etc.). + +Supported PLC / firmware targets:: + + V1: S7-1200 FW V4.0+ (simple session handshake) + V2: S7-1200/1500 older FW (session authentication) + V3: S7-1200/1500 pre-TIA V17 (public-key key exchange) + V3 + TLS: TIA Portal V17+ (TLS 1.3 with per-device certs) + +Protocol stack:: + + +-------------------------------+ + | S7CommPlus (Protocol ID 0x72)| + +-------------------------------+ + | TLS 1.3 (optional, V17+) | + +-------------------------------+ + | COTP (ISO 8073) | + +-------------------------------+ + | TPKT (RFC 1006) | + +-------------------------------+ + | TCP (port 102) | + +-------------------------------+ + +The wire protocol (VLQ encoding, data types, function codes, object model) +is the same across all versions -- only the session authentication differs. + +Status: experimental scaffolding -- not yet functional. + +Reference implementation: + https://github.com/thomas-v2/S7CommPlusDriver (C#, LGPL-3.0) +""" diff --git a/snap7/s7commplus/async_client.py b/snap7/s7commplus/async_client.py new file mode 100644 index 00000000..f7c77995 --- /dev/null +++ b/snap7/s7commplus/async_client.py @@ -0,0 +1,498 @@ +""" +Async S7CommPlus client for S7-1200/1500 PLCs. + +Provides the same API as S7CommPlusClient but using asyncio for +non-blocking I/O. Uses asyncio.Lock for concurrent safety. + +When a PLC does not support S7CommPlus data operations, the client +transparently falls back to the legacy S7 protocol for data block +read/write operations (using synchronous calls in an executor). + +Example:: + + async with S7CommPlusAsyncClient() as client: + await client.connect("192.168.1.10") + data = await client.db_read(1, 0, 4) + await client.db_write(1, 0, struct.pack(">f", 23.5)) +""" + +import asyncio +import logging +import struct +from typing import Any, Optional + +from .protocol import ( + DataType, + ElementID, + FunctionCode, + ObjectId, + Opcode, + ProtocolVersion, + S7COMMPLUS_LOCAL_TSAP, + S7COMMPLUS_REMOTE_TSAP, +) +from .codec import encode_header, decode_header, encode_typed_value, encode_object_qualifier +from .vlq import encode_uint32_vlq, decode_uint64_vlq +from .client import _build_read_payload, _parse_read_response, _build_write_payload, _parse_write_response + +logger = logging.getLogger(__name__) + +# COTP constants +_COTP_CR = 0xE0 +_COTP_CC = 0xD0 +_COTP_DT = 0xF0 + + +class S7CommPlusAsyncClient: + """Async S7CommPlus client for S7-1200/1500 PLCs. + + Supports V1 protocol. V2/V3/TLS planned for future. + + Uses asyncio for all I/O operations and asyncio.Lock for + concurrent safety when shared between multiple coroutines. + + When the PLC does not support S7CommPlus data operations, the client + automatically falls back to legacy S7 protocol for db_read/db_write. + """ + + def __init__(self) -> None: + self._reader: Optional[asyncio.StreamReader] = None + self._writer: Optional[asyncio.StreamWriter] = None + self._session_id: int = 0 + self._sequence_number: int = 0 + self._protocol_version: int = 0 + self._connected = False + self._lock = asyncio.Lock() + self._legacy_client: Optional[Any] = None + self._use_legacy_data: bool = False + self._host: str = "" + self._port: int = 102 + self._rack: int = 0 + self._slot: int = 1 + + @property + def connected(self) -> bool: + if self._use_legacy_data and self._legacy_client is not None: + return bool(self._legacy_client.connected) + return self._connected + + @property + def protocol_version(self) -> int: + return self._protocol_version + + @property + def session_id(self) -> int: + return self._session_id + + @property + def using_legacy_fallback(self) -> bool: + """Whether the client is using legacy S7 protocol for data operations.""" + return self._use_legacy_data + + async def connect( + self, + host: str, + port: int = 102, + rack: int = 0, + slot: int = 1, + ) -> None: + """Connect to an S7-1200/1500 PLC. + + If the PLC does not support S7CommPlus data operations, a secondary + legacy S7 connection is established transparently for data access. + + Args: + host: PLC IP address or hostname + port: TCP port (default 102) + rack: PLC rack number + slot: PLC slot number + """ + self._host = host + self._port = port + self._rack = rack + self._slot = slot + + # TCP connect + self._reader, self._writer = await asyncio.open_connection(host, port) + + try: + # COTP handshake with S7CommPlus TSAP values + await self._cotp_connect(S7COMMPLUS_LOCAL_TSAP, S7COMMPLUS_REMOTE_TSAP) + + # InitSSL handshake + await self._init_ssl() + + # S7CommPlus session setup + await self._create_session() + + self._connected = True + logger.info( + f"Async S7CommPlus connected to {host}:{port}, version=V{self._protocol_version}, session={self._session_id}" + ) + + # Probe S7CommPlus data operations + if not await self._probe_s7commplus_data(): + logger.info("S7CommPlus data operations not supported, falling back to legacy S7 protocol") + await self._setup_legacy_fallback() + + except Exception: + await self.disconnect() + raise + + async def _probe_s7commplus_data(self) -> bool: + """Test if the PLC supports S7CommPlus data operations.""" + try: + payload = struct.pack(">I", 0) + encode_uint32_vlq(0) + encode_uint32_vlq(0) + payload += encode_object_qualifier() + payload += struct.pack(">I", 0) + + response = await self._send_request(FunctionCode.GET_MULTI_VARIABLES, payload) + if len(response) < 1: + return False + return_value, _ = decode_uint64_vlq(response, 0) + if return_value != 0: + logger.debug(f"S7CommPlus probe: PLC returned error {return_value}") + return False + return True + except Exception as e: + logger.debug(f"S7CommPlus probe failed: {e}") + return False + + async def _setup_legacy_fallback(self) -> None: + """Establish a secondary legacy S7 connection for data operations.""" + from ..client import Client + + loop = asyncio.get_event_loop() + client = Client() + await loop.run_in_executor(None, lambda: client.connect(self._host, self._rack, self._slot, self._port)) + self._legacy_client = client + self._use_legacy_data = True + logger.info(f"Legacy S7 fallback connected to {self._host}:{self._port}") + + async def disconnect(self) -> None: + """Disconnect from PLC.""" + if self._legacy_client is not None: + try: + self._legacy_client.disconnect() + except Exception: + pass + self._legacy_client = None + self._use_legacy_data = False + + if self._connected and self._session_id: + try: + await self._delete_session() + except Exception: + pass + + self._connected = False + self._session_id = 0 + self._sequence_number = 0 + self._protocol_version = 0 + + if self._writer: + try: + self._writer.close() + await self._writer.wait_closed() + except Exception: + pass + self._writer = None + self._reader = None + + async def db_read(self, db_number: int, start: int, size: int) -> bytes: + """Read raw bytes from a data block. + + Args: + db_number: Data block number + start: Start byte offset + size: Number of bytes to read + + Returns: + Raw bytes read from the data block + """ + if self._use_legacy_data and self._legacy_client is not None: + client = self._legacy_client + loop = asyncio.get_event_loop() + data = await loop.run_in_executor(None, lambda: client.db_read(db_number, start, size)) + return bytes(data) + + payload = _build_read_payload([(db_number, start, size)]) + response = await self._send_request(FunctionCode.GET_MULTI_VARIABLES, payload) + + results = _parse_read_response(response) + if not results: + raise RuntimeError("Read returned no data") + if results[0] is None: + raise RuntimeError("Read failed: PLC returned error for item") + return results[0] + + async def db_write(self, db_number: int, start: int, data: bytes) -> None: + """Write raw bytes to a data block. + + Args: + db_number: Data block number + start: Start byte offset + data: Bytes to write + """ + if self._use_legacy_data and self._legacy_client is not None: + client = self._legacy_client + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, lambda: client.db_write(db_number, start, bytearray(data))) + return + + payload = _build_write_payload([(db_number, start, data)]) + response = await self._send_request(FunctionCode.SET_MULTI_VARIABLES, payload) + _parse_write_response(response) + + async def db_read_multi(self, items: list[tuple[int, int, int]]) -> list[bytes]: + """Read multiple data block regions in a single request. + + Args: + items: List of (db_number, start_offset, size) tuples + + Returns: + List of raw bytes for each item + """ + if self._use_legacy_data and self._legacy_client is not None: + client = self._legacy_client + loop = asyncio.get_event_loop() + multi_results: list[bytes] = [] + for db_number, start, size in items: + + def _read(db: int = db_number, s: int = start, sz: int = size) -> bytearray: + return bytearray(client.db_read(db, s, sz)) + + data = await loop.run_in_executor(None, _read) + multi_results.append(bytes(data)) + return multi_results + + payload = _build_read_payload(items) + response = await self._send_request(FunctionCode.GET_MULTI_VARIABLES, payload) + + parsed = _parse_read_response(response) + return [r if r is not None else b"" for r in parsed] + + async def explore(self) -> bytes: + """Browse the PLC object tree. + + Returns: + Raw response payload + """ + return await self._send_request(FunctionCode.EXPLORE, b"") + + # -- Internal methods -- + + async def _send_request(self, function_code: int, payload: bytes) -> bytes: + """Send an S7CommPlus request and receive the response.""" + async with self._lock: + if not self._connected or self._writer is None or self._reader is None: + raise RuntimeError("Not connected") + + seq_num = self._next_sequence_number() + + request = ( + struct.pack( + ">BHHHHIB", + Opcode.REQUEST, + 0x0000, + function_code, + 0x0000, + seq_num, + self._session_id, + 0x36, + ) + + payload + ) + + frame = encode_header(self._protocol_version, len(request)) + request + frame += struct.pack(">BBH", 0x72, self._protocol_version, 0x0000) + await self._send_cotp_dt(frame) + + response_data = await self._recv_cotp_dt() + + version, data_length, consumed = decode_header(response_data) + response = response_data[consumed : consumed + data_length] + + if len(response) < 14: + raise RuntimeError("Response too short") + + return response[14:] + + async def _cotp_connect(self, local_tsap: int, remote_tsap: bytes) -> None: + """Perform COTP Connection Request / Confirm handshake.""" + if self._writer is None or self._reader is None: + raise RuntimeError("Not connected") + + # Build COTP CR + base_pdu = struct.pack(">BBHHB", 6, _COTP_CR, 0x0000, 0x0001, 0x00) + calling_tsap = struct.pack(">BBH", 0xC1, 2, local_tsap) + called_tsap = struct.pack(">BB", 0xC2, len(remote_tsap)) + remote_tsap + pdu_size_param = struct.pack(">BBB", 0xC0, 1, 0x0A) + + params = calling_tsap + called_tsap + pdu_size_param + cr_pdu = struct.pack(">B", 6 + len(params)) + base_pdu[1:] + params + + # Send TPKT + CR + tpkt = struct.pack(">BBH", 3, 0, 4 + len(cr_pdu)) + cr_pdu + self._writer.write(tpkt) + await self._writer.drain() + + # Receive TPKT + CC + tpkt_header = await self._reader.readexactly(4) + _, _, length = struct.unpack(">BBH", tpkt_header) + payload = await self._reader.readexactly(length - 4) + + if len(payload) < 7 or payload[1] != _COTP_CC: + raise RuntimeError(f"Expected COTP CC, got {payload[1]:#04x}") + + async def _init_ssl(self) -> None: + """Send InitSSL request (required before CreateObject).""" + seq_num = self._next_sequence_number() + + request = struct.pack( + ">BHHHHIB", + Opcode.REQUEST, + 0x0000, + FunctionCode.INIT_SSL, + 0x0000, + seq_num, + 0x00000000, + 0x30, # Transport flags for InitSSL + ) + request += struct.pack(">I", 0) + + frame = encode_header(ProtocolVersion.V1, len(request)) + request + frame += struct.pack(">BBH", 0x72, ProtocolVersion.V1, 0x0000) + await self._send_cotp_dt(frame) + + response_data = await self._recv_cotp_dt() + version, data_length, consumed = decode_header(response_data) + response = response_data[consumed : consumed + data_length] + + if len(response) < 14: + raise RuntimeError("InitSSL response too short") + + logger.debug(f"InitSSL response received, version=V{version}") + + async def _create_session(self) -> None: + """Send CreateObject to establish S7CommPlus session.""" + seq_num = self._next_sequence_number() + + # Build CreateObject request header + request = struct.pack( + ">BHHHHIB", + Opcode.REQUEST, + 0x0000, + FunctionCode.CREATE_OBJECT, + 0x0000, + seq_num, + ObjectId.OBJECT_NULL_SERVER_SESSION, # SessionId = 288 + 0x36, + ) + + # RequestId: ObjectServerSessionContainer (285) + request += struct.pack(">I", ObjectId.OBJECT_SERVER_SESSION_CONTAINER) + + # RequestValue: ValueUDInt(0) + request += bytes([0x00, DataType.UDINT]) + encode_uint32_vlq(0) + + # Unknown padding + request += struct.pack(">I", 0) + + # RequestObject: NullServerSession PObject + request += bytes([ElementID.START_OF_OBJECT]) + request += struct.pack(">I", ObjectId.GET_NEW_RID_ON_SERVER) + request += encode_uint32_vlq(ObjectId.CLASS_SERVER_SESSION) + request += encode_uint32_vlq(0) # ClassFlags + request += encode_uint32_vlq(0) # AttributeId + + # Attribute: ServerSessionClientRID = 0x80c3c901 + request += bytes([ElementID.ATTRIBUTE]) + request += encode_uint32_vlq(ObjectId.SERVER_SESSION_CLIENT_RID) + request += encode_typed_value(DataType.RID, 0x80C3C901) + + # Nested: ClassSubscriptions + request += bytes([ElementID.START_OF_OBJECT]) + request += struct.pack(">I", ObjectId.GET_NEW_RID_ON_SERVER) + request += encode_uint32_vlq(ObjectId.CLASS_SUBSCRIPTIONS) + request += encode_uint32_vlq(0) + request += encode_uint32_vlq(0) + request += bytes([ElementID.TERMINATING_OBJECT]) + + request += bytes([ElementID.TERMINATING_OBJECT]) + request += struct.pack(">I", 0) + + # Frame header + trailer + frame = encode_header(ProtocolVersion.V1, len(request)) + request + frame += struct.pack(">BBH", 0x72, ProtocolVersion.V1, 0x0000) + await self._send_cotp_dt(frame) + + response_data = await self._recv_cotp_dt() + version, data_length, consumed = decode_header(response_data) + response = response_data[consumed : consumed + data_length] + + if len(response) < 14: + raise RuntimeError("CreateObject response too short") + + self._session_id = struct.unpack_from(">I", response, 9)[0] + self._protocol_version = version + + async def _delete_session(self) -> None: + """Send DeleteObject to close the session.""" + seq_num = self._next_sequence_number() + + request = struct.pack( + ">BHHHHIB", + Opcode.REQUEST, + 0x0000, + FunctionCode.DELETE_OBJECT, + 0x0000, + seq_num, + self._session_id, + 0x36, + ) + request += struct.pack(">I", 0) + + frame = encode_header(self._protocol_version, len(request)) + request + frame += struct.pack(">BBH", 0x72, self._protocol_version, 0x0000) + await self._send_cotp_dt(frame) + + try: + await asyncio.wait_for(self._recv_cotp_dt(), timeout=1.0) + except Exception: + pass + + async def _send_cotp_dt(self, data: bytes) -> None: + """Send data wrapped in COTP DT + TPKT.""" + if self._writer is None: + raise RuntimeError("Not connected") + + cotp_dt = struct.pack(">BBB", 2, _COTP_DT, 0x80) + data + tpkt = struct.pack(">BBH", 3, 0, 4 + len(cotp_dt)) + cotp_dt + self._writer.write(tpkt) + await self._writer.drain() + + async def _recv_cotp_dt(self) -> bytes: + """Receive TPKT + COTP DT and return the payload.""" + if self._reader is None: + raise RuntimeError("Not connected") + + tpkt_header = await self._reader.readexactly(4) + _, _, length = struct.unpack(">BBH", tpkt_header) + payload = await self._reader.readexactly(length - 4) + + if len(payload) < 3 or payload[1] != _COTP_DT: + raise RuntimeError(f"Expected COTP DT, got {payload[1]:#04x}") + + return payload[3:] + + def _next_sequence_number(self) -> int: + seq = self._sequence_number + self._sequence_number = (self._sequence_number + 1) & 0xFFFF + return seq + + async def __aenter__(self) -> "S7CommPlusAsyncClient": + return self + + async def __aexit__(self, *args: Any) -> None: + await self.disconnect() diff --git a/snap7/s7commplus/client.py b/snap7/s7commplus/client.py new file mode 100644 index 00000000..d5b38a40 --- /dev/null +++ b/snap7/s7commplus/client.py @@ -0,0 +1,510 @@ +""" +S7CommPlus client for S7-1200/1500 PLCs. + +Provides high-level operations over the S7CommPlus protocol, similar to +the existing snap7.Client but targeting S7-1200/1500 PLCs with full +engineering access (symbolic addressing, optimized data blocks, etc.). + +Supports all S7CommPlus protocol versions (V1/V2/V3/TLS). The protocol +version is auto-detected from the PLC's CreateObject response during +connection setup. + +When a PLC does not support S7CommPlus data operations (e.g. PLCs that +accept S7CommPlus sessions but return ERROR2 for GetMultiVariables), +the client transparently falls back to the legacy S7 protocol for +data block read/write operations. + +Status: V1 connection is functional. V2/V3/TLS authentication planned. + +Reference: thomas-v2/S7CommPlusDriver (C#, LGPL-3.0) +""" + +import logging +import struct +from typing import Any, Optional + +from .connection import S7CommPlusConnection +from .protocol import FunctionCode, Ids +from .vlq import encode_uint32_vlq, decode_uint32_vlq, decode_uint64_vlq +from .codec import ( + encode_item_address, + encode_object_qualifier, + encode_pvalue_blob, + decode_pvalue_to_bytes, +) + +logger = logging.getLogger(__name__) + + +class S7CommPlusClient: + """S7CommPlus client for S7-1200/1500 PLCs. + + Supports all S7CommPlus protocol versions: + - V1: S7-1200 FW V4.0+ + - V2: S7-1200/1500 with older firmware + - V3: S7-1200/1500 pre-TIA Portal V17 + - V3 + TLS: TIA Portal V17+ (recommended) + + The protocol version is auto-detected during connection. + + When the PLC does not support S7CommPlus data operations, the client + automatically falls back to legacy S7 protocol for db_read/db_write. + + Example:: + + client = S7CommPlusClient() + client.connect("192.168.1.10") + + # Read raw bytes from DB1 + data = client.db_read(1, 0, 4) + + # Write raw bytes to DB1 + client.db_write(1, 0, struct.pack(">f", 23.5)) + + client.disconnect() + """ + + def __init__(self) -> None: + self._connection: Optional[S7CommPlusConnection] = None + self._legacy_client: Optional[Any] = None + self._use_legacy_data: bool = False + self._host: str = "" + self._port: int = 102 + self._rack: int = 0 + self._slot: int = 1 + + @property + def connected(self) -> bool: + if self._use_legacy_data and self._legacy_client is not None: + return bool(self._legacy_client.connected) + return self._connection is not None and self._connection.connected + + @property + def protocol_version(self) -> int: + """Protocol version negotiated with the PLC.""" + if self._connection is None: + return 0 + return self._connection.protocol_version + + @property + def session_id(self) -> int: + """Session ID assigned by the PLC.""" + if self._connection is None: + return 0 + return self._connection.session_id + + @property + def using_legacy_fallback(self) -> bool: + """Whether the client is using legacy S7 protocol for data operations.""" + return self._use_legacy_data + + def connect( + self, + host: str, + port: int = 102, + rack: int = 0, + slot: int = 1, + use_tls: bool = False, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + tls_ca: Optional[str] = None, + ) -> None: + """Connect to an S7-1200/1500 PLC using S7CommPlus. + + If the PLC does not support S7CommPlus data operations, a secondary + legacy S7 connection is established transparently for data access. + + Args: + host: PLC IP address or hostname + port: TCP port (default 102) + rack: PLC rack number + slot: PLC slot number + use_tls: Whether to attempt TLS (requires V3 PLC + certs) + tls_cert: Path to client TLS certificate (PEM) + tls_key: Path to client private key (PEM) + tls_ca: Path to CA certificate for PLC verification (PEM) + """ + self._host = host + self._port = port + self._rack = rack + self._slot = slot + + self._connection = S7CommPlusConnection( + host=host, + port=port, + ) + + self._connection.connect( + use_tls=use_tls, + tls_cert=tls_cert, + tls_key=tls_key, + tls_ca=tls_ca, + ) + + # Probe S7CommPlus data operations with a minimal request + if not self._probe_s7commplus_data(): + logger.info("S7CommPlus data operations not supported, falling back to legacy S7 protocol") + self._setup_legacy_fallback() + + def _probe_s7commplus_data(self) -> bool: + """Test if the PLC supports S7CommPlus data operations. + + Sends a minimal GetMultiVariables request with zero items. If the PLC + responds with ERROR2 or a non-zero return code, data operations are + not supported. + + Returns: + True if S7CommPlus data operations work. + """ + if self._connection is None: + return False + + try: + # Send a minimal GetMultiVariables with 0 items + payload = struct.pack(">I", 0) + encode_uint32_vlq(0) + encode_uint32_vlq(0) + payload += encode_object_qualifier() + payload += struct.pack(">I", 0) + + response = self._connection.send_request(FunctionCode.GET_MULTI_VARIABLES, payload) + + # Check if we got a valid response (return value = 0) + if len(response) < 1: + return False + return_value, _ = decode_uint64_vlq(response, 0) + if return_value != 0: + logger.debug(f"S7CommPlus probe: PLC returned error {return_value}") + return False + return True + except Exception as e: + logger.debug(f"S7CommPlus probe failed: {e}") + return False + + def _setup_legacy_fallback(self) -> None: + """Establish a secondary legacy S7 connection for data operations.""" + from ..client import Client + + self._legacy_client = Client() + self._legacy_client.connect(self._host, self._rack, self._slot, self._port) + self._use_legacy_data = True + logger.info(f"Legacy S7 fallback connected to {self._host}:{self._port}") + + def disconnect(self) -> None: + """Disconnect from PLC.""" + if self._legacy_client is not None: + try: + self._legacy_client.disconnect() + except Exception: + pass + self._legacy_client = None + self._use_legacy_data = False + + if self._connection: + self._connection.disconnect() + self._connection = None + + # -- Data block read/write -- + + def db_read(self, db_number: int, start: int, size: int) -> bytes: + """Read raw bytes from a data block. + + Uses S7CommPlus protocol when supported, otherwise falls back to + legacy S7 protocol transparently. + + Args: + db_number: Data block number + start: Start byte offset + size: Number of bytes to read + + Returns: + Raw bytes read from the data block + """ + if self._use_legacy_data and self._legacy_client is not None: + return bytes(self._legacy_client.db_read(db_number, start, size)) + + if self._connection is None: + raise RuntimeError("Not connected") + + payload = _build_read_payload([(db_number, start, size)]) + logger.debug(f"db_read: db={db_number} start={start} size={size} payload={payload.hex(' ')}") + + response = self._connection.send_request(FunctionCode.GET_MULTI_VARIABLES, payload) + logger.debug(f"db_read: response ({len(response)} bytes): {response.hex(' ')}") + + results = _parse_read_response(response) + if not results: + raise RuntimeError("Read returned no data") + if results[0] is None: + raise RuntimeError("Read failed: PLC returned error for item") + return results[0] + + def db_write(self, db_number: int, start: int, data: bytes) -> None: + """Write raw bytes to a data block. + + Uses S7CommPlus protocol when supported, otherwise falls back to + legacy S7 protocol transparently. + + Args: + db_number: Data block number + start: Start byte offset + data: Bytes to write + """ + if self._use_legacy_data and self._legacy_client is not None: + self._legacy_client.db_write(db_number, start, bytearray(data)) + return + + if self._connection is None: + raise RuntimeError("Not connected") + + payload = _build_write_payload([(db_number, start, data)]) + logger.debug( + f"db_write: db={db_number} start={start} data_len={len(data)} data={data.hex(' ')} payload={payload.hex(' ')}" + ) + + response = self._connection.send_request(FunctionCode.SET_MULTI_VARIABLES, payload) + logger.debug(f"db_write: response ({len(response)} bytes): {response.hex(' ')}") + + _parse_write_response(response) + + def db_read_multi(self, items: list[tuple[int, int, int]]) -> list[bytes]: + """Read multiple data block regions in a single request. + + Uses S7CommPlus protocol when supported, otherwise falls back to + legacy S7 protocol (individual reads) transparently. + + Args: + items: List of (db_number, start_offset, size) tuples + + Returns: + List of raw bytes for each item + """ + if self._use_legacy_data and self._legacy_client is not None: + results = [] + for db_number, start, size in items: + data = self._legacy_client.db_read(db_number, start, size) + results.append(bytes(data)) + return results + + if self._connection is None: + raise RuntimeError("Not connected") + + payload = _build_read_payload(items) + logger.debug(f"db_read_multi: {len(items)} items: {items} payload={payload.hex(' ')}") + + response = self._connection.send_request(FunctionCode.GET_MULTI_VARIABLES, payload) + logger.debug(f"db_read_multi: response ({len(response)} bytes): {response.hex(' ')}") + + parsed = _parse_read_response(response) + return [r if r is not None else b"" for r in parsed] + + # -- Explore (browse PLC object tree) -- + + def explore(self) -> bytes: + """Browse the PLC object tree. + + Returns the raw Explore response payload for parsing. + Full symbolic exploration will be implemented in a future version. + + Returns: + Raw response payload + """ + if self._connection is None: + raise RuntimeError("Not connected") + + response = self._connection.send_request(FunctionCode.EXPLORE, b"") + logger.debug(f"explore: response ({len(response)} bytes): {response.hex(' ')}") + return response + + # -- Context manager -- + + def __enter__(self) -> "S7CommPlusClient": + return self + + def __exit__(self, *args: Any) -> None: + self.disconnect() + + +# -- Request/response builders (module-level for reuse by async client) -- + + +def _build_read_payload(items: list[tuple[int, int, int]]) -> bytes: + """Build a GetMultiVariables request payload. + + Args: + items: List of (db_number, start_offset, size) tuples + + Returns: + Encoded payload bytes (after the 14-byte request header) + + Reference: thomas-v2/S7CommPlusDriver/Core/GetMultiVariablesRequest.cs + """ + # Encode all item addresses and compute total field count + addresses: list[bytes] = [] + total_field_count = 0 + for db_number, start, size in items: + access_area = Ids.DB_ACCESS_AREA_BASE + (db_number & 0xFFFF) + addr_bytes, field_count = encode_item_address( + access_area=access_area, + access_sub_area=Ids.DB_VALUE_ACTUAL, + lids=[start + 1, size], # LID byte offsets are 1-based in S7CommPlus + ) + addresses.append(addr_bytes) + total_field_count += field_count + + payload = bytearray() + # LinkId (UInt32 fixed = 0, for reading variables) + payload += struct.pack(">I", 0) + # Item count + payload += encode_uint32_vlq(len(items)) + # Total field count across all items + payload += encode_uint32_vlq(total_field_count) + # Item addresses + for addr in addresses: + payload += addr + # ObjectQualifier + payload += encode_object_qualifier() + # Padding + payload += struct.pack(">I", 0) + + return bytes(payload) + + +def _parse_read_response(response: bytes) -> list[Optional[bytes]]: + """Parse a GetMultiVariables response payload. + + Args: + response: Response payload (after the 14-byte response header) + + Returns: + List of raw bytes per item (None for errored items) + + Reference: thomas-v2/S7CommPlusDriver/Core/GetMultiVariablesResponse.cs + """ + offset = 0 + + # ReturnValue (UInt64 VLQ) + return_value, consumed = decode_uint64_vlq(response, offset) + offset += consumed + logger.debug(f"_parse_read_response: return_value={return_value}") + + if return_value != 0: + logger.error(f"_parse_read_response: PLC returned error: {return_value}") + return [] + + # Value list: ItemNumber (VLQ) + PValue, terminated by ItemNumber=0 + values: dict[int, bytes] = {} + while offset < len(response): + item_nr, consumed = decode_uint32_vlq(response, offset) + offset += consumed + if item_nr == 0: + break + raw_bytes, consumed = decode_pvalue_to_bytes(response, offset) + offset += consumed + values[item_nr] = raw_bytes + + # Error list: ErrorItemNumber (VLQ) + ErrorReturnValue (UInt64 VLQ), terminated by 0 + errors: dict[int, int] = {} + while offset < len(response): + err_item_nr, consumed = decode_uint32_vlq(response, offset) + offset += consumed + if err_item_nr == 0: + break + err_value, consumed = decode_uint64_vlq(response, offset) + offset += consumed + errors[err_item_nr] = err_value + logger.debug(f"_parse_read_response: error item {err_item_nr}: {err_value}") + + # Build result list (1-based item numbers) + max_item = max(max(values.keys(), default=0), max(errors.keys(), default=0)) + results: list[Optional[bytes]] = [] + for i in range(1, max_item + 1): + if i in values: + results.append(values[i]) + else: + results.append(None) + + return results + + +def _build_write_payload(items: list[tuple[int, int, bytes]]) -> bytes: + """Build a SetMultiVariables request payload. + + Args: + items: List of (db_number, start_offset, data) tuples + + Returns: + Encoded payload bytes + + Reference: thomas-v2/S7CommPlusDriver/Core/SetMultiVariablesRequest.cs + """ + # Encode all item addresses and compute total field count + addresses: list[bytes] = [] + total_field_count = 0 + for db_number, start, data in items: + access_area = Ids.DB_ACCESS_AREA_BASE + (db_number & 0xFFFF) + addr_bytes, field_count = encode_item_address( + access_area=access_area, + access_sub_area=Ids.DB_VALUE_ACTUAL, + lids=[start + 1, len(data)], # LID byte offsets are 1-based in S7CommPlus + ) + addresses.append(addr_bytes) + total_field_count += field_count + + payload = bytearray() + # InObjectId (UInt32 fixed = 0, for plain variable writes) + payload += struct.pack(">I", 0) + # Item count + payload += encode_uint32_vlq(len(items)) + # Total field count + payload += encode_uint32_vlq(total_field_count) + # Item addresses + for addr in addresses: + payload += addr + # Value list: ItemNumber (1-based) + PValue + for i, (_, _, data) in enumerate(items, 1): + payload += encode_uint32_vlq(i) + payload += encode_pvalue_blob(data) + # Fill byte + payload += bytes([0x00]) + # ObjectQualifier + payload += encode_object_qualifier() + # Padding + payload += struct.pack(">I", 0) + + return bytes(payload) + + +def _parse_write_response(response: bytes) -> None: + """Parse a SetMultiVariables response payload. + + Args: + response: Response payload (after the 14-byte response header) + + Raises: + RuntimeError: If the write failed + + Reference: thomas-v2/S7CommPlusDriver/Core/SetMultiVariablesResponse.cs + """ + offset = 0 + + # ReturnValue (UInt64 VLQ) + return_value, consumed = decode_uint64_vlq(response, offset) + offset += consumed + logger.debug(f"_parse_write_response: return_value={return_value}") + + if return_value != 0: + raise RuntimeError(f"Write failed with return value {return_value}") + + # Error list: ErrorItemNumber (VLQ) + ErrorReturnValue (UInt64 VLQ) + errors: list[tuple[int, int]] = [] + while offset < len(response): + err_item_nr, consumed = decode_uint32_vlq(response, offset) + offset += consumed + if err_item_nr == 0: + break + err_value, consumed = decode_uint64_vlq(response, offset) + offset += consumed + errors.append((err_item_nr, err_value)) + + if errors: + err_str = ", ".join(f"item {nr}: error {val}" for nr, val in errors) + raise RuntimeError(f"Write failed: {err_str}") diff --git a/snap7/s7commplus/codec.py b/snap7/s7commplus/codec.py new file mode 100644 index 00000000..74f94a2e --- /dev/null +++ b/snap7/s7commplus/codec.py @@ -0,0 +1,495 @@ +""" +S7CommPlus data encoding and decoding. + +Provides serialization for the S7CommPlus wire format including: +- Fixed-width integers (big-endian) +- VLQ-encoded integers +- Floating point values +- Strings (UTF-8 encoded WStrings) +- Blobs (raw byte arrays) +- S7CommPlus frame header + +Reference: thomas-v2/S7CommPlusDriver/Core/S7p.cs +""" + +import struct +from typing import Any + +from .protocol import PROTOCOL_ID, DataType, Ids +from .vlq import ( + encode_uint32_vlq, + decode_uint32_vlq, + encode_int32_vlq, + encode_uint64_vlq, + decode_uint64_vlq, + encode_int64_vlq, +) + + +def encode_header(version: int, data_length: int) -> bytes: + """Encode an S7CommPlus frame header. + + Header format (4 bytes):: + + [0] Protocol ID: 0x72 + [1] Protocol version + [2-3] Data length (big-endian uint16) + + Args: + version: Protocol version byte + data_length: Length of data following the header + + Returns: + 4-byte header + """ + return struct.pack(">BBH", PROTOCOL_ID, version, data_length) + + +def decode_header(data: bytes, offset: int = 0) -> tuple[int, int, int]: + """Decode an S7CommPlus frame header. + + Args: + data: Buffer containing the header + offset: Starting position + + Returns: + Tuple of (protocol_version, data_length, bytes_consumed) + + Raises: + ValueError: If protocol ID is not 0x72 + """ + if len(data) - offset < 4: + raise ValueError("Not enough data for S7CommPlus header") + + proto_id, version, length = struct.unpack_from(">BBH", data, offset) + + if proto_id != PROTOCOL_ID: + raise ValueError(f"Invalid protocol ID: {proto_id:#04x}, expected {PROTOCOL_ID:#04x}") + + return version, length, 4 + + +def encode_request_header( + function_code: int, + sequence_number: int, + session_id: int = 0, + transport_flags: int = 0x36, +) -> bytes: + """Encode an S7CommPlus request header (after the frame header). + + Request header format:: + + [0] Opcode: 0x31 (Request) + [1-2] Reserved: 0x0000 + [3-4] Function code (big-endian uint16) + [5-6] Reserved: 0x0000 + [7-8] Sequence number (big-endian uint16) + [9-12] Session ID (big-endian uint32) + [13] Transport flags + + Args: + function_code: S7CommPlus function code + sequence_number: Request sequence number + session_id: Session identifier (0 for initial connection) + transport_flags: Transport flags byte + + Returns: + 14-byte request header + """ + from .protocol import Opcode + + return struct.pack( + ">BHHHHIB", + Opcode.REQUEST, + 0x0000, # Reserved + function_code, + 0x0000, # Reserved + sequence_number, + session_id, + transport_flags, + ) + + +def decode_response_header(data: bytes, offset: int = 0) -> dict[str, Any]: + """Decode an S7CommPlus response header. + + Args: + data: Buffer containing the response + offset: Starting position + + Returns: + Dictionary with opcode, function_code, sequence_number, session_id, + transport_flags, and bytes_consumed + """ + if len(data) - offset < 14: + raise ValueError("Not enough data for S7CommPlus response header") + + opcode, reserved1, function_code, reserved2, seq_num, session_id, transport_flags = struct.unpack_from( + ">BHHHHIB", data, offset + ) + + return { + "opcode": opcode, + "function_code": function_code, + "sequence_number": seq_num, + "session_id": session_id, + "transport_flags": transport_flags, + "bytes_consumed": 14, + } + + +# -- Fixed-width encoding (big-endian) -- + + +def encode_uint8(value: int) -> bytes: + return struct.pack(">B", value) + + +def decode_uint8(data: bytes, offset: int = 0) -> tuple[int, int]: + return struct.unpack_from(">B", data, offset)[0], 1 + + +def encode_uint16(value: int) -> bytes: + return struct.pack(">H", value) + + +def decode_uint16(data: bytes, offset: int = 0) -> tuple[int, int]: + return struct.unpack_from(">H", data, offset)[0], 2 + + +def encode_uint32(value: int) -> bytes: + return struct.pack(">I", value) + + +def decode_uint32(data: bytes, offset: int = 0) -> tuple[int, int]: + return struct.unpack_from(">I", data, offset)[0], 4 + + +def encode_uint64(value: int) -> bytes: + return struct.pack(">Q", value) + + +def decode_uint64(data: bytes, offset: int = 0) -> tuple[int, int]: + return struct.unpack_from(">Q", data, offset)[0], 8 + + +def encode_int16(value: int) -> bytes: + return struct.pack(">h", value) + + +def decode_int16(data: bytes, offset: int = 0) -> tuple[int, int]: + return struct.unpack_from(">h", data, offset)[0], 2 + + +def encode_int32(value: int) -> bytes: + return struct.pack(">i", value) + + +def decode_int32(data: bytes, offset: int = 0) -> tuple[int, int]: + return struct.unpack_from(">i", data, offset)[0], 4 + + +def encode_int64(value: int) -> bytes: + return struct.pack(">q", value) + + +def decode_int64(data: bytes, offset: int = 0) -> tuple[int, int]: + return struct.unpack_from(">q", data, offset)[0], 8 + + +def encode_float32(value: float) -> bytes: + return struct.pack(">f", value) + + +def decode_float32(data: bytes, offset: int = 0) -> tuple[float, int]: + return struct.unpack_from(">f", data, offset)[0], 4 + + +def encode_float64(value: float) -> bytes: + return struct.pack(">d", value) + + +def decode_float64(data: bytes, offset: int = 0) -> tuple[float, int]: + return struct.unpack_from(">d", data, offset)[0], 8 + + +# -- String encoding -- + + +def encode_wstring(value: str) -> bytes: + """Encode a string as UTF-8 (S7CommPlus WString wire format).""" + return value.encode("utf-8") + + +def decode_wstring(data: bytes, offset: int, length: int) -> tuple[str, int]: + """Decode a UTF-8 string. + + Args: + data: Buffer + offset: Start position + length: Number of bytes to decode + + Returns: + Tuple of (decoded_string, bytes_consumed) + """ + return data[offset : offset + length].decode("utf-8"), length + + +# -- Typed value encoding -- + + +def encode_typed_value(datatype: int, value: Any) -> bytes: + """Encode a value with its type tag. + + This prepends the DataType byte before the encoded value, which is how + attribute values are serialized in the S7CommPlus object model. + + Args: + datatype: DataType enum value + value: Value to encode + + Returns: + Type-tagged encoded value + """ + tag = struct.pack(">B", datatype) + + if datatype == DataType.NULL: + return tag + elif datatype == DataType.BOOL: + return tag + struct.pack(">B", 1 if value else 0) + elif datatype == DataType.USINT or datatype == DataType.BYTE: + return tag + struct.pack(">B", value) + elif datatype == DataType.UINT or datatype == DataType.WORD: + return tag + struct.pack(">H", value) + elif datatype == DataType.UDINT or datatype == DataType.DWORD: + return tag + encode_uint32_vlq(value) + elif datatype == DataType.ULINT or datatype == DataType.LWORD: + return tag + encode_uint64_vlq(value) + elif datatype == DataType.SINT: + return tag + struct.pack(">b", value) + elif datatype == DataType.INT: + return tag + struct.pack(">h", value) + elif datatype == DataType.DINT: + return tag + encode_int32_vlq(value) + elif datatype == DataType.LINT: + return tag + encode_int64_vlq(value) + elif datatype == DataType.REAL: + return tag + struct.pack(">f", value) + elif datatype == DataType.LREAL: + return tag + struct.pack(">d", value) + elif datatype == DataType.TIMESTAMP: + return tag + struct.pack(">Q", value) + elif datatype == DataType.TIMESPAN: + return tag + encode_int64_vlq(value) + elif datatype == DataType.RID: + return tag + struct.pack(">I", value) + elif datatype == DataType.AID: + return tag + encode_uint32_vlq(value) + elif datatype == DataType.WSTRING: + encoded: bytes = value.encode("utf-8") + return tag + encode_uint32_vlq(len(encoded)) + encoded + elif datatype == DataType.BLOB: + return bytes(tag + encode_uint32_vlq(len(value)) + value) + else: + raise ValueError(f"Unsupported DataType for encoding: {datatype:#04x}") + + +# -- S7CommPlus request/response payload helpers -- + + +def encode_object_qualifier() -> bytes: + """Encode the S7CommPlus ObjectQualifier structure. + + This fixed structure is appended to GetMultiVariables and + SetMultiVariables requests. + + Reference: thomas-v2/S7CommPlusDriver/Core/S7p.cs EncodeObjectQualifier + """ + result = bytearray() + result += struct.pack(">I", Ids.OBJECT_QUALIFIER) + # ParentRID = RID(0) + result += encode_uint32_vlq(Ids.PARENT_RID) + result += bytes([0x00, DataType.RID]) + struct.pack(">I", 0) + # CompositionAID = AID(0) + result += encode_uint32_vlq(Ids.COMPOSITION_AID) + result += bytes([0x00, DataType.AID]) + encode_uint32_vlq(0) + # KeyQualifier = UDInt(0) + result += encode_uint32_vlq(Ids.KEY_QUALIFIER) + result += bytes([0x00, DataType.UDINT]) + encode_uint32_vlq(0) + # Terminator + result += bytes([0x00]) + return bytes(result) + + +def encode_item_address( + access_area: int, + access_sub_area: int, + lids: list[int] | None = None, + symbol_crc: int = 0, +) -> tuple[bytes, int]: + """Encode an S7CommPlus ItemAddress for variable access. + + Args: + access_area: Access area ID (e.g., 0x8A0E0001 for DB1) + access_sub_area: Sub-area ID (e.g., Ids.DB_VALUE_ACTUAL) + lids: Additional LID values for sub-addressing + symbol_crc: Symbol CRC (0 for no CRC check) + + Returns: + Tuple of (encoded_bytes, field_count) + + Reference: thomas-v2/S7CommPlusDriver/ClientApi/ItemAddress.cs + """ + if lids is None: + lids = [] + result = bytearray() + result += encode_uint32_vlq(symbol_crc) + result += encode_uint32_vlq(access_area) + result += encode_uint32_vlq(len(lids) + 1) # +1 for AccessSubArea + result += encode_uint32_vlq(access_sub_area) + for lid in lids: + result += encode_uint32_vlq(lid) + field_count = 4 + len(lids) # SymbolCrc + AccessArea + NumLIDs + AccessSubArea + LIDs + return bytes(result), field_count + + +def encode_pvalue_blob(data: bytes) -> bytes: + """Encode raw bytes as a BLOB PValue. + + PValue format: [flags:1][datatype:1][length:VLQ][data] + """ + result = bytearray() + result += bytes([0x00, DataType.BLOB]) + result += encode_uint32_vlq(len(data)) + result += data + return bytes(result) + + +def decode_pvalue_to_bytes(data: bytes, offset: int) -> tuple[bytes, int]: + """Decode a PValue from S7CommPlus response to raw bytes. + + Supports scalar types and BLOBs. Returns the raw big-endian bytes + of the value regardless of type. + + Args: + data: Response buffer + offset: Position of the PValue + + Returns: + Tuple of (raw_bytes, bytes_consumed) + """ + if offset + 2 > len(data): + raise ValueError("Not enough data for PValue header") + + flags = data[offset] + datatype = data[offset + 1] + consumed = 2 + + is_array = bool(flags & 0x10) + + if is_array: + # Array: read count then elements + count, c = decode_uint32_vlq(data, offset + consumed) + consumed += c + elem_size = _pvalue_element_size(datatype) + if elem_size > 0: + raw = data[offset + consumed : offset + consumed + count * elem_size] + consumed += count * elem_size + return bytes(raw), consumed + else: + # Variable-length elements (VLQ encoded) + result = bytearray() + for _ in range(count): + val, c = decode_uint32_vlq(data, offset + consumed) + consumed += c + result += encode_uint32_vlq(val) + return bytes(result), consumed + + # Scalar types + if datatype == DataType.NULL: + return b"", consumed + elif datatype == DataType.BOOL: + return data[offset + consumed : offset + consumed + 1], consumed + 1 + elif datatype in (DataType.USINT, DataType.BYTE, DataType.SINT): + return data[offset + consumed : offset + consumed + 1], consumed + 1 + elif datatype in (DataType.UINT, DataType.WORD, DataType.INT): + return data[offset + consumed : offset + consumed + 2], consumed + 2 + elif datatype in (DataType.UDINT, DataType.DWORD): + val, c = decode_uint32_vlq(data, offset + consumed) + consumed += c + return struct.pack(">I", val), consumed + elif datatype in (DataType.DINT,): + # Signed VLQ + from .vlq import decode_int32_vlq + + val, c = decode_int32_vlq(data, offset + consumed) + consumed += c + return struct.pack(">i", val), consumed + elif datatype == DataType.REAL: + return data[offset + consumed : offset + consumed + 4], consumed + 4 + elif datatype == DataType.LREAL: + return data[offset + consumed : offset + consumed + 8], consumed + 8 + elif datatype in (DataType.ULINT, DataType.LWORD): + val, c = decode_uint64_vlq(data, offset + consumed) + consumed += c + return struct.pack(">Q", val), consumed + elif datatype in (DataType.LINT,): + from .vlq import decode_int64_vlq + + val, c = decode_int64_vlq(data, offset + consumed) + consumed += c + return struct.pack(">q", val), consumed + elif datatype == DataType.TIMESTAMP: + return data[offset + consumed : offset + consumed + 8], consumed + 8 + elif datatype == DataType.TIMESPAN: + from .vlq import decode_int64_vlq + + val, c = decode_int64_vlq(data, offset + consumed) + consumed += c + return struct.pack(">q", val), consumed + elif datatype == DataType.RID: + return data[offset + consumed : offset + consumed + 4], consumed + 4 + elif datatype == DataType.AID: + val, c = decode_uint32_vlq(data, offset + consumed) + consumed += c + return struct.pack(">I", val), consumed + elif datatype == DataType.BLOB: + length, c = decode_uint32_vlq(data, offset + consumed) + consumed += c + raw = data[offset + consumed : offset + consumed + length] + consumed += length + return bytes(raw), consumed + elif datatype == DataType.WSTRING: + length, c = decode_uint32_vlq(data, offset + consumed) + consumed += c + raw = data[offset + consumed : offset + consumed + length] + consumed += length + return bytes(raw), consumed + elif datatype == DataType.STRUCT: + # Struct: read count, then nested PValues + count, c = decode_uint32_vlq(data, offset + consumed) + consumed += c + result = bytearray() + for _ in range(count): + val_bytes, c = decode_pvalue_to_bytes(data, offset + consumed) + consumed += c + result += val_bytes + return bytes(result), consumed + else: + raise ValueError(f"Unsupported PValue datatype: {datatype:#04x}") + + +def _pvalue_element_size(datatype: int) -> int: + """Return the fixed byte size for a PValue array element, or 0 for variable-length.""" + if datatype in (DataType.BOOL, DataType.USINT, DataType.BYTE, DataType.SINT): + return 1 + elif datatype in (DataType.UINT, DataType.WORD, DataType.INT): + return 2 + elif datatype in (DataType.REAL,): + return 4 + elif datatype in (DataType.LREAL, DataType.TIMESTAMP): + return 8 + elif datatype in (DataType.RID,): + return 4 + else: + return 0 # Variable-length (VLQ encoded) diff --git a/snap7/s7commplus/connection.py b/snap7/s7commplus/connection.py new file mode 100644 index 00000000..fbbaf60d --- /dev/null +++ b/snap7/s7commplus/connection.py @@ -0,0 +1,743 @@ +""" +S7CommPlus connection management. + +Establishes an ISO-on-TCP connection to S7-1200/1500 PLCs using the +S7CommPlus protocol, with support for all protocol versions: + +- V1: Early S7-1200 (FW >= V4.0). Simple session handshake. +- V2: Adds integrity checking and session authentication. +- V3: Adds public-key-based key exchange. +- V3 + TLS: TIA Portal V17+. Standard TLS 1.3 with per-device certificates. + +The wire protocol (VLQ encoding, data types, function codes, object model) is +the same across all versions -- only the session authentication layer differs. + +Connection sequence (all versions):: + + 1. TCP connect to port 102 + 2. COTP Connection Request / Confirm + - Local TSAP: 0x0600 + - Remote TSAP: "SIMATIC-ROOT-HMI" (16-byte ASCII string) + 3. InitSSL request / response (unencrypted) + 4. TLS activation (for V3/TLS PLCs) + 5. S7CommPlus CreateObject request (NullServer session setup) + - SessionId = ObjectNullServerSession (288) + - Proper PObject tree with ServerSession class + 6. PLC responds with CreateObject response containing: + - Protocol version (V1/V2/V3) + - Session ID + - Server session challenge (V2/V3) + +Version-specific authentication after step 6:: + + V1: No further authentication needed + V2: Session key derivation and integrity checking + V3 (no TLS): Public-key key exchange + V3 (TLS): TLS 1.3 handshake is already done in step 4 + +Reference: thomas-v2/S7CommPlusDriver (C#, LGPL-3.0) +""" + +import logging +import ssl +import struct +from typing import Optional, Type +from types import TracebackType + +from ..connection import ISOTCPConnection +from .protocol import ( + FunctionCode, + Opcode, + ProtocolVersion, + ElementID, + ObjectId, + S7COMMPLUS_LOCAL_TSAP, + S7COMMPLUS_REMOTE_TSAP, +) +from .codec import encode_header, decode_header, encode_typed_value, encode_object_qualifier +from .vlq import encode_uint32_vlq, decode_uint32_vlq, decode_uint64_vlq +from .protocol import DataType + +logger = logging.getLogger(__name__) + + +def _element_size(datatype: int) -> int: + """Return the fixed byte size for an array element, or 0 for variable-length.""" + if datatype in (DataType.BOOL, DataType.USINT, DataType.BYTE, DataType.SINT): + return 1 + elif datatype in (DataType.UINT, DataType.WORD, DataType.INT): + return 2 + elif datatype in (DataType.REAL, DataType.RID): + return 4 + elif datatype in (DataType.LREAL, DataType.TIMESTAMP): + return 8 + else: + return 0 + + +class S7CommPlusConnection: + """S7CommPlus connection with multi-version support. + + Wraps an ISOTCPConnection and adds: + - S7CommPlus session establishment (CreateObject) + - Protocol version detection from PLC response + - Version-appropriate authentication (V1/V2/V3/TLS) + - Frame send/receive (TLS-encrypted when using V17+ firmware) + + Currently implements V1 authentication. V2/V3/TLS authentication + layers are planned for future development. + """ + + def __init__( + self, + host: str, + port: int = 102, + ): + self.host = host + self.port = port + + self._iso_conn = ISOTCPConnection( + host=host, + port=port, + local_tsap=S7COMMPLUS_LOCAL_TSAP, + remote_tsap=S7COMMPLUS_REMOTE_TSAP, + ) + + self._ssl_context: Optional[ssl.SSLContext] = None + self._session_id: int = 0 + self._sequence_number: int = 0 + self._protocol_version: int = 0 # Detected from PLC response + self._tls_active: bool = False + self._connected = False + self._server_session_version: Optional[int] = None + + @property + def connected(self) -> bool: + return self._connected + + @property + def protocol_version(self) -> int: + """Protocol version negotiated with the PLC.""" + return self._protocol_version + + @property + def session_id(self) -> int: + """Session ID assigned by the PLC.""" + return self._session_id + + @property + def tls_active(self) -> bool: + """Whether TLS encryption is active on this connection.""" + return self._tls_active + + def connect( + self, + timeout: float = 5.0, + use_tls: bool = False, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + tls_ca: Optional[str] = None, + ) -> None: + """Establish S7CommPlus connection. + + The connection sequence: + 1. COTP connection (same as legacy S7comm) + 2. CreateObject to establish S7CommPlus session + 3. Protocol version is detected from PLC response + 4. If use_tls=True and PLC supports it, TLS is negotiated + + Args: + timeout: Connection timeout in seconds + use_tls: Whether to attempt TLS negotiation. + tls_cert: Path to client TLS certificate (PEM) + tls_key: Path to client private key (PEM) + tls_ca: Path to CA certificate for PLC verification (PEM) + """ + try: + # Step 1: COTP connection (same TSAP for all S7CommPlus versions) + self._iso_conn.connect(timeout) + + # Step 2: InitSSL handshake (required before CreateObject) + self._init_ssl() + + # Step 3: TLS activation (required for modern firmware) + if use_tls: + # TODO: Perform TLS 1.3 handshake over the existing COTP connection + raise NotImplementedError("TLS activation is not yet implemented. Use use_tls=False for V1 connections.") + + # Step 4: CreateObject (S7CommPlus session setup) + self._create_session() + + # Step 5: Session setup - echo ServerSessionVersion back to PLC + if self._server_session_version is not None: + self._setup_session() + else: + logger.warning("PLC did not provide ServerSessionVersion - session setup incomplete") + + # Step 6: Version-specific authentication + if self._protocol_version >= ProtocolVersion.V3: + if not use_tls: + logger.warning( + "PLC reports V3 protocol but TLS is not enabled. Connection may not work without use_tls=True." + ) + elif self._protocol_version == ProtocolVersion.V2: + # TODO: Proprietary HMAC-SHA256/AES session auth + raise NotImplementedError("V2 authentication is not yet implemented.") + + # V1: No further authentication needed after CreateObject + self._connected = True + logger.info( + f"S7CommPlus connected to {self.host}:{self.port}, version=V{self._protocol_version}, session={self._session_id}" + ) + + except Exception: + self.disconnect() + raise + + def disconnect(self) -> None: + """Disconnect from PLC.""" + if self._connected and self._session_id: + try: + self._delete_session() + except Exception: + pass + + self._connected = False + self._tls_active = False + self._session_id = 0 + self._sequence_number = 0 + self._protocol_version = 0 + self._server_session_version = None + self._iso_conn.disconnect() + + def send_request(self, function_code: int, payload: bytes = b"") -> bytes: + """Send an S7CommPlus request and receive the response. + + Args: + function_code: S7CommPlus function code + payload: Request payload (after the 14-byte request header) + + Returns: + Response payload (after the 14-byte response header) + """ + if not self._connected: + from ..error import S7ConnectionError + + raise S7ConnectionError("Not connected") + + seq_num = self._next_sequence_number() + + # Build request header + request_header = struct.pack( + ">BHHHHIB", + Opcode.REQUEST, + 0x0000, # Reserved + function_code, + 0x0000, # Reserved + seq_num, + self._session_id, + 0x36, # Transport flags + ) + request = request_header + payload + + logger.debug(f"=== SEND REQUEST === function_code=0x{function_code:04X} seq={seq_num} session=0x{self._session_id:08X}") + logger.debug(f" Request header (14 bytes): {request_header.hex(' ')}") + logger.debug(f" Request payload ({len(payload)} bytes): {payload.hex(' ')}") + + # Add S7CommPlus frame header and trailer, then send + frame = encode_header(self._protocol_version, len(request)) + request + frame += struct.pack(">BBH", 0x72, self._protocol_version, 0x0000) + + logger.debug(f" Full frame ({len(frame)} bytes): {frame.hex(' ')}") + self._iso_conn.send_data(frame) + + # Receive response + response_frame = self._iso_conn.receive_data() + logger.debug(f"=== RECV RESPONSE === raw frame ({len(response_frame)} bytes): {response_frame.hex(' ')}") + + # Parse frame header, use data_length to exclude trailer + version, data_length, consumed = decode_header(response_frame) + logger.debug(f" Frame header: version=V{version}, data_length={data_length}, header_size={consumed}") + + response = response_frame[consumed : consumed + data_length] + logger.debug(f" Response data ({len(response)} bytes): {response.hex(' ')}") + + if len(response) < 14: + from ..error import S7ConnectionError + + raise S7ConnectionError("Response too short") + + # Parse response header for debug + resp_opcode = response[0] + resp_func = struct.unpack_from(">H", response, 3)[0] + resp_seq = struct.unpack_from(">H", response, 7)[0] + resp_session = struct.unpack_from(">I", response, 9)[0] + resp_transport = response[13] + logger.debug( + f" Response header: opcode=0x{resp_opcode:02X} function=0x{resp_func:04X} " + f"seq={resp_seq} session=0x{resp_session:08X} transport=0x{resp_transport:02X}" + ) + + resp_payload = response[14:] + logger.debug(f" Response payload ({len(resp_payload)} bytes): {resp_payload.hex(' ')}") + + # Check for trailer bytes after data_length + trailer = response_frame[consumed + data_length :] + if trailer: + logger.debug(f" Trailer ({len(trailer)} bytes): {trailer.hex(' ')}") + + return resp_payload + + def _init_ssl(self) -> None: + """Send InitSSL request to prepare the connection. + + This is the first S7CommPlus message sent after COTP connect. + The PLC responds with an InitSSL response. For PLCs that support + TLS, the caller should then activate TLS before sending CreateObject. + For V1 PLCs without TLS, the response may indicate that TLS is + not supported, but the connection can continue without it. + + Reference: thomas-v2/S7CommPlusDriver InitSslRequest + """ + seq_num = self._next_sequence_number() + + # InitSSL request: header + padding + request = struct.pack( + ">BHHHHIB", + Opcode.REQUEST, + 0x0000, # Reserved + FunctionCode.INIT_SSL, + 0x0000, # Reserved + seq_num, + 0x00000000, # No session yet + 0x30, # Transport flags (0x30 for InitSSL) + ) + # Trailing padding + request += struct.pack(">I", 0) + + # Wrap in S7CommPlus frame header + trailer + frame = encode_header(ProtocolVersion.V1, len(request)) + request + frame += struct.pack(">BBH", 0x72, ProtocolVersion.V1, 0x0000) + + logger.debug(f"=== InitSSL === sending ({len(frame)} bytes): {frame.hex(' ')}") + self._iso_conn.send_data(frame) + + # Receive InitSSL response + response_frame = self._iso_conn.receive_data() + logger.debug(f"=== InitSSL === received ({len(response_frame)} bytes): {response_frame.hex(' ')}") + + # Parse S7CommPlus frame header + version, data_length, consumed = decode_header(response_frame) + response = response_frame[consumed:] + + if len(response) < 14: + from ..error import S7ConnectionError + + raise S7ConnectionError("InitSSL response too short") + + logger.debug(f"InitSSL response: version=V{version}, data_length={data_length}") + logger.debug(f"InitSSL response body ({len(response)} bytes): {response.hex(' ')}") + + def _create_session(self) -> None: + """Send CreateObject request to establish an S7CommPlus session. + + Builds a NullServerSession CreateObject request matching the + structure expected by S7-1200/1500 PLCs: + + Reference: thomas-v2/S7CommPlusDriver CreateObjectRequest.SetNullServerSessionData() + """ + seq_num = self._next_sequence_number() + + # Build CreateObject request header + request = struct.pack( + ">BHHHHIB", + Opcode.REQUEST, + 0x0000, + FunctionCode.CREATE_OBJECT, + 0x0000, + seq_num, + ObjectId.OBJECT_NULL_SERVER_SESSION, # SessionId = 288 for initial setup + 0x36, # Transport flags + ) + + # RequestId: ObjectServerSessionContainer (285) + request += struct.pack(">I", ObjectId.OBJECT_SERVER_SESSION_CONTAINER) + + # RequestValue: ValueUDInt(0) = DatatypeFlags(0x00) + Datatype.UDInt(0x04) + VLQ(0) + request += bytes([0x00, DataType.UDINT]) + encode_uint32_vlq(0) + + # Unknown padding (always 0) + request += struct.pack(">I", 0) + + # RequestObject: PObject for NullServerSession + # StartOfObject + request += bytes([ElementID.START_OF_OBJECT]) + # RelationId: GetNewRIDOnServer (211) + request += struct.pack(">I", ObjectId.GET_NEW_RID_ON_SERVER) + # ClassId: ClassServerSession (287), VLQ encoded + request += encode_uint32_vlq(ObjectId.CLASS_SERVER_SESSION) + # ClassFlags: 0 + request += encode_uint32_vlq(0) + # AttributeId: None (0) + request += encode_uint32_vlq(0) + + # Attribute: ServerSessionClientRID (300) = RID 0x80c3c901 + request += bytes([ElementID.ATTRIBUTE]) + request += encode_uint32_vlq(ObjectId.SERVER_SESSION_CLIENT_RID) + request += encode_typed_value(DataType.RID, 0x80C3C901) + + # Nested object: ClassSubscriptions + request += bytes([ElementID.START_OF_OBJECT]) + request += struct.pack(">I", ObjectId.GET_NEW_RID_ON_SERVER) + request += encode_uint32_vlq(ObjectId.CLASS_SUBSCRIPTIONS) + request += encode_uint32_vlq(0) # ClassFlags + request += encode_uint32_vlq(0) # AttributeId + request += bytes([ElementID.TERMINATING_OBJECT]) + + # End outer object + request += bytes([ElementID.TERMINATING_OBJECT]) + + # Trailing padding + request += struct.pack(">I", 0) + + # Wrap in S7CommPlus frame header + trailer + frame = encode_header(ProtocolVersion.V1, len(request)) + request + # S7CommPlus trailer (end-of-frame marker) + frame += struct.pack(">BBH", 0x72, ProtocolVersion.V1, 0x0000) + + logger.debug(f"=== CreateObject === sending ({len(frame)} bytes): {frame.hex(' ')}") + self._iso_conn.send_data(frame) + + # Receive response + response_frame = self._iso_conn.receive_data() + logger.debug(f"=== CreateObject === received ({len(response_frame)} bytes): {response_frame.hex(' ')}") + + # Parse S7CommPlus frame header + version, data_length, consumed = decode_header(response_frame) + response = response_frame[consumed:] + + logger.debug(f"CreateObject response: version=V{version}, data_length={data_length}") + logger.debug(f"CreateObject response body ({len(response)} bytes): {response.hex(' ')}") + + if len(response) < 14: + from ..error import S7ConnectionError + + raise S7ConnectionError("CreateObject response too short") + + # Extract session ID from response header + self._session_id = struct.unpack_from(">I", response, 9)[0] + self._protocol_version = version + + # Parse and log the full response header + resp_opcode = response[0] + resp_func = struct.unpack_from(">H", response, 3)[0] + resp_seq = struct.unpack_from(">H", response, 7)[0] + resp_transport = response[13] + logger.debug( + f"CreateObject response header: opcode=0x{resp_opcode:02X} function=0x{resp_func:04X} " + f"seq={resp_seq} session=0x{self._session_id:08X} transport=0x{resp_transport:02X}" + ) + logger.debug(f"CreateObject response payload: {response[14:].hex(' ')}") + logger.debug(f"Session created: id=0x{self._session_id:08X} ({self._session_id}), version=V{version}") + + # Parse response payload to extract ServerSessionVersion + self._parse_create_object_response(response[14:]) + + def _parse_create_object_response(self, payload: bytes) -> None: + """Parse CreateObject response payload to extract ServerSessionVersion. + + The response contains a PObject tree with attributes. We scan for + attribute 306 (ServerSessionVersion) which must be echoed back to + complete the session handshake. + + Args: + payload: Response payload after the 14-byte response header + """ + offset = 0 + while offset < len(payload): + tag = payload[offset] + + if tag == ElementID.ATTRIBUTE: + offset += 1 + if offset >= len(payload): + break + attr_id, consumed = decode_uint32_vlq(payload, offset) + offset += consumed + + if attr_id == ObjectId.SERVER_SESSION_VERSION: + # Next bytes are the typed value: flags + datatype + VLQ value + if offset + 2 > len(payload): + break + _flags = payload[offset] + datatype = payload[offset + 1] + offset += 2 + if datatype == DataType.UDINT: + value, consumed = decode_uint32_vlq(payload, offset) + offset += consumed + self._server_session_version = value + logger.info(f"ServerSessionVersion = {value}") + return + elif datatype == DataType.DWORD: + value, consumed = decode_uint32_vlq(payload, offset) + offset += consumed + self._server_session_version = value + logger.info(f"ServerSessionVersion = {value}") + return + else: + # Skip unknown type - try to continue scanning + logger.debug(f"ServerSessionVersion has unexpected type {datatype:#04x}") + else: + # Skip this attribute's value - we don't parse it, just advance + # Try to skip the typed value (flags + datatype + value) + if offset + 2 > len(payload): + break + _flags = payload[offset] + datatype = payload[offset + 1] + offset += 2 + offset = self._skip_typed_value(payload, offset, datatype, _flags) + + elif tag == ElementID.START_OF_OBJECT: + offset += 1 + # Skip RelationId (4 bytes fixed) + ClassId (VLQ) + ClassFlags (VLQ) + AttributeId (VLQ) + if offset + 4 > len(payload): + break + offset += 4 # RelationId + _, consumed = decode_uint32_vlq(payload, offset) + offset += consumed # ClassId + _, consumed = decode_uint32_vlq(payload, offset) + offset += consumed # ClassFlags + _, consumed = decode_uint32_vlq(payload, offset) + offset += consumed # AttributeId + + elif tag == ElementID.TERMINATING_OBJECT: + offset += 1 + + elif tag == 0x00: + # Null terminator / padding + offset += 1 + + else: + # Unknown tag - try to skip + offset += 1 + + logger.debug("ServerSessionVersion not found in CreateObject response") + + def _skip_typed_value(self, data: bytes, offset: int, datatype: int, flags: int) -> int: + """Skip over a typed value in the PObject tree. + + Best-effort: advances offset past common value types. + Returns new offset. + """ + is_array = bool(flags & 0x10) + + if is_array: + if offset >= len(data): + return offset + count, consumed = decode_uint32_vlq(data, offset) + offset += consumed + # For fixed-size types, skip count * size + elem_size = _element_size(datatype) + if elem_size > 0: + offset += count * elem_size + else: + # Variable-length: skip each VLQ element + for _ in range(count): + if offset >= len(data): + break + _, consumed = decode_uint32_vlq(data, offset) + offset += consumed + return offset + + if datatype == DataType.NULL: + return offset + elif datatype in (DataType.BOOL, DataType.USINT, DataType.BYTE, DataType.SINT): + return offset + 1 + elif datatype in (DataType.UINT, DataType.WORD, DataType.INT): + return offset + 2 + elif datatype in (DataType.UDINT, DataType.DWORD, DataType.AID, DataType.DINT): + _, consumed = decode_uint32_vlq(data, offset) + return offset + consumed + elif datatype in (DataType.ULINT, DataType.LWORD, DataType.LINT): + _, consumed = decode_uint64_vlq(data, offset) + return offset + consumed + elif datatype == DataType.REAL: + return offset + 4 + elif datatype == DataType.LREAL: + return offset + 8 + elif datatype == DataType.TIMESTAMP: + return offset + 8 + elif datatype == DataType.TIMESPAN: + _, consumed = decode_uint64_vlq(data, offset) # int64 VLQ + return offset + consumed + elif datatype == DataType.RID: + return offset + 4 + elif datatype in (DataType.BLOB, DataType.WSTRING): + length, consumed = decode_uint32_vlq(data, offset) + return offset + consumed + length + elif datatype == DataType.STRUCT: + count, consumed = decode_uint32_vlq(data, offset) + offset += consumed + for _ in range(count): + if offset + 2 > len(data): + break + sub_flags = data[offset] + sub_type = data[offset + 1] + offset += 2 + offset = self._skip_typed_value(data, offset, sub_type, sub_flags) + return offset + else: + # Unknown type - can't skip reliably + return offset + + def _setup_session(self) -> None: + """Send SetMultiVariables to echo ServerSessionVersion back to the PLC. + + This completes the session handshake by writing the ServerSessionVersion + attribute back to the session object. Without this step, the PLC rejects + all subsequent data operations with ERROR2 (0x05A9). + + Reference: thomas-v2/S7CommPlusDriver SetSessionSetupData + """ + if self._server_session_version is None: + return + + seq_num = self._next_sequence_number() + + # Build SetMultiVariables request + request = struct.pack( + ">BHHHHIB", + Opcode.REQUEST, + 0x0000, + FunctionCode.SET_MULTI_VARIABLES, + 0x0000, + seq_num, + self._session_id, + 0x36, # Transport flags + ) + + payload = bytearray() + # InObjectId = session ID (tells PLC which object we're writing to) + payload += struct.pack(">I", self._session_id) + # Item count = 1 + payload += encode_uint32_vlq(1) + # Total address field count = 1 (just the attribute ID) + payload += encode_uint32_vlq(1) + # Address: attribute ID = ServerSessionVersion (306) as VLQ + payload += encode_uint32_vlq(ObjectId.SERVER_SESSION_VERSION) + # Value: ItemNumber = 1 (VLQ) + payload += encode_uint32_vlq(1) + # PValue: flags=0x00, type=UDInt, VLQ-encoded value + payload += bytes([0x00, DataType.UDINT]) + payload += encode_uint32_vlq(self._server_session_version) + # Fill byte + payload += bytes([0x00]) + # ObjectQualifier + payload += encode_object_qualifier() + # Trailing padding + payload += struct.pack(">I", 0) + + request += bytes(payload) + + # Wrap in S7CommPlus frame + frame = encode_header(self._protocol_version, len(request)) + request + frame += struct.pack(">BBH", 0x72, self._protocol_version, 0x0000) + + logger.debug(f"=== SetupSession === sending ({len(frame)} bytes): {frame.hex(' ')}") + self._iso_conn.send_data(frame) + + # Receive response + response_frame = self._iso_conn.receive_data() + logger.debug(f"=== SetupSession === received ({len(response_frame)} bytes): {response_frame.hex(' ')}") + + version, data_length, consumed = decode_header(response_frame) + response = response_frame[consumed : consumed + data_length] + + if len(response) < 14: + from ..error import S7ConnectionError + + raise S7ConnectionError("SetupSession response too short") + + resp_func = struct.unpack_from(">H", response, 3)[0] + logger.debug(f"SetupSession response: function=0x{resp_func:04X}") + + # Parse return value from payload + resp_payload = response[14:] + if len(resp_payload) >= 1: + return_value, _ = decode_uint64_vlq(resp_payload, 0) + if return_value != 0: + logger.warning(f"SetupSession: PLC returned error {return_value}") + else: + logger.info("Session setup completed successfully") + + def _delete_session(self) -> None: + """Send DeleteObject to close the session.""" + seq_num = self._next_sequence_number() + + request = struct.pack( + ">BHHHHIB", + Opcode.REQUEST, + 0x0000, + FunctionCode.DELETE_OBJECT, + 0x0000, + seq_num, + self._session_id, + 0x36, + ) + request += struct.pack(">I", 0) + + frame = encode_header(self._protocol_version, len(request)) + request + frame += struct.pack(">BBH", 0x72, self._protocol_version, 0x0000) + self._iso_conn.send_data(frame) + + # Best-effort receive + try: + self._iso_conn.receive_data() + except Exception: + pass + + def _next_sequence_number(self) -> int: + """Get next sequence number and increment.""" + seq = self._sequence_number + self._sequence_number = (self._sequence_number + 1) & 0xFFFF + return seq + + def _setup_ssl_context( + self, + cert_path: Optional[str] = None, + key_path: Optional[str] = None, + ca_path: Optional[str] = None, + ) -> ssl.SSLContext: + """Create TLS context for S7CommPlus. + + Args: + cert_path: Client certificate path (PEM) + key_path: Client private key path (PEM) + ca_path: PLC CA certificate path (PEM) + + Returns: + Configured SSLContext + """ + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.minimum_version = ssl.TLSVersion.TLSv1_3 + + if cert_path and key_path: + ctx.load_cert_chain(cert_path, key_path) + + if ca_path: + ctx.load_verify_locations(ca_path) + else: + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + return ctx + + def __enter__(self) -> "S7CommPlusConnection": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.disconnect() diff --git a/snap7/s7commplus/protocol.py b/snap7/s7commplus/protocol.py new file mode 100644 index 00000000..2095cb29 --- /dev/null +++ b/snap7/s7commplus/protocol.py @@ -0,0 +1,226 @@ +""" +S7CommPlus protocol constants and types. + +Defines the protocol framing, opcodes, function codes, data types, +element IDs, and other constants needed for S7CommPlus communication. + +Reference: thomas-v2/S7CommPlusDriver (C#, LGPL-3.0) +Reference: Wireshark S7CommPlus dissector +""" + +from enum import IntEnum + + +# Protocol identification byte (vs 0x32 for legacy S7comm) +PROTOCOL_ID = 0x72 + + +class ProtocolVersion(IntEnum): + """S7CommPlus protocol versions. + + V1: Early S7-1200 FW V4.0 -- simple session handshake + V2: Adds integrity checking and session authentication + V3: Adds public-key-based key exchange + TLS: TIA Portal V17+ -- standard TLS 1.3 with per-device certificates + + For new implementations, TLS (V3 + InitSsl) is the recommended target. + """ + + V1 = 0x01 + V2 = 0x02 + V3 = 0x03 + SYSTEM_EVENT = 0xFE + + +class Opcode(IntEnum): + """S7CommPlus opcodes (first byte after header).""" + + REQUEST = 0x31 + RESPONSE = 0x32 + NOTIFICATION = 0x33 + RESPONSE2 = 0x02 # Seen in some older firmware + + +class FunctionCode(IntEnum): + """S7CommPlus function codes. + + These identify the type of operation in a request/response pair. + """ + + ERROR = 0x04B1 + EXPLORE = 0x04BB + CREATE_OBJECT = 0x04CA + DELETE_OBJECT = 0x04D4 + SET_VARIABLE = 0x04F2 + GET_VARIABLE = 0x04FC # Only in old S7-1200 firmware + ADD_LINK = 0x0506 + REMOVE_LINK = 0x051A + GET_LINK = 0x0524 + SET_MULTI_VARIABLES = 0x0542 + GET_MULTI_VARIABLES = 0x054C + BEGIN_SEQUENCE = 0x0556 + END_SEQUENCE = 0x0560 + INVOKE = 0x056B + SET_VAR_SUBSTREAMED = 0x057C + GET_VAR_SUBSTREAMED = 0x0586 + GET_VARIABLES_ADDRESS = 0x0590 + ABORT = 0x059A + ERROR2 = 0x05A9 + INIT_SSL = 0x05B3 + + +class ElementID(IntEnum): + """Tag IDs used in the object serialization format. + + S7CommPlus uses a tagged object model where data is structured as + nested objects with attributes, similar to TLV encoding. + """ + + START_OF_OBJECT = 0xA1 + TERMINATING_OBJECT = 0xA2 + ATTRIBUTE = 0xA3 + RELATION = 0xA4 + START_OF_TAG_DESCRIPTION = 0xA7 + TERMINATING_TAG_DESCRIPTION = 0xA8 + VARTYPE_LIST = 0xAB + VARNAME_LIST = 0xAC + + +class ObjectId(IntEnum): + """Well-known object IDs used in session establishment. + + Reference: thomas-v2/S7CommPlusDriver/Core/Ids.cs + """ + + NONE = 0 + GET_NEW_RID_ON_SERVER = 211 + CLASS_SUBSCRIPTIONS = 255 + CLASS_SERVER_SESSION_CONTAINER = 284 + OBJECT_SERVER_SESSION_CONTAINER = 285 + CLASS_SERVER_SESSION = 287 + OBJECT_NULL_SERVER_SESSION = 288 + SERVER_SESSION_CLIENT_RID = 300 + SERVER_SESSION_VERSION = 306 + + +# Default TSAP for S7CommPlus connections +# The remote TSAP is the ASCII string "SIMATIC-ROOT-HMI" (16 bytes) +S7COMMPLUS_LOCAL_TSAP = 0x0600 +S7COMMPLUS_REMOTE_TSAP = b"SIMATIC-ROOT-HMI" + + +class DataType(IntEnum): + """S7CommPlus wire data types. + + These identify how values are encoded on the wire in the S7CommPlus + protocol. Note: these differ from the Softdatatype IDs used for + PLC variable type metadata. + """ + + NULL = 0x00 + BOOL = 0x01 + USINT = 0x02 + UINT = 0x03 + UDINT = 0x04 + ULINT = 0x05 + SINT = 0x06 + INT = 0x07 + DINT = 0x08 + LINT = 0x09 + BYTE = 0x0A + WORD = 0x0B + DWORD = 0x0C + LWORD = 0x0D + REAL = 0x0E + LREAL = 0x0F + TIMESTAMP = 0x10 + TIMESPAN = 0x11 + RID = 0x12 + AID = 0x13 + BLOB = 0x14 + WSTRING = 0x15 + VARIANT = 0x16 + STRUCT = 0x17 + S7STRING = 0x19 + + +class Ids(IntEnum): + """Well-known IDs for S7CommPlus protocol structures. + + Reference: thomas-v2/S7CommPlusDriver/Core/Ids.cs + """ + + # Data block access sub-areas + DB_VALUE_ACTUAL = 2550 + CONTROLLER_AREA_VALUE_ACTUAL = 2551 + + # ObjectQualifier structure IDs + OBJECT_QUALIFIER = 1256 + PARENT_RID = 1257 + COMPOSITION_AID = 1258 + KEY_QUALIFIER = 1259 + + # Native object RIDs for memory areas + NATIVE_THE_I_AREA_RID = 80 + NATIVE_THE_Q_AREA_RID = 81 + NATIVE_THE_M_AREA_RID = 82 + NATIVE_THE_S7_COUNTERS_RID = 83 + NATIVE_THE_S7_TIMERS_RID = 84 + + # DB AccessArea base (add DB number to get area ID) + DB_ACCESS_AREA_BASE = 0x8A0E0000 + + +class SoftDataType(IntEnum): + """PLC soft data types (used in variable metadata / tag descriptions). + + These correspond to the data types as they appear in the PLC's symbol + table and are used for symbolic access to optimized data blocks. + """ + + VOID = 0 + BOOL = 1 + BYTE = 2 + CHAR = 3 + WORD = 4 + INT = 5 + DWORD = 6 + DINT = 7 + REAL = 8 + DATE = 9 + TIME_OF_DAY = 10 + TIME = 11 + S5TIME = 12 + DATE_AND_TIME = 14 + ARRAY = 16 + STRUCT = 17 + STRING = 19 + POINTER = 20 + ANY = 22 + BLOCK_FB = 23 + BLOCK_FC = 24 + BLOCK_DB = 25 + BLOCK_SDB = 26 + COUNTER = 28 + TIMER = 29 + IEC_COUNTER = 30 + IEC_TIMER = 31 + BLOCK_SFB = 32 + BLOCK_SFC = 33 + BLOCK_OB = 36 + BLOCK_UDT = 37 + LREAL = 48 + ULINT = 49 + LINT = 50 + LWORD = 51 + USINT = 52 + UINT = 53 + UDINT = 54 + SINT = 55 + WCHAR = 61 + WSTRING = 62 + VARIANT = 63 + LTIME = 64 + LTOD = 65 + LDT = 66 + DTL = 67 diff --git a/snap7/s7commplus/server.py b/snap7/s7commplus/server.py new file mode 100644 index 00000000..cc08a057 --- /dev/null +++ b/snap7/s7commplus/server.py @@ -0,0 +1,902 @@ +""" +S7CommPlus server emulator for testing. + +Emulates an S7-1200/1500 PLC for integration testing without real hardware. +Handles the S7CommPlus protocol including: +- COTP connection setup (reuses ISOTCPConnection transport) +- CreateObject session handshake +- Explore (browse registered data blocks and variables) +- GetMultiVariables / SetMultiVariables (read/write by address) +- Internal PLC memory model with thread-safe access + +This server does NOT implement TLS or the proprietary authentication +layers (V2/V3 crypto). It emulates a V1 PLC for testing purposes, +which is sufficient for validating protocol framing, data encoding, +and client logic. + +Usage:: + + server = S7CommPlusServer() + server.register_db(1, {"temperature": ("Real", 0), "pressure": ("Real", 4)}) + server.start(port=11020) + + # ... run tests against localhost:11020 ... + + server.stop() +""" + +import logging +import socket +import struct +import threading +from enum import IntEnum +from typing import Any, Callable, Optional + +from .protocol import ( + DataType, + ElementID, + FunctionCode, + Opcode, + ProtocolVersion, + SoftDataType, +) +from .vlq import encode_uint32_vlq, decode_uint32_vlq, encode_uint64_vlq +from .codec import ( + encode_header, + decode_header, + encode_typed_value, + encode_pvalue_blob, + decode_pvalue_to_bytes, +) + +logger = logging.getLogger(__name__) + + +class CPUState(IntEnum): + """Emulated CPU operational state.""" + + UNKNOWN = 0 + STOP = 1 + RUN = 2 + + +# Mapping from SoftDataType to wire DataType and byte size +_SOFT_TO_WIRE: dict[int, tuple[int, int]] = { + SoftDataType.BOOL: (DataType.BOOL, 1), + SoftDataType.BYTE: (DataType.BYTE, 1), + SoftDataType.CHAR: (DataType.BYTE, 1), + SoftDataType.WORD: (DataType.WORD, 2), + SoftDataType.INT: (DataType.INT, 2), + SoftDataType.DWORD: (DataType.DWORD, 4), + SoftDataType.DINT: (DataType.DINT, 4), + SoftDataType.REAL: (DataType.REAL, 4), + SoftDataType.LREAL: (DataType.LREAL, 8), + SoftDataType.USINT: (DataType.USINT, 1), + SoftDataType.UINT: (DataType.UINT, 2), + SoftDataType.UDINT: (DataType.UDINT, 4), + SoftDataType.SINT: (DataType.SINT, 1), + SoftDataType.ULINT: (DataType.ULINT, 8), + SoftDataType.LINT: (DataType.LINT, 8), + SoftDataType.LWORD: (DataType.LWORD, 8), + SoftDataType.STRING: (DataType.S7STRING, 256), + SoftDataType.WSTRING: (DataType.WSTRING, 512), +} + +# Map string type names to SoftDataType values +_TYPE_NAME_MAP: dict[str, int] = { + "Bool": SoftDataType.BOOL, + "Byte": SoftDataType.BYTE, + "Char": SoftDataType.CHAR, + "Word": SoftDataType.WORD, + "Int": SoftDataType.INT, + "DWord": SoftDataType.DWORD, + "DInt": SoftDataType.DINT, + "Real": SoftDataType.REAL, + "LReal": SoftDataType.LREAL, + "USInt": SoftDataType.USINT, + "UInt": SoftDataType.UINT, + "UDInt": SoftDataType.UDINT, + "SInt": SoftDataType.SINT, + "ULInt": SoftDataType.ULINT, + "LInt": SoftDataType.LINT, + "LWord": SoftDataType.LWORD, + "String": SoftDataType.STRING, + "WString": SoftDataType.WSTRING, +} + + +class DBVariable: + """A variable in a data block.""" + + def __init__(self, name: str, soft_datatype: int, byte_offset: int): + self.name = name + self.soft_datatype = soft_datatype + self.byte_offset = byte_offset + + wire_info = _SOFT_TO_WIRE.get(soft_datatype, (DataType.BYTE, 1)) + self.wire_datatype = wire_info[0] + self.byte_size = wire_info[1] + + def __repr__(self) -> str: + return f"DBVariable({self.name!r}, type={self.soft_datatype}, offset={self.byte_offset})" + + +class DataBlock: + """An emulated PLC data block with named variables.""" + + def __init__(self, number: int, size: int = 1024): + self.number = number + self.data = bytearray(size) + self.variables: dict[str, DBVariable] = {} + self.lock = threading.Lock() + # Assign a unique object ID for the S7CommPlus object tree + self.object_id = 0x00010000 | (number & 0xFFFF) + + def add_variable(self, name: str, type_name: str, byte_offset: int) -> None: + """Register a named variable in this data block. + + Args: + name: Variable name (e.g. "temperature") + type_name: PLC type name (e.g. "Real", "Int", "Bool") + byte_offset: Byte offset within the data block + """ + soft_type = _TYPE_NAME_MAP.get(type_name) + if soft_type is None: + raise ValueError(f"Unknown type name: {type_name!r}") + self.variables[name] = DBVariable(name, soft_type, byte_offset) + + def read(self, offset: int, size: int) -> bytes: + """Read bytes from the data block.""" + with self.lock: + end = min(offset + size, len(self.data)) + result = bytes(self.data[offset:end]) + # Pad with zeros if reading past end + if len(result) < size: + result += b"\x00" * (size - len(result)) + return result + + def write(self, offset: int, data: bytes) -> None: + """Write bytes to the data block.""" + with self.lock: + end = min(offset + len(data), len(self.data)) + self.data[offset:end] = data[: end - offset] + + def read_variable(self, name: str) -> tuple[int, bytes]: + """Read a named variable. + + Returns: + Tuple of (wire_datatype, raw_bytes) + """ + var = self.variables.get(name) + if var is None: + raise KeyError(f"Variable not found: {name!r}") + raw = self.read(var.byte_offset, var.byte_size) + return var.wire_datatype, raw + + def write_variable(self, name: str, data: bytes) -> None: + """Write a named variable.""" + var = self.variables.get(name) + if var is None: + raise KeyError(f"Variable not found: {name!r}") + self.write(var.byte_offset, data) + + +class S7CommPlusServer: + """S7CommPlus PLC emulator for testing. + + Emulates an S7-1200/1500 PLC with: + - Internal data block storage with named variables + - S7CommPlus protocol handling (V1 level) + - Multi-client support (threaded) + - CPU state management + """ + + def __init__(self) -> None: + self._data_blocks: dict[int, DataBlock] = {} + self._cpu_state = CPUState.RUN + self._protocol_version = ProtocolVersion.V1 + self._next_session_id = 1 + + self._server_socket: Optional[socket.socket] = None + self._server_thread: Optional[threading.Thread] = None + self._client_threads: list[threading.Thread] = [] + self._running = False + self._lock = threading.Lock() + self._event_callback: Optional[Callable[..., None]] = None + + @property + def cpu_state(self) -> CPUState: + return self._cpu_state + + @cpu_state.setter + def cpu_state(self, state: CPUState) -> None: + self._cpu_state = state + + def register_db(self, db_number: int, variables: dict[str, tuple[str, int]], size: int = 1024) -> DataBlock: + """Register a data block with named variables. + + Args: + db_number: Data block number (e.g. 1 for DB1) + variables: Dict mapping variable name to (type_name, byte_offset) + e.g. {"temperature": ("Real", 0), "count": ("Int", 4)} + size: Data block size in bytes + + Returns: + The created DataBlock + + Example:: + + server.register_db(1, { + "temperature": ("Real", 0), + "pressure": ("Real", 4), + "running": ("Bool", 8), + "count": ("DInt", 10), + }) + """ + db = DataBlock(db_number, size) + for name, (type_name, offset) in variables.items(): + db.add_variable(name, type_name, offset) + self._data_blocks[db_number] = db + return db + + def register_raw_db(self, db_number: int, data: bytearray) -> DataBlock: + """Register a data block with raw data (no named variables). + + Args: + db_number: Data block number + data: Initial data block content + + Returns: + The created DataBlock + """ + db = DataBlock(db_number, len(data)) + db.data = data + self._data_blocks[db_number] = db + return db + + def get_db(self, db_number: int) -> Optional[DataBlock]: + """Get a registered data block.""" + return self._data_blocks.get(db_number) + + def start(self, host: str = "127.0.0.1", port: int = 11020) -> None: + """Start the server. + + Args: + host: Bind address + port: TCP port to listen on + """ + if self._running: + raise RuntimeError("Server is already running") + + self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._server_socket.settimeout(1.0) + self._server_socket.bind((host, port)) + self._server_socket.listen(5) + + self._running = True + self._server_thread = threading.Thread(target=self._server_loop, daemon=True, name="s7commplus-server") + self._server_thread.start() + logger.info(f"S7CommPlus server started on {host}:{port}") + + def stop(self) -> None: + """Stop the server.""" + self._running = False + + if self._server_socket: + try: + self._server_socket.close() + except Exception: + pass + self._server_socket = None + + if self._server_thread: + self._server_thread.join(timeout=5.0) + self._server_thread = None + + for t in self._client_threads: + t.join(timeout=2.0) + self._client_threads.clear() + + logger.info("S7CommPlus server stopped") + + def _server_loop(self) -> None: + """Main server accept loop.""" + while self._running: + try: + if self._server_socket is None: + break + client_sock, address = self._server_socket.accept() + logger.info(f"Client connected from {address}") + t = threading.Thread( + target=self._handle_client, + args=(client_sock, address), + daemon=True, + name=f"s7commplus-client-{address}", + ) + self._client_threads.append(t) + t.start() + except socket.timeout: + continue + except OSError: + break + + def _handle_client(self, client_sock: socket.socket, address: tuple[str, int]) -> None: + """Handle a single client connection.""" + try: + client_sock.settimeout(5.0) + + # Step 1: COTP handshake + if not self._handle_cotp_connect(client_sock): + return + + # Step 2: S7CommPlus session + session_id = 0 + + while self._running: + try: + # Receive TPKT + COTP DT + S7CommPlus data + data = self._recv_s7commplus_frame(client_sock) + if data is None: + break + + # Process the S7CommPlus request + response = self._process_request(data, session_id) + + if response is not None: + # Check if session ID was assigned + if session_id == 0 and len(response) >= 14: + # Extract session ID from response for tracking + session_id = struct.unpack_from(">I", response, 9)[0] + + self._send_s7commplus_frame(client_sock, response) + + except socket.timeout: + continue + except (ConnectionError, OSError): + break + + except Exception as e: + logger.debug(f"Client handler error: {e}") + finally: + try: + client_sock.close() + except Exception: + pass + logger.info(f"Client disconnected: {address}") + + def _handle_cotp_connect(self, sock: socket.socket) -> bool: + """Handle COTP Connection Request / Confirm.""" + try: + # Receive TPKT header + tpkt_header = self._recv_exact(sock, 4) + version, _, length = struct.unpack(">BBH", tpkt_header) + if version != 3: + return False + + # Receive COTP CR + payload = self._recv_exact(sock, length - 4) + if len(payload) < 7: + return False + + _pdu_len, pdu_type = payload[0], payload[1] + if pdu_type != 0xE0: # COTP CR + return False + + # Parse source ref from CR + src_ref = struct.unpack_from(">H", payload, 4)[0] + + # Build COTP CC response + cc_pdu = struct.pack( + ">BBHHB", + 6, # PDU length + 0xD0, # COTP CC + src_ref, # Destination ref (client's src ref) + 0x0001, # Source ref (our ref) + 0x00, # Class 0 + ) + + # Add PDU size parameter + pdu_size_param = struct.pack(">BBB", 0xC0, 1, 0x0A) # 1024 bytes + cc_pdu = struct.pack(">B", 6 + len(pdu_size_param)) + cc_pdu[1:] + pdu_size_param + + # Send TPKT + CC + tpkt = struct.pack(">BBH", 3, 0, 4 + len(cc_pdu)) + cc_pdu + sock.sendall(tpkt) + + logger.debug("COTP connection established") + return True + + except Exception as e: + logger.debug(f"COTP handshake failed: {e}") + return False + + def _recv_s7commplus_frame(self, sock: socket.socket) -> Optional[bytes]: + """Receive a TPKT/COTP/S7CommPlus frame, return the S7CommPlus payload.""" + try: + # TPKT header + tpkt_header = self._recv_exact(sock, 4) + version, _, length = struct.unpack(">BBH", tpkt_header) + if version != 3 or length <= 4: + return None + + # Remaining data + payload = self._recv_exact(sock, length - 4) + + # Skip COTP DT header (3 bytes: length, type 0xF0, EOT) + if len(payload) < 3 or payload[1] != 0xF0: + return None + + return payload[3:] # S7CommPlus data + + except Exception: + return None + + def _send_s7commplus_frame(self, sock: socket.socket, data: bytes) -> None: + """Send an S7CommPlus frame wrapped in TPKT/COTP.""" + # S7CommPlus header (4 bytes) + data + trailer (4 bytes) + s7plus_frame = encode_header(self._protocol_version, len(data)) + data + s7plus_frame += struct.pack(">BBH", 0x72, self._protocol_version, 0x0000) + + # COTP DT header + cotp_dt = struct.pack(">BBB", 2, 0xF0, 0x80) + s7plus_frame + + # TPKT + tpkt = struct.pack(">BBH", 3, 0, 4 + len(cotp_dt)) + cotp_dt + sock.sendall(tpkt) + + def _process_request(self, data: bytes, session_id: int) -> Optional[bytes]: + """Process an S7CommPlus request and return a response.""" + if len(data) < 4: + return None + + # Parse S7CommPlus frame header + try: + version, data_length, consumed = decode_header(data) + except ValueError: + return None + + # Use data_length to exclude any trailer + payload = data[consumed : consumed + data_length] + if len(payload) < 14: + return None + + # Parse request header + opcode = payload[0] + if opcode != Opcode.REQUEST: + return None + + function_code = struct.unpack_from(">H", payload, 3)[0] + seq_num = struct.unpack_from(">H", payload, 7)[0] + req_session_id = struct.unpack_from(">I", payload, 9)[0] + request_data = payload[14:] + + if function_code == FunctionCode.INIT_SSL: + return self._handle_init_ssl(seq_num) + elif function_code == FunctionCode.CREATE_OBJECT: + return self._handle_create_object(seq_num, request_data) + elif function_code == FunctionCode.DELETE_OBJECT: + return self._handle_delete_object(seq_num, req_session_id) + elif function_code == FunctionCode.EXPLORE: + return self._handle_explore(seq_num, req_session_id, request_data) + elif function_code == FunctionCode.GET_MULTI_VARIABLES: + return self._handle_get_multi_variables(seq_num, req_session_id, request_data) + elif function_code == FunctionCode.SET_MULTI_VARIABLES: + return self._handle_set_multi_variables(seq_num, req_session_id, request_data) + else: + return self._build_error_response(seq_num, req_session_id, function_code) + + def _handle_init_ssl(self, seq_num: int) -> bytes: + """Handle InitSSL -- respond to SSL initialization (V1 emulation, no real TLS).""" + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.INIT_SSL, + 0x0000, + seq_num, + 0x00000000, + 0x00, # Transport flags + ) + response += encode_uint32_vlq(0) # Return code: success + response += struct.pack(">I", 0) + return bytes(response) + + def _handle_create_object(self, seq_num: int, request_data: bytes) -> bytes: + """Handle CreateObject -- establish a session.""" + with self._lock: + session_id = self._next_session_id + self._next_session_id += 1 + + # Build CreateObject response + response = bytearray() + + # Response header + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, # Reserved + FunctionCode.CREATE_OBJECT, + 0x0000, # Reserved + seq_num, + session_id, + 0x00, # Transport flags + ) + + # Return code: success + response += encode_uint32_vlq(0) + + # Object with session info + response += bytes([ElementID.START_OF_OBJECT]) + response += struct.pack(">I", 0x00000001) # Relation ID + response += encode_uint32_vlq(0x00000000) # Class ID + response += encode_uint32_vlq(0x00000000) # Class flags + response += encode_uint32_vlq(0x00000000) # Attribute ID + + # Session ID attribute + response += bytes([ElementID.ATTRIBUTE]) + response += encode_uint32_vlq(0x0131) # ServerSession ID attribute + response += encode_typed_value(DataType.UDINT, session_id) + + # Protocol version attribute + response += bytes([ElementID.ATTRIBUTE]) + response += encode_uint32_vlq(0x0132) # Protocol version attribute + response += encode_typed_value(DataType.USINT, self._protocol_version) + + response += bytes([ElementID.TERMINATING_OBJECT]) + + # Trailing zeros + response += struct.pack(">I", 0) + + return bytes(response) + + def _handle_delete_object(self, seq_num: int, session_id: int) -> bytes: + """Handle DeleteObject -- close a session.""" + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.DELETE_OBJECT, + 0x0000, + seq_num, + session_id, + 0x00, + ) + response += encode_uint32_vlq(0) # Return code: success + response += struct.pack(">I", 0) + return bytes(response) + + def _handle_explore(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: + """Handle Explore -- return the object tree (registered data blocks).""" + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.EXPLORE, + 0x0000, + seq_num, + session_id, + 0x00, + ) + response += encode_uint32_vlq(0) # Return code: success + + # Return list of data blocks as objects + for db_num, db in sorted(self._data_blocks.items()): + response += bytes([ElementID.START_OF_OBJECT]) + response += struct.pack(">I", db.object_id) # Relation ID + response += encode_uint32_vlq(0x00000100) # Class: DataBlock + response += encode_uint32_vlq(0x00000000) # Class flags + response += encode_uint32_vlq(0x00000000) # Attribute ID + + # DB number attribute + response += bytes([ElementID.ATTRIBUTE]) + response += encode_uint32_vlq(0x0001) # DB number attribute ID + response += encode_typed_value(DataType.UINT, db_num) + + # DB size attribute + response += bytes([ElementID.ATTRIBUTE]) + response += encode_uint32_vlq(0x0002) # DB size attribute ID + response += encode_typed_value(DataType.UDINT, len(db.data)) + + # Variable list + if db.variables: + response += bytes([ElementID.VARNAME_LIST]) + response += encode_uint32_vlq(len(db.variables)) + for var_name, var in db.variables.items(): + name_bytes = var_name.encode("utf-8") + response += encode_uint32_vlq(len(name_bytes)) + response += name_bytes + response += encode_uint32_vlq(var.soft_datatype) + response += encode_uint32_vlq(var.byte_offset) + + response += bytes([ElementID.TERMINATING_OBJECT]) + + # Final terminator + response += struct.pack(">I", 0) + return bytes(response) + + def _handle_get_multi_variables(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: + """Handle GetMultiVariables -- read variables from data blocks. + + Parses the S7CommPlus request format with ItemAddress structures. + The server extracts db_number from AccessArea and byte offset/size + from the LID values. + + Reference: thomas-v2/S7CommPlusDriver/Core/GetMultiVariablesRequest.cs + """ + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.GET_MULTI_VARIABLES, + 0x0000, + seq_num, + session_id, + 0x00, + ) + + # Parse request payload + items = _server_parse_read_request(request_data) + + # ReturnValue: success + response += encode_uint64_vlq(0) + + # Value list: ItemNumber (1-based) + PValue, terminated by ItemNumber=0 + for i, (db_num, byte_offset, byte_size) in enumerate(items, 1): + db = self._data_blocks.get(db_num) + if db is not None: + data = db.read(byte_offset, byte_size) + response += encode_uint32_vlq(i) # ItemNumber + response += encode_pvalue_blob(data) # Value as BLOB + # Errors handled in error list below + + # Terminate value list + response += encode_uint32_vlq(0) + + # Error list + for i, (db_num, byte_offset, byte_size) in enumerate(items, 1): + db = self._data_blocks.get(db_num) + if db is None: + response += encode_uint32_vlq(i) # ErrorItemNumber + response += encode_uint64_vlq(0x8104) # Error: object not found + + # Terminate error list + response += encode_uint32_vlq(0) + + # IntegrityId + response += encode_uint32_vlq(0) + + return bytes(response) + + def _handle_set_multi_variables(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: + """Handle SetMultiVariables -- write variables to data blocks. + + Reference: thomas-v2/S7CommPlusDriver/Core/SetMultiVariablesRequest.cs + """ + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.SET_MULTI_VARIABLES, + 0x0000, + seq_num, + session_id, + 0x00, + ) + + # Parse request payload + items, values = _server_parse_write_request(request_data) + + # Write data + errors: list[tuple[int, int]] = [] + for i, ((db_num, byte_offset, _), data) in enumerate(zip(items, values), 1): + db = self._data_blocks.get(db_num) + if db is not None: + db.write(byte_offset, data) + else: + errors.append((i, 0x8104)) # Object not found + + # ReturnValue: success + response += encode_uint64_vlq(0) + + # Error list + for err_item, err_code in errors: + response += encode_uint32_vlq(err_item) + response += encode_uint64_vlq(err_code) + + # Terminate error list + response += encode_uint32_vlq(0) + + # IntegrityId + response += encode_uint32_vlq(0) + + return bytes(response) + + def _build_error_response(self, seq_num: int, session_id: int, function_code: int) -> bytes: + """Build a generic error response for unsupported function codes.""" + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.ERROR, + 0x0000, + seq_num, + session_id, + 0x00, + ) + response += encode_uint32_vlq(0x04B1) # Error function code + response += struct.pack(">I", 0) + return bytes(response) + + @staticmethod + def _recv_exact(sock: socket.socket, size: int) -> bytes: + """Receive exactly the specified number of bytes.""" + data = bytearray() + while len(data) < size: + chunk = sock.recv(size - len(data)) + if not chunk: + raise ConnectionError("Connection closed") + data.extend(chunk) + return bytes(data) + + def __enter__(self) -> "S7CommPlusServer": + return self + + def __exit__(self, *args: Any) -> None: + self.stop() + + +# -- Server-side request parsers -- + + +def _server_parse_read_request(request_data: bytes) -> list[tuple[int, int, int]]: + """Parse a GetMultiVariables request payload on the server side. + + Extracts (db_number, byte_offset, byte_size) for each item from the + S7CommPlus ItemAddress format. + + Returns: + List of (db_number, byte_offset, byte_size) tuples + """ + if not request_data: + return [] + + offset = 0 + items: list[tuple[int, int, int]] = [] + + # LinkId (UInt32 fixed) + if offset + 4 > len(request_data): + return [] + offset += 4 + + # ItemCount (VLQ) + item_count, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + + # FieldCount (VLQ) + _field_count, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + + # Parse each ItemAddress + for _ in range(item_count): + if offset >= len(request_data): + break + + # SymbolCrc + _symbol_crc, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + + # AccessArea + access_area, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + + # NumberOfLIDs + num_lids, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + + # AccessSubArea (first LID) + _access_sub_area, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + + # Additional LIDs + lids: list[int] = [] + for _ in range(num_lids - 1): # -1 because AccessSubArea counts as one + if offset >= len(request_data): + break + lid_val, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + lids.append(lid_val) + + # Extract db_number from AccessArea + db_num = access_area & 0xFFFF + + # Extract byte offset and size from LIDs (LID offsets are 1-based) + byte_offset = (lids[0] - 1) if len(lids) > 0 else 0 + byte_size = lids[1] if len(lids) > 1 else 1 + + items.append((db_num, byte_offset, byte_size)) + + return items + + +def _server_parse_write_request(request_data: bytes) -> tuple[list[tuple[int, int, int]], list[bytes]]: + """Parse a SetMultiVariables request payload on the server side. + + Returns: + Tuple of (items, values) where items is list of (db_number, byte_offset, byte_size) + and values is list of raw bytes to write + """ + if not request_data: + return [], [] + + offset = 0 + + # InObjectId (UInt32 fixed) + if offset + 4 > len(request_data): + return [], [] + offset += 4 + + # ItemCount (VLQ) + item_count, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + + # FieldCount (VLQ) + _field_count, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + + # Parse each ItemAddress + items: list[tuple[int, int, int]] = [] + for _ in range(item_count): + if offset >= len(request_data): + break + + # SymbolCrc + _symbol_crc, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + + # AccessArea + access_area, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + + # NumberOfLIDs + num_lids, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + + # AccessSubArea + _access_sub_area, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + + # Additional LIDs + lids: list[int] = [] + for _ in range(num_lids - 1): + if offset >= len(request_data): + break + lid_val, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + lids.append(lid_val) + + db_num = access_area & 0xFFFF + byte_offset = (lids[0] - 1) if len(lids) > 0 else 0 # LID offsets are 1-based + byte_size = lids[1] if len(lids) > 1 else 1 + items.append((db_num, byte_offset, byte_size)) + + # Parse value list: ItemNumber (VLQ, 1-based) + PValue + values: list[bytes] = [] + for _ in range(item_count): + if offset >= len(request_data): + break + item_nr, consumed = decode_uint32_vlq(request_data, offset) + offset += consumed + if item_nr == 0: + break + raw_bytes, consumed = decode_pvalue_to_bytes(request_data, offset) + offset += consumed + values.append(raw_bytes) + + return items, values diff --git a/snap7/s7commplus/vlq.py b/snap7/s7commplus/vlq.py new file mode 100644 index 00000000..19e9c388 --- /dev/null +++ b/snap7/s7commplus/vlq.py @@ -0,0 +1,338 @@ +""" +Variable-Length Quantity (VLQ) encoding for S7CommPlus. + +S7CommPlus uses VLQ encoding for integer values in the protocol framing. +This is similar to MIDI VLQ or protobuf varints, with some S7-specific +variations for signed values and 64-bit special handling. + +Encoding scheme: + - Each byte uses 7 data bits + 1 continuation bit (MSB) + - continuation bit = 1 means more bytes follow + - continuation bit = 0 means this is the last byte + - Big-endian byte order (most significant group first) + - Signed values use bit 6 of the first byte as a sign flag + +64-bit special case: + - 8 bytes of 7-bit groups = 56 bits, which is less than 64 + - The 9th byte uses all 8 bits (no continuation flag) + - This avoids needing a 10th byte + +Reference: thomas-v2/S7CommPlusDriver/Core/S7p.cs +""" + + +def encode_uint32_vlq(value: int) -> bytes: + """Encode an unsigned 32-bit integer as VLQ. + + Args: + value: Unsigned integer (0 to 2^32-1) + + Returns: + VLQ-encoded bytes (1-5 bytes) + """ + if value < 0 or value > 0xFFFFFFFF: + raise ValueError(f"Value out of range for uint32 VLQ: {value}") + + result = bytearray() + + # Find the highest non-zero 7-bit group + num_groups = 1 + for i in range(4, 0, -1): + if value & (0x7F << (i * 7)): + num_groups = i + 1 + break + + # Encode each group, MSB first + for i in range(num_groups - 1, -1, -1): + group = (value >> (i * 7)) & 0x7F + if i > 0: + group |= 0x80 # Set continuation bit + result.append(group) + + return bytes(result) + + +def decode_uint32_vlq(data: bytes, offset: int = 0) -> tuple[int, int]: + """Decode a VLQ-encoded unsigned 32-bit integer. + + Args: + data: Buffer containing VLQ data + offset: Starting position in buffer + + Returns: + Tuple of (decoded_value, bytes_consumed) + """ + value = 0 + consumed = 0 + + for _ in range(5): # Max 5 bytes for 32-bit + if offset + consumed >= len(data): + raise ValueError("Unexpected end of VLQ data") + + octet = data[offset + consumed] + consumed += 1 + + value = (value << 7) | (octet & 0x7F) + + if not (octet & 0x80): # No continuation bit + break + + return value, consumed + + +def encode_int32_vlq(value: int) -> bytes: + """Encode a signed 32-bit integer as VLQ. + + Signed VLQ uses bit 6 of the first byte as a sign indicator. + Negative values are encoded in a compact two's-complement-like form. + + Args: + value: Signed integer (-2^31 to 2^31-1) + + Returns: + VLQ-encoded bytes (1-5 bytes) + """ + if value < -0x80000000 or value > 0x7FFFFFFF: + raise ValueError(f"Value out of range for int32 VLQ: {value}") + + result = bytearray() + + if value == -0x80000000: + abs_v = 0x80000000 + else: + abs_v = abs(value) + + b = [0] * 5 + b[0] = value & 0x7F + length = 1 + + for i in range(1, 5): + if abs_v >= 0x40: + length += 1 + abs_v >>= 7 + value >>= 7 + b[i] = ((value & 0x7F) + 0x80) & 0xFF + else: + break + + # Emit in reverse order (big-endian) + for i in range(length - 1, -1, -1): + result.append(b[i]) + + return bytes(result) + + +def decode_int32_vlq(data: bytes, offset: int = 0) -> tuple[int, int]: + """Decode a VLQ-encoded signed 32-bit integer. + + Args: + data: Buffer containing VLQ data + offset: Starting position in buffer + + Returns: + Tuple of (decoded_value, bytes_consumed) + """ + value = 0 + consumed = 0 + + for counter in range(1, 6): # Max 5 bytes for 32-bit + if offset + consumed >= len(data): + raise ValueError("Unexpected end of VLQ data") + + octet = data[offset + consumed] + consumed += 1 + + if counter == 1 and (octet & 0x40): # Check sign bit + octet &= 0xBF + value = -64 # Pre-load with one's complement + else: + value <<= 7 + + value += octet & 0x7F + + if not (octet & 0x80): # No continuation bit + break + + return value, consumed + + +def encode_uint64_vlq(value: int) -> bytes: + """Encode an unsigned 64-bit integer as VLQ. + + 64-bit VLQ has special handling: since 8 groups of 7 bits = 56 bits < 64, + the 9th byte uses all 8 bits (no continuation flag). + + Args: + value: Unsigned integer (0 to 2^64-1) + + Returns: + VLQ-encoded bytes (1-9 bytes) + """ + if value < 0 or value > 0xFFFFFFFFFFFFFFFF: + raise ValueError(f"Value out of range for uint64 VLQ: {value}") + + special = value > 0x00FFFFFFFFFFFFFF + + b = [0] * 9 + if special: + b[0] = value & 0xFF + else: + b[0] = value & 0x7F + + length = 1 + for i in range(1, 9): + if value >= 0x80: + length += 1 + if i == 1 and special: + value >>= 8 + else: + value >>= 7 + b[i] = ((value & 0x7F) + 0x80) & 0xFF + else: + break + + if special and length == 8: + length += 1 + b[8] = 0x80 + + # Emit in reverse order + result = bytearray() + for i in range(length - 1, -1, -1): + result.append(b[i]) + + return bytes(result) + + +def decode_uint64_vlq(data: bytes, offset: int = 0) -> tuple[int, int]: + """Decode a VLQ-encoded unsigned 64-bit integer. + + Args: + data: Buffer containing VLQ data + offset: Starting position in buffer + + Returns: + Tuple of (decoded_value, bytes_consumed) + """ + value = 0 + consumed = 0 + cont = 0 + + for counter in range(1, 9): # Max 8 groups of 7 bits + if offset + consumed >= len(data): + raise ValueError("Unexpected end of VLQ data") + + octet = data[offset + consumed] + consumed += 1 + + value = (value << 7) | (octet & 0x7F) + cont = octet & 0x80 + + if not cont: + break + + if cont: + # 9th byte: all 8 bits are data (special 64-bit handling) + if offset + consumed >= len(data): + raise ValueError("Unexpected end of VLQ data") + + octet = data[offset + consumed] + consumed += 1 + value = (value << 8) | octet + + return value, consumed + + +def encode_int64_vlq(value: int) -> bytes: + """Encode a signed 64-bit integer as VLQ. + + Args: + value: Signed integer (-2^63 to 2^63-1) + + Returns: + VLQ-encoded bytes (1-9 bytes) + """ + if value < -0x8000000000000000 or value > 0x7FFFFFFFFFFFFFFF: + raise ValueError(f"Value out of range for int64 VLQ: {value}") + + if value == -0x8000000000000000: + abs_v = 0x8000000000000000 + else: + abs_v = abs(value) + + special = abs_v > 0x007FFFFFFFFFFFFF + + b = [0] * 9 + if special: + b[0] = value & 0xFF + else: + b[0] = value & 0x7F + + length = 1 + for i in range(1, 9): + if abs_v >= 0x40: + length += 1 + if i == 1 and special: + abs_v >>= 8 + value >>= 8 + else: + abs_v >>= 7 + value >>= 7 + b[i] = ((value & 0x7F) + 0x80) & 0xFF + else: + break + + if special and length == 8: + length += 1 + b[8] = 0x80 if value >= 0 else 0xFF + + # Emit in reverse order + result = bytearray() + for i in range(length - 1, -1, -1): + result.append(b[i]) + + return bytes(result) + + +def decode_int64_vlq(data: bytes, offset: int = 0) -> tuple[int, int]: + """Decode a VLQ-encoded signed 64-bit integer. + + Args: + data: Buffer containing VLQ data + offset: Starting position in buffer + + Returns: + Tuple of (decoded_value, bytes_consumed) + """ + value = 0 + consumed = 0 + cont = 0 + + for counter in range(1, 9): # Max 8 groups of 7 bits + if offset + consumed >= len(data): + raise ValueError("Unexpected end of VLQ data") + + octet = data[offset + consumed] + consumed += 1 + + if counter == 1 and (octet & 0x40): # Check sign bit + octet &= 0xBF + value = -64 # Pre-load with one's complement + else: + value <<= 7 + + cont = octet & 0x80 + value += octet & 0x7F + + if not cont: + break + + if cont: + # 9th byte: all 8 bits are data + if offset + consumed >= len(data): + raise ValueError("Unexpected end of VLQ data") + + octet = data[offset + consumed] + consumed += 1 + value = (value << 8) | octet + + return value, consumed diff --git a/snap7/util/db.py b/snap7/util/db.py index 47f65aa2..48834898 100644 --- a/snap7/util/db.py +++ b/snap7/util/db.py @@ -635,7 +635,9 @@ def get_value(self, byte_index: Union[str, int], type_: str) -> ValueType: return type_to_func[type_](bytearray_, byte_index) raise ValueError - def set_value(self, byte_index: Union[str, int], type_: str, value: Union[bool, str, float]) -> Optional[bytearray]: + def set_value( + self, byte_index: Union[str, int], type_: str, value: Union[bool, str, float, date, datetime, timedelta] + ) -> Optional[Union[bytearray, memoryview]]: """Sets the value for a specific type in the specified byte index. Args: @@ -685,7 +687,7 @@ def set_value(self, byte_index: Union[str, int], type_: str, value: Union[bool, set_wstring(bytearray_, byte_index, value, max_size_int) return None - if type_ == "REAL": + if type_ == "REAL" and isinstance(value, (bool, str, float, int)): return set_real(bytearray_, byte_index, value) if type_ == "LREAL" and isinstance(value, float): diff --git a/snap7/util/getters.py b/snap7/util/getters.py index 32b85433..01c2f963 100644 --- a/snap7/util/getters.py +++ b/snap7/util/getters.py @@ -1,12 +1,16 @@ import struct from datetime import timedelta, datetime, date -from typing import NoReturn +from typing import NoReturn, Union from logging import getLogger +#: Buffer types accepted by getter functions. +#: Both :class:`bytearray` and :class:`memoryview` are supported. +Buffer = Union[bytearray, memoryview] + logger = getLogger(__name__) -def get_bool(bytearray_: bytearray, byte_index: int, bool_index: int) -> bool: +def get_bool(bytearray_: Buffer, byte_index: int, bool_index: int) -> bool: """Get the boolean value from location in bytearray Args: @@ -28,7 +32,7 @@ def get_bool(bytearray_: bytearray, byte_index: int, bool_index: int) -> bool: return current_value == index_value -def get_byte(bytearray_: bytearray, byte_index: int) -> bytes: +def get_byte(bytearray_: Buffer, byte_index: int) -> bytes: """Get byte value from bytearray. Notes: @@ -48,7 +52,7 @@ def get_byte(bytearray_: bytearray, byte_index: int) -> bytes: return value -def get_word(bytearray_: bytearray, byte_index: int) -> bytearray: +def get_word(bytearray_: Buffer, byte_index: int) -> bytearray: """Get word value from bytearray. Notes: @@ -73,7 +77,7 @@ def get_word(bytearray_: bytearray, byte_index: int) -> bytearray: return value -def get_int(bytearray_: bytearray, byte_index: int) -> int: +def get_int(bytearray_: Buffer, byte_index: int) -> int: """Get int value from bytearray. Notes: @@ -98,7 +102,7 @@ def get_int(bytearray_: bytearray, byte_index: int) -> int: return value -def get_uint(bytearray_: bytearray, byte_index: int) -> int: +def get_uint(bytearray_: Buffer, byte_index: int) -> int: """Get unsigned int value from bytearray. Notes: @@ -121,7 +125,7 @@ def get_uint(bytearray_: bytearray, byte_index: int) -> int: return int(get_word(bytearray_, byte_index)) -def get_real(bytearray_: bytearray, byte_index: int) -> float: +def get_real(bytearray_: Buffer, byte_index: int) -> float: """Get real value. Notes: @@ -145,7 +149,7 @@ def get_real(bytearray_: bytearray, byte_index: int) -> float: return real -def get_fstring(bytearray_: bytearray, byte_index: int, max_length: int, remove_padding: bool = True) -> str: +def get_fstring(bytearray_: Buffer, byte_index: int, max_length: int, remove_padding: bool = True) -> str: """Parse space-padded fixed-length string from bytearray Notes: @@ -176,7 +180,7 @@ def get_fstring(bytearray_: bytearray, byte_index: int, max_length: int, remove_ return string -def get_string(bytearray_: bytearray, byte_index: int) -> str: +def get_string(bytearray_: Buffer, byte_index: int) -> str: """Parse string from bytearray Notes: @@ -210,7 +214,7 @@ def get_string(bytearray_: bytearray, byte_index: int) -> str: return "".join(data) -def get_dword(bytearray_: bytearray, byte_index: int) -> int: +def get_dword(bytearray_: Buffer, byte_index: int) -> int: """Gets the dword from the buffer. Notes: @@ -235,7 +239,7 @@ def get_dword(bytearray_: bytearray, byte_index: int) -> int: return dword -def get_dint(bytearray_: bytearray, byte_index: int) -> int: +def get_dint(bytearray_: Buffer, byte_index: int) -> int: """Get dint value from bytearray. Notes: @@ -262,7 +266,7 @@ def get_dint(bytearray_: bytearray, byte_index: int) -> int: return dint -def get_udint(bytearray_: bytearray, byte_index: int) -> int: +def get_udint(bytearray_: Buffer, byte_index: int) -> int: """Get unsigned dint value from bytearray. Notes: @@ -289,7 +293,7 @@ def get_udint(bytearray_: bytearray, byte_index: int) -> int: return dint -def get_s5time(bytearray_: bytearray, byte_index: int) -> str: +def get_s5time(bytearray_: Buffer, byte_index: int) -> str: micro_to_milli = 1000 data_bytearray = bytearray_[byte_index : byte_index + 2] s5time_data_int_like = list(data_bytearray.hex()) @@ -315,7 +319,7 @@ def get_s5time(bytearray_: bytearray, byte_index: int) -> str: return "".join(str(s5time)) -def get_dt(bytearray_: bytearray, byte_index: int) -> str: +def get_dt(bytearray_: Buffer, byte_index: int) -> str: """Get DATE_AND_TIME Value from bytearray as ISO 8601 formatted Date String Notes: Datatype `DATE_AND_TIME` consists in 8 bytes in the PLC. @@ -331,7 +335,7 @@ def get_dt(bytearray_: bytearray, byte_index: int) -> str: return get_date_time_object(bytearray_, byte_index).isoformat(timespec="microseconds") -def get_date_time_object(bytearray_: bytearray, byte_index: int) -> datetime: +def get_date_time_object(bytearray_: Buffer, byte_index: int) -> datetime: """Get DATE_AND_TIME Value from bytearray as python datetime object Notes: Datatype `DATE_AND_TIME` consists in 8 bytes in the PLC. @@ -364,7 +368,7 @@ def bcd_to_byte(byte: int) -> int: return datetime(year, month, day, hour, min_, sec, microsec) -def get_time(bytearray_: bytearray, byte_index: int) -> str: +def get_time(bytearray_: Buffer, byte_index: int) -> str: """Get time value from bytearray. Notes: @@ -408,7 +412,7 @@ def get_time(bytearray_: bytearray, byte_index: int) -> str: return time_str -def get_usint(bytearray_: bytearray, byte_index: int) -> int: +def get_usint(bytearray_: Buffer, byte_index: int) -> int: """Get the unsigned small int from the bytearray Notes: @@ -434,7 +438,7 @@ def get_usint(bytearray_: bytearray, byte_index: int) -> int: return value -def get_sint(bytearray_: bytearray, byte_index: int) -> int: +def get_sint(bytearray_: Buffer, byte_index: int) -> int: """Get the small int Notes: @@ -460,7 +464,7 @@ def get_sint(bytearray_: bytearray, byte_index: int) -> int: return value -def get_lint(bytearray_: bytearray, byte_index: int) -> int: +def get_lint(bytearray_: Buffer, byte_index: int) -> int: """Get the long int THIS VALUE IS NEITHER TESTED NOR VERIFIED BY A REAL PLC AT THE MOMENT @@ -490,7 +494,7 @@ def get_lint(bytearray_: bytearray, byte_index: int) -> int: return int(lint) -def get_lreal(bytearray_: bytearray, byte_index: int) -> float: +def get_lreal(bytearray_: Buffer, byte_index: int) -> float: """Get the long real Datatype `lreal` (long real) consists in 8 bytes in the PLC. @@ -515,7 +519,7 @@ def get_lreal(bytearray_: bytearray, byte_index: int) -> float: return float(struct.unpack_from(">d", bytearray_, offset=byte_index)[0]) -def get_lword(bytearray_: bytearray, byte_index: int) -> int: +def get_lword(bytearray_: Buffer, byte_index: int) -> int: """Get the long word Notes: @@ -540,7 +544,7 @@ def get_lword(bytearray_: bytearray, byte_index: int) -> int: return lword -def get_ulint(bytearray_: bytearray, byte_index: int) -> int: +def get_ulint(bytearray_: Buffer, byte_index: int) -> int: """Get ulint value from bytearray. Notes: @@ -565,7 +569,7 @@ def get_ulint(bytearray_: bytearray, byte_index: int) -> int: return lint -def get_tod(bytearray_: bytearray, byte_index: int) -> timedelta: +def get_tod(bytearray_: Buffer, byte_index: int) -> timedelta: len_bytearray_ = len(bytearray_) byte_range = byte_index + 4 if len_bytearray_ < byte_range: @@ -576,7 +580,7 @@ def get_tod(bytearray_: bytearray, byte_index: int) -> timedelta: return time_val -def get_date(bytearray_: bytearray, byte_index: int = 0) -> date: +def get_date(bytearray_: Buffer, byte_index: int = 0) -> date: len_bytearray_ = len(bytearray_) byte_range = byte_index + 2 if len_bytearray_ < byte_range: @@ -587,7 +591,7 @@ def get_date(bytearray_: bytearray, byte_index: int = 0) -> date: return date_val -def get_ltime(bytearray_: bytearray, byte_index: int) -> timedelta: +def get_ltime(bytearray_: Buffer, byte_index: int) -> timedelta: """Get LTIME value from bytearray. Notes: @@ -612,7 +616,7 @@ def get_ltime(bytearray_: bytearray, byte_index: int) -> timedelta: return timedelta(microseconds=nanoseconds // 1000) -def get_ltod(bytearray_: bytearray, byte_index: int) -> timedelta: +def get_ltod(bytearray_: Buffer, byte_index: int) -> timedelta: """Get LTOD (Long Time of Day) value from bytearray. Notes: @@ -635,7 +639,7 @@ def get_ltod(bytearray_: bytearray, byte_index: int) -> timedelta: return result -def get_ldt(bytearray_: bytearray, byte_index: int) -> datetime: +def get_ldt(bytearray_: Buffer, byte_index: int) -> datetime: """Get LDT (Long Date and Time) value from bytearray. Notes: @@ -655,7 +659,7 @@ def get_ldt(bytearray_: bytearray, byte_index: int) -> datetime: return epoch + timedelta(microseconds=nanoseconds // 1000) -def get_dtl(bytearray_: bytearray, byte_index: int) -> datetime: +def get_dtl(bytearray_: Buffer, byte_index: int) -> datetime: time_to_datetime = datetime( year=int.from_bytes(bytearray_[byte_index : byte_index + 2], byteorder="big"), month=int(bytearray_[byte_index + 2]), @@ -670,7 +674,7 @@ def get_dtl(bytearray_: bytearray, byte_index: int) -> datetime: return time_to_datetime -def get_char(bytearray_: bytearray, byte_index: int) -> str: +def get_char(bytearray_: Buffer, byte_index: int) -> str: """Get char value from bytearray. Notes: @@ -694,7 +698,7 @@ def get_char(bytearray_: bytearray, byte_index: int) -> str: return char -def get_wchar(bytearray_: bytearray, byte_index: int) -> str: +def get_wchar(bytearray_: Buffer, byte_index: int) -> str: """Get wchar value from bytearray. Datatype `wchar` in the PLC is represented in 2 bytes. It has to be in utf-16-be format. @@ -715,10 +719,10 @@ def get_wchar(bytearray_: bytearray, byte_index: int) -> str: """ if bytearray_[byte_index] == 0: return chr(bytearray_[byte_index + 1]) - return bytearray_[byte_index : byte_index + 2].decode("utf-16-be") + return bytes(bytearray_[byte_index : byte_index + 2]).decode("utf-16-be") -def get_wstring(bytearray_: bytearray, byte_index: int) -> str: +def get_wstring(bytearray_: Buffer, byte_index: int) -> str: """Parse wstring from bytearray Notes: @@ -759,8 +763,8 @@ def get_wstring(bytearray_: bytearray, byte_index: int) -> str: f"expected or is larger than 16382. Bytearray doesn't seem to be a valid string." ) - return bytearray_[wstring_start : wstring_start + wstr_symbols_amount].decode("utf-16-be") + return bytes(bytearray_[wstring_start : wstring_start + wstr_symbols_amount]).decode("utf-16-be") -def get_array(bytearray_: bytearray, byte_index: int) -> NoReturn: +def get_array(bytearray_: Buffer, byte_index: int) -> NoReturn: raise NotImplementedError diff --git a/snap7/util/setters.py b/snap7/util/setters.py index 4cf8ad60..31d6d174 100644 --- a/snap7/util/setters.py +++ b/snap7/util/setters.py @@ -5,8 +5,12 @@ from .getters import get_bool +#: Buffer types accepted by setter functions. +#: Both :class:`bytearray` and writable :class:`memoryview` are supported. +Buffer = Union[bytearray, memoryview] -def set_bool(bytearray_: bytearray, byte_index: int, bool_index: int, value: bool) -> bytearray: + +def set_bool(bytearray_: Buffer, byte_index: int, bool_index: int, value: bool) -> Buffer: """Set boolean value on location in bytearray. Args: @@ -40,7 +44,7 @@ def set_bool(bytearray_: bytearray, byte_index: int, bool_index: int, value: boo return bytearray_ -def set_byte(bytearray_: bytearray, byte_index: int, _int: int) -> bytearray: +def set_byte(bytearray_: Buffer, byte_index: int, _int: int) -> Buffer: """Set value in bytearray to byte Args: @@ -61,7 +65,7 @@ def set_byte(bytearray_: bytearray, byte_index: int, _int: int) -> bytearray: return bytearray_ -def set_word(bytearray_: bytearray, byte_index: int, _int: int) -> bytearray: +def set_word(bytearray_: Buffer, byte_index: int, _int: int) -> Buffer: """Set value in bytearray to word Notes: @@ -80,7 +84,7 @@ def set_word(bytearray_: bytearray, byte_index: int, _int: int) -> bytearray: return bytearray_ -def set_int(bytearray_: bytearray, byte_index: int, _int: int) -> bytearray: +def set_int(bytearray_: Buffer, byte_index: int, _int: int) -> Buffer: """Set value in bytearray to int Notes: @@ -105,7 +109,7 @@ def set_int(bytearray_: bytearray, byte_index: int, _int: int) -> bytearray: return bytearray_ -def set_uint(bytearray_: bytearray, byte_index: int, _int: int) -> bytearray: +def set_uint(bytearray_: Buffer, byte_index: int, _int: int) -> Buffer: """Set value in bytearray to unsigned int Notes: @@ -131,7 +135,7 @@ def set_uint(bytearray_: bytearray, byte_index: int, _int: int) -> bytearray: return bytearray_ -def set_real(bytearray_: bytearray, byte_index: int, real: Union[bool, str, float, int]) -> bytearray: +def set_real(bytearray_: Buffer, byte_index: int, real: Union[bool, str, float, int]) -> Buffer: """Set Real value Notes: @@ -155,7 +159,7 @@ def set_real(bytearray_: bytearray, byte_index: int, real: Union[bool, str, floa return bytearray_ -def set_fstring(bytearray_: bytearray, byte_index: int, value: str, max_length: int) -> bytearray: +def set_fstring(bytearray_: Buffer, byte_index: int, value: str, max_length: int) -> Buffer: """Set space-padded fixed-length string value Args: @@ -193,7 +197,7 @@ def set_fstring(bytearray_: bytearray, byte_index: int, value: str, max_length: return bytearray_ -def set_string(bytearray_: bytearray, byte_index: int, value: str, max_size: int = 254) -> bytearray: +def set_string(bytearray_: Buffer, byte_index: int, value: str, max_size: int = 254) -> Buffer: """Set string value Args: @@ -248,7 +252,7 @@ def set_string(bytearray_: bytearray, byte_index: int, value: str, max_size: int return bytearray_ -def set_dword(bytearray_: bytearray, byte_index: int, dword: int) -> bytearray: +def set_dword(bytearray_: Buffer, byte_index: int, dword: int) -> Buffer: """Set a DWORD to the buffer. Notes: @@ -271,7 +275,7 @@ def set_dword(bytearray_: bytearray, byte_index: int, dword: int) -> bytearray: return bytearray_ -def set_dint(bytearray_: bytearray, byte_index: int, dint: int) -> bytearray: +def set_dint(bytearray_: Buffer, byte_index: int, dint: int) -> Buffer: """Set value in bytearray to dint Notes: @@ -295,7 +299,7 @@ def set_dint(bytearray_: bytearray, byte_index: int, dint: int) -> bytearray: return bytearray_ -def set_udint(bytearray_: bytearray, byte_index: int, udint: int) -> bytearray: +def set_udint(bytearray_: Buffer, byte_index: int, udint: int) -> Buffer: """Set value in bytearray to unsigned dint Notes: @@ -319,7 +323,7 @@ def set_udint(bytearray_: bytearray, byte_index: int, udint: int) -> bytearray: return bytearray_ -def set_time(bytearray_: bytearray, byte_index: int, time_string: str) -> bytearray: +def set_time(bytearray_: Buffer, byte_index: int, time_string: str) -> Buffer: """Set value in bytearray to time Notes: @@ -366,7 +370,7 @@ def set_time(bytearray_: bytearray, byte_index: int, time_string: str) -> bytear raise ValueError("time value out of range, please check the value interval") -def set_usint(bytearray_: bytearray, byte_index: int, _int: int) -> bytearray: +def set_usint(bytearray_: Buffer, byte_index: int, _int: int) -> Buffer: """Set unsigned small int Notes: @@ -392,7 +396,7 @@ def set_usint(bytearray_: bytearray, byte_index: int, _int: int) -> bytearray: return bytearray_ -def set_sint(bytearray_: bytearray, byte_index: int, _int: int) -> bytearray: +def set_sint(bytearray_: Buffer, byte_index: int, _int: int) -> Buffer: """Set small int to the buffer. Notes: @@ -418,7 +422,7 @@ def set_sint(bytearray_: bytearray, byte_index: int, _int: int) -> bytearray: return bytearray_ -def set_lreal(bytearray_: bytearray, byte_index: int, lreal: float) -> bytearray: +def set_lreal(bytearray_: Buffer, byte_index: int, lreal: float) -> Buffer: """Set the long real Notes: @@ -447,7 +451,7 @@ def set_lreal(bytearray_: bytearray, byte_index: int, lreal: float) -> bytearray return bytearray_ -def set_lword(bytearray_: bytearray, byte_index: int, lword: int) -> bytearray: +def set_lword(bytearray_: Buffer, byte_index: int, lword: int) -> Buffer: """Set the long word Notes: @@ -474,7 +478,7 @@ def set_lword(bytearray_: bytearray, byte_index: int, lword: int) -> bytearray: return bytearray_ -def set_char(bytearray_: bytearray, byte_index: int, chr_: str) -> bytearray: +def set_char(bytearray_: Buffer, byte_index: int, chr_: str) -> Buffer: """Set char value in a bytearray. Notes: @@ -510,7 +514,7 @@ def set_char(bytearray_: bytearray, byte_index: int, chr_: str) -> bytearray: raise ValueError(f"chr_ : {chr_} contains ascii value > 255, which is not compatible with PLC Type CHAR.") -def set_date(bytearray_: bytearray, byte_index: int, date_: date) -> bytearray: +def set_date(bytearray_: Buffer, byte_index: int, date_: date) -> Buffer: """Set value in bytearray to date Notes: Datatype `date` consists in the number of days elapsed from 1990-01-01. @@ -534,7 +538,7 @@ def set_date(bytearray_: bytearray, byte_index: int, date_: date) -> bytearray: return bytearray_ -def set_wchar(bytearray_: bytearray, byte_index: int, chr_: str) -> bytearray: +def set_wchar(bytearray_: Buffer, byte_index: int, chr_: str) -> Buffer: """Set wchar value in a bytearray. Notes: @@ -563,7 +567,7 @@ def set_wchar(bytearray_: bytearray, byte_index: int, chr_: str) -> bytearray: return bytearray_ -def set_wstring(bytearray_: bytearray, byte_index: int, value: str, max_size: int = 16382) -> None: +def set_wstring(bytearray_: Buffer, byte_index: int, value: str, max_size: int = 16382) -> None: """Set wstring value Notes: @@ -606,7 +610,7 @@ def set_wstring(bytearray_: bytearray, byte_index: int, value: str, max_size: in bytearray_[byte_index + 4 : byte_index + 4 + len(encoded)] = encoded -def set_tod(bytearray_: bytearray, byte_index: int, tod: timedelta) -> bytearray: +def set_tod(bytearray_: Buffer, byte_index: int, tod: timedelta) -> Buffer: """Set TIME_OF_DAY value in bytearray. Notes: @@ -633,7 +637,7 @@ def set_tod(bytearray_: bytearray, byte_index: int, tod: timedelta) -> bytearray return bytearray_ -def set_dtl(bytearray_: bytearray, byte_index: int, dt_: datetime) -> bytearray: +def set_dtl(bytearray_: Buffer, byte_index: int, dt_: datetime) -> Buffer: """Set DTL (Date and Time Long) value in bytearray. Notes: @@ -678,7 +682,7 @@ def set_dtl(bytearray_: bytearray, byte_index: int, dt_: datetime) -> bytearray: return bytearray_ -def set_dt(bytearray_: bytearray, byte_index: int, dt_: datetime) -> bytearray: +def set_dt(bytearray_: Buffer, byte_index: int, dt_: datetime) -> Buffer: """Set DATE_AND_TIME value in bytearray. Notes: diff --git a/tests/conftest.py b/tests/conftest.py index c0e3eac1..4e53e6d3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,8 +65,13 @@ def pytest_configure(config: pytest.Config) -> None: def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: """Propagate CLI options and skip e2e tests unless --e2e flag is provided.""" - # Propagate CLI options to test_client_e2e module globals - for mod_name in ["tests.test_client_e2e", "test_client_e2e"]: + # Propagate CLI options to e2e test module globals + for mod_name in [ + "tests.test_client_e2e", + "test_client_e2e", + "tests.test_s7commplus_e2e", + "test_s7commplus_e2e", + ]: e2e = sys.modules.get(mod_name) if e2e is not None: e2e.PLC_IP = str(config.getoption("--plc-ip")) @@ -75,7 +80,6 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item e2e.PLC_PORT = int(config.getoption("--plc-port")) e2e.DB_READ_ONLY = int(config.getoption("--plc-db-read")) e2e.DB_READ_WRITE = int(config.getoption("--plc-db-write")) - break # Skip e2e tests if flag not provided if config.getoption("--e2e"): diff --git a/tests/test_async_client.py b/tests/test_async_client.py new file mode 100644 index 00000000..86f55617 --- /dev/null +++ b/tests/test_async_client.py @@ -0,0 +1,329 @@ +"""Tests for the native async client (AsyncClient). + +Uses the same Server fixture as test_client.py for integration tests. +""" + +import asyncio +import logging +from collections.abc import AsyncGenerator, Generator + +import pytest +import pytest_asyncio + +from snap7.async_client import AsyncClient +from snap7.server import Server +from snap7.type import SrvArea, Area, Parameter + +logging.basicConfig(level=logging.WARNING) + +ip = "127.0.0.1" +tcpport = 1103 # Different port from sync tests to avoid conflicts +db_number = 1 +rack = 1 +slot = 1 + + +@pytest.fixture(scope="module") +def server() -> Generator[Server]: + srv = Server() + srv.register_area(SrvArea.DB, 0, bytearray(600)) + srv.register_area(SrvArea.DB, 1, bytearray(600)) + srv.register_area(SrvArea.PA, 0, bytearray(100)) + srv.register_area(SrvArea.PA, 1, bytearray(100)) + srv.register_area(SrvArea.PE, 0, bytearray(100)) + srv.register_area(SrvArea.PE, 1, bytearray(100)) + srv.register_area(SrvArea.MK, 0, bytearray(100)) + srv.register_area(SrvArea.MK, 1, bytearray(100)) + srv.register_area(SrvArea.TM, 0, bytearray(100)) + srv.register_area(SrvArea.TM, 1, bytearray(100)) + srv.register_area(SrvArea.CT, 0, bytearray(100)) + srv.register_area(SrvArea.CT, 1, bytearray(100)) + srv.start(tcp_port=tcpport) + yield srv + srv.stop() + srv.destroy() + + +@pytest_asyncio.fixture +async def client(server: Server) -> AsyncGenerator[AsyncClient]: + c = AsyncClient() + await c.connect(ip, rack, slot, tcpport) + yield c + await c.disconnect() + + +# ------------------------------------------------------------------- +# Connection +# ------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_connect_disconnect(server: Server) -> None: + c = AsyncClient() + await c.connect(ip, rack, slot, tcpport) + assert c.get_connected() + await c.disconnect() + assert not c.get_connected() + + +@pytest.mark.asyncio +async def test_context_manager(server: Server) -> None: + async with AsyncClient() as c: + await c.connect(ip, rack, slot, tcpport) + assert c.get_connected() + assert not c.get_connected() + + +# ------------------------------------------------------------------- +# DB read / write +# ------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_db_read(client: AsyncClient) -> None: + data = bytearray(40) + await client.db_write(db_number=1, start=0, data=data) + result = await client.db_read(db_number=1, start=0, size=40) + assert data == result + + +@pytest.mark.asyncio +async def test_db_write(client: AsyncClient) -> None: + data = bytearray(b"\x01\x02\x03\x04") + await client.db_write(db_number=1, start=0, data=data) + result = await client.db_read(db_number=1, start=0, size=4) + assert result == data + + +@pytest.mark.asyncio +async def test_db_get(client: AsyncClient) -> None: + result = await client.db_get(db_number=1) + assert isinstance(result, bytearray) + assert len(result) > 0 + + +# ------------------------------------------------------------------- +# read_area / write_area +# ------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_read_write_area(client: AsyncClient) -> None: + data = bytearray(b"\xaa\xbb\xcc\xdd") + await client.write_area(Area.DB, 1, 0, data) + result = await client.read_area(Area.DB, 1, 0, 4) + assert result == data + + +@pytest.mark.asyncio +async def test_read_area_large(client: AsyncClient) -> None: + """Test chunked read for data larger than PDU.""" + size = 500 # Exceeds typical single-PDU payload + data = bytearray(range(256)) * 2 # 512 bytes of pattern + data = data[:size] + await client.write_area(Area.DB, 1, 0, data) + result = await client.read_area(Area.DB, 1, 0, size) + assert result == data + + +# ------------------------------------------------------------------- +# Memory area convenience methods +# ------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ab_read_write(client: AsyncClient) -> None: + data = bytearray(b"\x01\x02\x03\x04") + await client.ab_write(0, data) + result = await client.ab_read(0, 4) + assert result == data + + +@pytest.mark.asyncio +async def test_eb_read_write(client: AsyncClient) -> None: + data = bytearray(b"\x05\x06\x07\x08") + await client.eb_write(0, 4, data) + result = await client.eb_read(0, 4) + assert result == data + + +@pytest.mark.asyncio +async def test_mb_read_write(client: AsyncClient) -> None: + data = bytearray(b"\x0a\x0b\x0c\x0d") + await client.mb_write(0, 4, data) + result = await client.mb_read(0, 4) + assert result == data + + +# ------------------------------------------------------------------- +# Concurrent safety (the key fix) +# ------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_concurrent_reads(client: AsyncClient) -> None: + """Verify asyncio.gather with multiple reads doesn't corrupt data. + + This is the critical test — it validates that the asyncio.Lock + serializes send/receive cycles correctly. + """ + # Write known data + data1 = bytearray(b"\x11\x22\x33\x44") + data2 = bytearray(b"\xaa\xbb\xcc\xdd") + await client.db_write(1, 0, data1) + await client.db_write(1, 10, data2) + + # Read concurrently + results = await asyncio.gather( + client.db_read(1, 0, 4), + client.db_read(1, 10, 4), + ) + + assert results[0] == data1 + assert results[1] == data2 + + +@pytest.mark.asyncio +async def test_concurrent_read_write(client: AsyncClient) -> None: + """Verify concurrent read and write don't interfere.""" + write_data = bytearray(b"\xff\xfe\xfd\xfc") + + async def do_write() -> None: + await client.db_write(1, 20, write_data) + + async def do_read() -> bytearray: + return await client.db_read(1, 0, 4) + + await asyncio.gather(do_write(), do_read()) + + # Verify write went through + result = await client.db_read(1, 20, 4) + assert result == write_data + + +@pytest.mark.asyncio +async def test_many_concurrent_reads(client: AsyncClient) -> None: + """Stress test with many concurrent reads.""" + # Write test data + for i in range(10): + await client.db_write(1, i * 4, bytearray([i] * 4)) + + # Read all concurrently + tasks = [client.db_read(1, i * 4, 4) for i in range(10)] + results = await asyncio.gather(*tasks) + + for i, result in enumerate(results): + assert result == bytearray([i] * 4), f"Mismatch at index {i}" + + +# ------------------------------------------------------------------- +# Multi-var +# ------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_read_multi_vars(client: AsyncClient) -> None: + await client.db_write(1, 0, bytearray(b"\x01\x02\x03\x04")) + await client.db_write(1, 4, bytearray(b"\x05\x06\x07\x08")) + + items = [ + {"area": Area.DB, "db_number": 1, "start": 0, "size": 4}, + {"area": Area.DB, "db_number": 1, "start": 4, "size": 4}, + ] + code, results = await client.read_multi_vars(items) + assert code == 0 + assert results[0] == bytearray(b"\x01\x02\x03\x04") + assert results[1] == bytearray(b"\x05\x06\x07\x08") + + +@pytest.mark.asyncio +async def test_write_multi_vars(client: AsyncClient) -> None: + items = [ + {"area": Area.DB, "db_number": 1, "start": 0, "data": bytearray(b"\xaa\xbb")}, + {"area": Area.DB, "db_number": 1, "start": 2, "data": bytearray(b"\xcc\xdd")}, + ] + result = await client.write_multi_vars(items) + assert result == 0 + + data = await client.db_read(1, 0, 4) + assert data == bytearray(b"\xaa\xbb\xcc\xdd") + + +# ------------------------------------------------------------------- +# Synchronous helpers (no I/O) +# ------------------------------------------------------------------- + + +def test_get_pdu_length() -> None: + c = AsyncClient() + assert c.get_pdu_length() == 480 + + +def test_error_text() -> None: + c = AsyncClient() + assert c.error_text(0) == "OK" + assert "Not connected" in c.error_text(0x0003) + + +def test_set_clear_session_password() -> None: + c = AsyncClient() + assert c.session_password is None + c.set_session_password("secret") + assert c.session_password == "secret" + c.clear_session_password() + assert c.session_password is None + + +def test_set_connection_params() -> None: + c = AsyncClient() + c.set_connection_params("10.0.0.1", 0x0100, 0x0200) + assert c.host == "10.0.0.1" + assert c.local_tsap == 0x0100 + assert c.remote_tsap == 0x0200 + + +def test_set_connection_type() -> None: + c = AsyncClient() + c.set_connection_type(2) + assert c.connection_type == 2 + + +def test_get_set_param() -> None: + c = AsyncClient() + c.set_param(Parameter.PDURequest, 960) + assert c.get_param(Parameter.PDURequest) == 960 + assert c.pdu_length == 960 + + +def test_get_param_non_client_raises() -> None: + c = AsyncClient() + with pytest.raises(RuntimeError): + c.get_param(Parameter.LocalPort) + + +# ------------------------------------------------------------------- +# Block info / CPU info (against server) +# ------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_list_blocks(client: AsyncClient) -> None: + result = await client.list_blocks() + assert hasattr(result, "DBCount") + + +@pytest.mark.asyncio +async def test_get_cpu_state(client: AsyncClient) -> None: + state = await client.get_cpu_state() + assert isinstance(state, str) + + +@pytest.mark.asyncio +async def test_get_cpu_info(client: AsyncClient) -> None: + info = await client.get_cpu_info() + assert hasattr(info, "ModuleTypeName") + + +@pytest.mark.asyncio +async def test_get_pdu_length_after_connect(client: AsyncClient) -> None: + assert client.get_pdu_length() > 0 diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 00000000..dababa78 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,170 @@ +"""Tests for the CLI tools.""" + +import unittest + +import pytest + +click = pytest.importorskip("click") +from click.testing import CliRunner # noqa: E402 + +from snap7.cli import main # noqa: E402 +from snap7.server import Server # noqa: E402 +from snap7.type import SrvArea # noqa: E402 + +ip = "127.0.0.1" +tcpport = 1102 +rack = 1 +slot = 1 + + +@pytest.mark.client +class TestCLI(unittest.TestCase): + server: Server = None # type: ignore + + @classmethod + def setUpClass(cls) -> None: + cls.server = Server() + cls.server.register_area(SrvArea.DB, 0, bytearray(600)) + cls.server.register_area(SrvArea.DB, 1, bytearray(600)) + cls.server.register_area(SrvArea.PA, 0, bytearray(100)) + cls.server.register_area(SrvArea.PE, 0, bytearray(100)) + cls.server.register_area(SrvArea.MK, 0, bytearray(100)) + cls.server.register_area(SrvArea.TM, 0, bytearray(100)) + cls.server.register_area(SrvArea.CT, 0, bytearray(100)) + cls.server.start(tcp_port=tcpport) + + @classmethod + def tearDownClass(cls) -> None: + if cls.server: + cls.server.stop() + cls.server.destroy() + + def setUp(self) -> None: + self.runner = CliRunner() + + def test_help(self) -> None: + result = self.runner.invoke(main, ["--help"]) + assert result.exit_code == 0 + assert "s7" in result.output + + def test_version(self) -> None: + result = self.runner.invoke(main, ["--version"]) + assert result.exit_code == 0 + + def test_read_bytes(self) -> None: + result = self.runner.invoke(main, ["read", ip, "--db", "1", "--offset", "0", "--size", "4", "--port", str(tcpport)]) + assert result.exit_code == 0 + assert "0000" in result.output + + def test_read_bytes_missing_size(self) -> None: + result = self.runner.invoke(main, ["read", ip, "--db", "1", "--offset", "0", "--port", str(tcpport)]) + assert result.exit_code != 0 + + def test_read_int(self) -> None: + result = self.runner.invoke(main, ["read", ip, "--db", "1", "--offset", "0", "--type", "int", "--port", str(tcpport)]) + assert result.exit_code == 0 + + def test_read_real(self) -> None: + result = self.runner.invoke(main, ["read", ip, "--db", "1", "--offset", "0", "--type", "real", "--port", str(tcpport)]) + assert result.exit_code == 0 + + def test_read_bool(self) -> None: + result = self.runner.invoke( + main, ["read", ip, "--db", "1", "--offset", "0", "--type", "bool", "--bit", "0", "--port", str(tcpport)] + ) + assert result.exit_code == 0 + assert result.output.strip() in ("True", "False") + + def test_write_int(self) -> None: + result = self.runner.invoke( + main, ["write", ip, "--db", "1", "--offset", "0", "--type", "int", "--value", "42", "--port", str(tcpport)] + ) + assert result.exit_code == 0 + assert "OK" in result.output + + # Verify + result = self.runner.invoke(main, ["read", ip, "--db", "1", "--offset", "0", "--type", "int", "--port", str(tcpport)]) + assert result.exit_code == 0 + assert "42" in result.output + + def test_write_real(self) -> None: + result = self.runner.invoke( + main, ["write", ip, "--db", "1", "--offset", "4", "--type", "real", "--value", "3.14", "--port", str(tcpport)] + ) + assert result.exit_code == 0 + assert "OK" in result.output + + def test_write_bool(self) -> None: + result = self.runner.invoke( + main, + [ + "write", + ip, + "--db", + "1", + "--offset", + "10", + "--type", + "bool", + "--value", + "true", + "--bit", + "3", + "--port", + str(tcpport), + ], + ) + assert result.exit_code == 0 + assert "OK" in result.output + + def test_write_bytes_hex(self) -> None: + result = self.runner.invoke( + main, ["write", ip, "--db", "1", "--offset", "20", "--type", "bytes", "--value", "DEADBEEF", "--port", str(tcpport)] + ) + assert result.exit_code == 0 + assert "OK" in result.output + + def test_dump(self) -> None: + result = self.runner.invoke(main, ["dump", ip, "--db", "1", "--size", "32", "--port", str(tcpport)]) + assert result.exit_code == 0 + assert "DB1" in result.output + assert "0000" in result.output + + def test_dump_bytes_format(self) -> None: + result = self.runner.invoke(main, ["dump", ip, "--db", "1", "--size", "16", "--format", "bytes", "--port", str(tcpport)]) + assert result.exit_code == 0 + + def test_info(self) -> None: + result = self.runner.invoke(main, ["info", ip, "--port", str(tcpport)]) + assert result.exit_code == 0 + + def test_read_connection_failure(self) -> None: + result = self.runner.invoke(main, ["read", "192.0.2.1", "--db", "1", "--offset", "0", "--size", "4", "--port", "9999"]) + assert result.exit_code != 0 + assert "Connection failed" in result.output + + def test_server_help(self) -> None: + result = self.runner.invoke(main, ["server", "--help"]) + assert result.exit_code == 0 + + def test_write_dint(self) -> None: + result = self.runner.invoke( + main, ["write", ip, "--db", "1", "--offset", "30", "--type", "dint", "--value", "-100000", "--port", str(tcpport)] + ) + assert result.exit_code == 0 + assert "OK" in result.output + + def test_write_word(self) -> None: + result = self.runner.invoke( + main, ["write", ip, "--db", "1", "--offset", "34", "--type", "word", "--value", "1234", "--port", str(tcpport)] + ) + assert result.exit_code == 0 + assert "OK" in result.output + + def test_read_all_types(self) -> None: + """Test that all type names are accepted without error.""" + for type_name in ["byte", "uint", "word", "dword", "udint", "lreal"]: + result = self.runner.invoke( + main, ["read", ip, "--db", "1", "--offset", "0", "--type", type_name, "--port", str(tcpport)] + ) + assert result.exit_code == 0, f"Failed for type {type_name}: {result.output}" diff --git a/tests/test_conformance.py b/tests/test_conformance.py new file mode 100644 index 00000000..4c4a3557 --- /dev/null +++ b/tests/test_conformance.py @@ -0,0 +1,529 @@ +"""Protocol conformance test suite. + +Validates that the S7 protocol implementation correctly encodes/decodes +packets according to the TPKT, COTP, and S7 protocol specifications. +""" + +import struct + +import pytest + +from snap7.connection import ISOTCPConnection, TPDUSize +from snap7.datatypes import S7Area, S7WordLen +from snap7.error import S7ConnectionError, S7ProtocolError +from snap7.s7protocol import S7Function, S7PDUType, S7Protocol, S7_RETURN_CODES + + +@pytest.mark.conformance +class TestTPKTConformance: + """Verify TPKT frame encoding per RFC 1006.""" + + def test_tpkt_header_format(self) -> None: + """TPKT header: version=3, reserved=0, 2-byte big-endian length.""" + conn = ISOTCPConnection("127.0.0.1") + payload = b"\x01\x02\x03" + frame = conn._build_tpkt(payload) + + assert frame[0] == 3, "TPKT version must be 3" + assert frame[1] == 0, "TPKT reserved must be 0" + + def test_tpkt_length_includes_header(self) -> None: + """Length field includes the 4-byte TPKT header.""" + conn = ISOTCPConnection("127.0.0.1") + payload = b"\x01\x02\x03" + frame = conn._build_tpkt(payload) + + length = struct.unpack(">H", frame[2:4])[0] + assert length == len(payload) + 4 + + def test_tpkt_payload_preserved(self) -> None: + """Payload appears intact after the 4-byte header.""" + conn = ISOTCPConnection("127.0.0.1") + payload = b"\xde\xad\xbe\xef" + frame = conn._build_tpkt(payload) + + assert frame[4:] == payload + + def test_tpkt_empty_payload(self) -> None: + """Empty payload produces a 4-byte frame.""" + conn = ISOTCPConnection("127.0.0.1") + frame = conn._build_tpkt(b"") + + assert len(frame) == 4 + length = struct.unpack(">H", frame[2:4])[0] + assert length == 4 + + def test_tpkt_large_payload(self) -> None: + """Length field correctly handles large payloads.""" + conn = ISOTCPConnection("127.0.0.1") + payload = b"\x00" * 1000 + frame = conn._build_tpkt(payload) + + length = struct.unpack(">H", frame[2:4])[0] + assert length == 1004 + + +@pytest.mark.conformance +class TestCOTPConformance: + """Verify COTP PDU encoding per ISO 8073.""" + + def test_cotp_cr_pdu_type(self) -> None: + """CR PDU type code is 0xE0.""" + conn = ISOTCPConnection("127.0.0.1") + cr = conn._build_cotp_cr() + assert cr[1] == 0xE0 + + def test_cotp_cr_destination_reference_zero(self) -> None: + """CR destination reference must be 0x0000.""" + conn = ISOTCPConnection("127.0.0.1") + cr = conn._build_cotp_cr() + dst_ref = struct.unpack(">H", cr[2:4])[0] + assert dst_ref == 0x0000 + + def test_cotp_cr_source_reference(self) -> None: + """CR source reference matches connection setting.""" + conn = ISOTCPConnection("127.0.0.1") + conn.src_ref = 0x1234 + cr = conn._build_cotp_cr() + src_ref = struct.unpack(">H", cr[4:6])[0] + assert src_ref == 0x1234 + + def test_cotp_cr_class_zero(self) -> None: + """CR class/option byte is 0x00 (Class 0, no extended formats).""" + conn = ISOTCPConnection("127.0.0.1") + cr = conn._build_cotp_cr() + assert cr[6] == 0x00 + + def test_cotp_cr_contains_tsap_parameters(self) -> None: + """CR includes calling TSAP (0xC1) and called TSAP (0xC2) parameters.""" + conn = ISOTCPConnection("127.0.0.1", local_tsap=0x0100, remote_tsap=0x0102) + cr = conn._build_cotp_cr() + # Search for parameter codes in the parameter section + param_data = cr[7:] # Parameters start after the 7-byte base header + param_codes = [] + offset = 0 + while offset < len(param_data): + param_codes.append(param_data[offset]) + param_len = param_data[offset + 1] + offset += 2 + param_len + assert 0xC1 in param_codes, "Must contain calling TSAP parameter" + assert 0xC2 in param_codes, "Must contain called TSAP parameter" + + def test_cotp_cr_pdu_size_parameter(self) -> None: + """CR includes PDU size parameter (0xC0).""" + conn = ISOTCPConnection("127.0.0.1") + cr = conn._build_cotp_cr() + param_data = cr[7:] + param_codes = [] + offset = 0 + while offset < len(param_data): + param_codes.append(param_data[offset]) + param_len = param_data[offset + 1] + offset += 2 + param_len + assert 0xC0 in param_codes, "Must contain PDU size parameter" + + def test_cotp_dt_pdu_format(self) -> None: + """DT PDU: length=2, type=0xF0, EOT=0x80.""" + conn = ISOTCPConnection("127.0.0.1") + dt = conn._build_cotp_dt(b"\x01\x02") + assert dt[0] == 2, "DT PDU length must be 2" + assert dt[1] == 0xF0, "DT PDU type must be 0xF0" + assert dt[2] == 0x80, "EOT+number must be 0x80" + + def test_cotp_dt_carries_data(self) -> None: + """DT PDU correctly carries the S7 data payload.""" + conn = ISOTCPConnection("127.0.0.1") + payload = b"\xde\xad\xbe\xef" + dt = conn._build_cotp_dt(payload) + assert dt[3:] == payload + + def test_cotp_cc_parsing(self) -> None: + """CC PDU parsing extracts destination reference.""" + conn = ISOTCPConnection("127.0.0.1") + # Build a minimal CC: pdu_len, type=0xD0, dst_ref, src_ref, class + cc = struct.pack(">BBHHB", 6, 0xD0, 0x0042, 0x0001, 0x00) + conn._parse_cotp_cc(cc) + assert conn.dst_ref == 0x0042 + + def test_cotp_cc_wrong_type_rejected(self) -> None: + """Non-CC PDU type raises error.""" + conn = ISOTCPConnection("127.0.0.1") + bad_cc = struct.pack(">BBHHB", 6, 0xE0, 0x0000, 0x0001, 0x00) + with pytest.raises(S7ConnectionError, match="Expected COTP CC"): + conn._parse_cotp_cc(bad_cc) + + def test_cotp_cc_too_short_rejected(self) -> None: + """CC PDU shorter than 7 bytes is rejected.""" + conn = ISOTCPConnection("127.0.0.1") + with pytest.raises(S7ConnectionError, match="too short"): + conn._parse_cotp_cc(b"\x06\xd0\x00") + + def test_cotp_data_parsing(self) -> None: + """Data parsing extracts payload from DT PDU.""" + conn = ISOTCPConnection("127.0.0.1") + cotp_pdu = struct.pack(">BBB", 2, 0xF0, 0x80) + b"\x32\x01\x02\x03" + data = conn._parse_cotp_data(cotp_pdu) + assert data == b"\x32\x01\x02\x03" + + def test_cotp_data_wrong_type_rejected(self) -> None: + """Non-DT PDU type in data parsing raises error.""" + conn = ISOTCPConnection("127.0.0.1") + bad_dt = struct.pack(">BBB", 2, 0xE0, 0x80) + b"\x01" + with pytest.raises(S7ConnectionError, match="Expected COTP DT"): + conn._parse_cotp_data(bad_dt) + + def test_cotp_data_too_short_rejected(self) -> None: + """DT PDU shorter than 3 bytes is rejected.""" + conn = ISOTCPConnection("127.0.0.1") + with pytest.raises(S7ConnectionError, match="too short"): + conn._parse_cotp_data(b"\x02\xf0") + + +@pytest.mark.conformance +class TestS7HeaderConformance: + """Verify S7 PDU header encoding.""" + + def test_protocol_id(self) -> None: + """S7 protocol ID is always 0x32.""" + proto = S7Protocol() + pdu = proto.build_read_request(S7Area.DB, 1, 0, S7WordLen.BYTE, 1) + assert pdu[0] == 0x32 + + def test_request_pdu_type(self) -> None: + """Read/write requests use PDU type 0x01 (REQUEST).""" + proto = S7Protocol() + read_pdu = proto.build_read_request(S7Area.DB, 1, 0, S7WordLen.BYTE, 1) + assert read_pdu[1] == S7PDUType.REQUEST + + proto2 = S7Protocol() + write_pdu = proto2.build_write_request(S7Area.DB, 1, 0, S7WordLen.BYTE, b"\x00") + assert write_pdu[1] == S7PDUType.REQUEST + + def test_header_reserved_zero(self) -> None: + """Reserved field (bytes 2-3) is always 0x0000.""" + proto = S7Protocol() + pdu = proto.build_read_request(S7Area.DB, 1, 0, S7WordLen.BYTE, 1) + reserved = struct.unpack(">H", pdu[2:4])[0] + assert reserved == 0x0000 + + def test_sequence_number_increments(self) -> None: + """Sequence number increments with each request.""" + proto = S7Protocol() + pdu1 = proto.build_read_request(S7Area.DB, 1, 0, S7WordLen.BYTE, 1) + pdu2 = proto.build_read_request(S7Area.DB, 1, 0, S7WordLen.BYTE, 1) + seq1 = struct.unpack(">H", pdu1[4:6])[0] + seq2 = struct.unpack(">H", pdu2[4:6])[0] + assert seq2 == seq1 + 1 + + def test_header_is_12_bytes(self) -> None: + """S7 request header is exactly 12 bytes (proto, type, reserved, seq, param_len, data_len).""" + proto = S7Protocol() + pdu = proto.build_setup_communication_request() + # Header: proto(1) + type(1) + reserved(2) + seq(2) + param_len(2) + data_len(2) = 10 + # Actually for REQUEST type it's 10 bytes + assert pdu[0] == 0x32 + assert len(pdu) >= 10 + + +@pytest.mark.conformance +class TestS7FunctionCodes: + """Verify S7 function codes match the specification.""" + + def test_read_area_function_code(self) -> None: + """Read area function code is 0x04.""" + proto = S7Protocol() + pdu = proto.build_read_request(S7Area.DB, 1, 0, S7WordLen.BYTE, 1) + # Function code is first byte of parameter section (after 10-byte header) + assert pdu[10] == 0x04 + + def test_write_area_function_code(self) -> None: + """Write area function code is 0x05.""" + proto = S7Protocol() + pdu = proto.build_write_request(S7Area.DB, 1, 0, S7WordLen.BYTE, b"\x00") + assert pdu[10] == 0x05 + + def test_setup_communication_function_code(self) -> None: + """Setup communication function code is 0xF0.""" + proto = S7Protocol() + pdu = proto.build_setup_communication_request() + assert pdu[10] == 0xF0 + + def test_plc_control_function_code(self) -> None: + """PLC control function code is 0x28.""" + proto = S7Protocol() + pdu = proto.build_plc_control_request("hot_start") + assert pdu[10] == 0x28 + + +@pytest.mark.conformance +class TestS7AreaCodes: + """Verify S7 area codes match the specification.""" + + def test_area_code_pe(self) -> None: + assert S7Area.PE.value == 0x81 + + def test_area_code_pa(self) -> None: + assert S7Area.PA.value == 0x82 + + def test_area_code_mk(self) -> None: + assert S7Area.MK.value == 0x83 + + def test_area_code_db(self) -> None: + assert S7Area.DB.value == 0x84 + + def test_area_code_ct(self) -> None: + assert S7Area.CT.value == 0x1C + + def test_area_code_tm(self) -> None: + assert S7Area.TM.value == 0x1D + + +@pytest.mark.conformance +class TestS7WordLenCodes: + """Verify S7 word length codes match the specification.""" + + def test_wordlen_bit(self) -> None: + assert S7WordLen.BIT.value == 0x01 + + def test_wordlen_byte(self) -> None: + assert S7WordLen.BYTE.value == 0x02 + + def test_wordlen_char(self) -> None: + assert S7WordLen.CHAR.value == 0x03 + + def test_wordlen_word(self) -> None: + assert S7WordLen.WORD.value == 0x04 + + def test_wordlen_int(self) -> None: + assert S7WordLen.INT.value == 0x05 + + def test_wordlen_dword(self) -> None: + assert S7WordLen.DWORD.value == 0x06 + + def test_wordlen_dint(self) -> None: + assert S7WordLen.DINT.value == 0x07 + + def test_wordlen_real(self) -> None: + assert S7WordLen.REAL.value == 0x08 + + def test_wordlen_counter(self) -> None: + assert S7WordLen.COUNTER.value == 0x1C + + def test_wordlen_timer(self) -> None: + assert S7WordLen.TIMER.value == 0x1D + + +@pytest.mark.conformance +class TestS7PDUTypes: + """Verify S7 PDU type codes match the specification.""" + + def test_pdu_type_request(self) -> None: + assert S7PDUType.REQUEST.value == 0x01 + + def test_pdu_type_ack(self) -> None: + assert S7PDUType.ACK.value == 0x02 + + def test_pdu_type_ack_data(self) -> None: + assert S7PDUType.ACK_DATA.value == 0x03 + + def test_pdu_type_userdata(self) -> None: + assert S7PDUType.USERDATA.value == 0x07 + + +@pytest.mark.conformance +class TestS7ReadRequestEncoding: + """Verify read request PDU structure.""" + + def test_read_request_item_count(self) -> None: + """Read request has item count = 1.""" + proto = S7Protocol() + pdu = proto.build_read_request(S7Area.DB, 1, 0, S7WordLen.BYTE, 4) + assert pdu[11] == 0x01 # Item count + + def test_read_request_variable_spec(self) -> None: + """Variable specification marker is 0x12.""" + proto = S7Protocol() + pdu = proto.build_read_request(S7Area.DB, 1, 0, S7WordLen.BYTE, 4) + assert pdu[12] == 0x12 + + def test_read_request_data_length_zero(self) -> None: + """Read requests have data length = 0.""" + proto = S7Protocol() + pdu = proto.build_read_request(S7Area.DB, 1, 0, S7WordLen.BYTE, 4) + data_len = struct.unpack(">H", pdu[8:10])[0] + assert data_len == 0 + + def test_read_request_parameter_length(self) -> None: + """Read request parameter length is 14 (function + count + address spec).""" + proto = S7Protocol() + pdu = proto.build_read_request(S7Area.DB, 1, 0, S7WordLen.BYTE, 4) + param_len = struct.unpack(">H", pdu[6:8])[0] + assert param_len == 14 + + +@pytest.mark.conformance +class TestS7WriteRequestEncoding: + """Verify write request PDU structure.""" + + def test_write_request_has_data_section(self) -> None: + """Write requests include a data section.""" + proto = S7Protocol() + data = b"\x01\x02\x03\x04" + pdu = proto.build_write_request(S7Area.DB, 1, 0, S7WordLen.BYTE, data) + data_len = struct.unpack(">H", pdu[8:10])[0] + assert data_len > 0 + + def test_write_request_data_section_structure(self) -> None: + """Write data section: reserved(1) + transport_size(1) + bit_length(2) + data.""" + proto = S7Protocol() + data = b"\x01\x02\x03\x04" + pdu = proto.build_write_request(S7Area.DB, 1, 0, S7WordLen.BYTE, data) + # Data section starts after header(10) + parameters(14) + data_section = pdu[24:] + assert data_section[0] == 0x00 # Reserved + assert len(data_section) >= 4 + len(data) # transport header + data + + def test_write_request_bit_length(self) -> None: + """Bit length in data section is data_bytes * 8.""" + proto = S7Protocol() + data = b"\x01\x02\x03\x04" + pdu = proto.build_write_request(S7Area.DB, 1, 0, S7WordLen.BYTE, data) + data_section = pdu[24:] + bit_length = struct.unpack(">H", data_section[2:4])[0] + assert bit_length == len(data) * 8 + + +@pytest.mark.conformance +class TestS7SetupCommunication: + """Verify setup communication PDU structure.""" + + def test_setup_comm_pdu_size(self) -> None: + """Setup communication encodes requested PDU size.""" + proto = S7Protocol() + pdu = proto.build_setup_communication_request(pdu_length=480) + # Parameter section: function(1) + reserved(1) + max_amq_caller(2) + max_amq_callee(2) + pdu_len(2) + param_start = 10 + pdu_length = struct.unpack(">H", pdu[param_start + 6 : param_start + 8])[0] + assert pdu_length == 480 + + def test_setup_comm_amq_values(self) -> None: + """Setup communication encodes AMQ caller/callee.""" + proto = S7Protocol() + pdu = proto.build_setup_communication_request(max_amq_caller=3, max_amq_callee=3, pdu_length=960) + param_start = 10 + amq_caller = struct.unpack(">H", pdu[param_start + 2 : param_start + 4])[0] + amq_callee = struct.unpack(">H", pdu[param_start + 4 : param_start + 6])[0] + assert amq_caller == 3 + assert amq_callee == 3 + + +@pytest.mark.conformance +class TestS7ResponseParsing: + """Verify S7 response PDU parsing.""" + + def test_parse_valid_ack_data(self) -> None: + """Valid ACK_DATA response parses without error.""" + proto = S7Protocol() + # Build a minimal ACK_DATA response: header(12 bytes) + pdu = struct.pack( + ">BBHHHHBB", + 0x32, # Protocol ID + S7PDUType.ACK_DATA, + 0x0000, # Reserved + 0x0001, # Sequence + 0x0000, # Parameter length + 0x0000, # Data length + 0x00, # Error class + 0x00, # Error code + ) + response = proto.parse_response(pdu) + assert response["sequence"] == 1 + assert response["error_code"] == 0 + + def test_parse_ack_response(self) -> None: + """ACK (write response) parses correctly.""" + proto = S7Protocol() + # ACK with function code + item count in parameters (min 2 bytes for write response) + pdu = struct.pack( + ">BBHHHHBB", + 0x32, + S7PDUType.ACK, + 0x0000, + 0x0005, + 0x0002, # Param length = 2 + 0x0000, # Data length + 0x00, + 0x00, + ) + struct.pack(">BB", S7Function.WRITE_AREA, 0x01) + response = proto.parse_response(pdu) + assert response["error_code"] == 0 + + def test_reject_invalid_protocol_id(self) -> None: + """Non-0x32 protocol ID raises error.""" + proto = S7Protocol() + pdu = struct.pack(">BBHHHHBB", 0x33, S7PDUType.ACK_DATA, 0, 1, 0, 0, 0, 0) + with pytest.raises(S7ProtocolError, match="Invalid protocol ID"): + proto.parse_response(pdu) + + def test_reject_request_pdu_type(self) -> None: + """REQUEST PDU type in response is rejected.""" + proto = S7Protocol() + pdu = struct.pack(">BBHHHHBB", 0x32, S7PDUType.REQUEST, 0, 1, 0, 0, 0, 0) + with pytest.raises(S7ProtocolError, match="Expected response PDU"): + proto.parse_response(pdu) + + def test_reject_too_short_pdu(self) -> None: + """PDU shorter than 10 bytes is rejected.""" + proto = S7Protocol() + with pytest.raises(S7ProtocolError, match="too short"): + proto.parse_response(b"\x32\x03\x00") + + def test_error_class_raises(self) -> None: + """Non-zero error class raises S7ProtocolError.""" + proto = S7Protocol() + pdu = struct.pack(">BBHHHHBB", 0x32, S7PDUType.ACK_DATA, 0, 1, 0, 0, 0x81, 0x04) + with pytest.raises(S7ProtocolError): + proto.parse_response(pdu) + + +@pytest.mark.conformance +class TestS7ReturnCodes: + """Verify S7 return code definitions.""" + + def test_success_code(self) -> None: + assert S7_RETURN_CODES[0xFF] == "Success" + + def test_hardware_error_code(self) -> None: + assert S7_RETURN_CODES[0x01] == "Hardware error" + + def test_invalid_address_code(self) -> None: + assert S7_RETURN_CODES[0x05] == "Invalid address" + + def test_object_does_not_exist_code(self) -> None: + assert S7_RETURN_CODES[0x0A] == "Object does not exist" + + def test_all_codes_have_descriptions(self) -> None: + """Every defined return code has a non-empty description.""" + for code, desc in S7_RETURN_CODES.items(): + assert desc, f"Return code {code:#04x} has empty description" + + +@pytest.mark.conformance +class TestTPDUSizes: + """Verify TPDU size constants match ISO 8073.""" + + def test_tpdu_sizes_are_powers_of_two(self) -> None: + """Each TPDU size value is an exponent where actual_size = 2^value.""" + for size in TPDUSize: + actual = 1 << size.value + assert actual >= 128 + assert actual <= 8192 + + def test_tpdu_size_values(self) -> None: + assert TPDUSize.S_128.value == 0x07 + assert TPDUSize.S_256.value == 0x08 + assert TPDUSize.S_512.value == 0x09 + assert TPDUSize.S_1024.value == 0x0A + assert TPDUSize.S_2048.value == 0x0B + assert TPDUSize.S_4096.value == 0x0C + assert TPDUSize.S_8192.value == 0x0D diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 00000000..ed784e67 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,475 @@ +"""Tests for snap7.connection module — socket mocking, COTP parsing, exception paths.""" + +import socket +import struct +import pytest +from unittest.mock import patch, MagicMock + +from snap7.connection import ISOTCPConnection, TPDUSize +from snap7.error import S7ConnectionError, S7TimeoutError + + +class TestTPDUSize: + """Test TPDUSize enum values.""" + + def test_sizes(self) -> None: + assert TPDUSize.S_128.value == 0x07 + assert TPDUSize.S_1024.value == 0x0A + assert TPDUSize.S_8192.value == 0x0D + + +class TestISOTCPConnectionInit: + """Test constructor defaults.""" + + def test_defaults(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + assert conn.host == "1.2.3.4" + assert conn.port == 102 + assert conn.connected is False + assert conn.socket is None + assert conn.pdu_size == 240 + + def test_custom_params(self) -> None: + conn = ISOTCPConnection("1.2.3.4", port=1102, local_tsap=0x200, remote_tsap=0x300, tpdu_size=TPDUSize.S_512) + assert conn.port == 1102 + assert conn.local_tsap == 0x200 + assert conn.remote_tsap == 0x300 + assert conn.tpdu_size == TPDUSize.S_512 + + +class TestBuildTPKT: + """Test TPKT frame building.""" + + def test_tpkt_structure(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + payload = b"\x01\x02\x03" + frame = conn._build_tpkt(payload) + assert frame[:2] == b"\x03\x00" # version=3, reserved=0 + length = struct.unpack(">H", frame[2:4])[0] + assert length == 7 # 4 header + 3 payload + assert frame[4:] == payload + + +class TestBuildCOTPCR: + """Test COTP Connection Request building.""" + + def test_cr_structure(self) -> None: + conn = ISOTCPConnection("1.2.3.4", local_tsap=0x0100, remote_tsap=0x0102) + cr = conn._build_cotp_cr() + # First byte = PDU length + pdu_type = cr[1] + assert pdu_type == 0xE0 # COTP_CR + # Should contain parameters for TSAP and PDU size + assert len(cr) > 7 + + +class TestBuildCOTPDT: + """Test COTP Data Transfer building.""" + + def test_dt_structure(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + data = b"\xaa\xbb" + dt = conn._build_cotp_dt(data) + assert dt[0] == 2 # PDU length + assert dt[1] == 0xF0 # COTP_DT + assert dt[2] == 0x80 # EOT + assert dt[3:] == data + + +class TestParseCOTPCC: + """Test COTP Connection Confirm parsing.""" + + def test_valid_cc(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + # Build a valid CC: len, type, dst_ref, src_ref, class_opt + cc_data = struct.pack(">BBHHB", 6, 0xD0, 0x1234, 0x0001, 0x00) + conn._parse_cotp_cc(cc_data) + assert conn.dst_ref == 0x1234 + + def test_cc_too_short(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + with pytest.raises(S7ConnectionError, match="too short"): + conn._parse_cotp_cc(b"\x00\x01\x02") + + def test_cc_wrong_type(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + cc_data = struct.pack(">BBHHB", 6, 0xE0, 0x0000, 0x0001, 0x00) # CR instead of CC + with pytest.raises(S7ConnectionError, match="Expected COTP CC"): + conn._parse_cotp_cc(cc_data) + + def test_cc_with_pdu_size_param_1byte(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + base = struct.pack(">BBHHB", 10, 0xD0, 0x0001, 0x0001, 0x00) + # PDU size parameter: code=0xC0, len=1, value=0x0A (=1024) + param = struct.pack(">BBB", 0xC0, 1, 0x0A) + conn._parse_cotp_cc(base + param) + assert conn.pdu_size == 1024 + + def test_cc_with_pdu_size_param_2byte(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + base = struct.pack(">BBHHB", 11, 0xD0, 0x0001, 0x0001, 0x00) + # PDU size parameter: code=0xC0, len=2, value=2048 + param = struct.pack(">BBH", 0xC0, 2, 2048) + conn._parse_cotp_cc(base + param) + assert conn.pdu_size == 2048 + + +class TestParseCOTPParameters: + """Test COTP parameter parsing edge cases.""" + + def test_unknown_parameter(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + # Unknown param code 0xFF, length 1, data 0x00 + params = struct.pack(">BBB", 0xFF, 1, 0x00) + conn._parse_cotp_parameters(params) + # Should not crash; pdu_size should remain default + assert conn.pdu_size == 240 + + def test_truncated_params(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + # Just one byte — should break out of loop + conn._parse_cotp_parameters(b"\xc0") + assert conn.pdu_size == 240 + + def test_param_len_exceeds_data(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + # code=0xC0, len=5, but only 1 byte of data follows + params = struct.pack(">BBB", 0xC0, 5, 0x0A) + conn._parse_cotp_parameters(params) + # Should break early without error + assert conn.pdu_size == 240 + + +class TestParseCOTPData: + """Test COTP Data Transfer parsing.""" + + def test_valid_dt(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + pdu = struct.pack(">BBB", 2, 0xF0, 0x80) + b"\xde\xad" + result = conn._parse_cotp_data(pdu) + assert result == b"\xde\xad" + + def test_dt_too_short(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + with pytest.raises(S7ConnectionError, match="too short"): + conn._parse_cotp_data(b"\x02") + + def test_dt_wrong_type(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + pdu = struct.pack(">BBB", 2, 0xD0, 0x80) # CC instead of DT + with pytest.raises(S7ConnectionError, match="Expected COTP DT"): + conn._parse_cotp_data(pdu) + + +class TestSendData: + """Test send_data() error paths.""" + + def test_send_when_not_connected(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + with pytest.raises(S7ConnectionError, match="Not connected"): + conn.send_data(b"\x00") + + def test_send_socket_error(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.connected = True + conn.socket = MagicMock() + conn.socket.sendall.side_effect = socket.error("broken pipe") + with pytest.raises(S7ConnectionError, match="Send failed"): + conn.send_data(b"\x00") + assert conn.connected is False + + +class TestReceiveData: + """Test receive_data() error paths.""" + + def test_receive_when_not_connected(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + with pytest.raises(S7ConnectionError, match="Not connected"): + conn.receive_data() + + def test_receive_invalid_tpkt_version(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.connected = True + mock_socket = MagicMock() + conn.socket = mock_socket + # TPKT with version 5 instead of 3 + mock_socket.recv.return_value = struct.pack(">BBH", 5, 0, 10) + with pytest.raises(S7ConnectionError, match="Invalid TPKT version"): + conn.receive_data() + + def test_receive_invalid_tpkt_length(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.connected = True + mock_socket = MagicMock() + conn.socket = mock_socket + # Length = 3, remaining = -1 + mock_socket.recv.return_value = struct.pack(">BBH", 3, 0, 3) + with pytest.raises(S7ConnectionError, match="Invalid TPKT length"): + conn.receive_data() + + def test_receive_timeout(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.connected = True + mock_socket = MagicMock() + conn.socket = mock_socket + mock_socket.recv.side_effect = socket.timeout("timeout") + with pytest.raises(S7TimeoutError, match="Receive timeout"): + conn.receive_data() + assert conn.connected is False + + def test_receive_socket_error(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.connected = True + mock_socket = MagicMock() + conn.socket = mock_socket + # First recv returns valid TPKT header, second raises error + mock_socket.recv.side_effect = [struct.pack(">BBH", 3, 0, 10), socket.error("reset")] + with pytest.raises(S7ConnectionError, match="Receive error"): + conn.receive_data() + assert conn.connected is False + + +class TestRecvExact: + """Test _recv_exact() with various scenarios.""" + + def test_socket_none(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + with pytest.raises(S7ConnectionError, match="Socket not initialized"): + conn._recv_exact(4) + + def test_connection_closed(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.socket = MagicMock() + conn.socket.recv.return_value = b"" # empty = connection closed + with pytest.raises(S7ConnectionError, match="Connection closed"): + conn._recv_exact(4) + assert conn.connected is False + + def test_partial_reads(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.socket = MagicMock() + conn.socket.recv.side_effect = [b"\x01\x02", b"\x03\x04"] + result = conn._recv_exact(4) + assert result == b"\x01\x02\x03\x04" + + def test_timeout(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.socket = MagicMock() + conn.socket.recv.side_effect = socket.timeout("timeout") + with pytest.raises(S7TimeoutError): + conn._recv_exact(4) + + def test_socket_error(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.socket = MagicMock() + conn.socket.recv.side_effect = socket.error("broken") + with pytest.raises(S7ConnectionError, match="Receive error"): + conn._recv_exact(4) + + +class TestSendCOTPDisconnect: + """Test _send_cotp_disconnect().""" + + def test_disconnect_no_socket(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.socket = None + # Should return without error + conn._send_cotp_disconnect() + + def test_disconnect_sends_dr(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + mock_socket = MagicMock() + conn.socket = mock_socket + conn._send_cotp_disconnect() + mock_socket.sendall.assert_called_once() + + def test_disconnect_ignores_socket_error(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + mock_socket = MagicMock() + mock_socket.sendall.side_effect = socket.error("broken") + conn.socket = mock_socket + # Should not raise + conn._send_cotp_disconnect() + + +class TestConnect: + """Test connect() orchestration.""" + + @patch.object(ISOTCPConnection, "_tcp_connect") + @patch.object(ISOTCPConnection, "_iso_connect") + def test_successful_connect(self, mock_iso: MagicMock, mock_tcp: MagicMock) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.connect(timeout=2.0) + assert conn.connected is True + assert conn.timeout == 2.0 + mock_tcp.assert_called_once() + mock_iso.assert_called_once() + + @patch.object(ISOTCPConnection, "_tcp_connect", side_effect=OSError("connection refused")) + @patch.object(ISOTCPConnection, "disconnect") + def test_connect_failure_wraps_in_s7error(self, mock_disc: MagicMock, mock_tcp: MagicMock) -> None: + conn = ISOTCPConnection("1.2.3.4") + with pytest.raises(S7ConnectionError, match="Connection failed"): + conn.connect() + mock_disc.assert_called_once() + + @patch.object(ISOTCPConnection, "_tcp_connect") + @patch.object(ISOTCPConnection, "_iso_connect", side_effect=S7ConnectionError("COTP fail")) + @patch.object(ISOTCPConnection, "disconnect") + def test_connect_reraises_s7_errors(self, mock_disc: MagicMock, mock_iso: MagicMock, mock_tcp: MagicMock) -> None: + conn = ISOTCPConnection("1.2.3.4") + with pytest.raises(S7ConnectionError, match="COTP fail"): + conn.connect() + + +class TestDisconnect: + """Test disconnect() behavior.""" + + def test_disconnect_when_no_socket(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + # Should not raise + conn.disconnect() + + def test_disconnect_closes_socket(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + mock_socket = MagicMock() + conn.socket = mock_socket + conn.connected = True + conn.disconnect() + mock_socket.close.assert_called_once() + assert conn.socket is None + assert conn.connected is False + + def test_disconnect_ignores_errors(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + mock_socket = MagicMock() + mock_socket.close.side_effect = OSError("already closed") + conn.socket = mock_socket + conn.connected = False + conn.disconnect() + assert conn.socket is None + + +class TestContextManager: + """Test __enter__ / __exit__.""" + + def test_enter_returns_self(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + assert conn.__enter__() is conn + + def test_exit_calls_disconnect(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.socket = MagicMock() + conn.connected = True + conn.__exit__(None, None, None) + assert conn.socket is None + assert conn.connected is False + + def test_context_manager_protocol(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + with conn as c: + assert c is conn + assert conn.connected is False + + +class TestCheckConnection: + """Test check_connection() method.""" + + def test_not_connected(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + assert conn.check_connection() is False + + def test_socket_none(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.connected = True + conn.socket = None + assert conn.check_connection() is False + + def test_connection_alive_no_data(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.connected = True + mock_socket = MagicMock() + conn.socket = mock_socket + mock_socket.gettimeout.return_value = 5.0 + mock_socket.recv.side_effect = BlockingIOError + assert conn.check_connection() is True + + def test_connection_alive_with_data(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.connected = True + mock_socket = MagicMock() + conn.socket = mock_socket + mock_socket.gettimeout.return_value = 5.0 + mock_socket.recv.return_value = b"\x00" + assert conn.check_connection() is True + + def test_connection_closed_by_peer(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.connected = True + mock_socket = MagicMock() + conn.socket = mock_socket + mock_socket.gettimeout.return_value = 5.0 + mock_socket.recv.return_value = b"" + assert conn.check_connection() is False + assert conn.connected is False + + def test_connection_socket_error(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.connected = True + mock_socket = MagicMock() + conn.socket = mock_socket + mock_socket.gettimeout.return_value = 5.0 + mock_socket.recv.side_effect = socket.error("reset") + assert conn.check_connection() is False + assert conn.connected is False + + def test_connection_exception_in_outer_try(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.connected = True + mock_socket = MagicMock() + conn.socket = mock_socket + mock_socket.gettimeout.side_effect = Exception("unexpected") + assert conn.check_connection() is False + + +class TestTCPConnect: + """Test _tcp_connect().""" + + @patch("snap7.connection.socket.socket") + def test_tcp_connect_failure(self, mock_socket_cls: MagicMock) -> None: + mock_sock = MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.connect.side_effect = socket.error("refused") + conn = ISOTCPConnection("1.2.3.4") + with pytest.raises(S7ConnectionError, match="TCP connection failed"): + conn._tcp_connect() + + @patch("snap7.connection.socket.socket") + def test_tcp_connect_success(self, mock_socket_cls: MagicMock) -> None: + mock_sock = MagicMock() + mock_socket_cls.return_value = mock_sock + conn = ISOTCPConnection("1.2.3.4") + conn._tcp_connect() + mock_sock.settimeout.assert_called_once() + mock_sock.connect.assert_called_once_with(("1.2.3.4", 102)) + + +class TestISOConnect: + """Test _iso_connect().""" + + def test_iso_connect_no_socket(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + conn.socket = None + with pytest.raises(S7ConnectionError, match="Socket not initialized"): + conn._iso_connect() + + def test_iso_connect_bad_tpkt_version(self) -> None: + conn = ISOTCPConnection("1.2.3.4") + mock_socket = MagicMock() + conn.socket = mock_socket + # Build a valid CC response wrapped in a bad TPKT + cc = struct.pack(">BBHHB", 6, 0xD0, 0x0001, 0x0001, 0x00) + bad_tpkt = struct.pack(">BBH", 5, 0, 4 + len(cc)) + mock_socket.recv.side_effect = [bad_tpkt, cc] + with pytest.raises(S7ConnectionError, match="Invalid TPKT version"): + conn._iso_connect() diff --git a/tests/test_error.py b/tests/test_error.py new file mode 100644 index 00000000..7e32f9e4 --- /dev/null +++ b/tests/test_error.py @@ -0,0 +1,181 @@ +"""Tests for snap7.error module — error routing, check_error(), error_wrap() decorator.""" + +import pytest + +from snap7.error import ( + S7Error, + S7ConnectionError, + S7ProtocolError, + S7TimeoutError, + S7AuthenticationError, + S7StalePacketError, + S7PacketLostError, + get_error_message, + get_protocol_error_message, + check_error, + error_text, + error_wrap, +) + + +class TestExceptionClasses: + """Verify all exception classes can be instantiated with expected attributes.""" + + def test_s7error_with_code(self) -> None: + err = S7Error("msg", error_code=42) + assert str(err) == "msg" + assert err.error_code == 42 + + def test_s7error_without_code(self) -> None: + err = S7Error("msg") + assert err.error_code is None + + def test_subclass_hierarchy(self) -> None: + assert issubclass(S7ConnectionError, S7Error) + assert issubclass(S7ProtocolError, S7Error) + assert issubclass(S7TimeoutError, S7Error) + assert issubclass(S7AuthenticationError, S7Error) + assert issubclass(S7StalePacketError, S7ProtocolError) + assert issubclass(S7PacketLostError, S7ProtocolError) + + def test_all_subclasses_instantiate(self) -> None: + for cls in ( + S7ConnectionError, + S7ProtocolError, + S7TimeoutError, + S7AuthenticationError, + S7StalePacketError, + S7PacketLostError, + ): + e = cls("test", error_code=1) + assert str(e) == "test" + assert e.error_code == 1 + + +class TestGetErrorMessage: + """Tests for get_error_message() — known and unknown codes.""" + + def test_success_code(self) -> None: + assert get_error_message(0x00000000) == "Success" + + def test_known_client_error(self) -> None: + # Use a code unique to client errors (not overlapping with server: 0x009+) + assert get_error_message(0x00900000) == "errCliAddressOutOfRange" + + def test_known_isotcp_error(self) -> None: + assert get_error_message(0x00010000) == "errIsoConnect" + + def test_known_server_error(self) -> None: + assert get_error_message(0x00200000) == "errSrvDBNullPointer" + + def test_unknown_code(self) -> None: + msg = get_error_message(0xDEADBEEF) + assert "Unknown error" in msg + assert "0xdeadbeef" in msg + + +class TestGetProtocolErrorMessage: + """Tests for get_protocol_error_message() — known and unknown protocol codes.""" + + def test_known_protocol_code(self) -> None: + assert get_protocol_error_message(0x0000) == "No error" + + def test_known_protocol_error(self) -> None: + assert "block number" in get_protocol_error_message(0x0110).lower() + + def test_unknown_protocol_code(self) -> None: + msg = get_protocol_error_message(0xFFFF) + assert "Unknown protocol error" in msg + + +class TestErrorText: + """Tests for error_text() with different contexts.""" + + def test_client_context(self) -> None: + msg = error_text(0x00100000, "client") + assert msg == "errNegotiatingPDU" + + def test_server_context(self) -> None: + # Server dict has its own 0x00100000 entry + msg = error_text(0x00100000, "server") + assert msg == "errSrvCannotStart" + + def test_partner_context(self) -> None: + # Partner uses client errors + msg = error_text(0x00100000, "partner") + assert msg == "errNegotiatingPDU" + + def test_unknown_context_falls_back_to_client(self) -> None: + msg = error_text(0x00100000, "unknown_context") + assert msg == "errNegotiatingPDU" + + def test_unknown_error_code(self) -> None: + msg = error_text(0xBADC0DE, "client") + assert "Unknown error" in msg + + def test_caching(self) -> None: + # Calling twice should return the same cached result + a = error_text(0x00100000, "client") + b = error_text(0x00100000, "client") + assert a == b + + +class TestCheckError: + """Tests for check_error() — routes error codes to exception types.""" + + def test_zero_returns_none(self) -> None: + # Should not raise + check_error(0) + + def test_iso_connect_raises_connection_error(self) -> None: + with pytest.raises(S7ConnectionError): + check_error(0x00010000) + + def test_iso_disconnect_raises_connection_error(self) -> None: + with pytest.raises(S7ConnectionError): + check_error(0x00020000) + + def test_timeout_raises_timeout_error(self) -> None: + with pytest.raises(S7TimeoutError): + check_error(0x02000000) + + def test_other_isotcp_raises_connection_error(self) -> None: + with pytest.raises(S7ConnectionError): + check_error(0x00030000) # errIsoInvalidPDU + + def test_generic_error_raises_runtime_error(self) -> None: + with pytest.raises(RuntimeError): + check_error(0x00100000) # errNegotiatingPDU (client error) + + +class TestErrorWrap: + """Tests for error_wrap() decorator.""" + + def test_no_error(self) -> None: + @error_wrap("client") + def ok_func() -> int: + return 0 + + # Should not raise, returns None (decorator suppresses return value) + result = ok_func() + assert result is None + + def test_raises_on_error(self) -> None: + @error_wrap("client") + def bad_func() -> int: + return 0x02000000 # timeout + + with pytest.raises(S7TimeoutError): + bad_func() + + def test_passes_args_through(self) -> None: + @error_wrap("client") + def func_with_args(a: int, b: int) -> int: + return a + b + + # 0 + 0 = 0, no error + func_with_args(0, 0) + + with pytest.raises(RuntimeError): + # Non-zero = error + func_with_args(0x00100000, 0) diff --git a/tests/test_logo_client.py b/tests/test_logo_client.py index 58bf5d5c..a5d48a6f 100644 --- a/tests/test_logo_client.py +++ b/tests/test_logo_client.py @@ -4,8 +4,9 @@ from typing import Optional import snap7 +from snap7.logo import Logo, parse_address from snap7.server import Server -from snap7.type import Parameter, SrvArea +from snap7.type import Parameter, SrvArea, WordLen logging.basicConfig(level=logging.WARNING) @@ -124,5 +125,247 @@ def test_set_param(self) -> None: self.client.set_param(param, value) +logo_coverage_tcpport = 11102 + + +# --------------------------------------------------------------------------- +# parse_address() unit tests (no server needed) +# --------------------------------------------------------------------------- + + +@pytest.mark.logo +class TestParseAddress(unittest.TestCase): + """Test every branch of parse_address().""" + + def test_byte_address(self) -> None: + start, wl = parse_address("V10") + self.assertEqual(start, 10) + self.assertEqual(wl, WordLen.Byte) + + def test_byte_address_large(self) -> None: + start, wl = parse_address("V999") + self.assertEqual(start, 999) + self.assertEqual(wl, WordLen.Byte) + + def test_word_address(self) -> None: + start, wl = parse_address("VW20") + self.assertEqual(start, 20) + self.assertEqual(wl, WordLen.Word) + + def test_word_address_zero(self) -> None: + start, wl = parse_address("VW0") + self.assertEqual(start, 0) + self.assertEqual(wl, WordLen.Word) + + def test_dword_address(self) -> None: + start, wl = parse_address("VD30") + self.assertEqual(start, 30) + self.assertEqual(wl, WordLen.DWord) + + def test_bit_address(self) -> None: + start, wl = parse_address("V10.3") + # bit offset = 10*8 + 3 = 83 + self.assertEqual(start, 83) + self.assertEqual(wl, WordLen.Bit) + + def test_bit_address_zero(self) -> None: + start, wl = parse_address("V0.0") + self.assertEqual(start, 0) + self.assertEqual(wl, WordLen.Bit) + + def test_bit_address_high_bit(self) -> None: + start, wl = parse_address("V0.7") + self.assertEqual(start, 7) + self.assertEqual(wl, WordLen.Bit) + + def test_invalid_address_raises(self) -> None: + with self.assertRaises(ValueError): + parse_address("INVALID") + + def test_invalid_address_empty(self) -> None: + with self.assertRaises(ValueError): + parse_address("") + + def test_invalid_address_wrong_prefix(self) -> None: + with self.assertRaises(ValueError): + parse_address("M10") + + +# --------------------------------------------------------------------------- +# Integration tests: Logo client against the built-in Server +# --------------------------------------------------------------------------- + + +@pytest.mark.logo +class TestLogoReadWrite(unittest.TestCase): + """Test Logo read/write against a real server with DB1 registered.""" + + server: Optional[Server] = None + db_data: bytearray + + @classmethod + def setUpClass(cls) -> None: + cls.db_data = bytearray(256) + cls.server = Server() + cls.server.register_area(SrvArea.DB, 0, bytearray(256)) + cls.server.register_area(SrvArea.DB, 1, cls.db_data) + cls.server.start(tcp_port=logo_coverage_tcpport) + + @classmethod + def tearDownClass(cls) -> None: + if cls.server: + cls.server.stop() + cls.server.destroy() + + def setUp(self) -> None: + self.client = Logo() + self.client.connect(ip, 0x1000, 0x2000, logo_coverage_tcpport) + + def tearDown(self) -> None: + self.client.disconnect() + self.client.destroy() + + # -- read tests --------------------------------------------------------- + + def test_read_byte(self) -> None: + """Write a known byte into DB1 via client, then read it back.""" + self.client.write("V5", 0xAB) + result = self.client.read("V5") + self.assertEqual(result, 0xAB) + + def test_read_word(self) -> None: + """Write and read back a word (signed 16-bit big-endian).""" + self.client.write("VW10", 1234) + result = self.client.read("VW10") + self.assertEqual(result, 1234) + + def test_read_word_negative(self) -> None: + """Words are signed — negative values should round-trip.""" + self.client.write("VW12", -500) + result = self.client.read("VW12") + self.assertEqual(result, -500) + + def test_read_dword(self) -> None: + """Write and read back a dword (signed 32-bit big-endian).""" + self.client.write("VD20", 70000) + result = self.client.read("VD20") + self.assertEqual(result, 70000) + + def test_read_dword_negative(self) -> None: + """DWords are signed — negative values should round-trip.""" + self.client.write("VD24", -123456) + result = self.client.read("VD24") + self.assertEqual(result, -123456) + + def test_read_bit_set(self) -> None: + """Write bit=1, then read it back.""" + self.client.write("V50.2", 1) + result = self.client.read("V50.2") + self.assertEqual(result, 1) + + def test_read_bit_clear(self) -> None: + """Write bit=0, then read it back.""" + # First set it so we know we're actually clearing + self.client.write("V51.5", 1) + self.assertEqual(self.client.read("V51.5"), 1) + self.client.write("V51.5", 0) + result = self.client.read("V51.5") + self.assertEqual(result, 0) + + def test_read_bit_zero(self) -> None: + """Read bit 0 of byte 0.""" + self.client.write("V60", 0) # clear byte first + self.client.write("V60.0", 1) + self.assertEqual(self.client.read("V60.0"), 1) + # Other bits should be 0 + self.assertEqual(self.client.read("V60.1"), 0) + + def test_read_bit_seven(self) -> None: + """Read bit 7 of a byte.""" + self.client.write("V61", 0) # clear byte + self.client.write("V61.7", 1) + self.assertEqual(self.client.read("V61.7"), 1) + # Byte should be 0x80 + self.assertEqual(self.client.read("V61"), 0x80) + + # -- write tests -------------------------------------------------------- + + def test_write_byte(self) -> None: + """Write a byte and verify.""" + result = self.client.write("V70", 42) + self.assertEqual(result, 0) + self.assertEqual(self.client.read("V70"), 42) + + def test_write_word(self) -> None: + """Write a word and verify.""" + result = self.client.write("VW80", 2000) + self.assertEqual(result, 0) + self.assertEqual(self.client.read("VW80"), 2000) + + def test_write_dword(self) -> None: + """Write a dword and verify.""" + result = self.client.write("VD90", 100000) + self.assertEqual(result, 0) + self.assertEqual(self.client.read("VD90"), 100000) + + def test_write_bit_true(self) -> None: + """Write a bit to True.""" + result = self.client.write("V100.4", 1) + self.assertEqual(result, 0) + self.assertEqual(self.client.read("V100.4"), 1) + + def test_write_bit_false(self) -> None: + """Write a bit to False after setting it.""" + self.client.write("V101.6", 1) + result = self.client.write("V101.6", 0) + self.assertEqual(result, 0) + self.assertEqual(self.client.read("V101.6"), 0) + + def test_write_bit_preserves_other_bits(self) -> None: + """Setting one bit should not disturb other bits in the same byte.""" + # Write 0xFF to the byte + self.client.write("V110", 0xFF) + # Clear bit 3 + self.client.write("V110.3", 0) + # Byte should now be 0xF7 (all bits set except bit 3) + self.assertEqual(self.client.read("V110"), 0xF7) + # Set bit 3 back + self.client.write("V110.3", 1) + self.assertEqual(self.client.read("V110"), 0xFF) + + def test_write_byte_boundary_values(self) -> None: + """Test boundary values: 0 and 255.""" + self.client.write("V120", 0) + self.assertEqual(self.client.read("V120"), 0) + self.client.write("V120", 255) + self.assertEqual(self.client.read("V120"), 255) + + def test_write_word_boundary_values(self) -> None: + """Test word boundary values: max positive and max negative.""" + self.client.write("VW130", 32767) + self.assertEqual(self.client.read("VW130"), 32767) + self.client.write("VW130", -32768) + self.assertEqual(self.client.read("VW130"), -32768) + + def test_write_dword_boundary_values(self) -> None: + """Test dword boundary values.""" + self.client.write("VD140", 2147483647) + self.assertEqual(self.client.read("VD140"), 2147483647) + self.client.write("VD140", -2147483648) + self.assertEqual(self.client.read("VD140"), -2147483648) + + def test_read_write_multiple_addresses(self) -> None: + """Verify different address types can coexist.""" + self.client.write("V200", 0x42) + self.client.write("VW202", 1000) + self.client.write("VD204", 50000) + self.client.write("V208.1", 1) + + self.assertEqual(self.client.read("V200"), 0x42) + self.assertEqual(self.client.read("VW202"), 1000) + self.assertEqual(self.client.read("VD204"), 50000) + self.assertEqual(self.client.read("V208.1"), 1) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_partner.py b/tests/test_partner.py index 34c9cb27..570fbca9 100644 --- a/tests/test_partner.py +++ b/tests/test_partner.py @@ -1,10 +1,16 @@ import logging +import socket +import struct +import threading +import time import pytest import unittest as unittest -from snap7.error import error_text +from snap7.connection import ISOTCPConnection +from snap7.error import error_text, S7Error, S7ConnectionError import snap7.partner +from snap7.partner import Partner, PartnerStatus from snap7.type import Parameter logging.basicConfig(level=logging.WARNING) @@ -116,5 +122,614 @@ def test_wait_as_b_send_completion(self) -> None: self.assertRaises(RuntimeError, self.partner.wait_as_b_send_completion) +def _free_port() -> int: + """Return a free TCP port chosen by the OS.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + port: int = s.getsockname()[1] + return port + + +# --------------------------------------------------------------------------- +# PDU building / parsing unit tests (no network required) +# --------------------------------------------------------------------------- + + +@pytest.mark.partner +class TestPartnerPDU: + """Unit tests for partner PDU building and parsing.""" + + def test_build_partner_data_pdu_small(self) -> None: + p = Partner() + data = b"\x01\x02\x03" + pdu = p._build_partner_data_pdu(data) + assert pdu[0:1] == b"\x32" + assert pdu[1:2] == b"\x07" + assert struct.unpack(">H", pdu[2:4])[0] == len(data) + assert pdu[6:] == data + + def test_build_partner_data_pdu_empty(self) -> None: + p = Partner() + pdu = p._build_partner_data_pdu(b"") + assert pdu[0:1] == b"\x32" + assert struct.unpack(">H", pdu[2:4])[0] == 0 + + def test_build_partner_data_pdu_large(self) -> None: + p = Partner() + data = bytes(range(256)) * 4 # 1024 bytes + pdu = p._build_partner_data_pdu(data) + assert struct.unpack(">H", pdu[2:4])[0] == 1024 + assert pdu[6:] == data + + def test_parse_partner_data_pdu_roundtrip(self) -> None: + p = Partner() + original = b"Hello, Partner!" + pdu = p._build_partner_data_pdu(original) + parsed = p._parse_partner_data_pdu(pdu) + assert parsed == original + + def test_parse_partner_data_pdu_roundtrip_various_sizes(self) -> None: + p = Partner() + for size in [0, 1, 10, 100, 500, 1024]: + data = (bytes(range(256)) * (size // 256 + 1))[:size] + pdu = p._build_partner_data_pdu(data) + assert p._parse_partner_data_pdu(pdu) == data + + def test_parse_partner_data_pdu_too_short(self) -> None: + p = Partner() + with pytest.raises(S7Error, match="too short"): + p._parse_partner_data_pdu(b"\x32\x07\x00") + + def test_build_partner_ack(self) -> None: + p = Partner() + ack = p._build_partner_ack() + assert len(ack) == 6 + assert ack[0:1] == b"\x32" + assert ack[1:2] == b"\x08" + + def test_parse_partner_ack_valid(self) -> None: + p = Partner() + ack = p._build_partner_ack() + p._parse_partner_ack(ack) + + def test_parse_partner_ack_too_short(self) -> None: + p = Partner() + with pytest.raises(S7Error, match="too short"): + p._parse_partner_ack(b"\x32") + + def test_parse_partner_ack_wrong_type(self) -> None: + p = Partner() + bad_ack = struct.pack(">BBHH", 0x32, 0x07, 0x0000, 0x0000) + with pytest.raises(S7Error, match="Expected partner ACK"): + p._parse_partner_ack(bad_ack) + + def test_ack_roundtrip(self) -> None: + p = Partner() + ack = p._build_partner_ack() + p._parse_partner_ack(ack) + + +# --------------------------------------------------------------------------- +# Status, stats, lifecycle tests +# --------------------------------------------------------------------------- + + +@pytest.mark.partner +class TestPartnerLifecycle: + """Tests for partner lifecycle, status, and context manager.""" + + def test_initial_status_stopped(self) -> None: + p = Partner() + assert p.get_status().value == PartnerStatus.STOPPED + + def test_status_running_passive(self) -> None: + port = _free_port() + p = Partner(active=False) + p.port = port + try: + p.start_to("127.0.0.1", "", 0x0100, 0x0102) + assert p.running is True + assert p.get_status().value == PartnerStatus.RUNNING + finally: + p.stop() + + def test_stop_idempotent(self) -> None: + p = Partner() + p.stop() + p.stop() + + def test_destroy_returns_zero(self) -> None: + p = Partner() + assert p.destroy() == 0 + + def test_context_manager(self) -> None: + port = _free_port() + with Partner(active=False) as p: + p.port = port + p.start_to("127.0.0.1", "", 0x0100, 0x0102) + assert p.running is True + assert p.running is False + + def test_del_cleanup(self) -> None: + port = _free_port() + p = Partner(active=False) + p.port = port + p.start_to("127.0.0.1", "", 0x0100, 0x0102) + assert p.running is True + p.__del__() + assert p.running is False + + def test_create_noop(self) -> None: + p = Partner() + p.create(active=True) + + def test_get_stats_initial(self) -> None: + p = Partner() + sent, recv, s_err, r_err = p.get_stats() + assert sent.value == 0 + assert recv.value == 0 + assert s_err.value == 0 + assert r_err.value == 0 + + def test_get_times_initial(self) -> None: + p = Partner() + send_t, recv_t = p.get_times() + assert send_t.value == 0 + assert recv_t.value == 0 + + def test_get_last_error_initial(self) -> None: + p = Partner() + assert p.get_last_error().value == 0 + + +# --------------------------------------------------------------------------- +# Send / recv data buffer tests +# --------------------------------------------------------------------------- + + +@pytest.mark.partner +class TestPartnerSendRecvBuffers: + """Tests for set_send_data / get_recv_data and error paths.""" + + def test_set_send_data_and_retrieve(self) -> None: + p = Partner() + assert p._send_data is None + p.set_send_data(b"test") + assert p._send_data == b"test" + + def test_get_recv_data_initially_none(self) -> None: + p = Partner() + assert p.get_recv_data() is None + + def test_b_send_no_data(self) -> None: + p = Partner() + assert p.b_send() == -1 + + def test_b_send_not_connected(self) -> None: + p = Partner() + p.set_send_data(b"data") + with pytest.raises(S7ConnectionError, match="Not connected"): + p.b_send() + + def test_b_recv_not_connected(self) -> None: + p = Partner() + result = p.b_recv() + assert result == -1 + assert p.get_recv_data() is None + + def test_as_b_send_no_data(self) -> None: + p = Partner() + assert p.as_b_send() == -1 + + def test_as_b_send_not_connected(self) -> None: + p = Partner() + p.set_send_data(b"data") + result = p.as_b_send() + assert result == -1 + + def test_check_as_b_recv_completion_empty(self) -> None: + p = Partner() + assert p.check_as_b_recv_completion() == 1 + + def test_check_as_b_recv_completion_with_data(self) -> None: + p = Partner() + p._async_recv_queue.put(b"queued data") + assert p.check_as_b_recv_completion() == 0 + assert p._recv_data == b"queued data" + + def test_check_as_b_send_completion_not_in_progress(self) -> None: + p = Partner() + status, result = p.check_as_b_send_completion() + assert status == "job complete" + + def test_check_as_b_send_completion_in_progress(self) -> None: + p = Partner() + p._async_send_in_progress = True + status, result = p.check_as_b_send_completion() + assert status == "job in progress" + + def test_wait_as_b_send_no_operation(self) -> None: + p = Partner() + with pytest.raises(RuntimeError, match="No async send"): + p.wait_as_b_send_completion() + + def test_wait_as_b_send_timeout(self) -> None: + p = Partner() + p._async_send_in_progress = True + result = p.wait_as_b_send_completion(timeout=50) + assert result == -1 + + def test_wait_as_b_send_completes(self) -> None: + p = Partner() + p._async_send_in_progress = True + p._async_send_result = 0 + + def clear_flag() -> None: + time.sleep(0.05) + p._async_send_in_progress = False + + t = threading.Thread(target=clear_flag) + t.start() + result = p.wait_as_b_send_completion(timeout=2000) + t.join() + assert result == 0 + + +# --------------------------------------------------------------------------- +# Parameter tests +# --------------------------------------------------------------------------- + + +@pytest.mark.partner +class TestPartnerParams: + """Tests for get_param / set_param.""" + + def test_get_param_unsupported(self) -> None: + p = Partner() + with pytest.raises(RuntimeError, match="not supported"): + p.get_param(Parameter.MaxClients) + + def test_set_param_remote_port_raises(self) -> None: + p = Partner() + with pytest.raises(RuntimeError, match="Cannot set"): + p.set_param(Parameter.RemotePort, 1234) + + def test_set_param_local_port(self) -> None: + p = Partner() + p.set_param(Parameter.LocalPort, 5555) + assert p.local_port == 5555 + + def test_set_param_returns_zero(self) -> None: + p = Partner() + assert p.set_param(Parameter.PingTimeout, 999) == 0 + + def test_set_recv_callback_returns_zero(self) -> None: + p = Partner() + assert p.set_recv_callback() == 0 + + def test_set_send_callback_returns_zero(self) -> None: + p = Partner() + assert p.set_send_callback() == 0 + + +# --------------------------------------------------------------------------- +# Dual-partner integration tests using raw socket pairing +# --------------------------------------------------------------------------- + + +def _make_socket_pair() -> tuple[socket.socket, socket.socket]: + """Create a connected TCP socket pair via a temporary server socket.""" + srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + srv.bind(("127.0.0.1", 0)) + srv.listen(1) + port = srv.getsockname()[1] + + client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client.connect(("127.0.0.1", port)) + server_side, _ = srv.accept() + srv.close() + return client, server_side + + +def _wire_partner(partner: Partner, sock: socket.socket) -> None: + """Wire a connected socket into a Partner so it appears connected.""" + conn = ISOTCPConnection(host="127.0.0.1", port=0, local_tsap=0x0100, remote_tsap=0x0102) + conn.socket = sock + conn.connected = True + partner._socket = sock + partner._connection = conn + partner.connected = True + partner.running = True + + +@pytest.mark.partner +class TestDualPartner: + """Integration tests using two Partner instances exchanging data over sockets.""" + + def test_active_to_passive_send(self) -> None: + sock_a, sock_b = _make_socket_pair() + pa, pb = Partner(), Partner() + try: + _wire_partner(pa, sock_a) + _wire_partner(pb, sock_b) + + payload = b"Hello from A" + pa.set_send_data(payload) + + errors: list[Exception] = [] + + def do_send() -> None: + try: + pa.b_send() + except Exception as e: + errors.append(e) + + t = threading.Thread(target=do_send) + t.start() + + assert pb.b_recv() == 0 + t.join(timeout=3.0) + assert pb.get_recv_data() == payload + assert not errors + finally: + pa.stop() + pb.stop() + + def test_passive_to_active_send(self) -> None: + sock_a, sock_b = _make_socket_pair() + pa, pb = Partner(), Partner() + try: + _wire_partner(pa, sock_a) + _wire_partner(pb, sock_b) + + payload = b"Hello from B" + pb.set_send_data(payload) + + errors: list[Exception] = [] + + def do_send() -> None: + try: + pb.b_send() + except Exception as e: + errors.append(e) + + t = threading.Thread(target=do_send) + t.start() + + assert pa.b_recv() == 0 + t.join(timeout=3.0) + assert pa.get_recv_data() == payload + assert not errors + finally: + pa.stop() + pb.stop() + + def test_bidirectional_exchange(self) -> None: + sock_a, sock_b = _make_socket_pair() + pa, pb = Partner(), Partner() + try: + _wire_partner(pa, sock_a) + _wire_partner(pb, sock_b) + + errors: list[Exception] = [] + + # A -> B + pa.set_send_data(b"A->B") + + def send_a() -> None: + try: + pa.b_send() + except Exception as e: + errors.append(e) + + t1 = threading.Thread(target=send_a) + t1.start() + pb.b_recv() + t1.join(timeout=3.0) + assert pb.get_recv_data() == b"A->B" + + # B -> A + pb.set_send_data(b"B->A") + + def send_b() -> None: + try: + pb.b_send() + except Exception as e: + errors.append(e) + + t2 = threading.Thread(target=send_b) + t2.start() + pa.b_recv() + t2.join(timeout=3.0) + assert pa.get_recv_data() == b"B->A" + assert not errors + finally: + pa.stop() + pb.stop() + + def test_various_payload_sizes(self) -> None: + sock_a, sock_b = _make_socket_pair() + pa, pb = Partner(), Partner() + try: + _wire_partner(pa, sock_a) + _wire_partner(pb, sock_b) + + for size in [1, 10, 100, 480]: + payload = (bytes(range(256)) * (size // 256 + 1))[:size] + pa.set_send_data(payload) + errors: list[Exception] = [] + + def do_send() -> None: + try: + pa.b_send() + except Exception as e: + errors.append(e) + + t = threading.Thread(target=do_send) + t.start() + pb.b_recv() + t.join(timeout=3.0) + assert pb.get_recv_data() == payload, f"Failed for size {size}" + assert not errors + finally: + pa.stop() + pb.stop() + + def test_stats_updated_after_exchange(self) -> None: + sock_a, sock_b = _make_socket_pair() + pa, pb = Partner(), Partner() + try: + _wire_partner(pa, sock_a) + _wire_partner(pb, sock_b) + + payload = b"stats test" + pa.set_send_data(payload) + + def do_send() -> None: + pa.b_send() + + t = threading.Thread(target=do_send) + t.start() + pb.b_recv() + t.join(timeout=3.0) + + sent, _, s_err, _ = pa.get_stats() + assert sent.value == len(payload) + assert s_err.value == 0 + + _, recv, _, r_err = pb.get_stats() + assert recv.value == len(payload) + assert r_err.value == 0 + + send_t, _ = pa.get_times() + assert send_t.value >= 0 + _, recv_t = pb.get_times() + assert recv_t.value >= 0 + finally: + pa.stop() + pb.stop() + + def test_status_connected(self) -> None: + sock_a, sock_b = _make_socket_pair() + pa, pb = Partner(), Partner() + try: + _wire_partner(pa, sock_a) + _wire_partner(pb, sock_b) + assert pa.get_status().value == PartnerStatus.CONNECTED + assert pb.get_status().value == PartnerStatus.CONNECTED + finally: + pa.stop() + pb.stop() + + def test_status_after_stop(self) -> None: + sock_a, sock_b = _make_socket_pair() + pa, pb = Partner(), Partner() + try: + _wire_partner(pa, sock_a) + _wire_partner(pb, sock_b) + pa.stop() + assert pa.get_status().value == PartnerStatus.STOPPED + finally: + pa.stop() + pb.stop() + + def test_recv_callback_fires(self) -> None: + sock_a, sock_b = _make_socket_pair() + pa, pb = Partner(), Partner() + try: + _wire_partner(pa, sock_a) + _wire_partner(pb, sock_b) + + received_data: list[bytes] = [] + pb._recv_callback = lambda data: received_data.append(data) + + payload = b"callback test" + pa.set_send_data(payload) + + def do_send() -> None: + pa.b_send() + + t = threading.Thread(target=do_send) + t.start() + pb.b_recv() + t.join(timeout=3.0) + + assert len(received_data) == 1 + assert received_data[0] == payload + finally: + pa.stop() + pb.stop() + + def test_b_recv_error_returns_negative(self) -> None: + """b_recv returns -1 on receive error when no data arrives.""" + sock_a, sock_b = _make_socket_pair() + pa, pb = Partner(), Partner() + try: + _wire_partner(pa, sock_a) + _wire_partner(pb, sock_b) + # Close sender side so receiver gets an error + sock_a.close() + result = pb.b_recv() + assert result == -1 + finally: + pa.stop() + pb.stop() + + +# --------------------------------------------------------------------------- +# Passive partner accept/listen tests +# --------------------------------------------------------------------------- + + +@pytest.mark.partner +class TestPassivePartner: + """Tests for passive partner listening and accept behavior.""" + + def test_accept_connection_server_socket_none(self) -> None: + """_accept_connection returns immediately if server socket is None.""" + p = Partner(active=False) + p._server_socket = None + p._accept_connection() # Should not raise + + +# --------------------------------------------------------------------------- +# Active partner connection error tests +# --------------------------------------------------------------------------- + + +@pytest.mark.partner +class TestPartnerConnectionErrors: + """Tests for connection error paths.""" + + def test_active_no_remote_ip(self) -> None: + p = Partner(active=True) + with pytest.raises(S7ConnectionError, match="Remote IP"): + p.start_to("127.0.0.1", "", 0x0100, 0x0102) + + def test_active_connect_refused(self) -> None: + p = Partner(active=True) + port = _free_port() + p.port = port + with pytest.raises(S7ConnectionError): + p.start_to("127.0.0.1", "127.0.0.1", 0x0100, 0x0102) + + def test_b_send_increments_send_errors(self) -> None: + p = Partner() + p.set_send_data(b"data") + try: + p.b_send() + except S7ConnectionError: + pass + _, _, s_err, _ = p.get_stats() + assert s_err.value == 1 + + def test_b_recv_increments_recv_errors(self) -> None: + p = Partner() + p.b_recv() + _, _, _, r_err = p.get_stats() + assert r_err.value == 1 + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_s7commplus_codec.py b/tests/test_s7commplus_codec.py new file mode 100644 index 00000000..9b03881e --- /dev/null +++ b/tests/test_s7commplus_codec.py @@ -0,0 +1,629 @@ +"""Tests for S7CommPlus codec (header encoding, typed values, payload builders).""" + +import struct +import pytest + +from snap7.s7commplus.codec import ( + encode_header, + decode_header, + encode_request_header, + decode_response_header, + encode_typed_value, + encode_uint8, + decode_uint8, + encode_uint16, + decode_uint16, + encode_uint32, + decode_uint32, + encode_uint64, + decode_uint64, + encode_int16, + decode_int16, + encode_int32, + decode_int32, + encode_int64, + decode_int64, + encode_float32, + decode_float32, + encode_float64, + decode_float64, + encode_wstring, + decode_wstring, + encode_item_address, + encode_pvalue_blob, + decode_pvalue_to_bytes, + encode_object_qualifier, + _pvalue_element_size, +) +from snap7.s7commplus.protocol import PROTOCOL_ID, DataType, Opcode, FunctionCode, Ids +from snap7.s7commplus.vlq import encode_uint32_vlq, encode_int32_vlq, encode_uint64_vlq, encode_int64_vlq + + +class TestFrameHeader: + def test_encode_header(self) -> None: + header = encode_header(version=0x03, data_length=100) + assert len(header) == 4 + assert header[0] == PROTOCOL_ID + assert header[1] == 0x03 + assert struct.unpack(">H", header[2:4])[0] == 100 + + def test_decode_header(self) -> None: + header = encode_header(version=0x03, data_length=256) + version, length, consumed = decode_header(header) + assert version == 0x03 + assert length == 256 + assert consumed == 4 + + def test_decode_header_with_offset(self) -> None: + prefix = bytes([0x00, 0x00]) + header = encode_header(version=0x01, data_length=42) + version, length, consumed = decode_header(prefix + header, offset=2) + assert version == 0x01 + assert length == 42 + + def test_decode_header_wrong_protocol_id(self) -> None: + bad_header = bytes([0x32, 0x03, 0x00, 0x10]) # S7comm ID, not S7CommPlus + with pytest.raises(ValueError, match="Invalid protocol ID"): + decode_header(bad_header) + + def test_decode_header_too_short(self) -> None: + with pytest.raises(ValueError, match="Not enough data"): + decode_header(bytes([0x72, 0x03])) + + +class TestRequestHeader: + def test_encode_request_header(self) -> None: + header = encode_request_header( + function_code=FunctionCode.CREATE_OBJECT, + sequence_number=1, + session_id=0, + transport_flags=0x36, + ) + assert len(header) == 14 + assert header[0] == Opcode.REQUEST + + def test_roundtrip_request_response_header(self) -> None: + header = encode_request_header( + function_code=FunctionCode.GET_MULTI_VARIABLES, + sequence_number=42, + session_id=0x12345678, + ) + result = decode_response_header(header) + assert result["function_code"] == FunctionCode.GET_MULTI_VARIABLES + assert result["sequence_number"] == 42 + assert result["session_id"] == 0x12345678 + assert result["bytes_consumed"] == 14 + + def test_decode_response_header_too_short(self) -> None: + with pytest.raises(ValueError, match="Not enough data"): + decode_response_header(bytes(10)) + + +class TestFixedWidth: + def test_uint8_roundtrip(self) -> None: + for val in [0, 1, 127, 255]: + encoded = encode_uint8(val) + decoded, consumed = decode_uint8(encoded) + assert decoded == val + assert consumed == 1 + + def test_uint16_roundtrip(self) -> None: + for val in [0, 1, 0xFF, 0xFFFF]: + encoded = encode_uint16(val) + decoded, consumed = decode_uint16(encoded) + assert decoded == val + assert consumed == 2 + + def test_uint32_roundtrip(self) -> None: + for val in [0, 1, 0xFFFF, 0xFFFFFFFF]: + encoded = encode_uint32(val) + decoded, consumed = decode_uint32(encoded) + assert decoded == val + assert consumed == 4 + + def test_uint64_roundtrip(self) -> None: + for val in [0, 1, 0xFFFFFFFF, 0xFFFFFFFFFFFFFFFF]: + encoded = encode_uint64(val) + decoded, consumed = decode_uint64(encoded) + assert decoded == val + assert consumed == 8 + + def test_int16_roundtrip(self) -> None: + for val in [0, 1, -1, -32768, 32767]: + encoded = encode_int16(val) + decoded, consumed = decode_int16(encoded) + assert decoded == val + assert consumed == 2 + + def test_int32_roundtrip(self) -> None: + for val in [0, 1, -1, -2147483648, 2147483647]: + encoded = encode_int32(val) + decoded, consumed = decode_int32(encoded) + assert decoded == val + assert consumed == 4 + + def test_int64_roundtrip(self) -> None: + for val in [0, 1, -1, -(2**63), 2**63 - 1]: + encoded = encode_int64(val) + decoded, consumed = decode_int64(encoded) + assert decoded == val + assert consumed == 8 + + def test_float32_roundtrip(self) -> None: + for val in [0.0, 1.0, -1.0, 3.14]: + encoded = encode_float32(val) + decoded, consumed = decode_float32(encoded) + assert abs(decoded - val) < 1e-6 + assert consumed == 4 + + def test_float64_roundtrip(self) -> None: + for val in [0.0, 1.0, -1.0, 3.141592653589793]: + encoded = encode_float64(val) + decoded, consumed = decode_float64(encoded) + assert decoded == val + assert consumed == 8 + + def test_uint8_with_offset(self) -> None: + data = bytes([0xFF, 42, 0xFF]) + decoded, consumed = decode_uint8(data, offset=1) + assert decoded == 42 + + def test_uint64_with_offset(self) -> None: + prefix = bytes(4) + data = prefix + encode_uint64(0x123456789ABCDEF0) + decoded, consumed = decode_uint64(data, offset=4) + assert decoded == 0x123456789ABCDEF0 + + def test_int16_with_offset(self) -> None: + prefix = bytes(3) + data = prefix + encode_int16(-1000) + decoded, consumed = decode_int16(data, offset=3) + assert decoded == -1000 + + def test_int32_with_offset(self) -> None: + prefix = bytes(2) + data = prefix + encode_int32(-100000) + decoded, consumed = decode_int32(data, offset=2) + assert decoded == -100000 + + def test_int64_with_offset(self) -> None: + prefix = bytes(5) + data = prefix + encode_int64(-(2**50)) + decoded, consumed = decode_int64(data, offset=5) + assert decoded == -(2**50) + + def test_float32_with_offset(self) -> None: + prefix = bytes(1) + data = prefix + encode_float32(2.5) + decoded, consumed = decode_float32(data, offset=1) + assert abs(decoded - 2.5) < 1e-6 + + def test_float64_with_offset(self) -> None: + prefix = bytes(3) + data = prefix + encode_float64(1.23456789) + decoded, consumed = decode_float64(data, offset=3) + assert decoded == 1.23456789 + + +class TestWString: + def test_ascii(self) -> None: + encoded = encode_wstring("hello") + decoded, consumed = decode_wstring(encoded, 0, len(encoded)) + assert decoded == "hello" + + def test_unicode(self) -> None: + encoded = encode_wstring("Ölprüfung") + decoded, consumed = decode_wstring(encoded, 0, len(encoded)) + assert decoded == "Ölprüfung" + + def test_empty(self) -> None: + encoded = encode_wstring("") + assert encoded == b"" + decoded, consumed = decode_wstring(encoded, 0, 0) + assert decoded == "" + + +class TestTypedValue: + def test_null(self) -> None: + encoded = encode_typed_value(DataType.NULL, None) + assert encoded == bytes([DataType.NULL]) + + def test_bool_true(self) -> None: + encoded = encode_typed_value(DataType.BOOL, True) + assert encoded == bytes([DataType.BOOL, 0x01]) + + def test_bool_false(self) -> None: + encoded = encode_typed_value(DataType.BOOL, False) + assert encoded == bytes([DataType.BOOL, 0x00]) + + def test_usint(self) -> None: + encoded = encode_typed_value(DataType.USINT, 42) + assert encoded == bytes([DataType.USINT, 42]) + + def test_byte(self) -> None: + encoded = encode_typed_value(DataType.BYTE, 0xAB) + assert encoded == bytes([DataType.BYTE, 0xAB]) + + def test_uint(self) -> None: + encoded = encode_typed_value(DataType.UINT, 0x1234) + assert encoded == bytes([DataType.UINT]) + struct.pack(">H", 0x1234) + + def test_word(self) -> None: + encoded = encode_typed_value(DataType.WORD, 0xBEEF) + assert encoded == bytes([DataType.WORD]) + struct.pack(">H", 0xBEEF) + + def test_udint(self) -> None: + encoded = encode_typed_value(DataType.UDINT, 100000) + assert encoded[0] == DataType.UDINT + # Rest is VLQ-encoded + assert len(encoded) > 1 + + def test_dword(self) -> None: + encoded = encode_typed_value(DataType.DWORD, 0xDEADBEEF) + assert encoded[0] == DataType.DWORD + + def test_ulint(self) -> None: + encoded = encode_typed_value(DataType.ULINT, 2**40) + assert encoded[0] == DataType.ULINT + + def test_lword(self) -> None: + encoded = encode_typed_value(DataType.LWORD, 0xCAFEBABE12345678) + assert encoded[0] == DataType.LWORD + + def test_sint(self) -> None: + encoded = encode_typed_value(DataType.SINT, -42) + assert encoded == bytes([DataType.SINT]) + struct.pack(">b", -42) + + def test_int(self) -> None: + encoded = encode_typed_value(DataType.INT, -1000) + assert encoded == bytes([DataType.INT]) + struct.pack(">h", -1000) + + def test_dint(self) -> None: + encoded = encode_typed_value(DataType.DINT, -100000) + assert encoded[0] == DataType.DINT + + def test_lint(self) -> None: + encoded = encode_typed_value(DataType.LINT, -(2**40)) + assert encoded[0] == DataType.LINT + + def test_real(self) -> None: + encoded = encode_typed_value(DataType.REAL, 1.0) + assert encoded == bytes([DataType.REAL]) + struct.pack(">f", 1.0) + + def test_lreal(self) -> None: + encoded = encode_typed_value(DataType.LREAL, 3.14) + assert encoded == bytes([DataType.LREAL]) + struct.pack(">d", 3.14) + + def test_timestamp(self) -> None: + ts = 0x0001020304050607 + encoded = encode_typed_value(DataType.TIMESTAMP, ts) + assert encoded == bytes([DataType.TIMESTAMP]) + struct.pack(">Q", ts) + + def test_timespan(self) -> None: + encoded = encode_typed_value(DataType.TIMESPAN, -5000) + assert encoded[0] == DataType.TIMESPAN + + def test_rid(self) -> None: + encoded = encode_typed_value(DataType.RID, 0x12345678) + assert encoded == bytes([DataType.RID]) + struct.pack(">I", 0x12345678) + + def test_aid(self) -> None: + encoded = encode_typed_value(DataType.AID, 306) + assert encoded[0] == DataType.AID + + def test_wstring(self) -> None: + encoded = encode_typed_value(DataType.WSTRING, "test") + assert encoded[0] == DataType.WSTRING + assert b"test" in encoded + + def test_blob(self) -> None: + data = bytes([1, 2, 3, 4]) + encoded = encode_typed_value(DataType.BLOB, data) + assert encoded[0] == DataType.BLOB + assert encoded.endswith(data) + + def test_unsupported_type(self) -> None: + with pytest.raises(ValueError, match="Unsupported DataType"): + encode_typed_value(0xFF, None) + + +class TestItemAddress: + def test_basic_db_access(self) -> None: + addr_bytes, field_count = encode_item_address( + access_area=Ids.DB_ACCESS_AREA_BASE + 1, + access_sub_area=Ids.DB_VALUE_ACTUAL, + ) + assert isinstance(addr_bytes, bytes) + assert len(addr_bytes) > 0 + # No LIDs, so field_count = 4 (SymbolCrc + AccessArea + NumLIDs + AccessSubArea) + assert field_count == 4 + + def test_with_lids(self) -> None: + addr_bytes, field_count = encode_item_address( + access_area=Ids.DB_ACCESS_AREA_BASE + 1, + access_sub_area=Ids.DB_VALUE_ACTUAL, + lids=[1, 4], + ) + assert field_count == 6 # 4 + 2 LIDs + + def test_custom_symbol_crc(self) -> None: + addr_bytes, field_count = encode_item_address( + access_area=Ids.DB_ACCESS_AREA_BASE + 1, + access_sub_area=Ids.DB_VALUE_ACTUAL, + symbol_crc=0x1234, + ) + # First bytes should be VLQ(0x1234) which is non-zero + assert addr_bytes[0] != 0 + assert field_count == 4 + + +class TestPValueBlob: + def test_basic_blob(self) -> None: + data = bytes([1, 2, 3, 4]) + encoded = encode_pvalue_blob(data) + assert encoded[0] == 0x00 # flags + assert encoded[1] == DataType.BLOB + assert encoded.endswith(data) + + def test_empty_blob(self) -> None: + encoded = encode_pvalue_blob(b"") + assert encoded[0] == 0x00 + assert encoded[1] == DataType.BLOB + + def test_roundtrip_with_decode(self) -> None: + data = bytes([0xDE, 0xAD, 0xBE, 0xEF]) + encoded = encode_pvalue_blob(data) + decoded, consumed = decode_pvalue_to_bytes(encoded, 0) + assert decoded == data + assert consumed == len(encoded) + + +class TestDecodePValue: + """Test decode_pvalue_to_bytes for all scalar and array type branches.""" + + def test_null(self) -> None: + data = bytes([0x00, DataType.NULL]) + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == b"" + assert consumed == 2 + + def test_bool_true(self) -> None: + data = bytes([0x00, DataType.BOOL, 0x01]) + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == bytes([0x01]) + assert consumed == 3 + + def test_bool_false(self) -> None: + data = bytes([0x00, DataType.BOOL, 0x00]) + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == bytes([0x00]) + + def test_usint(self) -> None: + data = bytes([0x00, DataType.USINT, 42]) + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == bytes([42]) + assert consumed == 3 + + def test_byte(self) -> None: + data = bytes([0x00, DataType.BYTE, 0xAB]) + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == bytes([0xAB]) + + def test_sint(self) -> None: + data = bytes([0x00, DataType.SINT, 0xD6]) # -42 as unsigned byte + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == bytes([0xD6]) + + def test_uint(self) -> None: + raw = struct.pack(">H", 0x1234) + data = bytes([0x00, DataType.UINT]) + raw + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == raw + + def test_word(self) -> None: + raw = struct.pack(">H", 0xBEEF) + data = bytes([0x00, DataType.WORD]) + raw + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == raw + + def test_int(self) -> None: + raw = struct.pack(">H", 0xFC18) # -1000 as unsigned + data = bytes([0x00, DataType.INT]) + raw + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == raw + + def test_udint(self) -> None: + vlq = encode_uint32_vlq(100000) + data = bytes([0x00, DataType.UDINT]) + vlq + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == struct.pack(">I", 100000) + + def test_dword(self) -> None: + vlq = encode_uint32_vlq(0xDEADBEEF) + data = bytes([0x00, DataType.DWORD]) + vlq + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == struct.pack(">I", 0xDEADBEEF) + + def test_dint_positive(self) -> None: + vlq = encode_int32_vlq(12345) + data = bytes([0x00, DataType.DINT]) + vlq + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == struct.pack(">i", 12345) + + def test_dint_negative(self) -> None: + vlq = encode_int32_vlq(-100000) + data = bytes([0x00, DataType.DINT]) + vlq + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == struct.pack(">i", -100000) + + def test_real(self) -> None: + raw = struct.pack(">f", 3.14) + data = bytes([0x00, DataType.REAL]) + raw + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == raw + + def test_lreal(self) -> None: + raw = struct.pack(">d", 2.718281828) + data = bytes([0x00, DataType.LREAL]) + raw + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == raw + + def test_ulint(self) -> None: + vlq = encode_uint64_vlq(2**40) + data = bytes([0x00, DataType.ULINT]) + vlq + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == struct.pack(">Q", 2**40) + + def test_lword(self) -> None: + vlq = encode_uint64_vlq(0xCAFEBABE12345678) + data = bytes([0x00, DataType.LWORD]) + vlq + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == struct.pack(">Q", 0xCAFEBABE12345678) + + def test_lint_positive(self) -> None: + vlq = encode_int64_vlq(2**50) + data = bytes([0x00, DataType.LINT]) + vlq + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == struct.pack(">q", 2**50) + + def test_lint_negative(self) -> None: + vlq = encode_int64_vlq(-(2**40)) + data = bytes([0x00, DataType.LINT]) + vlq + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == struct.pack(">q", -(2**40)) + + def test_timestamp(self) -> None: + ts = 0x0001020304050607 + raw = struct.pack(">Q", ts) + data = bytes([0x00, DataType.TIMESTAMP]) + raw + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == raw + assert consumed == 10 # 2 header + 8 bytes + + def test_timespan_positive(self) -> None: + vlq = encode_int64_vlq(5000000) + data = bytes([0x00, DataType.TIMESPAN]) + vlq + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == struct.pack(">q", 5000000) + + def test_timespan_negative(self) -> None: + vlq = encode_int64_vlq(-5000000) + data = bytes([0x00, DataType.TIMESPAN]) + vlq + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == struct.pack(">q", -5000000) + + def test_rid(self) -> None: + raw = struct.pack(">I", 0x12345678) + data = bytes([0x00, DataType.RID]) + raw + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == raw + + def test_aid(self) -> None: + vlq = encode_uint32_vlq(306) + data = bytes([0x00, DataType.AID]) + vlq + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == struct.pack(">I", 306) + + def test_blob(self) -> None: + blob_data = bytes([0xDE, 0xAD, 0xBE, 0xEF]) + vlq_len = encode_uint32_vlq(len(blob_data)) + data = bytes([0x00, DataType.BLOB]) + vlq_len + blob_data + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == blob_data + + def test_wstring(self) -> None: + text = "hello".encode("utf-8") + vlq_len = encode_uint32_vlq(len(text)) + data = bytes([0x00, DataType.WSTRING]) + vlq_len + text + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == text + + def test_struct_nested(self) -> None: + # Struct with 2 USINT elements + vlq_count = encode_uint32_vlq(2) + elem1 = bytes([0x00, DataType.USINT, 0x0A]) + elem2 = bytes([0x00, DataType.USINT, 0x14]) + data = bytes([0x00, DataType.STRUCT]) + vlq_count + elem1 + elem2 + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == bytes([0x0A, 0x14]) + + def test_unsupported_type(self) -> None: + data = bytes([0x00, 0xFF]) + with pytest.raises(ValueError, match="Unsupported PValue datatype"): + decode_pvalue_to_bytes(data, 0) + + def test_too_short_header(self) -> None: + with pytest.raises(ValueError, match="Not enough data for PValue header"): + decode_pvalue_to_bytes(bytes([0x00]), 0) + + def test_with_offset(self) -> None: + prefix = bytes([0xFF, 0xFF, 0xFF]) + pvalue = bytes([0x00, DataType.USINT, 42]) + result, consumed = decode_pvalue_to_bytes(prefix + pvalue, 3) + assert result == bytes([42]) + + # -- Array tests -- + + def test_array_fixed_size_usint(self) -> None: + count_vlq = encode_uint32_vlq(3) + elements = bytes([10, 20, 30]) + data = bytes([0x10, DataType.USINT]) + count_vlq + elements + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == elements + + def test_array_fixed_size_uint(self) -> None: + count_vlq = encode_uint32_vlq(2) + elements = struct.pack(">HH", 1000, 2000) + data = bytes([0x10, DataType.UINT]) + count_vlq + elements + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == elements + + def test_array_fixed_size_real(self) -> None: + count_vlq = encode_uint32_vlq(2) + elements = struct.pack(">ff", 1.0, 2.0) + data = bytes([0x10, DataType.REAL]) + count_vlq + elements + result, consumed = decode_pvalue_to_bytes(data, 0) + assert result == elements + + def test_array_variable_length_udint(self) -> None: + # Variable-length array (VLQ-encoded elements) + count_vlq = encode_uint32_vlq(2) + elem1 = encode_uint32_vlq(100) + elem2 = encode_uint32_vlq(200) + data = bytes([0x10, DataType.UDINT]) + count_vlq + elem1 + elem2 + result, consumed = decode_pvalue_to_bytes(data, 0) + # Result re-encodes each element as VLQ + assert result == encode_uint32_vlq(100) + encode_uint32_vlq(200) + + +class TestPValueElementSize: + def test_single_byte_types(self) -> None: + for dt in (DataType.BOOL, DataType.USINT, DataType.BYTE, DataType.SINT): + assert _pvalue_element_size(dt) == 1 + + def test_two_byte_types(self) -> None: + for dt in (DataType.UINT, DataType.WORD, DataType.INT): + assert _pvalue_element_size(dt) == 2 + + def test_four_byte_types(self) -> None: + assert _pvalue_element_size(DataType.REAL) == 4 + assert _pvalue_element_size(DataType.RID) == 4 + + def test_eight_byte_types(self) -> None: + assert _pvalue_element_size(DataType.LREAL) == 8 + assert _pvalue_element_size(DataType.TIMESTAMP) == 8 + + def test_variable_length_types(self) -> None: + for dt in (DataType.UDINT, DataType.DWORD, DataType.BLOB, DataType.WSTRING, DataType.STRUCT): + assert _pvalue_element_size(dt) == 0 + + +class TestObjectQualifier: + def test_encode(self) -> None: + result = encode_object_qualifier() + assert isinstance(result, bytes) + assert len(result) > 0 + # Starts with ObjectQualifier ID (1256) as uint32 big-endian + assert result[:4] == struct.pack(">I", Ids.OBJECT_QUALIFIER) + # Ends with null terminator + assert result[-1] == 0x00 diff --git a/tests/test_s7commplus_e2e.py b/tests/test_s7commplus_e2e.py new file mode 100644 index 00000000..f8c8bf0d --- /dev/null +++ b/tests/test_s7commplus_e2e.py @@ -0,0 +1,607 @@ +"""End-to-end tests for S7CommPlus client against a real Siemens S7-1200/1500 PLC. + +These tests require a real PLC connection. Run with: + + pytest tests/test_s7commplus_e2e.py --e2e --plc-ip=YOUR_PLC_IP + +Available options: + --e2e Enable e2e tests (required) + --plc-ip PLC IP address (default: 10.10.10.100) + --plc-rack PLC rack number (default: 0) + --plc-slot PLC slot number (default: 1) + --plc-port PLC TCP port (default: 102) + --plc-db-read Read-only DB number (default: 1) + --plc-db-write Read-write DB number (default: 2) + +The PLC needs two data blocks configured with the same layout as the +regular S7 e2e tests: + +DB1 "Read_only" - Read-only data block with predefined values: + int1: Int = 10 (offset 0, 2 bytes) + int2: Int = 255 (offset 2, 2 bytes) + float1: Real = 123.45 (offset 4, 4 bytes) + float2: Real = 543.21 (offset 8, 4 bytes) + byte1: Byte = 0x0F (offset 12, 1 byte) + byte2: Byte = 0xF0 (offset 13, 1 byte) + word1: Word = 0xABCD (offset 14, 2 bytes) + word2: Word = 0x1234 (offset 16, 2 bytes) + dword1: DWord = 0x12345678 (offset 18, 4 bytes) + dword2: DWord = 0x89ABCDEF (offset 22, 4 bytes) + dint1: DInt = 2147483647 (offset 26, 4 bytes) + dint2: DInt = 42 (offset 30, 4 bytes) + char1: Char = 'F' (offset 34, 1 byte) + char2: Char = '-' (offset 35, 1 byte) + bool0-bool7: Bool (offset 36, 1 byte, value: 0x01) + +DB2 "Data_block_2" - Read/write data block with same structure. + +Note: S7CommPlus targets S7-1200/1500 PLCs, which use optimized block +access. Ensure data blocks have "Optimized block access" disabled in +TIA Portal so that byte offsets match the layout above. +""" + +import logging +import os +import struct +import unittest + +import pytest + +from snap7.s7commplus.client import S7CommPlusClient + +# Enable DEBUG logging for all s7commplus modules so we get full hex dumps +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(name)s %(levelname)s %(message)s", +) +for _mod in ["snap7.s7commplus.client", "snap7.s7commplus.connection", "snap7.connection"]: + logging.getLogger(_mod).setLevel(logging.DEBUG) + +# ============================================================================= +# PLC Connection Configuration +# These can be overridden via pytest command line options or environment variables +# ============================================================================= +PLC_IP = os.environ.get("PLC_IP", "10.10.10.100") +PLC_RACK = int(os.environ.get("PLC_RACK", "0")) +PLC_SLOT = int(os.environ.get("PLC_SLOT", "1")) +PLC_PORT = int(os.environ.get("PLC_PORT", "102")) + +# Data block numbers +DB_READ_ONLY = int(os.environ.get("PLC_DB_READ", "1")) +DB_READ_WRITE = int(os.environ.get("PLC_DB_WRITE", "2")) + + +# ============================================================================= +# DB Structure - Byte offsets for each variable (same as regular S7 e2e tests) +# ============================================================================= +OFFSET_INT1 = 0 # Int (2 bytes) +OFFSET_INT2 = 2 # Int (2 bytes) +OFFSET_FLOAT1 = 4 # Real (4 bytes) +OFFSET_FLOAT2 = 8 # Real (4 bytes) +OFFSET_BYTE1 = 12 # Byte (1 byte) +OFFSET_BYTE2 = 13 # Byte (1 byte) +OFFSET_WORD1 = 14 # Word (2 bytes) +OFFSET_WORD2 = 16 # Word (2 bytes) +OFFSET_DWORD1 = 18 # DWord (4 bytes) +OFFSET_DWORD2 = 22 # DWord (4 bytes) +OFFSET_DINT1 = 26 # DInt (4 bytes) +OFFSET_DINT2 = 30 # DInt (4 bytes) +OFFSET_CHAR1 = 34 # Char (1 byte) +OFFSET_CHAR2 = 35 # Char (1 byte) +OFFSET_BOOLS = 36 # 8 Bools packed in 1 byte + +# Total size of DB +DB_SIZE = 37 + +# ============================================================================= +# Expected values from DB1 "Read_only" +# ============================================================================= +EXPECTED_INT1 = 10 +EXPECTED_INT2 = 255 +EXPECTED_FLOAT1 = 123.45 +EXPECTED_FLOAT2 = 543.21 +EXPECTED_BYTE1 = 0x0F +EXPECTED_BYTE2 = 0xF0 +EXPECTED_WORD1 = 0xABCD +EXPECTED_WORD2 = 0x1234 +EXPECTED_DWORD1 = 0x12345678 +EXPECTED_DWORD2 = 0x89ABCDEF +EXPECTED_DINT1 = 2147483647 +EXPECTED_DINT2 = 42 +EXPECTED_CHAR1 = "F" +EXPECTED_CHAR2 = "-" +EXPECTED_BOOL0 = True +EXPECTED_BOOL1 = False + + +# ============================================================================= +# Test Classes +# ============================================================================= + + +@pytest.mark.e2e +class TestS7CommPlusConnection(unittest.TestCase): + """Tests for S7CommPlus connection.""" + + def test_connect_disconnect(self) -> None: + """Test connect() and disconnect().""" + client = S7CommPlusClient() + client.connect(PLC_IP, PLC_PORT, PLC_RACK, PLC_SLOT) + self.assertTrue(client.connected) + self.assertGreater(client.protocol_version, 0) + self.assertGreater(client.session_id, 0) + client.disconnect() + self.assertFalse(client.connected) + + def test_context_manager(self) -> None: + """Test S7CommPlusClient as context manager.""" + with S7CommPlusClient() as client: + client.connect(PLC_IP, PLC_PORT, PLC_RACK, PLC_SLOT) + self.assertTrue(client.connected) + # After exiting context, client should be disconnected + + def test_properties_before_connect(self) -> None: + """Test properties return defaults before connection.""" + client = S7CommPlusClient() + self.assertFalse(client.connected) + self.assertEqual(0, client.protocol_version) + self.assertEqual(0, client.session_id) + + +@pytest.mark.e2e +class TestS7CommPlusDBRead(unittest.TestCase): + """Tests for db_read() - reading from DB1 (read-only).""" + + client: S7CommPlusClient + + @classmethod + def setUpClass(cls) -> None: + cls.client = S7CommPlusClient() + cls.client.connect(PLC_IP, PLC_PORT, PLC_RACK, PLC_SLOT) + + @classmethod + def tearDownClass(cls) -> None: + if cls.client: + cls.client.disconnect() + + def test_db_read_int(self) -> None: + """Test db_read() for Int values.""" + data = self.client.db_read(DB_READ_ONLY, OFFSET_INT1, 2) + value = struct.unpack(">h", data)[0] + self.assertEqual(EXPECTED_INT1, value) + + data = self.client.db_read(DB_READ_ONLY, OFFSET_INT2, 2) + value = struct.unpack(">h", data)[0] + self.assertEqual(EXPECTED_INT2, value) + + def test_db_read_real(self) -> None: + """Test db_read() for Real values.""" + data = self.client.db_read(DB_READ_ONLY, OFFSET_FLOAT1, 4) + value = struct.unpack(">f", data)[0] + self.assertAlmostEqual(EXPECTED_FLOAT1, value, places=2) + + data = self.client.db_read(DB_READ_ONLY, OFFSET_FLOAT2, 4) + value = struct.unpack(">f", data)[0] + self.assertAlmostEqual(EXPECTED_FLOAT2, value, places=2) + + def test_db_read_byte(self) -> None: + """Test db_read() for Byte values.""" + data = self.client.db_read(DB_READ_ONLY, OFFSET_BYTE1, 1) + self.assertEqual(EXPECTED_BYTE1, data[0]) + + data = self.client.db_read(DB_READ_ONLY, OFFSET_BYTE2, 1) + self.assertEqual(EXPECTED_BYTE2, data[0]) + + def test_db_read_word(self) -> None: + """Test db_read() for Word values.""" + data = self.client.db_read(DB_READ_ONLY, OFFSET_WORD1, 2) + value = struct.unpack(">H", data)[0] + self.assertEqual(EXPECTED_WORD1, value) + + data = self.client.db_read(DB_READ_ONLY, OFFSET_WORD2, 2) + value = struct.unpack(">H", data)[0] + self.assertEqual(EXPECTED_WORD2, value) + + def test_db_read_dword(self) -> None: + """Test db_read() for DWord values.""" + data = self.client.db_read(DB_READ_ONLY, OFFSET_DWORD1, 4) + value = struct.unpack(">I", data)[0] + self.assertEqual(EXPECTED_DWORD1, value) + + data = self.client.db_read(DB_READ_ONLY, OFFSET_DWORD2, 4) + value = struct.unpack(">I", data)[0] + self.assertEqual(EXPECTED_DWORD2, value) + + def test_db_read_dint(self) -> None: + """Test db_read() for DInt values.""" + data = self.client.db_read(DB_READ_ONLY, OFFSET_DINT1, 4) + value = struct.unpack(">i", data)[0] + self.assertEqual(EXPECTED_DINT1, value) + + data = self.client.db_read(DB_READ_ONLY, OFFSET_DINT2, 4) + value = struct.unpack(">i", data)[0] + self.assertEqual(EXPECTED_DINT2, value) + + def test_db_read_char(self) -> None: + """Test db_read() for Char values.""" + data = self.client.db_read(DB_READ_ONLY, OFFSET_CHAR1, 1) + self.assertEqual(EXPECTED_CHAR1, chr(data[0])) + + data = self.client.db_read(DB_READ_ONLY, OFFSET_CHAR2, 1) + self.assertEqual(EXPECTED_CHAR2, chr(data[0])) + + def test_db_read_bool(self) -> None: + """Test db_read() for Bool values (packed in byte).""" + data = self.client.db_read(DB_READ_ONLY, OFFSET_BOOLS, 1) + self.assertEqual(EXPECTED_BOOL0, bool(data[0] & 0x01)) + self.assertEqual(EXPECTED_BOOL1, bool(data[0] & 0x02)) + + def test_db_read_entire_block(self) -> None: + """Test db_read() for entire DB.""" + data = self.client.db_read(DB_READ_ONLY, 0, DB_SIZE) + self.assertEqual(DB_SIZE, len(data)) + + # Verify a few values + int1 = struct.unpack(">h", data[OFFSET_INT1 : OFFSET_INT1 + 2])[0] + self.assertEqual(EXPECTED_INT1, int1) + + float1 = struct.unpack(">f", data[OFFSET_FLOAT1 : OFFSET_FLOAT1 + 4])[0] + self.assertAlmostEqual(EXPECTED_FLOAT1, float1, places=2) + + dword1 = struct.unpack(">I", data[OFFSET_DWORD1 : OFFSET_DWORD1 + 4])[0] + self.assertEqual(EXPECTED_DWORD1, dword1) + + +@pytest.mark.e2e +class TestS7CommPlusDBWrite(unittest.TestCase): + """Tests for db_write() - writing to DB2 (read/write).""" + + client: S7CommPlusClient + + @classmethod + def setUpClass(cls) -> None: + cls.client = S7CommPlusClient() + cls.client.connect(PLC_IP, PLC_PORT, PLC_RACK, PLC_SLOT) + + @classmethod + def tearDownClass(cls) -> None: + if cls.client: + cls.client.disconnect() + + def test_db_write_int(self) -> None: + """Test db_write() for Int values.""" + test_value = 10 + data = struct.pack(">h", test_value) + self.client.db_write(DB_READ_WRITE, OFFSET_INT1, data) + + result = self.client.db_read(DB_READ_WRITE, OFFSET_INT1, 2) + self.assertEqual(test_value, struct.unpack(">h", result)[0]) + + def test_db_write_real(self) -> None: + """Test db_write() for Real values.""" + test_value = 456.789 + data = struct.pack(">f", test_value) + self.client.db_write(DB_READ_WRITE, OFFSET_FLOAT1, data) + + result = self.client.db_read(DB_READ_WRITE, OFFSET_FLOAT1, 4) + self.assertAlmostEqual(test_value, struct.unpack(">f", result)[0], places=2) + + def test_db_write_byte(self) -> None: + """Test db_write() for Byte values.""" + test_value = 0xAB + self.client.db_write(DB_READ_WRITE, OFFSET_BYTE1, bytes([test_value])) + + result = self.client.db_read(DB_READ_WRITE, OFFSET_BYTE1, 1) + self.assertEqual(test_value, result[0]) + + def test_db_write_word(self) -> None: + """Test db_write() for Word values.""" + test_value = 0x1234 + data = struct.pack(">H", test_value) + self.client.db_write(DB_READ_WRITE, OFFSET_WORD1, data) + + result = self.client.db_read(DB_READ_WRITE, OFFSET_WORD1, 2) + self.assertEqual(test_value, struct.unpack(">H", result)[0]) + + def test_db_write_dword(self) -> None: + """Test db_write() for DWord values.""" + test_value = 0xDEADBEEF + data = struct.pack(">I", test_value) + self.client.db_write(DB_READ_WRITE, OFFSET_DWORD1, data) + + result = self.client.db_read(DB_READ_WRITE, OFFSET_DWORD1, 4) + self.assertEqual(test_value, struct.unpack(">I", result)[0]) + + def test_db_write_dint(self) -> None: + """Test db_write() for DInt values.""" + test_value = -123456789 + data = struct.pack(">i", test_value) + self.client.db_write(DB_READ_WRITE, OFFSET_DINT1, data) + + result = self.client.db_read(DB_READ_WRITE, OFFSET_DINT1, 4) + self.assertEqual(test_value, struct.unpack(">i", result)[0]) + + def test_db_write_char(self) -> None: + """Test db_write() for Char values.""" + test_value = "X" + self.client.db_write(DB_READ_WRITE, OFFSET_CHAR1, test_value.encode("ascii")) + + result = self.client.db_read(DB_READ_WRITE, OFFSET_CHAR1, 1) + self.assertEqual(test_value, chr(result[0])) + + def test_db_write_bool(self) -> None: + """Test db_write() for Bool values (packed in byte).""" + # Read current byte, set bit 0 and bit 7, write back + data = bytearray(self.client.db_read(DB_READ_WRITE, OFFSET_BOOLS, 1)) + data[0] = data[0] | 0x01 | 0x80 # Set bit 0 and bit 7 + self.client.db_write(DB_READ_WRITE, OFFSET_BOOLS, bytes(data)) + + result = self.client.db_read(DB_READ_WRITE, OFFSET_BOOLS, 1) + self.assertTrue(bool(result[0] & 0x01)) + self.assertTrue(bool(result[0] & 0x80)) + + +@pytest.mark.e2e +class TestS7CommPlusMultiRead(unittest.TestCase): + """Tests for db_read_multi() - multiple reads in a single request.""" + + client: S7CommPlusClient + + @classmethod + def setUpClass(cls) -> None: + cls.client = S7CommPlusClient() + cls.client.connect(PLC_IP, PLC_PORT, PLC_RACK, PLC_SLOT) + + @classmethod + def tearDownClass(cls) -> None: + if cls.client: + cls.client.disconnect() + + def test_multi_read(self) -> None: + """Test db_read_multi() reads multiple regions.""" + items = [ + (DB_READ_ONLY, OFFSET_INT1, 2), + (DB_READ_ONLY, OFFSET_FLOAT1, 4), + (DB_READ_ONLY, OFFSET_DWORD1, 4), + ] + results = self.client.db_read_multi(items) + self.assertEqual(3, len(results)) + + int_val = struct.unpack(">h", results[0])[0] + self.assertEqual(EXPECTED_INT1, int_val) + + float_val = struct.unpack(">f", results[1])[0] + self.assertAlmostEqual(EXPECTED_FLOAT1, float_val, places=2) + + dword_val = struct.unpack(">I", results[2])[0] + self.assertEqual(EXPECTED_DWORD1, dword_val) + + def test_multi_read_across_dbs(self) -> None: + """Test db_read_multi() across different data blocks.""" + # Write a known value to DB2 first + test_int = 777 + self.client.db_write(DB_READ_WRITE, OFFSET_INT1, struct.pack(">h", test_int)) + + items = [ + (DB_READ_ONLY, OFFSET_INT1, 2), + (DB_READ_WRITE, OFFSET_INT1, 2), + ] + results = self.client.db_read_multi(items) + self.assertEqual(2, len(results)) + + self.assertEqual(EXPECTED_INT1, struct.unpack(">h", results[0])[0]) + self.assertEqual(test_int, struct.unpack(">h", results[1])[0]) + + +@pytest.mark.e2e +class TestS7CommPlusExplore(unittest.TestCase): + """Tests for explore() - browsing the PLC object tree.""" + + client: S7CommPlusClient + + @classmethod + def setUpClass(cls) -> None: + cls.client = S7CommPlusClient() + cls.client.connect(PLC_IP, PLC_PORT, PLC_RACK, PLC_SLOT) + + @classmethod + def tearDownClass(cls) -> None: + if cls.client: + cls.client.disconnect() + + def test_explore(self) -> None: + """Test explore() returns data.""" + try: + data = self.client.explore() + except Exception as e: + pytest.skip(f"Explore not supported: {e}") + self.assertIsInstance(data, bytes) + self.assertGreater(len(data), 0) + + +@pytest.mark.e2e +class TestS7CommPlusDiagnostics(unittest.TestCase): + """Diagnostic tests for debugging protocol issues against real PLCs. + + These tests are designed to dump raw protocol data at every layer + to help diagnose why db_read/db_write fail against real hardware. + """ + + client: S7CommPlusClient + + @classmethod + def setUpClass(cls) -> None: + cls.client = S7CommPlusClient() + cls.client.connect(PLC_IP, PLC_PORT, PLC_RACK, PLC_SLOT) + + @classmethod + def tearDownClass(cls) -> None: + if cls.client: + cls.client.disconnect() + + def test_diag_connection_info(self) -> None: + """Dump connection state after successful connect.""" + print(f"\n{'=' * 60}") + print("DIAGNOSTIC: Connection Info") + print(f" connected: {self.client.connected}") + print(f" protocol_version: V{self.client.protocol_version}") + print(f" session_id: 0x{self.client.session_id:08X} ({self.client.session_id})") + print(f"{'=' * 60}") + self.assertTrue(self.client.connected) + + def test_diag_explore_raw(self) -> None: + """Explore and dump the raw response for analysis.""" + print(f"\n{'=' * 60}") + print("DIAGNOSTIC: Explore raw response") + try: + data = self.client.explore() + print(f" Length: {len(data)} bytes") + # Dump in 32-byte rows + for i in range(0, len(data), 32): + chunk = data[i : i + 32] + hex_str = chunk.hex(" ") + ascii_str = "".join(chr(b) if 32 <= b < 127 else "." for b in chunk) + print(f" {i:04x}: {hex_str:<96s} {ascii_str}") + except Exception as e: + print(f" Explore failed: {e}") + print(f"{'=' * 60}") + + def test_diag_db_read_single_byte(self) -> None: + """Try to read a single byte from DB1 offset 0 and dump everything.""" + print(f"\n{'=' * 60}") + print("DIAGNOSTIC: db_read(DB1, offset=0, size=1)") + try: + data = self.client.db_read(DB_READ_ONLY, 0, 1) + print(f" Success! Got {len(data)} bytes: {data.hex(' ')}") + except Exception as e: + print(f" FAILED: {type(e).__name__}: {e}") + print(f"{'=' * 60}") + + def test_diag_db_read_full_block(self) -> None: + """Try to read the full test DB and dump everything.""" + print(f"\n{'=' * 60}") + print(f"DIAGNOSTIC: db_read(DB{DB_READ_ONLY}, offset=0, size={DB_SIZE})") + try: + data = self.client.db_read(DB_READ_ONLY, 0, DB_SIZE) + print(f" Success! Got {len(data)} bytes:") + for i in range(0, len(data), 16): + chunk = data[i : i + 16] + print(f" {i:04x}: {chunk.hex(' ')}") + except Exception as e: + print(f" FAILED: {type(e).__name__}: {e}") + print(f"{'=' * 60}") + + def test_diag_raw_get_multi_variables(self) -> None: + """Send a raw GetMultiVariables with different payload formats and dump responses. + + This tries several payload encodings to see which ones the PLC accepts. + """ + from snap7.s7commplus.protocol import FunctionCode + from snap7.s7commplus.vlq import encode_uint32_vlq + + print(f"\n{'=' * 60}") + print("DIAGNOSTIC: Raw GetMultiVariables payload experiments") + + assert self.client._connection is not None + + # Experiment 1: Our current format (item_count + object_id + offset + size) + payloads = { + "current_format (count=1, obj=0x00010001, off=0, sz=2)": ( + encode_uint32_vlq(1) + encode_uint32_vlq(0x00010001) + encode_uint32_vlq(0) + encode_uint32_vlq(2) + ), + "empty_payload": b"", + "just_zero": encode_uint32_vlq(0), + "single_vlq_1": encode_uint32_vlq(1), + } + + for label, payload in payloads.items(): + print(f"\n --- {label} ---") + print(f" Payload ({len(payload)} bytes): {payload.hex(' ')}") + try: + response = self.client._connection.send_request(FunctionCode.GET_MULTI_VARIABLES, payload) + print(f" Response ({len(response)} bytes): {response.hex(' ')}") + + # Try to parse return code + if len(response) > 0: + from snap7.s7commplus.vlq import decode_uint32_vlq + + rc, consumed = decode_uint32_vlq(response, 0) + print(f" Return code (VLQ): {rc} (0x{rc:X})") + remaining = response[consumed:] + if remaining: + print(f" After return code ({len(remaining)} bytes): {remaining.hex(' ')}") + except Exception as e: + print(f" EXCEPTION: {type(e).__name__}: {e}") + + print(f"\n{'=' * 60}") + + def test_diag_raw_set_variable(self) -> None: + """Try SetVariable (0x04F2) instead of SetMultiVariables to see if PLC responds differently.""" + from snap7.s7commplus.protocol import FunctionCode + + print(f"\n{'=' * 60}") + print("DIAGNOSTIC: Raw SetVariable / GetVariable experiments") + + assert self.client._connection is not None + + function_codes = { + "GET_VARIABLE (0x04FC)": FunctionCode.GET_VARIABLE, + "GET_MULTI_VARIABLES (0x054C)": FunctionCode.GET_MULTI_VARIABLES, + "SET_VARIABLE (0x04F2)": FunctionCode.SET_VARIABLE, + } + + # Simple payload: just try empty or minimal + for label, fc in function_codes.items(): + print(f"\n --- {label} with empty payload ---") + try: + response = self.client._connection.send_request(fc, b"") + print(f" Response ({len(response)} bytes): {response.hex(' ')}") + except Exception as e: + print(f" EXCEPTION: {type(e).__name__}: {e}") + + print(f"\n{'=' * 60}") + + def test_diag_explore_then_read(self) -> None: + """Explore first to discover object IDs, then try reading using those IDs.""" + from snap7.s7commplus.protocol import FunctionCode, ElementID + from snap7.s7commplus.vlq import encode_uint32_vlq, decode_uint32_vlq + + print(f"\n{'=' * 60}") + print("DIAGNOSTIC: Explore -> extract object IDs -> try reading") + + assert self.client._connection is not None + + try: + explore_data = self.client._connection.send_request(FunctionCode.EXPLORE, b"") + print(f" Explore response ({len(explore_data)} bytes)") + + # Scan for StartOfObject markers and extract relation IDs + object_ids = [] + i = 0 + while i < len(explore_data): + if explore_data[i] == ElementID.START_OF_OBJECT: + if i + 5 <= len(explore_data): + rel_id = struct.unpack_from(">I", explore_data, i + 1)[0] + object_ids.append(rel_id) + print(f" Found object at offset {i}: relation_id=0x{rel_id:08X}") + i += 5 + else: + i += 1 + + # Try reading using each discovered object ID + for obj_id in object_ids[:5]: # Limit to first 5 + print(f"\n --- Read using object_id=0x{obj_id:08X} ---") + payload = encode_uint32_vlq(1) + encode_uint32_vlq(obj_id) + encode_uint32_vlq(0) + encode_uint32_vlq(4) + try: + response = self.client._connection.send_request(FunctionCode.GET_MULTI_VARIABLES, payload) + print(f" Response ({len(response)} bytes): {response.hex(' ')}") + if len(response) > 0: + rc, consumed = decode_uint32_vlq(response, 0) + print(f" Return code: {rc} (0x{rc:X})") + except Exception as e: + print(f" EXCEPTION: {type(e).__name__}: {e}") + + except Exception as e: + print(f" Explore failed: {type(e).__name__}: {e}") + + print(f"\n{'=' * 60}") diff --git a/tests/test_s7commplus_server.py b/tests/test_s7commplus_server.py new file mode 100644 index 00000000..2f08f575 --- /dev/null +++ b/tests/test_s7commplus_server.py @@ -0,0 +1,304 @@ +"""Integration tests for S7CommPlus server, client, and async client.""" + +import struct +import time +from collections.abc import Generator + +import pytest +import asyncio + +from snap7.s7commplus.server import S7CommPlusServer, CPUState, DataBlock +from snap7.s7commplus.client import S7CommPlusClient +from snap7.s7commplus.async_client import S7CommPlusAsyncClient +from snap7.s7commplus.protocol import ProtocolVersion + +# Use a high port to avoid conflicts +TEST_PORT = 11120 + + +@pytest.fixture() +def server() -> Generator[S7CommPlusServer, None, None]: + """Create and start an S7CommPlus server with test data blocks.""" + srv = S7CommPlusServer() + + # Register DB1 with named variables + srv.register_db( + 1, + { + "temperature": ("Real", 0), + "pressure": ("Real", 4), + "running": ("Bool", 8), + "count": ("DInt", 10), + "name": ("Int", 14), + }, + ) + + # Register DB2 with raw data + srv.register_raw_db(2, bytearray(256)) + + # Pre-populate some values in DB1 + db1 = srv.get_db(1) + assert db1 is not None + struct.pack_into(">f", db1.data, 0, 23.5) # temperature + struct.pack_into(">f", db1.data, 4, 1.013) # pressure + db1.data[8] = 1 # running = True + struct.pack_into(">i", db1.data, 10, 42) # count + + srv.start(port=TEST_PORT) + time.sleep(0.1) # Let server start + + yield srv + + srv.stop() + + +class TestServer: + """Test the server emulator itself.""" + + def test_register_db(self) -> None: + srv = S7CommPlusServer() + db = srv.register_db(1, {"temp": ("Real", 0)}) + assert db.number == 1 + assert "temp" in db.variables + assert db.variables["temp"].byte_offset == 0 + + def test_register_raw_db(self) -> None: + srv = S7CommPlusServer() + data = bytearray(b"\x01\x02\x03\x04") + db = srv.register_raw_db(10, data) + assert db.read(0, 4) == b"\x01\x02\x03\x04" + + def test_cpu_state(self) -> None: + srv = S7CommPlusServer() + assert srv.cpu_state == CPUState.RUN + srv.cpu_state = CPUState.STOP + assert srv.cpu_state == CPUState.STOP + + def test_data_block_read_write(self) -> None: + db = DataBlock(1, 100) + db.write(0, b"\x01\x02\x03\x04") + assert db.read(0, 4) == b"\x01\x02\x03\x04" + + def test_data_block_named_variable(self) -> None: + db = DataBlock(1, 100) + db.add_variable("temp", "Real", 0) + db.write(0, struct.pack(">f", 42.0)) + wire_type, raw = db.read_variable("temp") + value = struct.unpack(">f", raw)[0] + assert abs(value - 42.0) < 0.001 + + def test_data_block_read_past_end(self) -> None: + db = DataBlock(1, 4) + db.write(0, b"\xff\xff\xff\xff") + # Read past end should pad with zeros + data = db.read(2, 4) + assert data == b"\xff\xff\x00\x00" + + def test_unknown_variable_type(self) -> None: + db = DataBlock(1, 100) + with pytest.raises(ValueError, match="Unknown type name"): + db.add_variable("bad", "NonExistentType", 0) + + +class TestClientServerIntegration: + """Test client against the server emulator.""" + + def test_connect_disconnect(self, server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=TEST_PORT) + assert client.connected + assert client.session_id != 0 + assert client.protocol_version == ProtocolVersion.V1 + client.disconnect() + assert not client.connected + + def test_context_manager(self, server: S7CommPlusServer) -> None: + with S7CommPlusClient() as client: + client.connect("127.0.0.1", port=TEST_PORT) + assert client.connected + assert not client.connected + + def test_read_real(self, server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=TEST_PORT) + try: + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + finally: + client.disconnect() + + def test_read_multiple_values(self, server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=TEST_PORT) + try: + # Read temperature and pressure + data = client.db_read(1, 0, 8) + temp = struct.unpack_from(">f", data, 0)[0] + pressure = struct.unpack_from(">f", data, 4)[0] + assert abs(temp - 23.5) < 0.001 + assert abs(pressure - 1.013) < 0.001 + finally: + client.disconnect() + + def test_write_and_read_back(self, server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=TEST_PORT) + try: + # Write a new temperature + client.db_write(1, 0, struct.pack(">f", 99.9)) + + # Read it back + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 99.9) < 0.1 + finally: + client.disconnect() + + def test_write_dint(self, server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=TEST_PORT) + try: + # Write count + client.db_write(1, 10, struct.pack(">i", 12345)) + + # Read it back + data = client.db_read(1, 10, 4) + value = struct.unpack(">i", data)[0] + assert value == 12345 + finally: + client.disconnect() + + def test_read_db2_raw(self, server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=TEST_PORT) + try: + # DB2 should be all zeros + data = client.db_read(2, 0, 10) + assert data == b"\x00" * 10 + finally: + client.disconnect() + + def test_multi_read(self, server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=TEST_PORT) + try: + results = client.db_read_multi( + [ + (1, 0, 4), # temperature from DB1 + (1, 4, 4), # pressure from DB1 + (2, 0, 4), # zeros from DB2 + ] + ) + assert len(results) == 3 + temp = struct.unpack(">f", results[0])[0] + assert abs(temp - 23.5) < 0.001 + assert results[2] == b"\x00\x00\x00\x00" + finally: + client.disconnect() + + def test_explore(self, server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=TEST_PORT) + try: + response = client.explore() + # Response should contain data about registered DBs + assert len(response) > 0 + finally: + client.disconnect() + + def test_server_data_persists_across_clients(self, server: S7CommPlusServer) -> None: + # Client 1 writes + c1 = S7CommPlusClient() + c1.connect("127.0.0.1", port=TEST_PORT) + c1.db_write(2, 0, b"\xde\xad\xbe\xef") + c1.disconnect() + + # Client 2 reads + c2 = S7CommPlusClient() + c2.connect("127.0.0.1", port=TEST_PORT) + data = c2.db_read(2, 0, 4) + c2.disconnect() + + assert data == b"\xde\xad\xbe\xef" + + def test_multiple_concurrent_clients(self, server: S7CommPlusServer) -> None: + clients = [] + for _ in range(3): + c = S7CommPlusClient() + c.connect("127.0.0.1", port=TEST_PORT) + clients.append(c) + + # All should have different session IDs + session_ids = {c.session_id for c in clients} + assert len(session_ids) == 3 + + for c in clients: + c.disconnect() + + +@pytest.mark.asyncio +class TestAsyncClientServerIntegration: + """Test async client against the server emulator.""" + + async def test_connect_disconnect(self, server: S7CommPlusServer) -> None: + client = S7CommPlusAsyncClient() + await client.connect("127.0.0.1", port=TEST_PORT) + assert client.connected + assert client.session_id != 0 + await client.disconnect() + assert not client.connected + + async def test_async_context_manager(self, server: S7CommPlusServer) -> None: + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=TEST_PORT) + assert client.connected + assert not client.connected + + async def test_read_real(self, server: S7CommPlusServer) -> None: + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=TEST_PORT) + data = await client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + + async def test_write_and_read_back(self, server: S7CommPlusServer) -> None: + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=TEST_PORT) + await client.db_write(1, 0, struct.pack(">f", 77.7)) + data = await client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 77.7) < 0.1 + + async def test_multi_read(self, server: S7CommPlusServer) -> None: + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=TEST_PORT) + results = await client.db_read_multi( + [ + (1, 0, 4), + (1, 10, 4), + ] + ) + assert len(results) == 2 + temp = struct.unpack(">f", results[0])[0] + assert abs(temp - 23.5) < 0.1 # May be modified by earlier test + + async def test_explore(self, server: S7CommPlusServer) -> None: + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=TEST_PORT) + response = await client.explore() + assert len(response) > 0 + + async def test_concurrent_reads(self, server: S7CommPlusServer) -> None: + """Test that asyncio.Lock prevents interleaved requests.""" + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=TEST_PORT) + + async def read_temp() -> float: + data = await client.db_read(1, 0, 4) + return float(struct.unpack(">f", data)[0]) + + results = await asyncio.gather(read_temp(), read_temp(), read_temp()) + assert len(results) == 3 + for r in results: + assert isinstance(r, float) diff --git a/tests/test_s7commplus_unit.py b/tests/test_s7commplus_unit.py new file mode 100644 index 00000000..f7c5e57e --- /dev/null +++ b/tests/test_s7commplus_unit.py @@ -0,0 +1,459 @@ +"""Unit tests for S7CommPlus client payload builders, connection parsing, and error paths.""" + +import struct +import pytest + +from snap7.s7commplus.client import ( + S7CommPlusClient, + _build_read_payload, + _parse_read_response, + _build_write_payload, + _parse_write_response, +) +from snap7.s7commplus.codec import encode_pvalue_blob +from snap7.s7commplus.connection import S7CommPlusConnection, _element_size +from snap7.s7commplus.protocol import DataType, ElementID, ObjectId +from snap7.s7commplus.vlq import ( + encode_uint32_vlq, + encode_uint64_vlq, + encode_int32_vlq, + decode_uint32_vlq, +) + + +# -- Payload builder / parser tests -- + + +class TestBuildReadPayload: + def test_single_item(self) -> None: + payload = _build_read_payload([(1, 0, 4)]) + assert isinstance(payload, bytes) + assert len(payload) > 0 + + def test_multi_item(self) -> None: + payload = _build_read_payload([(1, 0, 4), (2, 10, 8)]) + assert isinstance(payload, bytes) + # Multi-item payload should be larger than single + single = _build_read_payload([(1, 0, 4)]) + assert len(payload) > len(single) + + +class TestParseReadResponse: + @staticmethod + def _build_response( + return_value: int = 0, + items: list[bytes] | None = None, + errors: list[tuple[int, int]] | None = None, + ) -> bytes: + """Build a synthetic GetMultiVariables response.""" + result = bytearray() + # ReturnValue (UInt64 VLQ) + result += encode_uint64_vlq(return_value) + + # Value list + if items: + for i, item_data in enumerate(items, 1): + result += encode_uint32_vlq(i) # ItemNumber + result += encode_pvalue_blob(item_data) # PValue + result += encode_uint32_vlq(0) # Terminator + + # Error list + if errors: + for err_item_nr, err_value in errors: + result += encode_uint32_vlq(err_item_nr) + result += encode_uint64_vlq(err_value) + result += encode_uint32_vlq(0) # Terminator + + return bytes(result) + + def test_single_item_success(self) -> None: + data = bytes([1, 2, 3, 4]) + response = self._build_response(items=[data]) + results = _parse_read_response(response) + assert len(results) == 1 + assert results[0] == data + + def test_multi_item_success(self) -> None: + data1 = bytes([0x0A, 0x0B]) + data2 = bytes([0x0C, 0x0D, 0x0E]) + response = self._build_response(items=[data1, data2]) + results = _parse_read_response(response) + assert len(results) == 2 + assert results[0] == data1 + assert results[1] == data2 + + def test_error_return_value(self) -> None: + response = self._build_response(return_value=0x05A9) + results = _parse_read_response(response) + assert results == [] + + def test_empty_response(self) -> None: + response = self._build_response() + results = _parse_read_response(response) + assert results == [] + + def test_with_error_items(self) -> None: + data1 = bytes([1, 2, 3, 4]) + response = self._build_response(items=[data1], errors=[(2, 0xDEAD)]) + results = _parse_read_response(response) + assert len(results) == 2 + assert results[0] == data1 + assert results[1] is None # Error item + + +class TestParseWriteResponse: + @staticmethod + def _build_response(return_value: int = 0, errors: list[tuple[int, int]] | None = None) -> bytes: + result = bytearray() + result += encode_uint64_vlq(return_value) + if errors: + for err_item_nr, err_value in errors: + result += encode_uint32_vlq(err_item_nr) + result += encode_uint64_vlq(err_value) + result += encode_uint32_vlq(0) # Terminator + return bytes(result) + + def test_success(self) -> None: + response = self._build_response(return_value=0) + _parse_write_response(response) # Should not raise + + def test_error_return_value(self) -> None: + response = self._build_response(return_value=0x05A9) + with pytest.raises(RuntimeError, match="Write failed"): + _parse_write_response(response) + + def test_error_items(self) -> None: + response = self._build_response(return_value=0, errors=[(1, 0xDEAD)]) + with pytest.raises(RuntimeError, match="Write failed"): + _parse_write_response(response) + + +class TestBuildWritePayload: + def test_single_item(self) -> None: + payload = _build_write_payload([(1, 0, bytes([1, 2, 3, 4]))]) + assert isinstance(payload, bytes) + assert len(payload) > 0 + + def test_data_appears_in_payload(self) -> None: + data = bytes([0xDE, 0xAD, 0xBE, 0xEF]) + payload = _build_write_payload([(1, 0, data)]) + # The raw data should appear in the payload (inside the BLOB PValue) + assert data in payload + + +# -- Client/server payload agreement -- + + +class TestPayloadAgreement: + """Verify client payloads can be parsed by the server's request parser.""" + + def test_read_payload_roundtrip(self) -> None: + """Build a read payload, then manually verify it has expected structure.""" + payload = _build_read_payload([(1, 0, 4)]) + offset = 0 + + # LinkId (4 bytes fixed) + link_id = struct.unpack_from(">I", payload, offset)[0] + offset += 4 + assert link_id == 0 + + # Item count (VLQ) + item_count, consumed = decode_uint32_vlq(payload, offset) + offset += consumed + assert item_count == 1 + + # Total field count (VLQ) + total_fields, consumed = decode_uint32_vlq(payload, offset) + offset += consumed + assert total_fields == 6 # 4 base + 2 LIDs + + def test_write_read_consistency(self) -> None: + """Build write and read payloads for same address, verify both compile.""" + read_payload = _build_read_payload([(1, 0, 4)]) + write_payload = _build_write_payload([(1, 0, bytes([1, 2, 3, 4]))]) + assert isinstance(read_payload, bytes) + assert isinstance(write_payload, bytes) + + +# -- Connection unit tests -- + + +class TestConnectionElementSize: + def test_single_byte(self) -> None: + for dt in (DataType.BOOL, DataType.USINT, DataType.BYTE, DataType.SINT): + assert _element_size(dt) == 1 + + def test_two_byte(self) -> None: + for dt in (DataType.UINT, DataType.WORD, DataType.INT): + assert _element_size(dt) == 2 + + def test_four_byte(self) -> None: + for dt in (DataType.REAL, DataType.RID): + assert _element_size(dt) == 4 + + def test_eight_byte(self) -> None: + for dt in (DataType.LREAL, DataType.TIMESTAMP): + assert _element_size(dt) == 8 + + def test_variable_length(self) -> None: + for dt in (DataType.UDINT, DataType.BLOB, DataType.WSTRING, DataType.STRUCT): + assert _element_size(dt) == 0 + + +class TestSkipTypedValue: + """Test S7CommPlusConnection._skip_typed_value with constructed byte buffers.""" + + @pytest.fixture() + def conn(self) -> S7CommPlusConnection: + return S7CommPlusConnection("127.0.0.1") + + def test_null(self, conn: S7CommPlusConnection) -> None: + assert conn._skip_typed_value(b"", 0, DataType.NULL, 0x00) == 0 + + def test_bool(self, conn: S7CommPlusConnection) -> None: + data = bytes([0x01]) + assert conn._skip_typed_value(data, 0, DataType.BOOL, 0x00) == 1 + + def test_usint(self, conn: S7CommPlusConnection) -> None: + data = bytes([42]) + assert conn._skip_typed_value(data, 0, DataType.USINT, 0x00) == 1 + + def test_byte(self, conn: S7CommPlusConnection) -> None: + data = bytes([0xAB]) + assert conn._skip_typed_value(data, 0, DataType.BYTE, 0x00) == 1 + + def test_sint(self, conn: S7CommPlusConnection) -> None: + data = bytes([0xD6]) + assert conn._skip_typed_value(data, 0, DataType.SINT, 0x00) == 1 + + def test_uint(self, conn: S7CommPlusConnection) -> None: + data = struct.pack(">H", 1000) + assert conn._skip_typed_value(data, 0, DataType.UINT, 0x00) == 2 + + def test_word(self, conn: S7CommPlusConnection) -> None: + data = struct.pack(">H", 0xBEEF) + assert conn._skip_typed_value(data, 0, DataType.WORD, 0x00) == 2 + + def test_int(self, conn: S7CommPlusConnection) -> None: + data = struct.pack(">h", -1000) + assert conn._skip_typed_value(data, 0, DataType.INT, 0x00) == 2 + + def test_udint(self, conn: S7CommPlusConnection) -> None: + vlq = encode_uint32_vlq(100000) + new_offset = conn._skip_typed_value(vlq, 0, DataType.UDINT, 0x00) + assert new_offset == len(vlq) + + def test_dword(self, conn: S7CommPlusConnection) -> None: + vlq = encode_uint32_vlq(0xDEADBEEF) + new_offset = conn._skip_typed_value(vlq, 0, DataType.DWORD, 0x00) + assert new_offset == len(vlq) + + def test_aid(self, conn: S7CommPlusConnection) -> None: + vlq = encode_uint32_vlq(306) + new_offset = conn._skip_typed_value(vlq, 0, DataType.AID, 0x00) + assert new_offset == len(vlq) + + def test_dint(self, conn: S7CommPlusConnection) -> None: + vlq = encode_int32_vlq(-100000) + new_offset = conn._skip_typed_value(vlq, 0, DataType.DINT, 0x00) + assert new_offset == len(vlq) + + def test_ulint(self, conn: S7CommPlusConnection) -> None: + vlq = encode_uint64_vlq(2**40) + new_offset = conn._skip_typed_value(vlq, 0, DataType.ULINT, 0x00) + assert new_offset == len(vlq) + + def test_lword(self, conn: S7CommPlusConnection) -> None: + vlq = encode_uint64_vlq(0xCAFE) + new_offset = conn._skip_typed_value(vlq, 0, DataType.LWORD, 0x00) + assert new_offset == len(vlq) + + def test_lint(self, conn: S7CommPlusConnection) -> None: + from snap7.s7commplus.vlq import encode_int64_vlq + + vlq = encode_int64_vlq(-(2**40)) + new_offset = conn._skip_typed_value(vlq, 0, DataType.LINT, 0x00) + assert new_offset == len(vlq) + + def test_real(self, conn: S7CommPlusConnection) -> None: + data = struct.pack(">f", 3.14) + assert conn._skip_typed_value(data, 0, DataType.REAL, 0x00) == 4 + + def test_lreal(self, conn: S7CommPlusConnection) -> None: + data = struct.pack(">d", 2.718) + assert conn._skip_typed_value(data, 0, DataType.LREAL, 0x00) == 8 + + def test_timestamp(self, conn: S7CommPlusConnection) -> None: + data = struct.pack(">Q", 0x0001020304050607) + assert conn._skip_typed_value(data, 0, DataType.TIMESTAMP, 0x00) == 8 + + def test_timespan(self, conn: S7CommPlusConnection) -> None: + from snap7.s7commplus.vlq import encode_int64_vlq + + vlq = encode_int64_vlq(5000) + # TIMESPAN uses uint64_vlq for skipping in _skip_typed_value + new_offset = conn._skip_typed_value(vlq, 0, DataType.TIMESPAN, 0x00) + assert new_offset == len(vlq) + + def test_rid(self, conn: S7CommPlusConnection) -> None: + data = struct.pack(">I", 0x12345678) + assert conn._skip_typed_value(data, 0, DataType.RID, 0x00) == 4 + + def test_blob(self, conn: S7CommPlusConnection) -> None: + blob_data = bytes([1, 2, 3, 4]) + vlq_len = encode_uint32_vlq(len(blob_data)) + data = vlq_len + blob_data + new_offset = conn._skip_typed_value(data, 0, DataType.BLOB, 0x00) + assert new_offset == len(data) + + def test_wstring(self, conn: S7CommPlusConnection) -> None: + text = "hello".encode("utf-8") + vlq_len = encode_uint32_vlq(len(text)) + data = vlq_len + text + new_offset = conn._skip_typed_value(data, 0, DataType.WSTRING, 0x00) + assert new_offset == len(data) + + def test_struct(self, conn: S7CommPlusConnection) -> None: + # Struct with 2 USINT sub-values + vlq_count = encode_uint32_vlq(2) + sub1 = bytes([0x00, DataType.USINT, 0x0A]) # flags + type + value + sub2 = bytes([0x00, DataType.USINT, 0x14]) + data = vlq_count + sub1 + sub2 + new_offset = conn._skip_typed_value(data, 0, DataType.STRUCT, 0x00) + assert new_offset == len(data) + + def test_unknown_type(self, conn: S7CommPlusConnection) -> None: + # Unknown type should return same offset (can't skip) + assert conn._skip_typed_value(bytes([0xFF]), 0, 0xFF, 0x00) == 0 + + # -- Array tests -- + + def test_array_fixed_size(self, conn: S7CommPlusConnection) -> None: + count_vlq = encode_uint32_vlq(3) + elements = bytes([10, 20, 30]) + data = count_vlq + elements + new_offset = conn._skip_typed_value(data, 0, DataType.USINT, 0x10) + assert new_offset == len(data) + + def test_array_variable_length(self, conn: S7CommPlusConnection) -> None: + count_vlq = encode_uint32_vlq(2) + elem1 = encode_uint32_vlq(100) + elem2 = encode_uint32_vlq(200) + data = count_vlq + elem1 + elem2 + new_offset = conn._skip_typed_value(data, 0, DataType.UDINT, 0x10) + assert new_offset == len(data) + + def test_array_empty_data(self, conn: S7CommPlusConnection) -> None: + # Edge case: array flag but no data + assert conn._skip_typed_value(b"", 0, DataType.USINT, 0x10) == 0 + + +class TestParseCreateObjectResponse: + """Test _parse_create_object_response with constructed payloads.""" + + def _build_create_response_with_session_version(self, version: int, datatype: int = DataType.UDINT) -> bytes: + """Build a minimal CreateObject response containing ServerSessionVersion.""" + payload = bytearray() + # Attribute tag + payload += bytes([ElementID.ATTRIBUTE]) + # Attribute ID = ServerSessionVersion (306) + payload += encode_uint32_vlq(ObjectId.SERVER_SESSION_VERSION) + # Typed value: flags + datatype + VLQ value + payload += bytes([0x00, datatype]) + payload += encode_uint32_vlq(version) + return bytes(payload) + + def test_parse_udint_version(self) -> None: + conn = S7CommPlusConnection("127.0.0.1") + payload = self._build_create_response_with_session_version(3, DataType.UDINT) + conn._parse_create_object_response(payload) + assert conn._server_session_version == 3 + + def test_parse_dword_version(self) -> None: + conn = S7CommPlusConnection("127.0.0.1") + payload = self._build_create_response_with_session_version(2, DataType.DWORD) + conn._parse_create_object_response(payload) + assert conn._server_session_version == 2 + + def test_version_not_found(self) -> None: + conn = S7CommPlusConnection("127.0.0.1") + # Build payload with a different attribute, not ServerSessionVersion + payload = bytearray() + payload += bytes([ElementID.ATTRIBUTE]) + payload += encode_uint32_vlq(999) # Some other attribute ID + payload += bytes([0x00, DataType.USINT, 42]) + conn._parse_create_object_response(bytes(payload)) + assert conn._server_session_version is None + + def test_with_preceding_attributes(self) -> None: + conn = S7CommPlusConnection("127.0.0.1") + payload = bytearray() + # First attribute: some random one with a UINT value + payload += bytes([ElementID.ATTRIBUTE]) + payload += encode_uint32_vlq(100) # Random attribute ID + payload += bytes([0x00, DataType.UINT]) + payload += struct.pack(">H", 0x1234) + # Second attribute: ServerSessionVersion + payload += bytes([ElementID.ATTRIBUTE]) + payload += encode_uint32_vlq(ObjectId.SERVER_SESSION_VERSION) + payload += bytes([0x00, DataType.UDINT]) + payload += encode_uint32_vlq(1) + conn._parse_create_object_response(bytes(payload)) + assert conn._server_session_version == 1 + + def test_with_start_of_object(self) -> None: + conn = S7CommPlusConnection("127.0.0.1") + payload = bytearray() + # StartOfObject tag (needs RelationId + ClassId + ClassFlags + AttributeId) + payload += bytes([ElementID.START_OF_OBJECT]) + payload += struct.pack(">I", 0) # RelationId (4 bytes) + payload += encode_uint32_vlq(100) # ClassId + payload += encode_uint32_vlq(0) # ClassFlags + payload += encode_uint32_vlq(0) # AttributeId + # TerminatingObject + payload += bytes([ElementID.TERMINATING_OBJECT]) + # Now the attribute we want + payload += bytes([ElementID.ATTRIBUTE]) + payload += encode_uint32_vlq(ObjectId.SERVER_SESSION_VERSION) + payload += bytes([0x00, DataType.UDINT]) + payload += encode_uint32_vlq(3) + conn._parse_create_object_response(bytes(payload)) + assert conn._server_session_version == 3 + + +# -- Client error path tests -- + + +class TestClientErrorPaths: + def test_properties_not_connected(self) -> None: + client = S7CommPlusClient() + assert client.connected is False + assert client.protocol_version == 0 + assert client.session_id == 0 + assert client.using_legacy_fallback is False + + def test_db_read_not_connected(self) -> None: + client = S7CommPlusClient() + with pytest.raises(RuntimeError, match="Not connected"): + client.db_read(1, 0, 4) + + def test_db_write_not_connected(self) -> None: + client = S7CommPlusClient() + with pytest.raises(RuntimeError, match="Not connected"): + client.db_write(1, 0, bytes([1, 2, 3, 4])) + + def test_db_read_multi_not_connected(self) -> None: + client = S7CommPlusClient() + with pytest.raises(RuntimeError, match="Not connected"): + client.db_read_multi([(1, 0, 4)]) + + def test_explore_not_connected(self) -> None: + client = S7CommPlusClient() + with pytest.raises(RuntimeError, match="Not connected"): + client.explore() + + def test_context_manager_not_connected(self) -> None: + """Test that context manager works without connection (disconnect is a no-op).""" + with S7CommPlusClient() as client: + assert client.connected is False + # Should not raise diff --git a/tests/test_s7commplus_vlq.py b/tests/test_s7commplus_vlq.py new file mode 100644 index 00000000..d7dbb596 --- /dev/null +++ b/tests/test_s7commplus_vlq.py @@ -0,0 +1,161 @@ +"""Tests for S7CommPlus VLQ (Variable-Length Quantity) encoding.""" + +import pytest + +from snap7.s7commplus.vlq import ( + encode_uint32_vlq, + decode_uint32_vlq, + encode_int32_vlq, + decode_int32_vlq, + encode_uint64_vlq, + decode_uint64_vlq, + encode_int64_vlq, + decode_int64_vlq, +) + + +class TestUInt32Vlq: + """Test unsigned 32-bit VLQ encoding/decoding.""" + + @pytest.mark.parametrize( + "value, expected_bytes", + [ + (0, bytes([0x00])), + (1, bytes([0x01])), + (0x7F, bytes([0x7F])), + (0x80, bytes([0x81, 0x00])), + (0xFF, bytes([0x81, 0x7F])), + (0x100, bytes([0x82, 0x00])), + (0x3FFF, bytes([0xFF, 0x7F])), + (0x4000, bytes([0x81, 0x80, 0x00])), + ], + ) + def test_encode_known_values(self, value: int, expected_bytes: bytes) -> None: + assert encode_uint32_vlq(value) == expected_bytes + + @pytest.mark.parametrize( + "value", + [0, 1, 127, 128, 255, 256, 16383, 16384, 0xFFFF, 0xFFFFFF, 0xFFFFFFFF], + ) + def test_roundtrip(self, value: int) -> None: + encoded = encode_uint32_vlq(value) + decoded, consumed = decode_uint32_vlq(encoded) + assert decoded == value + assert consumed == len(encoded) + + def test_decode_with_offset(self) -> None: + prefix = bytes([0xAA, 0xBB]) + encoded = encode_uint32_vlq(12345) + data = prefix + encoded + decoded, consumed = decode_uint32_vlq(data, offset=2) + assert decoded == 12345 + + def test_encode_out_of_range(self) -> None: + with pytest.raises(ValueError): + encode_uint32_vlq(-1) + with pytest.raises(ValueError): + encode_uint32_vlq(0x100000000) + + def test_decode_truncated(self) -> None: + # Continuation bit set but no more data + with pytest.raises(ValueError): + decode_uint32_vlq(bytes([0x80])) + + +class TestInt32Vlq: + """Test signed 32-bit VLQ encoding/decoding.""" + + @pytest.mark.parametrize( + "value", + [0, 1, -1, 63, -64, 64, -65, 127, -128, 0x7FFFFFFF, -0x80000000, 1234567, -1234567], + ) + def test_roundtrip(self, value: int) -> None: + encoded = encode_int32_vlq(value) + decoded, consumed = decode_int32_vlq(encoded) + assert decoded == value + assert consumed == len(encoded) + + def test_negative_one(self) -> None: + """Test that -1 encodes compactly.""" + encoded = encode_int32_vlq(-1) + decoded, _ = decode_int32_vlq(encoded) + assert decoded == -1 + + def test_min_value(self) -> None: + """Test INT32_MIN boundary.""" + encoded = encode_int32_vlq(-0x80000000) + decoded, _ = decode_int32_vlq(encoded) + assert decoded == -0x80000000 + + def test_encode_out_of_range(self) -> None: + with pytest.raises(ValueError): + encode_int32_vlq(-0x80000001) + with pytest.raises(ValueError): + encode_int32_vlq(0x80000000) + + +class TestUInt64Vlq: + """Test unsigned 64-bit VLQ encoding/decoding.""" + + @pytest.mark.parametrize( + "value", + [ + 0, + 1, + 127, + 128, + 0xFFFF, + 0xFFFFFFFF, + 0xFFFFFFFFFF, + 0x00FFFFFFFFFFFFFF, # Just below the special threshold + 0x00FFFFFFFFFFFFFF + 1, # At the special threshold + 0xFFFFFFFFFFFFFFFF, # Max uint64 + ], + ) + def test_roundtrip(self, value: int) -> None: + encoded = encode_uint64_vlq(value) + decoded, consumed = decode_uint64_vlq(encoded) + assert decoded == value + assert consumed == len(encoded) + + def test_max_encoding_length(self) -> None: + """Max uint64 should encode in at most 9 bytes.""" + encoded = encode_uint64_vlq(0xFFFFFFFFFFFFFFFF) + assert len(encoded) <= 9 + + def test_encode_out_of_range(self) -> None: + with pytest.raises(ValueError): + encode_uint64_vlq(-1) + with pytest.raises(ValueError): + encode_uint64_vlq(0x10000000000000000) + + +class TestInt64Vlq: + """Test signed 64-bit VLQ encoding/decoding.""" + + @pytest.mark.parametrize( + "value", + [ + 0, + 1, + -1, + 63, + -64, + 127, + -128, + 0x7FFFFFFFFFFFFFFF, # Max int64 + -0x8000000000000000, # Min int64 + 123456789012345, + -123456789012345, + ], + ) + def test_roundtrip(self, value: int) -> None: + encoded = encode_int64_vlq(value) + decoded, consumed = decode_int64_vlq(encoded) + assert decoded == value + assert consumed == len(encoded) + + def test_max_encoding_length(self) -> None: + """Max/min int64 should encode in at most 9 bytes.""" + assert len(encode_int64_vlq(0x7FFFFFFFFFFFFFFF)) <= 9 + assert len(encode_int64_vlq(-0x8000000000000000)) <= 9 diff --git a/tests/test_s7protocol.py b/tests/test_s7protocol.py new file mode 100644 index 00000000..c0d62f1b --- /dev/null +++ b/tests/test_s7protocol.py @@ -0,0 +1,537 @@ +"""Tests for snap7.s7protocol — response parsers with crafted PDUs, error paths.""" + +import struct +from typing import Any + +import pytest +from datetime import datetime + +from snap7.s7protocol import ( + S7Protocol, + S7PDUType, + S7Function, + S7UserDataGroup, + S7UserDataSubfunction, + get_return_code_description, +) +from snap7.error import S7ProtocolError + + +class TestGetReturnCodeDescription: + def test_known_code(self) -> None: + assert get_return_code_description(0xFF) == "Success" + + def test_unknown_code(self) -> None: + assert get_return_code_description(0xAB) == "Unknown error" + + +class TestParseResponse: + """Test parse_response() with crafted PDUs.""" + + def setup_method(self) -> None: + self.proto = S7Protocol() + + def _build_ack_data_pdu( + self, + func_code: int, + item_count: int = 1, + data_section: bytes = b"", + error_class: int = 0, + error_code: int = 0, + sequence: int = 1, + ) -> bytes: + """Build a minimal ACK_DATA PDU.""" + params = struct.pack(">BB", func_code, item_count) + header = struct.pack( + ">BBHHHHBB", + 0x32, + S7PDUType.ACK_DATA, + 0x0000, + sequence, + len(params), + len(data_section), + error_class, + error_code, + ) + return header + params + data_section + + def test_pdu_too_short(self) -> None: + with pytest.raises(S7ProtocolError, match="too short"): + self.proto.parse_response(b"\x32\x03\x00") + + def test_invalid_protocol_id(self) -> None: + # Build a valid-length PDU with wrong protocol ID + pdu = struct.pack(">BBHHHHBB", 0x99, S7PDUType.ACK_DATA, 0, 1, 0, 0, 0, 0) + with pytest.raises(S7ProtocolError, match="Invalid protocol ID"): + self.proto.parse_response(pdu) + + def test_unexpected_pdu_type(self) -> None: + # REQUEST type (0x01) is not a valid response + pdu = struct.pack(">BBHHHHBB", 0x32, S7PDUType.REQUEST, 0, 1, 0, 0, 0, 0) + with pytest.raises(S7ProtocolError, match="Expected response PDU"): + self.proto.parse_response(pdu) + + def test_header_error(self) -> None: + pdu = struct.pack(">BBHHHHBB", 0x32, S7PDUType.ACK_DATA, 0, 1, 0, 0, 0x05, 0x04) + with pytest.raises(S7ProtocolError, match="S7 protocol error"): + self.proto.parse_response(pdu) + + def test_ack_no_data(self) -> None: + """ACK (type 0x02) PDU with no params or data — write response.""" + pdu = struct.pack(">BBHHHHBB", 0x32, S7PDUType.ACK, 0, 1, 0, 0, 0, 0) + resp = self.proto.parse_response(pdu) + assert resp["sequence"] == 1 + assert resp["parameters"] is None + assert resp["data"] is None + + def test_read_response(self) -> None: + """ACK_DATA with read parameters and data.""" + data_section = struct.pack(">BBH", 0xFF, 0x04, 16) + b"\xab\xcd" # 16 bits = 2 bytes + pdu = self._build_ack_data_pdu(S7Function.READ_AREA, 1, data_section) + resp = self.proto.parse_response(pdu) + assert resp["parameters"]["function_code"] == S7Function.READ_AREA + assert resp["data"]["data"] == b"\xab\xcd" + + def test_write_response_single_byte_data(self) -> None: + """Write response with single-byte data section (return code only).""" + data_section = b"\xff" # success + pdu = self._build_ack_data_pdu(S7Function.WRITE_AREA, 1, data_section) + resp = self.proto.parse_response(pdu) + assert resp["data"]["return_code"] == 0xFF + + def test_setup_comm_response(self) -> None: + params = struct.pack(">BBHHH", S7Function.SETUP_COMMUNICATION, 0x00, 1, 1, 480) + header = struct.pack(">BBHHHHBB", 0x32, S7PDUType.ACK_DATA, 0, 1, len(params), 0, 0, 0) + pdu = header + params + resp = self.proto.parse_response(pdu) + assert resp["parameters"]["pdu_length"] == 480 + + def test_param_section_extends_beyond_pdu(self) -> None: + # param_len = 100 but PDU is too short + header = struct.pack(">BBHHHHBB", 0x32, S7PDUType.ACK_DATA, 0, 1, 100, 0, 0, 0) + with pytest.raises(S7ProtocolError, match="Parameter section extends beyond PDU"): + self.proto.parse_response(header) + + def test_data_section_extends_beyond_pdu(self) -> None: + # data_len = 100 but no data follows + params = struct.pack(">BB", S7Function.READ_AREA, 1) + header = struct.pack(">BBHHHHBB", 0x32, S7PDUType.ACK_DATA, 0, 1, len(params), 100, 0, 0) + pdu = header + params + with pytest.raises(S7ProtocolError, match="Data section extends beyond PDU"): + self.proto.parse_response(pdu) + + def test_unknown_function_code(self) -> None: + pdu = self._build_ack_data_pdu(0xAA, 0) + resp = self.proto.parse_response(pdu) + assert resp["parameters"]["function_code"] == 0xAA + + +class TestUserDataParsing: + """Test USERDATA PDU parsing.""" + + def setup_method(self) -> None: + self.proto = S7Protocol() + + def _build_userdata_response( + self, + group: int = S7UserDataGroup.SZL, + subfunction: int = S7UserDataSubfunction.READ_SZL, + sequence_number: int = 0, + last_data_unit: int = 0x00, + error_code: int = 0, + data_payload: bytes = b"", + ) -> bytes: + """Build a USERDATA response PDU.""" + # Parameter section (12 bytes for response) + type_group = 0x80 | (group & 0x0F) # response type + param_data = struct.pack( + ">BBBBBBBBBBH", + 0x00, # Reserved + 0x01, # Parameter count + 0x12, # Type header + 0x08, # Length (response = 8) + 0x12, # Method (response) + type_group, + subfunction, + sequence_number, + 0x00, # Data unit reference + last_data_unit, + error_code, + ) + + # Data section + data_section = ( + struct.pack( + ">BBH", + 0xFF, # Return code (success) + 0x09, # Transport size (octet string) + len(data_payload), + ) + + data_payload + ) + + header = struct.pack( + ">BBHHHH", + 0x32, + S7PDUType.USERDATA, + 0x0000, + 1, + len(param_data), + len(data_section), + ) + + return header + param_data + data_section + + def test_userdata_too_short(self) -> None: + pdu = struct.pack(">BBHH", 0x32, S7PDUType.USERDATA, 0, 1) + with pytest.raises(S7ProtocolError, match="too short"): + self.proto.parse_response(pdu) + + def test_userdata_response(self) -> None: + pdu = self._build_userdata_response(data_payload=b"\x01\x02\x03\x04") + resp = self.proto.parse_response(pdu) + assert resp["parameters"]["group"] == S7UserDataGroup.SZL + assert resp["data"]["data"] == b"\x01\x02\x03\x04" + + def test_userdata_with_error(self) -> None: + pdu = self._build_userdata_response(error_code=0x8104) + resp = self.proto.parse_response(pdu) + assert resp["parameters"]["error_code"] == 0x8104 + + def test_userdata_more_data_available(self) -> None: + pdu = self._build_userdata_response(last_data_unit=0x01, sequence_number=0x05) + resp = self.proto.parse_response(pdu) + assert resp["parameters"]["last_data_unit"] == 0x01 + assert resp["parameters"]["sequence_number"] == 0x05 + + +class TestParseStartUploadResponse: + def setup_method(self) -> None: + self.proto = S7Protocol() + + def test_valid_response(self) -> None: + # Layout: func(1) + status(1) + reserved(1) + reserved(1) + upload_id(4) = 8 bytes + # Parser reads upload_id from raw_params[4:8] + raw_params = struct.pack(">BBBBI", S7Function.START_UPLOAD, 0x00, 0x00, 0x00, 0x12345678) + # Add block length: len_field(1) + length_str + # Condition: len(raw_params) > 9 + len_field, so we need total > 9 + len(length_str) + length_str = b"000100" + raw_params += struct.pack(">B", len(length_str)) + length_str + b"\x00" # extra byte to satisfy > + response = {"raw_parameters": raw_params} + result = self.proto.parse_start_upload_response(response) + assert result["upload_id"] == 0x12345678 + assert result["block_length"] == 100 + + def test_short_response(self) -> None: + response = {"raw_parameters": b"\x00\x00\x00"} + result = self.proto.parse_start_upload_response(response) + assert result["upload_id"] == 0 + assert result["block_length"] == 0 + + def test_no_raw_parameters(self) -> None: + response: dict[str, Any] = {} + result = self.proto.parse_start_upload_response(response) + assert result["upload_id"] == 0 + + def test_invalid_length_string(self) -> None: + raw_params = struct.pack(">BBBI", 0x1D, 0, 0, 1) + raw_params += struct.pack(">B", 3) + b"abc" + response = {"raw_parameters": raw_params} + result = self.proto.parse_start_upload_response(response) + assert result["block_length"] == 0 # ValueError caught + + +class TestParseUploadResponse: + def setup_method(self) -> None: + self.proto = S7Protocol() + + def test_valid_response(self) -> None: + response = {"data": {"data": b"\x01\x02\x03\x04\x05"}} + result = self.proto.parse_upload_response(response) + assert result == b"\x01\x02\x03\x04\x05" + + def test_short_data(self) -> None: + response = {"data": {"data": b"\x01\x02"}} + result = self.proto.parse_upload_response(response) + assert result == b"" + + def test_empty_response(self) -> None: + response = {"data": {"data": b""}} + result = self.proto.parse_upload_response(response) + assert result == b"" + + def test_no_data_key(self) -> None: + response: dict[str, Any] = {} + result = self.proto.parse_upload_response(response) + assert result == b"" + + +class TestParseListBlocksResponse: + def setup_method(self) -> None: + self.proto = S7Protocol() + + def test_valid_response(self) -> None: + # Build entries: indicator(0x30) + type + count(2 bytes) + data = b"" + data += struct.pack(">BBH", 0x30, 0x38, 5) # OB: 5 + data += struct.pack(">BBH", 0x30, 0x41, 3) # DB: 3 + data += struct.pack(">BBH", 0x30, 0x43, 7) # FC: 7 + response = {"data": {"data": data}} + result = self.proto.parse_list_blocks_response(response) + assert result["OBCount"] == 5 + assert result["DBCount"] == 3 + assert result["FCCount"] == 7 + assert result["FBCount"] == 0 + + def test_empty_data(self) -> None: + response = {"data": {"data": b""}} + result = self.proto.parse_list_blocks_response(response) + assert result["DBCount"] == 0 + + def test_no_data(self) -> None: + response: dict[str, Any] = {} + result = self.proto.parse_list_blocks_response(response) + assert all(v == 0 for v in result.values()) + + def test_unknown_block_type_ignored(self) -> None: + data = struct.pack(">BBH", 0x30, 0xFF, 99) # unknown type + response = {"data": {"data": data}} + result = self.proto.parse_list_blocks_response(response) + assert all(v == 0 for v in result.values()) + + +class TestParseListBlocksOfTypeResponse: + def setup_method(self) -> None: + self.proto = S7Protocol() + + def test_valid_response(self) -> None: + # Each entry: block_num(2) + unknown(1) + lang(1) + data = struct.pack(">HBB", 1, 0, 0) + struct.pack(">HBB", 5, 0, 0) + struct.pack(">HBB", 100, 0, 0) + response = {"data": {"data": data}} + result = self.proto.parse_list_blocks_of_type_response(response) + assert result == [1, 5, 100] + + def test_empty_data(self) -> None: + response = {"data": {"data": b""}} + result = self.proto.parse_list_blocks_of_type_response(response) + assert result == [] + + def test_no_data(self) -> None: + response: dict[str, Any] = {} + result = self.proto.parse_list_blocks_of_type_response(response) + assert result == [] + + +class TestParseGetBlockInfoResponse: + def setup_method(self) -> None: + self.proto = S7Protocol() + + def test_short_data(self) -> None: + response = {"data": {"data": b"\x00" * 10}} + result = self.proto.parse_get_block_info_response(response) + assert result["block_type"] == 0 + assert result["mc7_size"] == 0 + + def test_valid_data(self) -> None: + raw_data = bytearray(80) + raw_data[1] = 0x41 # block_type = DB + raw_data[9] = 0x01 # flags + raw_data[10] = 0x05 # lang + struct.pack_into(">H", raw_data, 12, 42) # block_number + struct.pack_into(">I", raw_data, 14, 1024) # load_size + struct.pack_into(">H", raw_data, 34, 100) # sbb_length + struct.pack_into(">H", raw_data, 38, 50) # local_data + struct.pack_into(">H", raw_data, 40, 200) # mc7_size + raw_data[66] = 0x03 # version + struct.pack_into(">H", raw_data, 68, 0xABCD) # checksum + + response = {"data": {"data": bytes(raw_data)}} + result = self.proto.parse_get_block_info_response(response) + assert result["block_type"] == 0x41 + assert result["block_number"] == 42 + assert result["mc7_size"] == 200 + assert result["load_size"] == 1024 + assert result["checksum"] == 0xABCD + assert result["version"] == 0x03 + + def test_no_data(self) -> None: + response: dict[str, Any] = {} + result = self.proto.parse_get_block_info_response(response) + assert result["block_type"] == 0 + + +class TestParseReadSZLResponse: + def setup_method(self) -> None: + self.proto = S7Protocol() + + def test_first_fragment(self) -> None: + raw_data = struct.pack(">HH", 0x0011, 0x0000) + b"\x01\x02\x03" + response = {"data": {"data": raw_data}} + result = self.proto.parse_read_szl_response(response, first_fragment=True) + assert result["szl_id"] == 0x0011 + assert result["szl_index"] == 0x0000 + assert result["data"] == b"\x01\x02\x03" + + def test_first_fragment_short_data(self) -> None: + response = {"data": {"data": b"\x00\x01"}} + result = self.proto.parse_read_szl_response(response, first_fragment=True) + assert result["szl_id"] == 0 + assert result["data"] == b"" + + def test_followup_fragment(self) -> None: + response = {"data": {"data": b"\xaa\xbb\xcc"}} + result = self.proto.parse_read_szl_response(response, first_fragment=False) + assert result["data"] == b"\xaa\xbb\xcc" + assert result["szl_id"] == 0 + + def test_empty_data(self) -> None: + response: dict[str, Any] = {} + result = self.proto.parse_read_szl_response(response) + assert result["data"] == b"" + + +class TestParseGetClockResponse: + def setup_method(self) -> None: + self.proto = S7Protocol() + + def test_valid_bcd_time(self) -> None: + # BCD: reserved, year=24, month=03, day=15, hour=10, minute=30, second=45, dow=6(Saturday) + raw_data = struct.pack(">BBBBBBBB", 0x00, 0x24, 0x03, 0x15, 0x10, 0x30, 0x45, 0x06) + response = {"data": {"data": raw_data}} + result = self.proto.parse_get_clock_response(response) + assert result.year == 2024 + assert result.month == 3 + assert result.day == 15 + assert result.hour == 10 + assert result.minute == 30 + assert result.second == 45 + + def test_year_90_is_1990(self) -> None: + raw_data = struct.pack(">BBBBBBBB", 0x00, 0x90, 0x01, 0x01, 0x00, 0x00, 0x00, 0x01) + response = {"data": {"data": raw_data}} + result = self.proto.parse_get_clock_response(response) + assert result.year == 1990 + + def test_short_data_returns_now(self) -> None: + response = {"data": {"data": b"\x00\x01"}} + result = self.proto.parse_get_clock_response(response) + # Should return roughly "now" + assert isinstance(result, datetime) + + def test_invalid_bcd_date_returns_now(self) -> None: + # Month=99 is invalid + raw_data = struct.pack(">BBBBBBBB", 0x00, 0x24, 0x99, 0x15, 0x10, 0x30, 0x45, 0x06) + response = {"data": {"data": raw_data}} + result = self.proto.parse_get_clock_response(response) + # Should fallback to now + assert isinstance(result, datetime) + + +class TestParseParameterEdgeCases: + def setup_method(self) -> None: + self.proto = S7Protocol() + + def test_empty_parameters(self) -> None: + result = self.proto._parse_parameters(b"") + assert result == {} + + def test_read_response_params_too_short(self) -> None: + with pytest.raises(S7ProtocolError, match="too short"): + self.proto._parse_read_response_params(b"\x04") + + def test_write_response_params_too_short(self) -> None: + with pytest.raises(S7ProtocolError, match="too short"): + self.proto._parse_write_response_params(b"\x05") + + def test_setup_comm_params_too_short(self) -> None: + with pytest.raises(S7ProtocolError, match="too short"): + self.proto._parse_setup_comm_response_params(b"\xf0\x00\x00") + + +class TestParseDataSection: + def setup_method(self) -> None: + self.proto = S7Protocol() + + def test_single_byte(self) -> None: + result = self.proto._parse_data_section(b"\xff") + assert result["return_code"] == 0xFF + + def test_two_three_bytes_raw(self) -> None: + result = self.proto._parse_data_section(b"\xaa\xbb") + assert result["raw_data"] == b"\xaa\xbb" + + def test_octet_string_transport(self) -> None: + # Transport size 0x09 = octet string (byte length) + data = struct.pack(">BBH", 0xFF, 0x09, 3) + b"\x01\x02\x03" + result = self.proto._parse_data_section(data) + assert result["data"] == b"\x01\x02\x03" + + def test_byte_transport_bit_length(self) -> None: + # Transport size 0x04 = byte (bit length) + data = struct.pack(">BBH", 0xFF, 0x04, 24) + b"\x01\x02\x03" # 24 bits = 3 bytes + result = self.proto._parse_data_section(data) + assert result["data"] == b"\x01\x02\x03" + + +class TestExtractReadData: + def setup_method(self) -> None: + self.proto = S7Protocol() + + def test_no_data_in_response(self) -> None: + with pytest.raises(S7ProtocolError, match="No data"): + self.proto.extract_read_data({}, None, 0) # type: ignore[arg-type] + + def test_non_success_return_code(self) -> None: + response = {"data": {"return_code": 0x05, "data": b""}} + with pytest.raises(S7ProtocolError, match="Read operation failed"): + self.proto.extract_read_data(response, None, 0) # type: ignore[arg-type] + + def test_success(self) -> None: + from snap7.datatypes import S7WordLen + + response = {"data": {"return_code": 0xFF, "data": b"\x01\x02\x03"}} + result = self.proto.extract_read_data(response, S7WordLen.BYTE, 3) + assert result == [1, 2, 3] + + +class TestCheckWriteResponse: + def setup_method(self) -> None: + self.proto = S7Protocol() + + def test_header_error(self) -> None: + with pytest.raises(S7ProtocolError, match="Write operation failed"): + self.proto.check_write_response({"error_code": 0x8104}) + + def test_data_section_error(self) -> None: + with pytest.raises(S7ProtocolError, match="Write operation failed"): + self.proto.check_write_response({"error_code": 0, "data": {"return_code": 0x05}}) + + def test_success_with_data(self) -> None: + # Should not raise + self.proto.check_write_response({"error_code": 0, "data": {"return_code": 0xFF}}) + + def test_success_without_data(self) -> None: + # ACK without data section — should not raise + self.proto.check_write_response({"error_code": 0}) + + +class TestValidatePDUReference: + def setup_method(self) -> None: + self.proto = S7Protocol() + self.proto.sequence = 5 + + def test_matching(self) -> None: + # Should not raise + self.proto.validate_pdu_reference(5) + + def test_stale(self) -> None: + from snap7.error import S7StalePacketError + + with pytest.raises(S7StalePacketError): + self.proto.validate_pdu_reference(3) + + def test_lost(self) -> None: + from snap7.error import S7PacketLostError + + with pytest.raises(S7PacketLostError): + self.proto.validate_pdu_reference(7) diff --git a/tests/test_server.py b/tests/test_server.py index 99ac7b60..4e17c895 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,14 +1,16 @@ from ctypes import c_char import logging import time +from datetime import datetime import pytest import unittest from threading import Thread +from snap7.client import Client from snap7.error import server_errors, error_text from snap7.server import Server -from snap7.type import SrvEvent, mkEvent, mkLog, SrvArea, Parameter +from snap7.type import SrvEvent, mkEvent, mkLog, SrvArea, Parameter, Block logging.basicConfig(level=logging.WARNING) @@ -237,8 +239,358 @@ def test_server_area_management(self) -> None: pass -if __name__ == "__main__": - import logging +ip = "127.0.0.1" +SERVER_PORT = 12200 + + +@pytest.mark.server +class TestServerBlockOperations(unittest.TestCase): + """Test block operations through client-server communication.""" + + server: Server = None # type: ignore + + @classmethod + def setUpClass(cls) -> None: + cls.server = Server() + # Register several DBs so list_blocks / list_blocks_of_type have something to report + cls.server.register_area(SrvArea.DB, 1, bytearray(100)) + cls.server.register_area(SrvArea.DB, 2, bytearray(200)) + cls.server.register_area(SrvArea.DB, 3, bytearray(50)) + # Also register other area types + cls.server.register_area(SrvArea.MK, 0, bytearray(64)) + cls.server.register_area(SrvArea.PA, 0, bytearray(64)) + cls.server.register_area(SrvArea.PE, 0, bytearray(64)) + cls.server.register_area(SrvArea.TM, 0, bytearray(64)) + cls.server.register_area(SrvArea.CT, 0, bytearray(64)) + cls.server.start(tcp_port=SERVER_PORT) + + @classmethod + def tearDownClass(cls) -> None: + if cls.server: + cls.server.stop() + cls.server.destroy() + + def setUp(self) -> None: + self.client = Client() + self.client.connect(ip, 0, 1, SERVER_PORT) + + def tearDown(self) -> None: + self.client.disconnect() + self.client.destroy() + + # ------------------------------------------------------------------ + # list_blocks + # ------------------------------------------------------------------ + def test_list_blocks(self) -> None: + """list_blocks() should return counts; DBCount >= 3 since we registered 3 DBs.""" + bl = self.client.list_blocks() + self.assertGreaterEqual(bl.DBCount, 3) + # OB/FB/FC should be 0 since the emulator only tracks DBs + self.assertEqual(bl.OBCount, 0) + self.assertEqual(bl.FBCount, 0) + self.assertEqual(bl.FCCount, 0) + + # ------------------------------------------------------------------ + # list_blocks_of_type + # ------------------------------------------------------------------ + def test_list_blocks_of_type_db(self) -> None: + """list_blocks_of_type(DB) should include the DB numbers we registered.""" + block_nums = self.client.list_blocks_of_type(Block.DB, 100) + self.assertIn(1, block_nums) + self.assertIn(2, block_nums) + self.assertIn(3, block_nums) + + def test_list_blocks_of_type_ob(self) -> None: + """list_blocks_of_type(OB) should return an empty list (no OBs registered).""" + block_nums = self.client.list_blocks_of_type(Block.OB, 100) + self.assertEqual(block_nums, []) + + # ------------------------------------------------------------------ + # get_block_info + # ------------------------------------------------------------------ + def test_get_block_info(self) -> None: + """get_block_info for a registered DB should return valid metadata.""" + info = self.client.get_block_info(Block.DB, 1) + self.assertEqual(info.MC7Size, 100) # matches registered size + self.assertEqual(info.BlkNumber, 1) + + def test_get_block_info_db2(self) -> None: + """get_block_info for DB2 with size 200.""" + info = self.client.get_block_info(Block.DB, 2) + self.assertEqual(info.MC7Size, 200) + self.assertEqual(info.BlkNumber, 2) + + # ------------------------------------------------------------------ + # upload (block transfer: START_UPLOAD -> UPLOAD -> END_UPLOAD) + # ------------------------------------------------------------------ + def test_upload(self) -> None: + """Upload a DB from the server and verify the returned data length.""" + # Write known data to DB1 first + test_data = bytearray(range(10)) + self.client.db_write(1, 0, test_data) + + # Upload the block + block_data = self.client.upload(1) + self.assertGreater(len(block_data), 0) + # Verify the first bytes match what we wrote + self.assertEqual(block_data[:10], test_data) + + def test_full_upload(self) -> None: + """full_upload should return block data and its size.""" + data, size = self.client.full_upload(Block.DB, 1) + self.assertGreater(size, 0) + self.assertEqual(len(data), size) + + # ------------------------------------------------------------------ + # download (block transfer: REQUEST_DOWNLOAD -> DOWNLOAD_BLOCK -> DOWNLOAD_ENDED) + # ------------------------------------------------------------------ + def test_download(self) -> None: + """Download data to a registered DB on the server.""" + download_data = bytearray([0xAA, 0xBB, 0xCC, 0xDD]) + result = self.client.download(download_data, block_num=1) + self.assertEqual(result, 0) + + # Verify the data was written by reading it back + read_back = self.client.db_read(1, 0, 4) + self.assertEqual(read_back, download_data) + + +@pytest.mark.server +class TestServerUserdataOperations(unittest.TestCase): + """Test USERDATA handlers (SZL, clock, CPU state) through client-server communication.""" + + server: Server = None # type: ignore + + @classmethod + def setUpClass(cls) -> None: + cls.server = Server() + cls.server.register_area(SrvArea.DB, 1, bytearray(100)) + cls.server.start(tcp_port=SERVER_PORT + 1) + + @classmethod + def tearDownClass(cls) -> None: + if cls.server: + cls.server.stop() + cls.server.destroy() + + def setUp(self) -> None: + self.client = Client() + self.client.connect(ip, 0, 1, SERVER_PORT + 1) + + def tearDown(self) -> None: + self.client.disconnect() + self.client.destroy() + + # ------------------------------------------------------------------ + # read_szl + # ------------------------------------------------------------------ + def test_read_szl_0x001c(self) -> None: + """read_szl(0x001C) should return component identification data.""" + szl = self.client.read_szl(0x001C, 0) + self.assertGreater(szl.Header.LengthDR, 0) + + def test_read_szl_0x0011(self) -> None: + """read_szl(0x0011) should return module identification data.""" + szl = self.client.read_szl(0x0011, 0) + self.assertGreater(szl.Header.LengthDR, 0) + + def test_read_szl_0x0131(self) -> None: + """read_szl(0x0131) should return communication parameters.""" + szl = self.client.read_szl(0x0131, 0) + self.assertGreater(szl.Header.LengthDR, 0) + + def test_read_szl_0x0232(self) -> None: + """read_szl(0x0232) should return protection level data.""" + szl = self.client.read_szl(0x0232, 0) + self.assertGreater(szl.Header.LengthDR, 0) + + def test_read_szl_0x0000(self) -> None: + """read_szl(0x0000) should return the list of available SZL IDs.""" + szl = self.client.read_szl(0x0000, 0) + self.assertGreater(szl.Header.LengthDR, 0) + + def test_read_szl_list(self) -> None: + """read_szl_list should return raw bytes of available SZL IDs.""" + data = self.client.read_szl_list() + self.assertIsInstance(data, bytes) + self.assertGreater(len(data), 0) + + # ------------------------------------------------------------------ + # get_cpu_info (uses read_szl 0x001C internally) + # ------------------------------------------------------------------ + def test_get_cpu_info(self) -> None: + """get_cpu_info should populate the S7CpuInfo structure.""" + info = self.client.get_cpu_info() + # The emulated server returns "CPU 315-2 PN/DP" + self.assertIn(b"CPU", info.ModuleTypeName) + + # ------------------------------------------------------------------ + # get_order_code (uses read_szl 0x0011 internally) + # ------------------------------------------------------------------ + def test_get_order_code(self) -> None: + """get_order_code should return order code data.""" + oc = self.client.get_order_code() + self.assertIn(b"6ES7", oc.OrderCode) + + # ------------------------------------------------------------------ + # get_cp_info (uses read_szl 0x0131 internally) + # ------------------------------------------------------------------ + def test_get_cp_info(self) -> None: + """get_cp_info should return communication parameters.""" + cp = self.client.get_cp_info() + self.assertGreater(cp.MaxPduLength, 0) + self.assertGreater(cp.MaxConnections, 0) + + # ------------------------------------------------------------------ + # get_protection (uses read_szl 0x0232 internally) + # ------------------------------------------------------------------ + def test_get_protection(self) -> None: + """get_protection should return protection settings.""" + prot = self.client.get_protection() + # Emulator returns no protection (sch_schal=1) + self.assertEqual(prot.sch_schal, 1) + + # ------------------------------------------------------------------ + # get/set PLC datetime (clock USERDATA handlers) + # ------------------------------------------------------------------ + def test_get_plc_datetime(self) -> None: + """get_plc_datetime should return a valid datetime object.""" + dt = self.client.get_plc_datetime() + self.assertIsInstance(dt, datetime) + # Should be recent (within last minute) + now = datetime.now() + delta = abs((now - dt).total_seconds()) + self.assertLess(delta, 60) + + def test_set_plc_datetime(self) -> None: + """set_plc_datetime should succeed (returns 0).""" + test_dt = datetime(2025, 6, 15, 12, 30, 45) + result = self.client.set_plc_datetime(test_dt) + self.assertEqual(result, 0) + + def test_set_plc_system_datetime(self) -> None: + """set_plc_system_datetime should succeed.""" + result = self.client.set_plc_system_datetime() + self.assertEqual(result, 0) + + # ------------------------------------------------------------------ + # get_cpu_state (SZL-based CPU state request) + # ------------------------------------------------------------------ + def test_get_cpu_state(self) -> None: + """get_cpu_state should return a string state.""" + state = self.client.get_cpu_state() + self.assertIsInstance(state, str) + + +@pytest.mark.server +class TestServerPLCControl(unittest.TestCase): + """Test PLC control operations (stop/start) through client-server communication.""" + + server: Server = None # type: ignore + + @classmethod + def setUpClass(cls) -> None: + cls.server = Server() + cls.server.register_area(SrvArea.DB, 1, bytearray(100)) + cls.server.start(tcp_port=SERVER_PORT + 2) - logging.basicConfig() + @classmethod + def tearDownClass(cls) -> None: + if cls.server: + cls.server.stop() + cls.server.destroy() + + def setUp(self) -> None: + self.client = Client() + self.client.connect(ip, 0, 1, SERVER_PORT + 2) + + def tearDown(self) -> None: + self.client.disconnect() + self.client.destroy() + + def test_plc_stop(self) -> None: + """plc_stop should succeed and set the server CPU state to STOP.""" + result = self.client.plc_stop() + self.assertEqual(result, 0) + + def test_plc_hot_start(self) -> None: + """plc_hot_start should succeed.""" + result = self.client.plc_hot_start() + self.assertEqual(result, 0) + + def test_plc_cold_start(self) -> None: + """plc_cold_start should succeed.""" + result = self.client.plc_cold_start() + self.assertEqual(result, 0) + + def test_plc_stop_then_start(self) -> None: + """Stopping then starting the PLC should work in sequence.""" + self.assertEqual(self.client.plc_stop(), 0) + self.assertEqual(self.client.plc_hot_start(), 0) + + def test_compress(self) -> None: + """compress should succeed.""" + result = self.client.compress(timeout=1000) + self.assertEqual(result, 0) + + def test_copy_ram_to_rom(self) -> None: + """copy_ram_to_rom should succeed.""" + result = self.client.copy_ram_to_rom(timeout=1000) + self.assertEqual(result, 0) + + +@pytest.mark.server +class TestServerErrorScenarios(unittest.TestCase): + """Test error handling paths in the server.""" + + server: Server = None # type: ignore + + @classmethod + def setUpClass(cls) -> None: + cls.server = Server() + # Only register DB1 with a small area + cls.server.register_area(SrvArea.DB, 1, bytearray(10)) + cls.server.start(tcp_port=SERVER_PORT + 3) + + @classmethod + def tearDownClass(cls) -> None: + if cls.server: + cls.server.stop() + cls.server.destroy() + + def setUp(self) -> None: + self.client = Client() + self.client.connect(ip, 0, 1, SERVER_PORT + 3) + + def tearDown(self) -> None: + self.client.disconnect() + self.client.destroy() + + def test_read_unregistered_db(self) -> None: + """Reading from an unregistered DB should still return data (server returns dummy data).""" + # The server returns dummy data for unregistered areas rather than an error + data = self.client.db_read(99, 0, 4) + self.assertEqual(len(data), 4) + + def test_write_beyond_area_bounds(self) -> None: + """Writing beyond area bounds should raise an error.""" + # DB1 is only 10 bytes, writing 20 bytes at offset 0 should fail + with self.assertRaises(Exception): + self.client.db_write(1, 0, bytearray(20)) + + def test_get_block_info_nonexistent(self) -> None: + """get_block_info for a non-existent block should raise an error.""" + with self.assertRaises(Exception): + self.client.get_block_info(Block.DB, 999) + + def test_upload_nonexistent_block(self) -> None: + """Uploading a non-existent block returns empty data (server has no data for that block).""" + # The server defaults to block_num=1 for unknown blocks due to parsing fallback, + # so the upload still completes but returns the default block's data. + # We just verify the operation doesn't crash. + data = self.client.upload(999) + self.assertIsInstance(data, bytearray) + + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_server_cli.py b/tests/test_server_cli.py new file mode 100644 index 00000000..b0e4372d --- /dev/null +++ b/tests/test_server_cli.py @@ -0,0 +1,30 @@ +"""Tests for snap7.server.__main__ — CLI entrypoint.""" + +import pytest + +click = pytest.importorskip("click") +from click.testing import CliRunner # noqa: E402 +from snap7.server.__main__ import main # noqa: E402 + + +class TestServerCLI: + """Test the Click CLI entrypoint.""" + + def test_help(self) -> None: + runner = CliRunner() + result = runner.invoke(main, ["--help"]) + assert result.exit_code == 0 + assert "Start a S7 dummy server" in result.output + + def test_help_short(self) -> None: + runner = CliRunner() + result = runner.invoke(main, ["-h"]) + assert result.exit_code == 0 + assert "--port" in result.output + + def test_version(self) -> None: + runner = CliRunner() + result = runner.invoke(main, ["--version"]) + assert result.exit_code == 0 + # Should print version string + assert "version" in result.output.lower() or "." in result.output diff --git a/tests/test_typed_access.py b/tests/test_typed_access.py new file mode 100644 index 00000000..c961cef6 --- /dev/null +++ b/tests/test_typed_access.py @@ -0,0 +1,202 @@ +"""Tests for typed data access methods on Client.""" + +import unittest + +import pytest + +from snap7.client import Client +from snap7.server import Server +from snap7.type import SrvArea + +ip = "127.0.0.1" +tcpport = 1102 +rack = 1 +slot = 1 + + +@pytest.mark.client +class TestTypedAccess(unittest.TestCase): + server: Server = None # type: ignore + + @classmethod + def setUpClass(cls) -> None: + cls.server = Server() + cls.server.register_area(SrvArea.DB, 0, bytearray(600)) + cls.server.register_area(SrvArea.DB, 1, bytearray(600)) + cls.server.register_area(SrvArea.PA, 0, bytearray(100)) + cls.server.register_area(SrvArea.PE, 0, bytearray(100)) + cls.server.register_area(SrvArea.MK, 0, bytearray(100)) + cls.server.register_area(SrvArea.TM, 0, bytearray(100)) + cls.server.register_area(SrvArea.CT, 0, bytearray(100)) + cls.server.start(tcp_port=tcpport) + + @classmethod + def tearDownClass(cls) -> None: + if cls.server: + cls.server.stop() + cls.server.destroy() + + def setUp(self) -> None: + self.client = Client() + self.client.connect(ip, rack, slot, tcpport) + + def tearDown(self) -> None: + self.client.disconnect() + self.client.destroy() + + # Bool tests + + def test_bool_roundtrip(self) -> None: + self.client.db_write_bool(1, 0, 0, True) + self.assertTrue(self.client.db_read_bool(1, 0, 0)) + + self.client.db_write_bool(1, 0, 0, False) + self.assertFalse(self.client.db_read_bool(1, 0, 0)) + + def test_bool_preserves_other_bits(self) -> None: + # Write byte 0xFF first + self.client.db_write_byte(1, 0, 0xFF) + + # Clear bit 3 + self.client.db_write_bool(1, 0, 3, False) + self.assertFalse(self.client.db_read_bool(1, 0, 3)) + + # Other bits should still be set + self.assertTrue(self.client.db_read_bool(1, 0, 0)) + self.assertTrue(self.client.db_read_bool(1, 0, 1)) + self.assertTrue(self.client.db_read_bool(1, 0, 7)) + + def test_bool_all_bits(self) -> None: + self.client.db_write_byte(1, 0, 0) + for bit in range(8): + self.client.db_write_bool(1, 0, bit, True) + self.assertTrue(self.client.db_read_bool(1, 0, bit)) + + # Byte tests + + def test_byte_roundtrip(self) -> None: + self.client.db_write_byte(1, 10, 42) + self.assertEqual(42, self.client.db_read_byte(1, 10)) + + def test_byte_min_max(self) -> None: + self.client.db_write_byte(1, 10, 0) + self.assertEqual(0, self.client.db_read_byte(1, 10)) + + self.client.db_write_byte(1, 10, 255) + self.assertEqual(255, self.client.db_read_byte(1, 10)) + + # INT tests + + def test_int_roundtrip(self) -> None: + self.client.db_write_int(1, 20, 12345) + self.assertEqual(12345, self.client.db_read_int(1, 20)) + + def test_int_negative(self) -> None: + self.client.db_write_int(1, 20, -12345) + self.assertEqual(-12345, self.client.db_read_int(1, 20)) + + def test_int_min_max(self) -> None: + self.client.db_write_int(1, 20, -32768) + self.assertEqual(-32768, self.client.db_read_int(1, 20)) + + self.client.db_write_int(1, 20, 32767) + self.assertEqual(32767, self.client.db_read_int(1, 20)) + + # UINT tests + + def test_uint_roundtrip(self) -> None: + self.client.db_write_uint(1, 30, 50000) + self.assertEqual(50000, self.client.db_read_uint(1, 30)) + + def test_uint_min_max(self) -> None: + self.client.db_write_uint(1, 30, 0) + self.assertEqual(0, self.client.db_read_uint(1, 30)) + + self.client.db_write_uint(1, 30, 65535) + self.assertEqual(65535, self.client.db_read_uint(1, 30)) + + # WORD tests + + def test_word_roundtrip(self) -> None: + self.client.db_write_word(1, 40, 0xABCD) + self.assertEqual(0xABCD, self.client.db_read_word(1, 40)) + + # DINT tests + + def test_dint_roundtrip(self) -> None: + self.client.db_write_dint(1, 50, 100000) + self.assertEqual(100000, self.client.db_read_dint(1, 50)) + + def test_dint_negative(self) -> None: + self.client.db_write_dint(1, 50, -100000) + self.assertEqual(-100000, self.client.db_read_dint(1, 50)) + + def test_dint_min_max(self) -> None: + self.client.db_write_dint(1, 50, -2147483648) + self.assertEqual(-2147483648, self.client.db_read_dint(1, 50)) + + self.client.db_write_dint(1, 50, 2147483647) + self.assertEqual(2147483647, self.client.db_read_dint(1, 50)) + + # UDINT tests + + def test_udint_roundtrip(self) -> None: + self.client.db_write_udint(1, 60, 3000000000) + self.assertEqual(3000000000, self.client.db_read_udint(1, 60)) + + # DWORD tests + + def test_dword_roundtrip(self) -> None: + self.client.db_write_dword(1, 70, 0xDEADBEEF) + self.assertEqual(0xDEADBEEF, self.client.db_read_dword(1, 70)) + + # REAL tests + + def test_real_roundtrip(self) -> None: + self.client.db_write_real(1, 80, 3.14) + self.assertAlmostEqual(3.14, self.client.db_read_real(1, 80), places=2) + + def test_real_zero(self) -> None: + self.client.db_write_real(1, 80, 0.0) + self.assertEqual(0.0, self.client.db_read_real(1, 80)) + + def test_real_negative(self) -> None: + self.client.db_write_real(1, 80, -273.15) + self.assertAlmostEqual(-273.15, self.client.db_read_real(1, 80), places=2) + + # LREAL tests + + def test_lreal_roundtrip(self) -> None: + self.client.db_write_lreal(1, 90, 3.141592653589793) + self.assertAlmostEqual(3.141592653589793, self.client.db_read_lreal(1, 90), places=10) + + def test_lreal_zero(self) -> None: + self.client.db_write_lreal(1, 90, 0.0) + self.assertEqual(0.0, self.client.db_read_lreal(1, 90)) + + # STRING tests + + def test_string_roundtrip(self) -> None: + # First write a proper S7 string header + self.client.db_write_string(1, 100, "Hello") + result = self.client.db_read_string(1, 100) + self.assertEqual("Hello", result) + + def test_string_empty(self) -> None: + self.client.db_write_string(1, 100, "") + result = self.client.db_read_string(1, 100) + self.assertEqual("", result) + + # Combined test + + def test_multiple_types_coexist(self) -> None: + """Write different types at different offsets and verify they don't interfere.""" + self.client.db_write_int(1, 400, 1234) + self.client.db_write_real(1, 404, 5.678) + self.client.db_write_bool(1, 408, 0, True) + self.client.db_write_dint(1, 410, -99999) + + self.assertEqual(1234, self.client.db_read_int(1, 400)) + self.assertAlmostEqual(5.678, self.client.db_read_real(1, 404), places=2) + self.assertTrue(self.client.db_read_bool(1, 408, 0)) + self.assertEqual(-99999, self.client.db_read_dint(1, 410)) diff --git a/tests/test_util.py b/tests/test_util.py index 2f76d2d0..49f3d192 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,13 +1,16 @@ import datetime +import logging import pytest import unittest import struct from typing import cast +from unittest.mock import MagicMock from snap7 import DB, Row +from snap7.type import Area, WordLen from snap7.util import get_byte, get_time, get_fstring, get_int from snap7.util import set_byte, set_time, set_fstring, set_int -from snap7.type import WordLen +from snap7.util.db import print_row test_spec = """ @@ -801,5 +804,675 @@ def test_set_dtl_in_row(self) -> None: self.assertEqual(result.second, 30) +class TestMemoryviewCompat(unittest.TestCase): + """Test that setter and getter functions work with memoryview buffers.""" + + def test_set_bool_memoryview(self) -> None: + from snap7.util.setters import set_bool + + buf = bytearray(1) + mv = memoryview(buf) + set_bool(mv, 0, 0, True) + self.assertEqual(buf[0], 1) + + def test_set_byte_memoryview(self) -> None: + buf = bytearray(1) + mv = memoryview(buf) + set_byte(mv, 0, 42) + self.assertEqual(buf[0], 42) + + def test_set_int_memoryview(self) -> None: + buf = bytearray(2) + mv = memoryview(buf) + set_int(mv, 0, -1234) + self.assertEqual(struct.unpack(">h", buf)[0], -1234) + + def test_set_word_memoryview(self) -> None: + from snap7.util.setters import set_word + + buf = bytearray(2) + mv = memoryview(buf) + set_word(mv, 0, 65535) + self.assertEqual(struct.unpack(">H", buf)[0], 65535) + + def test_set_real_memoryview(self) -> None: + from snap7.util.setters import set_real + + buf = bytearray(4) + mv = memoryview(buf) + set_real(mv, 0, 123.456) + val = struct.unpack(">f", buf)[0] + self.assertAlmostEqual(val, 123.456, places=2) + + def test_set_dword_memoryview(self) -> None: + from snap7.util.setters import set_dword + + buf = bytearray(4) + mv = memoryview(buf) + set_dword(mv, 0, 0xDEADBEEF) + self.assertEqual(struct.unpack(">I", buf)[0], 0xDEADBEEF) + + def test_set_dint_memoryview(self) -> None: + from snap7.util.setters import set_dint + + buf = bytearray(4) + mv = memoryview(buf) + set_dint(mv, 0, -100000) + self.assertEqual(struct.unpack(">i", buf)[0], -100000) + + def test_set_usint_memoryview(self) -> None: + from snap7.util.setters import set_usint + + buf = bytearray(1) + mv = memoryview(buf) + set_usint(mv, 0, 200) + self.assertEqual(buf[0], 200) + + def test_set_sint_memoryview(self) -> None: + from snap7.util.setters import set_sint + + buf = bytearray(1) + mv = memoryview(buf) + set_sint(mv, 0, -50) + self.assertEqual(struct.unpack(">b", buf)[0], -50) + + def test_set_lreal_memoryview(self) -> None: + from snap7.util.setters import set_lreal + + buf = bytearray(8) + mv = memoryview(buf) + set_lreal(mv, 0, 3.14159265358979) + val = struct.unpack(">d", buf)[0] + self.assertAlmostEqual(val, 3.14159265358979, places=10) + + def test_set_string_memoryview(self) -> None: + from snap7.util.setters import set_string + + buf = bytearray(20) + mv = memoryview(buf) + set_string(mv, 0, "hello", 10) + self.assertEqual(buf[1], 5) # length byte + + def test_set_fstring_memoryview(self) -> None: + buf = bytearray(10) + mv = memoryview(buf) + set_fstring(mv, 0, "hi", 5) + self.assertEqual(chr(buf[0]), "h") + self.assertEqual(chr(buf[1]), "i") + + def test_set_char_memoryview(self) -> None: + from snap7.util.setters import set_char + + buf = bytearray(1) + mv = memoryview(buf) + set_char(mv, 0, "A") + self.assertEqual(buf[0], ord("A")) + + def test_set_date_memoryview(self) -> None: + from snap7.util.setters import set_date + + buf = bytearray(2) + mv = memoryview(buf) + set_date(mv, 0, datetime.date(2024, 3, 27)) + self.assertEqual(buf, bytearray(b"\x30\xd8")) + + def test_set_udint_memoryview(self) -> None: + from snap7.util.setters import set_udint + + buf = bytearray(4) + mv = memoryview(buf) + set_udint(mv, 0, 4294967295) + self.assertEqual(struct.unpack(">I", buf)[0], 4294967295) + + def test_set_uint_memoryview(self) -> None: + from snap7.util.setters import set_uint + + buf = bytearray(2) + mv = memoryview(buf) + set_uint(mv, 0, 12345) + self.assertEqual(struct.unpack(">H", buf)[0], 12345) + + def test_set_time_memoryview(self) -> None: + buf = bytearray(4) + mv = memoryview(buf) + set_time(mv, 0, "1:2:3:4.567") + self.assertNotEqual(buf, bytearray(4)) + + +_db_test_spec = """ +4 ID INT +6 NAME STRING[4] + +12.0 testbool1 BOOL +12.1 testbool2 BOOL +13 testReal REAL +17 testDword DWORD +21 testint2 INT +23 testDint DINT +27 testWord WORD +29 testS5time S5TIME +31 testdateandtime DATE_AND_TIME +43 testusint0 USINT +44 testsint0 SINT +46 testTime TIME +50 testByte BYTE +51 testUint UINT +53 testUdint UDINT +57 testLreal LREAL +65 testChar CHAR +66 testWchar WCHAR +68 testWstring WSTRING[4] +80 testDate DATE +82 testTod TOD +86 testDtl DTL +98 testFstring FSTRING[8] +""" + +_db_bytearray = bytearray( + [ + 0, + 0, # test int + 4, + 4, + ord("t"), + ord("e"), + ord("s"), + ord("t"), # test string + 0x0F, # test bools + 68, + 78, + 211, + 51, # test real + 255, + 255, + 255, + 255, # test dword + 0, + 0, # test int 2 + 128, + 0, + 0, + 0, # test dint + 255, + 255, # test word + 0, + 16, # test s5time + 32, + 7, + 18, + 23, + 50, + 2, + 133, + 65, # date_and_time (8 bytes) + 254, + 254, + 254, + 254, + 254, # padding + 127, # usint + 128, # sint + 143, + 255, + 255, + 255, # time + 254, # byte + 48, + 57, # uint + 7, + 91, + 205, + 21, # udint + 65, + 157, + 111, + 52, + 84, + 126, + 107, + 117, # lreal + 65, # char 'A' + 3, + 169, # wchar + 0, + 4, + 0, + 4, + 3, + 169, + 0, + ord("s"), + 0, + ord("t"), + 0, + 196, # wstring + 45, + 235, # date + 2, + 179, + 41, + 128, # tod + 7, + 230, + 3, + 9, + 4, + 12, + 34, + 45, + 0, + 0, + 0, + 0, # dtl + 116, + 101, + 115, + 116, + 32, + 32, + 32, + 32, # fstring 'test ' + ] +) + + +class TestPrintRow: + def test_print_row_output(self, caplog: pytest.LogCaptureFixture) -> None: + data = bytearray([65, 66, 67, 68, 69]) + with caplog.at_level(logging.INFO, logger="snap7.util.db"): + print_row(data) + assert "65" in caplog.text + assert "A" in caplog.text + + +class TestDBDictInterface: + def setup_method(self) -> None: + test_array = bytearray(_db_bytearray * 3) + self.db = DB(1, test_array, _db_test_spec, row_size=len(_db_bytearray), size=3, layout_offset=4, db_offset=0) + + def test_len(self) -> None: + assert len(self.db) == 3 + + def test_getitem(self) -> None: + row = self.db["0"] + assert row is not None + + def test_getitem_missing(self) -> None: + row = self.db["999"] + assert row is None + + def test_contains(self) -> None: + assert "0" in self.db + assert "999" not in self.db + + def test_keys(self) -> None: + keys = list(self.db.keys()) + assert "0" in keys + assert len(keys) == 3 + + def test_items(self) -> None: + items = list(self.db.items()) + assert len(items) == 3 + for key, row in items: + assert isinstance(key, str) + assert isinstance(row, Row) + + def test_iter(self) -> None: + for key, row in self.db: + assert isinstance(key, str) + assert isinstance(row, Row) + + def test_get_bytearray(self) -> None: + ba = self.db.get_bytearray() + assert isinstance(ba, bytearray) + + +class TestDBWithIdField: + def test_id_field_creates_named_index(self) -> None: + test_array = bytearray(_db_bytearray * 2) + # Set different ID values for each row + struct.pack_into(">h", test_array, 0, 10) # row 0, ID at offset 0 (spec offset 4, layout_offset 4) + struct.pack_into(">h", test_array, len(_db_bytearray), 20) # row 1 + db = DB(1, test_array, _db_test_spec, row_size=len(_db_bytearray), size=2, id_field="ID", layout_offset=4, db_offset=0) + assert "10" in db + assert "20" in db + + +class TestDBSetData: + def test_set_data_valid(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(1, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0) + new_data = bytearray(len(_db_bytearray)) + db.set_data(new_data) + assert db.get_bytearray() is new_data + + def test_set_data_invalid_type(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(1, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0) + with pytest.raises(TypeError): + db.set_data(b"not a bytearray") # type: ignore[arg-type] + + +class TestDBReadWrite: + """Test DB.read() and DB.write() with mocked client.""" + + def test_read_db_area(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(1, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0) + mock_client = MagicMock() + mock_client.db_read.return_value = bytearray(len(_db_bytearray)) + db.read(mock_client) + mock_client.db_read.assert_called_once() + + def test_read_non_db_area(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(0, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0, area=Area.MK) + mock_client = MagicMock() + mock_client.read_area.return_value = bytearray(len(_db_bytearray)) + db.read(mock_client) + mock_client.read_area.assert_called_once() + + def test_read_negative_row_size(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(1, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0) + db.row_size = -1 + mock_client = MagicMock() + with pytest.raises(ValueError, match="row_size"): + db.read(mock_client) + + def test_write_db_area(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(1, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0) + mock_client = MagicMock() + db.write(mock_client) + mock_client.db_write.assert_called_once() + + def test_write_non_db_area(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(0, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0, area=Area.MK) + mock_client = MagicMock() + db.write(mock_client) + mock_client.write_area.assert_called_once() + + def test_write_negative_row_size(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(1, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0) + db.row_size = -1 + mock_client = MagicMock() + with pytest.raises(ValueError, match="row_size"): + db.write(mock_client) + + def test_write_with_row_offset(self) -> None: + test_array = bytearray(_db_bytearray * 2) + db = DB(1, test_array, _db_test_spec, row_size=len(_db_bytearray), size=2, layout_offset=4, db_offset=0, row_offset=4) + mock_client = MagicMock() + db.write(mock_client) + # Should write each row individually via Row.write() + assert mock_client.db_write.call_count == 2 + + +class TestRowRepr: + def test_repr(self) -> None: + test_array = bytearray(_db_bytearray) + row = Row(test_array, _db_test_spec, layout_offset=4) + r = repr(row) + assert "ID" in r + assert "NAME" in r + + +class TestRowUnchanged: + def test_unchanged_true(self) -> None: + test_array = bytearray(_db_bytearray) + row = Row(test_array, _db_test_spec, layout_offset=4) + assert row.unchanged(test_array) is True + + def test_unchanged_false(self) -> None: + test_array = bytearray(_db_bytearray) + row = Row(test_array, _db_test_spec, layout_offset=4) + other = bytearray(len(_db_bytearray)) + assert row.unchanged(other) is False + + +class TestRowTypeError: + def test_invalid_bytearray_type(self) -> None: + with pytest.raises(TypeError): + Row("not a bytearray", _db_test_spec) # type: ignore[arg-type] + + +class TestRowReadWrite: + """Test Row.read() and Row.write() with mocked client through DB parent.""" + + def test_row_write_db_area(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(1, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0) + row = db["0"] + assert row is not None + mock_client = MagicMock() + row.write(mock_client) + mock_client.db_write.assert_called_once() + + def test_row_write_non_db_area(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(0, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0, area=Area.MK) + row = db["0"] + assert row is not None + mock_client = MagicMock() + row.write(mock_client) + mock_client.write_area.assert_called_once() + + def test_row_write_not_db_parent(self) -> None: + test_array = bytearray(_db_bytearray) + row = Row(test_array, _db_test_spec, layout_offset=4) + mock_client = MagicMock() + with pytest.raises(TypeError): + row.write(mock_client) + + def test_row_write_negative_row_size(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(1, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0) + row = db["0"] + assert row is not None + row.row_size = -1 + mock_client = MagicMock() + with pytest.raises(ValueError, match="row_size"): + row.write(mock_client) + + def test_row_read_db_area(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(1, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0) + row = db["0"] + assert row is not None + mock_client = MagicMock() + mock_client.db_read.return_value = bytearray(len(_db_bytearray)) + row.read(mock_client) + mock_client.db_read.assert_called_once() + + def test_row_read_non_db_area(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(0, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0, area=Area.MK) + row = db["0"] + assert row is not None + mock_client = MagicMock() + mock_client.read_area.return_value = bytearray(len(_db_bytearray)) + row.read(mock_client) + mock_client.read_area.assert_called_once() + + def test_row_read_not_db_parent(self) -> None: + test_array = bytearray(_db_bytearray) + row = Row(test_array, _db_test_spec, layout_offset=4) + mock_client = MagicMock() + with pytest.raises(TypeError): + row.read(mock_client) + + def test_row_read_negative_row_size(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(1, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0) + row = db["0"] + assert row is not None + row.row_size = -1 + mock_client = MagicMock() + with pytest.raises(ValueError, match="row_size"): + row.read(mock_client) + + +class TestRowSetValueTypes: + """Test set_value for various type branches.""" + + def setup_method(self) -> None: + self.test_array = bytearray(_db_bytearray) + self.row = Row(self.test_array, _db_test_spec, layout_offset=4) + + def test_set_int(self) -> None: + self.row.set_value(4, "INT", 42) + assert self.row.get_value(4, "INT") == 42 + + def test_set_uint(self) -> None: + self.row.set_value(51, "UINT", 1000) + assert self.row.get_value(51, "UINT") == 1000 + + def test_set_dint(self) -> None: + self.row.set_value(23, "DINT", -100) + assert self.row.get_value(23, "DINT") == -100 + + def test_set_udint(self) -> None: + self.row.set_value(53, "UDINT", 999999) + assert self.row.get_value(53, "UDINT") == 999999 + + def test_set_word(self) -> None: + self.row.set_value(27, "WORD", 12345) + assert self.row.get_value(27, "WORD") == 12345 + + def test_set_usint(self) -> None: + self.row.set_value(43, "USINT", 200) + assert self.row.get_value(43, "USINT") == 200 + + def test_set_sint(self) -> None: + self.row.set_value(44, "SINT", -50) + assert self.row.get_value(44, "SINT") == -50 + + def test_set_time(self) -> None: + self.row.set_value(46, "TIME", "1:2:3:4.5") + assert self.row.get_value(46, "TIME") is not None + + def test_set_date(self) -> None: + d = datetime.date(2024, 1, 15) + self.row.set_value(80, "DATE", d) + assert self.row.get_value(80, "DATE") == d + + def test_set_tod(self) -> None: + td = datetime.timedelta(hours=5, minutes=30) + self.row.set_value(82, "TOD", td) + assert self.row.get_value(82, "TOD") == td + + def test_set_time_of_day(self) -> None: + td = datetime.timedelta(hours=1) + self.row.set_value(82, "TIME_OF_DAY", td) + assert self.row.get_value(82, "TIME_OF_DAY") == td + + def test_set_dtl(self) -> None: + dt = datetime.datetime(2024, 6, 15, 10, 20, 30) + self.row.set_value(86, "DTL", dt) + result = self.row.get_value(86, "DTL") + assert result.year == 2024 # type: ignore[union-attr] + + def test_set_date_and_time(self) -> None: + dt = datetime.datetime(2020, 7, 12, 17, 32, 2, 854000) + self.row.set_value(31, "DATE_AND_TIME", dt) + result = self.row.get_value(31, "DATE_AND_TIME") + assert "2020" in str(result) + + def test_set_unknown_type_raises(self) -> None: + with pytest.raises(ValueError): + self.row.set_value(4, "UNKNOWN_TYPE", 42) + + def test_set_string(self) -> None: + self.row.set_value(6, "STRING[4]", "ab") + assert self.row.get_value(6, "STRING[4]") == "ab" + + def test_set_wstring(self) -> None: + self.row.set_value(68, "WSTRING[4]", "ab") + assert self.row.get_value(68, "WSTRING[4]") == "ab" + + def test_set_fstring(self) -> None: + self.row.set_value(98, "FSTRING[8]", "hi") + assert self.row.get_value(98, "FSTRING[8]") == "hi" + + def test_set_real(self) -> None: + self.row.set_value(13, "REAL", 3.14) + assert abs(self.row.get_value(13, "REAL") - 3.14) < 0.01 # type: ignore[operator] + + def test_set_lreal(self) -> None: + self.row.set_value(57, "LREAL", 2.718281828) + assert abs(self.row.get_value(57, "LREAL") - 2.718281828) < 0.0001 # type: ignore[operator] + + def test_set_char(self) -> None: + self.row.set_value(65, "CHAR", "Z") + assert self.row.get_value(65, "CHAR") == "Z" + + def test_set_wchar(self) -> None: + self.row.set_value(66, "WCHAR", "W") + assert self.row.get_value(66, "WCHAR") == "W" + + +class TestRowGetValueEdgeCases: + """Test get_value for edge cases.""" + + def setup_method(self) -> None: + self.test_array = bytearray(_db_bytearray) + self.row = Row(self.test_array, _db_test_spec, layout_offset=4) + + def test_unknown_type_raises(self) -> None: + with pytest.raises(ValueError): + self.row.get_value(4, "NONEXISTENT") + + def test_string_no_max_size(self) -> None: + spec = "4 test STRING" + row = Row(bytearray(20), spec, layout_offset=0) + with pytest.raises(ValueError, match="Max size"): + row.get_value(4, "STRING") + + def test_fstring_no_max_size(self) -> None: + with pytest.raises(ValueError, match="Max size"): + self.row.get_value(98, "FSTRING") + + def test_wstring_no_max_size(self) -> None: + with pytest.raises(ValueError, match="Max size"): + self.row.get_value(68, "WSTRING") + + +class TestRowSetValueEdgeCases: + """Test set_value edge cases for string types.""" + + def setup_method(self) -> None: + self.test_array = bytearray(_db_bytearray) + self.row = Row(self.test_array, _db_test_spec, layout_offset=4) + + def test_fstring_no_max_size(self) -> None: + with pytest.raises(ValueError, match="Max size"): + self.row.set_value(98, "FSTRING", "test") + + def test_string_no_max_size(self) -> None: + with pytest.raises(ValueError, match="Max size"): + self.row.set_value(6, "STRING", "test") + + def test_wstring_no_max_size(self) -> None: + with pytest.raises(ValueError, match="Max size"): + self.row.set_value(68, "WSTRING", "test") + + +class TestRowWriteWithRowOffset: + """Test Row.write() with row_offset set.""" + + def test_write_with_row_offset(self) -> None: + test_array = bytearray(_db_bytearray) + db = DB(1, test_array, _db_test_spec, row_size=len(_db_bytearray), size=1, layout_offset=4, db_offset=0, row_offset=10) + row = db["0"] + assert row is not None + mock_client = MagicMock() + row.write(mock_client) + # The data written should start at db_offset + row_offset + mock_client.db_write.assert_called_once() + + if __name__ == "__main__": unittest.main() diff --git a/uv.lock b/uv.lock index 4e57fde2..38c470c2 100644 --- a/uv.lock +++ b/uv.lock @@ -25,13 +25,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/77/f5/21d2de20e8b8b0408f0681956ca2c69f1320a3848ac50e6e7f39c6159675/babel-2.18.0-py3-none-any.whl", hash = "sha256:e2b422b277c2b9a9630c1d7903c2a00d0830c409c59ac8cae9081c92f1aeba35", size = 10196845, upload-time = "2026-02-01T12:30:53.445Z" }, ] +[[package]] +name = "backports-asyncio-runner" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, +] + [[package]] name = "cachetools" -version = "7.0.4" +version = "7.0.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/cc/eb3fd22f3b96b8b70ce456d0854ef08434e5ca79c02bf8db3fc07ccfca87/cachetools-7.0.4.tar.gz", hash = "sha256:7042c0e4eea87812f04744ce6ee9ed3de457875eb1f82d8a206c46d6e48b6734", size = 37379, upload-time = "2026-03-08T21:37:17.133Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d4/07/56595285564e90777d758ebd383d6b0b971b87729bbe2184a849932a3736/cachetools-7.0.1.tar.gz", hash = "sha256:e31e579d2c5b6e2944177a0397150d312888ddf4e16e12f1016068f0c03b8341", size = 36126, upload-time = "2026-02-10T22:24:05.03Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/83/bc/72adfb3f2ed19eb0317f89ea9b1eeccc670ae46bc394ec2c4ba1dd8c22b7/cachetools-7.0.4-py3-none-any.whl", hash = "sha256:0c8bb1b9ec8194fa4d764accfde602dfe52f70d0f311e62792d4c3f8c051b1e9", size = 13900, upload-time = "2026-03-08T21:37:15.805Z" }, + { url = "https://files.pythonhosted.org/packages/ed/9e/5faefbf9db1db466d633735faceda1f94aa99ce506ac450d232536266b32/cachetools-7.0.1-py3-none-any.whl", hash = "sha256:8f086515c254d5664ae2146d14fc7f65c9a4bce75152eb247e5a9c5e6d7b2ecf", size = 13484, upload-time = "2026-02-10T22:24:03.741Z" }, ] [[package]] @@ -319,11 +328,11 @@ wheels = [ [[package]] name = "filelock" -version = "3.25.0" +version = "3.24.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/77/18/a1fd2231c679dcb9726204645721b12498aeac28e1ad0601038f94b42556/filelock-3.25.0.tar.gz", hash = "sha256:8f00faf3abf9dc730a1ffe9c354ae5c04e079ab7d3a683b7c32da5dd05f26af3", size = 40158, upload-time = "2026-03-01T15:08:45.916Z" } +sdist = { url = "https://files.pythonhosted.org/packages/73/92/a8e2479937ff39185d20dd6a851c1a63e55849e447a55e798cc2e1f49c65/filelock-3.24.3.tar.gz", hash = "sha256:011a5644dc937c22699943ebbfc46e969cdde3e171470a6e40b9533e5a72affa", size = 37935, upload-time = "2026-02-19T00:48:20.543Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f9/0b/de6f54d4a8bedfe8645c41497f3c18d749f0bd3218170c667bf4b81d0cdd/filelock-3.25.0-py3-none-any.whl", hash = "sha256:5ccf8069f7948f494968fc0713c10e5c182a9c9d9eef3a636307a20c2490f047", size = 26427, upload-time = "2026-03-01T15:08:44.593Z" }, + { url = "https://files.pythonhosted.org/packages/9c/0f/5d0c71a1aefeb08efff26272149e07ab922b64f46c63363756224bd6872e/filelock-3.24.3-py3-none-any.whl", hash = "sha256:426e9a4660391f7f8a810d71b0555bce9008b0a1cc342ab1f6947d37639e002d", size = 24331, upload-time = "2026-02-19T00:48:18.465Z" }, ] [[package]] @@ -631,11 +640,11 @@ wheels = [ [[package]] name = "platformdirs" -version = "4.9.4" +version = "4.9.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/19/56/8d4c30c8a1d07013911a8fdbd8f89440ef9f08d07a1b50ab8ca8be5a20f9/platformdirs-4.9.4.tar.gz", hash = "sha256:1ec356301b7dc906d83f371c8f487070e99d3ccf9e501686456394622a01a934", size = 28737, upload-time = "2026-03-05T18:34:13.271Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/04/fea538adf7dbbd6d186f551d595961e564a3b6715bdf276b477460858672/platformdirs-4.9.2.tar.gz", hash = "sha256:9a33809944b9db043ad67ca0db94b14bf452cc6aeaac46a88ea55b26e2e9d291", size = 28394, upload-time = "2026-02-16T03:56:10.574Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/63/d7/97f7e3a6abb67d8080dd406fd4df842c2be0efaf712d1c899c32a075027c/platformdirs-4.9.4-py3-none-any.whl", hash = "sha256:68a9a4619a666ea6439f2ff250c12a853cd1cbd5158d258bd824a7df6be2f868", size = 21216, upload-time = "2026-03-05T18:34:12.172Z" }, + { url = "https://files.pythonhosted.org/packages/48/31/05e764397056194206169869b50cf2fee4dbbbc71b344705b9c0d878d4d8/platformdirs-4.9.2-py3-none-any.whl", hash = "sha256:9170634f126f8efdae22fb58ae8a0eaa86f38365bc57897a6c4f781d1f5875bd", size = 21168, upload-time = "2026-02-16T03:56:08.891Z" }, ] [[package]] @@ -687,6 +696,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, + { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, +] + [[package]] name = "pytest-cov" version = "7.0.0" @@ -759,6 +782,7 @@ doc = [ test = [ { name = "mypy" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "pytest-html" }, { name = "ruff" }, @@ -774,6 +798,7 @@ requires-dist = [ { name = "click", marker = "extra == 'cli'" }, { name = "mypy", marker = "extra == 'test'" }, { name = "pytest", marker = "extra == 'test'" }, + { name = "pytest-asyncio", marker = "extra == 'test'" }, { name = "pytest-cov", marker = "extra == 'test'" }, { name = "pytest-html", marker = "extra == 'test'" }, { name = "rich", marker = "extra == 'cli'" }, @@ -827,27 +852,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.15.5" +version = "0.15.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/77/9b/840e0039e65fcf12758adf684d2289024d6140cde9268cc59887dc55189c/ruff-0.15.5.tar.gz", hash = "sha256:7c3601d3b6d76dce18c5c824fc8d06f4eef33d6df0c21ec7799510cde0f159a2", size = 4574214, upload-time = "2026-03-05T20:06:34.946Z" } +sdist = { url = "https://files.pythonhosted.org/packages/da/31/d6e536cdebb6568ae75a7f00e4b4819ae0ad2640c3604c305a0428680b0c/ruff-0.15.4.tar.gz", hash = "sha256:3412195319e42d634470cc97aa9803d07e9d5c9223b99bcb1518f0c725f26ae1", size = 4569550, upload-time = "2026-02-26T20:04:14.959Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/47/20/5369c3ce21588c708bcbe517a8fbe1a8dfdb5dfd5137e14790b1da71612c/ruff-0.15.5-py3-none-linux_armv6l.whl", hash = "sha256:4ae44c42281f42e3b06b988e442d344a5b9b72450ff3c892e30d11b29a96a57c", size = 10478185, upload-time = "2026-03-05T20:06:29.093Z" }, - { url = "https://files.pythonhosted.org/packages/44/ed/e81dd668547da281e5dce710cf0bc60193f8d3d43833e8241d006720e42b/ruff-0.15.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6edd3792d408ebcf61adabc01822da687579a1a023f297618ac27a5b51ef0080", size = 10859201, upload-time = "2026-03-05T20:06:32.632Z" }, - { url = "https://files.pythonhosted.org/packages/c4/8f/533075f00aaf19b07c5cd6aa6e5d89424b06b3b3f4583bfa9c640a079059/ruff-0.15.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:89f463f7c8205a9f8dea9d658d59eff49db05f88f89cc3047fb1a02d9f344010", size = 10184752, upload-time = "2026-03-05T20:06:40.312Z" }, - { url = "https://files.pythonhosted.org/packages/66/0e/ba49e2c3fa0395b3152bad634c7432f7edfc509c133b8f4529053ff024fb/ruff-0.15.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba786a8295c6574c1116704cf0b9e6563de3432ac888d8f83685654fe528fd65", size = 10534857, upload-time = "2026-03-05T20:06:19.581Z" }, - { url = "https://files.pythonhosted.org/packages/59/71/39234440f27a226475a0659561adb0d784b4d247dfe7f43ffc12dd02e288/ruff-0.15.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fd4b801e57955fe9f02b31d20375ab3a5c4415f2e5105b79fb94cf2642c91440", size = 10309120, upload-time = "2026-03-05T20:06:00.435Z" }, - { url = "https://files.pythonhosted.org/packages/f5/87/4140aa86a93df032156982b726f4952aaec4a883bb98cb6ef73c347da253/ruff-0.15.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:391f7c73388f3d8c11b794dbbc2959a5b5afe66642c142a6effa90b45f6f5204", size = 11047428, upload-time = "2026-03-05T20:05:51.867Z" }, - { url = "https://files.pythonhosted.org/packages/5a/f7/4953e7e3287676f78fbe85e3a0ca414c5ca81237b7575bdadc00229ac240/ruff-0.15.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8dc18f30302e379fe1e998548b0f5e9f4dff907f52f73ad6da419ea9c19d66c8", size = 11914251, upload-time = "2026-03-05T20:06:22.887Z" }, - { url = "https://files.pythonhosted.org/packages/77/46/0f7c865c10cf896ccf5a939c3e84e1cfaeed608ff5249584799a74d33835/ruff-0.15.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1cc6e7f90087e2d27f98dc34ed1b3ab7c8f0d273cc5431415454e22c0bd2a681", size = 11333801, upload-time = "2026-03-05T20:05:57.168Z" }, - { url = "https://files.pythonhosted.org/packages/d3/01/a10fe54b653061585e655f5286c2662ebddb68831ed3eaebfb0eb08c0a16/ruff-0.15.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1cb7169f53c1ddb06e71a9aebd7e98fc0fea936b39afb36d8e86d36ecc2636a", size = 11206821, upload-time = "2026-03-05T20:06:03.441Z" }, - { url = "https://files.pythonhosted.org/packages/7a/0d/2132ceaf20c5e8699aa83da2706ecb5c5dcdf78b453f77edca7fb70f8a93/ruff-0.15.5-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9b037924500a31ee17389b5c8c4d88874cc6ea8e42f12e9c61a3d754ff72f1ca", size = 11133326, upload-time = "2026-03-05T20:06:25.655Z" }, - { url = "https://files.pythonhosted.org/packages/72/cb/2e5259a7eb2a0f87c08c0fe5bf5825a1e4b90883a52685524596bfc93072/ruff-0.15.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:65bb414e5b4eadd95a8c1e4804f6772bbe8995889f203a01f77ddf2d790929dd", size = 10510820, upload-time = "2026-03-05T20:06:37.79Z" }, - { url = "https://files.pythonhosted.org/packages/ff/20/b67ce78f9e6c59ffbdb5b4503d0090e749b5f2d31b599b554698a80d861c/ruff-0.15.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d20aa469ae3b57033519c559e9bc9cd9e782842e39be05b50e852c7c981fa01d", size = 10302395, upload-time = "2026-03-05T20:05:54.504Z" }, - { url = "https://files.pythonhosted.org/packages/5f/e5/719f1acccd31b720d477751558ed74e9c88134adcc377e5e886af89d3072/ruff-0.15.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:15388dd28c9161cdb8eda68993533acc870aa4e646a0a277aa166de9ad5a8752", size = 10754069, upload-time = "2026-03-05T20:06:06.422Z" }, - { url = "https://files.pythonhosted.org/packages/c3/9c/d1db14469e32d98f3ca27079dbd30b7b44dbb5317d06ab36718dee3baf03/ruff-0.15.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b30da330cbd03bed0c21420b6b953158f60c74c54c5f4c1dabbdf3a57bf355d2", size = 11304315, upload-time = "2026-03-05T20:06:10.867Z" }, - { url = "https://files.pythonhosted.org/packages/28/3a/950367aee7c69027f4f422059227b290ed780366b6aecee5de5039d50fa8/ruff-0.15.5-py3-none-win32.whl", hash = "sha256:732e5ee1f98ba5b3679029989a06ca39a950cced52143a0ea82a2102cb592b74", size = 10551676, upload-time = "2026-03-05T20:06:13.705Z" }, - { url = "https://files.pythonhosted.org/packages/b8/00/bf077a505b4e649bdd3c47ff8ec967735ce2544c8e4a43aba42ee9bf935d/ruff-0.15.5-py3-none-win_amd64.whl", hash = "sha256:821d41c5fa9e19117616c35eaa3f4b75046ec76c65e7ae20a333e9a8696bc7fe", size = 11678972, upload-time = "2026-03-05T20:06:45.379Z" }, - { url = "https://files.pythonhosted.org/packages/fe/4e/cd76eca6db6115604b7626668e891c9dd03330384082e33662fb0f113614/ruff-0.15.5-py3-none-win_arm64.whl", hash = "sha256:b498d1c60d2fe5c10c45ec3f698901065772730b411f164ae270bb6bfcc4740b", size = 10965572, upload-time = "2026-03-05T20:06:16.984Z" }, + { url = "https://files.pythonhosted.org/packages/f2/82/c11a03cfec3a4d26a0ea1e571f0f44be5993b923f905eeddfc397c13d360/ruff-0.15.4-py3-none-linux_armv6l.whl", hash = "sha256:a1810931c41606c686bae8b5b9a8072adac2f611bb433c0ba476acba17a332e0", size = 10453333, upload-time = "2026-02-26T20:04:20.093Z" }, + { url = "https://files.pythonhosted.org/packages/ce/5d/6a1f271f6e31dffb31855996493641edc3eef8077b883eaf007a2f1c2976/ruff-0.15.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5a1632c66672b8b4d3e1d1782859e98d6e0b4e70829530666644286600a33992", size = 10853356, upload-time = "2026-02-26T20:04:05.808Z" }, + { url = "https://files.pythonhosted.org/packages/b1/d8/0fab9f8842b83b1a9c2bf81b85063f65e93fb512e60effa95b0be49bfc54/ruff-0.15.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a4386ba2cd6c0f4ff75252845906acc7c7c8e1ac567b7bc3d373686ac8c222ba", size = 10187434, upload-time = "2026-02-26T20:03:54.656Z" }, + { url = "https://files.pythonhosted.org/packages/85/cc/cc220fd9394eff5db8d94dec199eec56dd6c9f3651d8869d024867a91030/ruff-0.15.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2496488bdfd3732747558b6f95ae427ff066d1fcd054daf75f5a50674411e75", size = 10535456, upload-time = "2026-02-26T20:03:52.738Z" }, + { url = "https://files.pythonhosted.org/packages/fa/0f/bced38fa5cf24373ec767713c8e4cadc90247f3863605fb030e597878661/ruff-0.15.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3f1c4893841ff2d54cbda1b2860fa3260173df5ddd7b95d370186f8a5e66a4ac", size = 10287772, upload-time = "2026-02-26T20:04:08.138Z" }, + { url = "https://files.pythonhosted.org/packages/2b/90/58a1802d84fed15f8f281925b21ab3cecd813bde52a8ca033a4de8ab0e7a/ruff-0.15.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:820b8766bd65503b6c30aaa6331e8ef3a6e564f7999c844e9a547c40179e440a", size = 11049051, upload-time = "2026-02-26T20:04:03.53Z" }, + { url = "https://files.pythonhosted.org/packages/d2/ac/b7ad36703c35f3866584564dc15f12f91cb1a26a897dc2fd13d7cb3ae1af/ruff-0.15.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9fb74bab47139c1751f900f857fa503987253c3ef89129b24ed375e72873e85", size = 11890494, upload-time = "2026-02-26T20:04:10.497Z" }, + { url = "https://files.pythonhosted.org/packages/93/3d/3eb2f47a39a8b0da99faf9c54d3eb24720add1e886a5309d4d1be73a6380/ruff-0.15.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f80c98765949c518142b3a50a5db89343aa90f2c2bf7799de9986498ae6176db", size = 11326221, upload-time = "2026-02-26T20:04:12.84Z" }, + { url = "https://files.pythonhosted.org/packages/ff/90/bf134f4c1e5243e62690e09d63c55df948a74084c8ac3e48a88468314da6/ruff-0.15.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:451a2e224151729b3b6c9ffb36aed9091b2996fe4bdbd11f47e27d8f2e8888ec", size = 11168459, upload-time = "2026-02-26T20:04:00.969Z" }, + { url = "https://files.pythonhosted.org/packages/b5/e5/a64d27688789b06b5d55162aafc32059bb8c989c61a5139a36e1368285eb/ruff-0.15.4-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:a8f157f2e583c513c4f5f896163a93198297371f34c04220daf40d133fdd4f7f", size = 11104366, upload-time = "2026-02-26T20:03:48.099Z" }, + { url = "https://files.pythonhosted.org/packages/f1/f6/32d1dcb66a2559763fc3027bdd65836cad9eb09d90f2ed6a63d8e9252b02/ruff-0.15.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:917cc68503357021f541e69b35361c99387cdbbf99bd0ea4aa6f28ca99ff5338", size = 10510887, upload-time = "2026-02-26T20:03:45.771Z" }, + { url = "https://files.pythonhosted.org/packages/ff/92/22d1ced50971c5b6433aed166fcef8c9343f567a94cf2b9d9089f6aa80fe/ruff-0.15.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:e9737c8161da79fd7cfec19f1e35620375bd8b2a50c3e77fa3d2c16f574105cc", size = 10285939, upload-time = "2026-02-26T20:04:22.42Z" }, + { url = "https://files.pythonhosted.org/packages/e6/f4/7c20aec3143837641a02509a4668fb146a642fd1211846634edc17eb5563/ruff-0.15.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:291258c917539e18f6ba40482fe31d6f5ac023994ee11d7bdafd716f2aab8a68", size = 10765471, upload-time = "2026-02-26T20:03:58.924Z" }, + { url = "https://files.pythonhosted.org/packages/d0/09/6d2f7586f09a16120aebdff8f64d962d7c4348313c77ebb29c566cefc357/ruff-0.15.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3f83c45911da6f2cd5936c436cf86b9f09f09165f033a99dcf7477e34041cbc3", size = 11263382, upload-time = "2026-02-26T20:04:24.424Z" }, + { url = "https://files.pythonhosted.org/packages/1b/fa/2ef715a1cd329ef47c1a050e10dee91a9054b7ce2fcfdd6a06d139afb7ec/ruff-0.15.4-py3-none-win32.whl", hash = "sha256:65594a2d557d4ee9f02834fcdf0a28daa8b3b9f6cb2cb93846025a36db47ef22", size = 10506664, upload-time = "2026-02-26T20:03:50.56Z" }, + { url = "https://files.pythonhosted.org/packages/d0/a8/c688ef7e29983976820d18710f955751d9f4d4eb69df658af3d006e2ba3e/ruff-0.15.4-py3-none-win_amd64.whl", hash = "sha256:04196ad44f0df220c2ece5b0e959c2f37c777375ec744397d21d15b50a75264f", size = 11651048, upload-time = "2026-02-26T20:04:17.191Z" }, + { url = "https://files.pythonhosted.org/packages/3e/0a/9e1be9035b37448ce2e68c978f0591da94389ade5a5abafa4cf99985d1b2/ruff-0.15.4-py3-none-win_arm64.whl", hash = "sha256:60d5177e8cfc70e51b9c5fad936c634872a74209f934c1e79107d11787ad5453", size = 10966776, upload-time = "2026-02-26T20:03:56.908Z" }, ] [[package]] @@ -1091,18 +1116,9 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/23/d1/136eb2cb77520a31e1f64cbae9d33ec6df0d78bdf4160398e86eec8a8754/tomli-2.4.0-py3-none-any.whl", hash = "sha256:1f776e7d669ebceb01dee46484485f43a4048746235e683bcdffacdf1fb4785a", size = 14477, upload-time = "2026-01-11T11:22:37.446Z" }, ] -[[package]] -name = "tomli-w" -version = "1.2.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/19/75/241269d1da26b624c0d5e110e8149093c759b7a286138f4efd61a60e75fe/tomli_w-1.2.0.tar.gz", hash = "sha256:2dd14fac5a47c27be9cd4c976af5a12d87fb1f0b4512f81d69cce3b35ae25021", size = 7184, upload-time = "2025-01-15T12:07:24.262Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/18/c86eb8e0202e32dd3df50d43d7ff9854f8e0603945ff398974c1d91ac1ef/tomli_w-1.2.0-py3-none-any.whl", hash = "sha256:188306098d013b691fcadc011abd66727d3c414c571bb01b1a174ba8c983cf90", size = 6675, upload-time = "2025-01-15T12:07:22.074Z" }, -] - [[package]] name = "tox" -version = "4.49.0" +version = "4.46.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cachetools" }, @@ -1113,39 +1129,38 @@ dependencies = [ { name = "pluggy" }, { name = "pyproject-api" }, { name = "tomli", marker = "python_full_version < '3.11'" }, - { name = "tomli-w" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, { name = "virtualenv" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c5/5a/56146cae67d337426a98cf95f1a9f3ae8b557879df9a03332ef7d6654496/tox-4.49.0.tar.gz", hash = "sha256:2e01f09ae1226749466cbcd8c514fe988ffc8c76b5d523c7f9b745d1711a6e71", size = 259917, upload-time = "2026-03-06T19:57:10.723Z" } +sdist = { url = "https://files.pythonhosted.org/packages/45/03/10faee6ee03437867cd76198afd22dc5af3fca61d9b9b5a8d8cff1952db2/tox-4.46.3.tar.gz", hash = "sha256:2e87609b7832c818cad093304ea23d7eb124f8ecbab0625463b73ce5e850e1c2", size = 250933, upload-time = "2026-02-25T15:48:33.542Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/97/db/c13e849355a7833b319785bafbc947104f9161b964884b159ca94984965a/tox-4.49.0-py3-none-any.whl", hash = "sha256:97cf3cea10c12442569a31bfa411600fbbfc8cb972ad4e48039599935c94a584", size = 206768, upload-time = "2026-03-06T19:57:09.369Z" }, + { url = "https://files.pythonhosted.org/packages/03/c2/d0e0d9700f9e2a6f20361c59c9fc044c1efebcdc5f13cbf353dd7d112410/tox-4.46.3-py3-none-any.whl", hash = "sha256:e9e1a91bce2836dba8169c005254913bd22aac490131c75a5ffc4fd091dffe0b", size = 201424, upload-time = "2026-02-25T15:48:31.684Z" }, ] [[package]] name = "tox-uv" -version = "1.33.1" +version = "1.33.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "tox-uv-bare" }, { name = "uv" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/19/51/9a6dd32e34a3ee200c7890497093875e2c0a0b08737bb897e5916c6575bc/tox_uv-1.33.1-py3-none-any.whl", hash = "sha256:0617caa6444097434cdef24477307ff3242021a44088df673ae08771d3657f79", size = 5364, upload-time = "2026-03-02T17:06:18.32Z" }, + { url = "https://files.pythonhosted.org/packages/9f/67/736f40388b5e1d1b828b236014be7dd3d62a10642122763e6928d950edad/tox_uv-1.33.0-py3-none-any.whl", hash = "sha256:bb3055599940f111f3dead552dd7560b94339175ec58ffa7628ef59fad760d91", size = 5363, upload-time = "2026-02-25T13:22:52.186Z" }, ] [[package]] name = "tox-uv-bare" -version = "1.33.1" +version = "1.33.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "packaging" }, { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "tox" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b0/7b/5ce3aa477400c7791968037b3bf27a50a4e19160a111d9956d20e5ce6b06/tox_uv_bare-1.33.1.tar.gz", hash = "sha256:169185feb3cc8f321eb2a33c575c61dc6efd9bf6044b97636a7381261d29e85c", size = 27203, upload-time = "2026-03-02T17:06:21.118Z" } +sdist = { url = "https://files.pythonhosted.org/packages/07/e8/f927b6cb26dae64732cb8c31f20be009d264ecf34751e72cf8ae7c7db17b/tox_uv_bare-1.33.0.tar.gz", hash = "sha256:34d8484a36ad121257f22823df154c246d831b84b01df91c4369a56cb4689d2e", size = 26995, upload-time = "2026-02-25T13:22:54.9Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/8e/ae95104165f4e2da5d9d25d8c71c7c935227c3eeb88e0376dab48b787a1c/tox_uv_bare-1.33.1-py3-none-any.whl", hash = "sha256:e64fdcd607a0f66212ef9edb36a5a672f10b461fce2a8216dda3e93c45d4a3f9", size = 19718, upload-time = "2026-03-02T17:06:19.657Z" }, + { url = "https://files.pythonhosted.org/packages/32/e5/0cae08b6c2908b4b8e51a91adaead58d06fd2393333aadc88c9a448da2c3/tox_uv_bare-1.33.0-py3-none-any.whl", hash = "sha256:80b5c1f4f5eda2dfd3a9de569665ad2dccdfb128ed1ee9f69c1dacfd100f6b4a", size = 19528, upload-time = "2026-02-25T13:22:53.269Z" }, ] [[package]] @@ -1186,32 +1201,32 @@ wheels = [ [[package]] name = "uv" -version = "0.10.9" +version = "0.10.6" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f2/59/235fa08a6b56de82a45a385dc2bf724502f720f0a9692a1a8cb24aab3e6f/uv-0.10.9.tar.gz", hash = "sha256:31e76ae92e70fec47c3efab0c8094035ad7a578454482415b496fa39fc4d685c", size = 3945685, upload-time = "2026-03-06T21:21:16.219Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d5/53/7a4274dad70b1d17efb99e36d45fc1b5e4e1e531b43247e518604394c761/uv-0.10.6.tar.gz", hash = "sha256:de86e5e1eb264e74a20fccf56889eea2463edb5296f560958e566647c537b52e", size = 3921763, upload-time = "2026-02-25T00:26:27.066Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2f/6d/f87f1530d5db4132776d49dddd88b1c77bc08fa7b32bf585b366204e6fc2/uv-0.10.9-py3-none-linux_armv6l.whl", hash = "sha256:0649f83fa0f44f18627c00b2a9a60e5c3486a34799b2c874f2b3945b76048a67", size = 22617914, upload-time = "2026-03-06T21:20:48.282Z" }, - { url = "https://files.pythonhosted.org/packages/6f/34/2e5cd576d312eb1131b615f49ee95ff6efb740965324843617adae729cf2/uv-0.10.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:880dd4cffe4bd184e8871ddf4c7d3c3b042e1f16d2682310644aa8d61eaea3e6", size = 21778779, upload-time = "2026-03-06T21:21:01.804Z" }, - { url = "https://files.pythonhosted.org/packages/89/35/684f641de4de2b20db7d2163c735b2bb211e3b3c84c241706d6448e5e868/uv-0.10.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a7a784254380552398a6baf4149faf5b31a4003275f685c28421cf8197178a08", size = 20384301, upload-time = "2026-03-06T21:21:04.089Z" }, - { url = "https://files.pythonhosted.org/packages/eb/5c/7170cfd1b4af09b435abc5a89ff315af130cf4a5082e5eb1206ee46bba67/uv-0.10.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:5ea0e8598fa012cfa4480ecad4d112bc70f514157c3cc1555a7611c7b6b1ab0a", size = 22226893, upload-time = "2026-03-06T21:20:50.902Z" }, - { url = "https://files.pythonhosted.org/packages/43/5c/68a17934dc8a2897fd7928b1c03c965373a820dc182aad96f1be6cce33a1/uv-0.10.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:2d6b5367e9bf87eca51c0f2ecda26a1ff931e41409977b4f0a420de2f3e617cf", size = 22233832, upload-time = "2026-03-06T21:21:11.748Z" }, - { url = "https://files.pythonhosted.org/packages/00/10/d262172ac59b669ca9c006bcbdb49c1a168cc314a5de576a4bb476dfab4c/uv-0.10.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bd04e34db27f9a1d5a0871980edc9f910bb11afbc4abca8234d5a363cbe63c04", size = 22192193, upload-time = "2026-03-06T21:20:59.48Z" }, - { url = "https://files.pythonhosted.org/packages/a2/e6/f75fef1e3e5b0cf3592a4c35ed5128164ef2e6bd6a2570a0782c0baf6d4b/uv-0.10.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:547deb57311fc64e4a6b8336228fca4cb4dcbeabdc6e85f14f7804dcd0bc8cd2", size = 23571687, upload-time = "2026-03-06T21:20:45.403Z" }, - { url = "https://files.pythonhosted.org/packages/31/28/4b1ee6f4aa0e1b935e66b6018691258d1b702ef9c5d8c71e853564ad0a3a/uv-0.10.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0091b6d0b666640d7407a433860184f77667077b73564e86d49c2a851f073a8", size = 24418225, upload-time = "2026-03-06T21:21:09.459Z" }, - { url = "https://files.pythonhosted.org/packages/39/a2/5e67987f8d55eeecca7d8f4e94ac3e973fa1e8aaf426fcb8f442e9f7e2bc/uv-0.10.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81b2286e6fd869e3507971f39d14829c03e2e31caa8ecc6347b0ffacabb95a5b", size = 23555724, upload-time = "2026-03-06T21:20:54.085Z" }, - { url = "https://files.pythonhosted.org/packages/79/34/b104c413079874493eed7bf11838b47b697cf1f0ed7e9de374ea37b4e4e0/uv-0.10.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7c9d6deb30edbc22123be75479f99fb476613eaf38a8034c0e98bba24a344179", size = 23438145, upload-time = "2026-03-06T21:21:26.866Z" }, - { url = "https://files.pythonhosted.org/packages/27/8a/cad762b3e9bfb961b68b2ae43a258a92b522918958954b50b09dcb14bb4e/uv-0.10.9-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:24b1ce6d626e06c4582946b6af07b08a032fcccd81fe54c3db3ed2d1c63a97dc", size = 22326765, upload-time = "2026-03-06T21:21:14.283Z" }, - { url = "https://files.pythonhosted.org/packages/a7/62/7e066f197f3eb8f8f71e25d703a29c89849c9c047240c1223e29bc0a37e4/uv-0.10.9-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:fa3401780273d96a2960dbeab58452ce1b387ad8c5da25be6221c0188519e21d", size = 23215175, upload-time = "2026-03-06T21:21:29.673Z" }, - { url = "https://files.pythonhosted.org/packages/7e/06/51db93b5edb8b0202c0ec6caf3f24384f5abdfc180b6376a3710223fd56f/uv-0.10.9-py3-none-musllinux_1_1_i686.whl", hash = "sha256:8f94a31832d2b4c565312ea17a71b8dd2f971e5aa570c5b796a27b2c9fcdb163", size = 22784507, upload-time = "2026-03-06T21:21:20.676Z" }, - { url = "https://files.pythonhosted.org/packages/96/34/1db511d9259c1f32e5e094133546e5723e183a9ba2c64f7ca6156badddee/uv-0.10.9-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:842c39c19d9072f1ad53c71bb4ecd1c9caa311d5de9d19e09a636274a6c95e2e", size = 23660703, upload-time = "2026-03-06T21:21:06.667Z" }, - { url = "https://files.pythonhosted.org/packages/6c/a0/58388abb252c7a37bc67422fce3a6b87404ea3fac44ca20132a4ba502235/uv-0.10.9-py3-none-win32.whl", hash = "sha256:ed44047c602449916ba18a8596715ef7edbbd00859f3db9eac010dc62a0edd30", size = 21524142, upload-time = "2026-03-06T21:21:18.246Z" }, - { url = "https://files.pythonhosted.org/packages/c9/e9/adf7a12136573937d12ac189569e2e90e7fad18b458192083df6986f3013/uv-0.10.9-py3-none-win_amd64.whl", hash = "sha256:af79552276d8bd622048ab2d67ec22120a6af64d83963c46b1482218c27b571f", size = 24103389, upload-time = "2026-03-06T21:20:56.495Z" }, - { url = "https://files.pythonhosted.org/packages/5e/49/4971affd9c62d26b3ff4a84dc6432275be72d9615d95f7bb9e027beeeed8/uv-0.10.9-py3-none-win_arm64.whl", hash = "sha256:47e18a0521d76293d4f60d129f520b18bddf1976b4a47b50f0fcb04fb6a9d40f", size = 22454171, upload-time = "2026-03-06T21:21:24.596Z" }, + { url = "https://files.pythonhosted.org/packages/4f/f9/faf599c6928dc00d941629260bef157dadb67e8ffb7f4b127b8601f41177/uv-0.10.6-py3-none-linux_armv6l.whl", hash = "sha256:2b46ad78c86d68de6ec13ffaa3a8923467f757574eeaf318e0fce0f63ff77d7a", size = 22412946, upload-time = "2026-02-25T00:26:10.826Z" }, + { url = "https://files.pythonhosted.org/packages/c4/8f/82dd6aa8acd2e1b1ba12fd49210bd19843383538e0e63e8d7a23a7d39d93/uv-0.10.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:a1d9873eb26cbef9138f8c52525bc3fd63be2d0695344cdcf84f0dc2838a6844", size = 21524262, upload-time = "2026-02-25T00:27:09.318Z" }, + { url = "https://files.pythonhosted.org/packages/3b/48/5767af19db6f21176e43dfde46ea04e33c49ba245ac2634e83db15d23c8f/uv-0.10.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5a62cdf5ba356dcc792b960e744d67056b0e6d778ce7381e1d78182357bd82e8", size = 20184248, upload-time = "2026-02-25T00:26:20.281Z" }, + { url = "https://files.pythonhosted.org/packages/27/1b/13c2fcdb776ae78b5c22eb2d34931bb3ef9bd71b9578b8fa7af8dd7c11c4/uv-0.10.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:b70a04d51e2239b3aee0e4d4ed9af18c910360155953017cecded5c529588e65", size = 22049300, upload-time = "2026-02-25T00:26:07.039Z" }, + { url = "https://files.pythonhosted.org/packages/6f/43/348e2c378b3733eba15f6144b35a8c84af5c884232d6bbed29e256f74b6f/uv-0.10.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:2b622059a1ae287f8b995dcb6f5548de83b89b745ff112801abbf09e25fd8fa9", size = 22030505, upload-time = "2026-02-25T00:26:46.171Z" }, + { url = "https://files.pythonhosted.org/packages/a5/3f/dcec580099bc52f73036bfb09acb42616660733de1cc3f6c92287d2c7f3e/uv-0.10.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f43db1aa80776386646453c07d5590e1ae621f031a2afe6efba90f89c34c628c", size = 22041360, upload-time = "2026-02-25T00:26:53.725Z" }, + { url = "https://files.pythonhosted.org/packages/2c/96/f70abe813557d317998806517bb53b3caa5114591766db56ae9cc142ff39/uv-0.10.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4ca8a26694ba7d0ae902f11054734805741f2b080fe8397401b80c99264edab6", size = 23309916, upload-time = "2026-02-25T00:27:12.99Z" }, + { url = "https://files.pythonhosted.org/packages/db/1d/d8b955937dd0153b48fdcfd5ff70210d26e4b407188e976df620572534fd/uv-0.10.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6f2cddae800d14159a9ccb4ff161648b0b0d1b31690d9c17076ec00f538c52ac", size = 24191174, upload-time = "2026-02-25T00:26:30.051Z" }, + { url = "https://files.pythonhosted.org/packages/c2/3d/3d0669d65bf4a270420d70ca0670917ce5c25c976c8b0acd52465852509b/uv-0.10.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:153fcf5375c988b2161bf3a6a7d9cc907d6bbe38f3cb16276da01b2dae4df72c", size = 23320328, upload-time = "2026-02-25T00:26:23.82Z" }, + { url = "https://files.pythonhosted.org/packages/85/f2/f2ccc2196fd6cf1321c2e8751a96afabcbc9509b184c671ece3e804effda/uv-0.10.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f27f2d135d4533f88537ecd254c72dfd25311d912da8649d15804284d70adb93", size = 23229798, upload-time = "2026-02-25T00:26:50.12Z" }, + { url = "https://files.pythonhosted.org/packages/2d/b9/1008266a041e8a55430a92aef8ecc58aaaa7eb7107a26cf4f7c127d14363/uv-0.10.6-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:dd993ec2bf5303a170946342955509559763cf8dcfe334ec7bb9f115a0f86021", size = 22143661, upload-time = "2026-02-25T00:26:42.507Z" }, + { url = "https://files.pythonhosted.org/packages/93/e4/1f8de7da5f844b4c9eafa616e262749cd4e3d9c685190b7967c4681869da/uv-0.10.6-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:8529e4d4aac40b4e7588177321cb332cc3309d36d7cc482470a1f6cfe7a7e14a", size = 22888045, upload-time = "2026-02-25T00:26:15.935Z" }, + { url = "https://files.pythonhosted.org/packages/e2/2b/03b840dd0101dc69ef6e83ceb2e2970e4b4f118291266cf3332a4b64092c/uv-0.10.6-py3-none-musllinux_1_1_i686.whl", hash = "sha256:ed9e16453a5f73ee058c566392885f445d00534dc9e754e10ab9f50f05eb27a5", size = 22549404, upload-time = "2026-02-25T00:27:05.333Z" }, + { url = "https://files.pythonhosted.org/packages/4c/4e/1ee4d4301874136a4b3bbd9eeba88da39f4bafa6f633b62aef77d8195c56/uv-0.10.6-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:33e5362039bfa91599df0b7487854440ffef1386ac681ec392d9748177fb1d43", size = 23426872, upload-time = "2026-02-25T00:26:35.01Z" }, + { url = "https://files.pythonhosted.org/packages/d3/e3/e000030118ff1a82ecfc6bd5af70949821edac739975a027994f5b17258f/uv-0.10.6-py3-none-win32.whl", hash = "sha256:fa7c504a1e16713b845d457421b07dd9c40f40d911ffca6897f97388de49df5a", size = 21501863, upload-time = "2026-02-25T00:26:57.182Z" }, + { url = "https://files.pythonhosted.org/packages/1c/cc/dd88c9f20c054ef0aea84ad1dd9f8b547463824857e4376463a948983bed/uv-0.10.6-py3-none-win_amd64.whl", hash = "sha256:ecded4d21834b21002bc6e9a2628d21f5c8417fd77a5db14250f1101bcb69dac", size = 23981891, upload-time = "2026-02-25T00:26:38.773Z" }, + { url = "https://files.pythonhosted.org/packages/cf/06/ca117002cd64f6701359253d8566ec7a0edcf61715b4969f07ee41d06f61/uv-0.10.6-py3-none-win_arm64.whl", hash = "sha256:4b5688625fc48565418c56a5cd6c8c32020dbb7c6fb7d10864c2d2c93c508302", size = 22339889, upload-time = "2026-02-25T00:27:00.818Z" }, ] [[package]] name = "virtualenv" -version = "21.2.0" +version = "21.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "distlib" }, @@ -1220,7 +1235,7 @@ dependencies = [ { name = "python-discovery" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/aa/92/58199fe10049f9703c2666e809c4f686c54ef0a68b0f6afccf518c0b1eb9/virtualenv-21.2.0.tar.gz", hash = "sha256:1720dc3a62ef5b443092e3f499228599045d7fea4c79199770499df8becf9098", size = 5840618, upload-time = "2026-03-09T17:24:38.013Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ce/4f/d6a5ff3b020c801c808b14e2d2330cdc8ebefe1cdfbc457ecc368e971fec/virtualenv-21.0.0.tar.gz", hash = "sha256:e8efe4271b4a5efe7a4dce9d60a05fd11859406c0d6aa8464f4cf451bc132889", size = 5836591, upload-time = "2026-02-25T20:21:07.691Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c6/59/7d02447a55b2e55755011a647479041bc92a82e143f96a8195cb33bd0a1c/virtualenv-21.2.0-py3-none-any.whl", hash = "sha256:1bd755b504931164a5a496d217c014d098426cddc79363ad66ac78125f9d908f", size = 5825084, upload-time = "2026-03-09T17:24:35.378Z" }, + { url = "https://files.pythonhosted.org/packages/29/d1/3f62e4f9577b28c352c11623a03fb916096d5c131303d4861b4914481b6b/virtualenv-21.0.0-py3-none-any.whl", hash = "sha256:d44e70637402c7f4b10f48491c02a6397a3a187152a70cba0b6bc7642d69fb05", size = 5817167, upload-time = "2026-02-25T20:21:05.476Z" }, ]