# --- snap7/client.py: optimized multi-variable read support (reconstructed) ---
# NOTE(review): the surrounding patch also adds ``import copy`` and
# ``from .optimizer import ReadItem, ReadPacket, sort_items, merge_items,
# packetize, extract_results`` at the top of client.py, initialises
# ``self._opt_plan = None``, ``self.multi_read_max_gap = 5`` and
# ``self.use_optimizer = True`` in ``Client.__init__``, and clears
# ``self._opt_plan`` in ``Client.disconnect``.

# Pre-computed set of every valid raw Area enum value so the optimized read
# path can cheaply validate area codes before converting them to Area.
_VALID_AREA_VALUES: frozenset[int] = frozenset(member.value for member in Area)

logger = logging.getLogger(__name__)


class _OptimizationPlan:
    """Cached optimization plan for repeated read_multi_vars calls with the same layout."""

    def __init__(self, cache_key: tuple[int, ...], packets: list[ReadPacket], read_items: list[ReadItem]):
        self.cache_key = cache_key    # layout fingerprint of the requested items
        self.packets = packets        # pre-computed packet plan for that layout
        self.read_items = read_items  # original items, kept for reference


# The two functions below are methods of ``Client`` (reconstructed at their
# original granularity; ``self`` is the Client instance).

def read_multi_vars(self, items: Union[List[dict[str, Any]], "Array[S7DataItem]"]) -> Tuple[int, Any]:
    """Read multiple variables in a single request.

    When given a list of dicts with two or more items, uses the
    multi-variable read optimizer to merge adjacent reads and pack them
    into minimal PDU exchanges, which reduces round-trips compared to
    reading each variable individually.

    Args:
        items: List of item specifications (dicts with ``area``, ``start``,
            ``size``, and optionally ``db_number``) **or** a ctypes
            ``Array[S7DataItem]``.

    Returns:
        Tuple of (result_code, data) where *data* is either the updated
        ctypes array or a list of bytearrays in the original item order.

    Raises:
        ValueError: If more than MAX_VARS items are requested.
    """
    if not items:
        return (0, items)

    if len(items) > self.MAX_VARS:
        raise ValueError(f"Too many items: {len(items)} exceeds MAX_VARS ({self.MAX_VARS})")

    # Legacy ctypes path: an Array[S7DataItem] is serviced item by item,
    # copying each result into the caller-supplied pData buffer.
    if hasattr(items, "_type_") and hasattr(items[0], "Area"):
        s7_items = cast("Array[S7DataItem]", items)
        for entry in s7_items:
            payload = self.read_area(Area(entry.Area), entry.DBNumber, entry.Start, entry.Amount)
            if entry.pData:
                for pos, value in enumerate(payload):
                    entry.pData[pos] = value
        return (0, items)

    dict_items = cast(List[dict[str, Any]], items)

    # Single item, or optimizer explicitly disabled: plain sequential reads.
    if len(dict_items) <= 1 or not self.use_optimizer:
        plain: list[bytearray] = []
        for spec in dict_items:
            plain.append(self.read_area(spec["area"], spec.get("db_number", 0), spec["start"], spec["size"]))
        return (0, plain)

    return self._read_multi_vars_optimized(dict_items)


def _read_multi_vars_optimized(self, dict_items: List[dict[str, Any]]) -> Tuple[int, List[bytearray]]:
    """Optimized multi-variable read using the merge + packetize strategy.

    Args:
        dict_items: List of item dicts (area, db_number, start, size).

    Returns:
        Tuple of (0, list of bytearrays in original order).
    """
    read_items = [
        ReadItem(
            area=int(spec["area"]),
            db_number=spec.get("db_number", 0),
            byte_offset=spec["start"],
            bit_offset=0,
            byte_length=spec["size"],
            index=pos,
        )
        for pos, spec in enumerate(dict_items)
    ]

    # Layout fingerprint: (area, db, offset, length) per item, flattened.
    cache_key = tuple(
        value
        for item in read_items
        for value in (item.area, item.db_number, item.byte_offset, item.byte_length)
    )

    # Reuse the cached plan when the layout is unchanged; otherwise rebuild.
    plan = self._opt_plan
    if plan is not None and plan.cache_key == cache_key:
        packets = plan.packets
    else:
        merged = merge_items(
            sort_items(read_items),
            max_gap=self.multi_read_max_gap,
            max_block_size=self._max_read_size(),
        )
        packets = packetize(merged, self.pdu_length)
        self._opt_plan = _OptimizationPlan(cache_key, packets, read_items)

    # Work on a deep copy so the cached plan's blocks are never mutated.
    working_packets = copy.deepcopy(packets)

    for packet in working_packets:
        specs = [(blk.area, blk.db_number, blk.start_offset, blk.byte_length) for blk in packet.blocks]

        if len(specs) == 1:
            # Single block: a regular read avoids the multi-read overhead.
            # NOTE(review): unknown area codes silently fall back to Area.DB
            # here — confirm this is intended rather than raising.
            blk = packet.blocks[0]
            blk.buffer = self.read_area(
                Area(blk.area) if blk.area in _VALID_AREA_VALUES else Area.DB,
                blk.db_number,
                blk.start_offset,
                blk.byte_length,
            )
        else:
            # Multiple blocks: one multi-read PDU round-trip for the packet.
            request = self.protocol.build_multi_read_request(specs)
            response = self._send_receive(request)
            payloads = self.protocol.extract_multi_read_data(response, len(specs))
            for blk, payload in zip(packet.blocks, payloads):
                blk.buffer = payload

    # Map block buffers back onto the caller's original item order.
    return (0, extract_results(working_packets, len(dict_items)))
blk.byte_length) for blk in packet.blocks] + + if len(block_specs) == 1: + # Single block: use regular read to avoid multi-read overhead + blk = packet.blocks[0] + data = self.read_area( + Area(blk.area) if blk.area in _VALID_AREA_VALUES else Area.DB, + blk.db_number, + blk.start_offset, + blk.byte_length, + ) + blk.buffer = data + else: + # Multi-block: use multi-read PDU + request = self.protocol.build_multi_read_request(block_specs) + response = self._send_receive(request) + block_data_list = self.protocol.extract_multi_read_data(response, len(block_specs)) + for blk, buf in zip(packet.blocks, block_data_list): + blk.buffer = buf + + # Extract per-item results in original order + results = extract_results(working_packets, len(dict_items)) return (0, results) def write_multi_vars(self, items: Union[List[dict[str, Any]], List[S7DataItem]]) -> int: diff --git a/snap7/optimizer.py b/snap7/optimizer.py new file mode 100644 index 00000000..82a33aed --- /dev/null +++ b/snap7/optimizer.py @@ -0,0 +1,283 @@ +""" +Multi-variable read optimizer for S7 communication. + +Optimizes multiple scattered read requests into minimal PDU-packed S7 exchanges +by merging adjacent/overlapping reads and packing them into PDU-sized packets. +""" + +import logging +from dataclasses import dataclass, field + +logger = logging.getLogger(__name__) + + +@dataclass +class ReadItem: + """A single read request from the caller. + + Attributes: + area: S7Area value (e.g. 0x84 for DB). + db_number: DB number (0 for non-DB areas). + byte_offset: Start byte offset in the area. + bit_offset: Bit offset within the byte (0 for byte-level reads). + byte_length: Number of bytes to read. + index: Original ordering position so results can be returned in order. + """ + + area: int + db_number: int + byte_offset: int + bit_offset: int + byte_length: int + index: int + + +@dataclass +class ReadBlock: + """A merged contiguous block of bytes to read in one address spec. + + Attributes: + area: S7Area value. 
+ db_number: DB number. + start_offset: Start byte offset of the block. + byte_length: Total bytes to read. + items: The ReadItems contained in this block. + """ + + area: int + db_number: int + start_offset: int + byte_length: int + items: list[ReadItem] = field(default_factory=list) + buffer: bytearray = field(default_factory=bytearray) + + +@dataclass +class ReadPacket: + """A group of ReadBlocks that fit in a single S7 PDU exchange. + + Attributes: + blocks: The blocks in this packet. + """ + + blocks: list[ReadBlock] = field(default_factory=list) + + +def sort_items(items: list[ReadItem]) -> list[ReadItem]: + """Sort read items for optimal merging. + + Items are sorted by (area, db_number, byte_offset, bit_offset, -byte_length). + Sorting by descending byte_length ensures that when two items start at the same + offset, the larger one comes first, which simplifies overlap handling. + + Args: + items: List of read items to sort. + + Returns: + New sorted list (original is not modified). + """ + return sorted(items, key=lambda i: (i.area, i.db_number, i.byte_offset, i.bit_offset, -i.byte_length)) + + +def merge_items(sorted_items: list[ReadItem], max_gap: int = 5, max_block_size: int = 462) -> list[ReadBlock]: + """Merge sorted read items into contiguous blocks. + + Adjacent or overlapping items in the same area/db are merged when the gap + between them is at most *max_gap* bytes and the resulting block does not + exceed *max_block_size* bytes. + + Args: + sorted_items: Items pre-sorted by :func:`sort_items`. + max_gap: Maximum byte gap between items to still merge them. + max_block_size: Maximum byte length of a single merged block. + + Returns: + List of merged ReadBlocks. 
+ """ + if not sorted_items: + return [] + + blocks: list[ReadBlock] = [] + current = sorted_items[0] + block = ReadBlock( + area=current.area, + db_number=current.db_number, + start_offset=current.byte_offset, + byte_length=current.byte_length, + items=[current], + ) + + for item in sorted_items[1:]: + block_end = block.start_offset + block.byte_length + item_end = item.byte_offset + item.byte_length + + same_region = item.area == block.area and item.db_number == block.db_number + gap = item.byte_offset - block_end + new_length = max(block_end, item_end) - block.start_offset + + if same_region and gap <= max_gap and new_length <= max_block_size: + # Merge: extend block to cover the new item + block.byte_length = new_length + block.items.append(item) + else: + # Start a new block + blocks.append(block) + block = ReadBlock( + area=item.area, + db_number=item.db_number, + start_offset=item.byte_offset, + byte_length=item.byte_length, + items=[item], + ) + + blocks.append(block) + return blocks + + +def _ceil_even(n: int) -> int: + """Round up to the next even number.""" + return n + (n % 2) + + +def _split_block(block: ReadBlock, max_block_size: int) -> list[ReadBlock]: + """Split an oversized block at item boundaries. + + Never tears an item across two blocks. + + Args: + block: The block to split. + max_block_size: Maximum byte length per sub-block. + + Returns: + List of sub-blocks that each fit within *max_block_size*. 
+ """ + if block.byte_length <= max_block_size: + return [block] + + sub_blocks: list[ReadBlock] = [] + current_items: list[ReadItem] = [] + current_start = block.items[0].byte_offset + current_end = current_start + + for item in block.items: + item_end = item.byte_offset + item.byte_length + new_end = max(current_end, item_end) + new_length = new_end - current_start + + if current_items and new_length > max_block_size: + # Flush current sub-block + sub_blocks.append( + ReadBlock( + area=block.area, + db_number=block.db_number, + start_offset=current_start, + byte_length=current_end - current_start, + items=current_items, + ) + ) + current_items = [item] + current_start = item.byte_offset + current_end = item_end + else: + current_items.append(item) + current_end = new_end + + if current_items: + sub_blocks.append( + ReadBlock( + area=block.area, + db_number=block.db_number, + start_offset=current_start, + byte_length=current_end - current_start, + items=current_items, + ) + ) + + return sub_blocks + + +def packetize(blocks: list[ReadBlock], pdu_size: int) -> list[ReadPacket]: + """Pack blocks into PDU-sized packets. + + Two budgets are enforced per packet: + - **Request budget**: ``12 (header) + 2 (func+count) + 12*N (address specs) <= pdu_size`` + - **Reply budget**: ``12 (header) + 2 (func+count) + sum(4 + ceil_even(length)) <= pdu_size`` + + Oversized blocks are first split at item boundaries, then blocks are + greedily packed into packets. + + Args: + blocks: Merged read blocks. + pdu_size: Negotiated PDU size in bytes. + + Returns: + List of ReadPackets. 
+ """ + # First split any oversized blocks + # Max data payload per block in a single-block packet + max_single_block = pdu_size - 12 - 2 - 4 # header + param + data item header + all_blocks: list[ReadBlock] = [] + for block in blocks: + all_blocks.extend(_split_block(block, max_single_block)) + + if not all_blocks: + return [] + + request_overhead = 14 # 12 header + 2 (func + count) + reply_overhead = 14 # 12 header + 2 (func + count) + addr_spec_size = 12 # per block in request + + packets: list[ReadPacket] = [] + current_packet = ReadPacket() + current_req_used = request_overhead + current_reply_used = reply_overhead + + for block in all_blocks: + req_cost = addr_spec_size + reply_cost = 4 + _ceil_even(block.byte_length) + + fits_request = current_req_used + req_cost <= pdu_size + fits_reply = current_reply_used + reply_cost <= pdu_size + + if current_packet.blocks and (not fits_request or not fits_reply): + # Start a new packet + packets.append(current_packet) + current_packet = ReadPacket() + current_req_used = request_overhead + current_reply_used = reply_overhead + + current_packet.blocks.append(block) + current_req_used += req_cost + current_reply_used += reply_cost + + if current_packet.blocks: + packets.append(current_packet) + + return packets + + +def extract_results(packets: list[ReadPacket], original_count: int) -> list[bytearray]: + """Map block buffers back to original items using offset math. + + Each block must have its ``buffer`` attribute set (a bytearray of the + block's data as returned by the PLC) before calling this function. + The buffer is stored as a dynamic attribute on the ReadBlock dataclass. + + Args: + packets: Packets with block buffers populated. + original_count: Number of original read items. + + Returns: + List of bytearrays indexed by original ``ReadItem.index``. 
+ """ + results: list[bytearray] = [bytearray() for _ in range(original_count)] + + for packet in packets: + for block in packet.blocks: + buf = block.buffer + for item in block.items: + local_offset = item.byte_offset - block.start_offset + item_data = buf[local_offset : local_offset + item.byte_length] + results[item.index] = bytearray(item_data) + + return results diff --git a/snap7/s7protocol.py b/snap7/s7protocol.py index 9290ba5b..e2f29d23 100644 --- a/snap7/s7protocol.py +++ b/snap7/s7protocol.py @@ -7,7 +7,7 @@ import struct import logging from datetime import datetime -from typing import List, Dict, Any +from typing import List, Dict, Any, Tuple from enum import IntEnum from .datatypes import S7Area, S7WordLen, S7DataTypes @@ -173,6 +173,102 @@ def build_read_request(self, area: S7Area, db_number: int, start: int, word_len: return header + parameters + def build_multi_read_request(self, items: List[Tuple[int, int, int, int]]) -> bytes: + """Build S7 multi-variable read request PDU. + + Encodes multiple address specifications into a single READ_AREA request + so the PLC can return all data in one response. + + Args: + items: List of (area, db_number, start_offset, byte_length) tuples. + + Returns: + Complete S7 PDU. 
# --- snap7/s7protocol.py: multi-variable read support (reconstructed) ---
# The two functions below are methods of the protocol class (``self`` is the
# protocol instance).

def build_multi_read_request(self, items: List[Tuple[int, int, int, int]]) -> bytes:
    """Build S7 multi-variable read request PDU.

    Encodes multiple address specifications into a single READ_AREA request
    so the PLC can return all data in one response.

    Args:
        items: List of (area, db_number, start_offset, byte_length) tuples.

    Returns:
        Complete S7 PDU.
    """
    # N * 12-byte address specifications, one per requested block.
    addr_specs = b"".join(
        S7DataTypes.encode_address(S7Area(area_code), db_number, start_offset, S7WordLen.BYTE, byte_length)
        for area_code, db_number, start_offset, byte_length in items
    )

    # Parameter section: function_code(1) + item_count(1) + N * address_spec(12).
    param_data = struct.pack(">BB", S7Function.READ_AREA, len(items)) + addr_specs

    # S7 header (12 bytes): protocol id, PDU type, reserved, sequence,
    # parameter length, data length (zero: a read carries no data section).
    header = struct.pack(
        ">BBHHHH",
        0x32,
        S7PDUType.REQUEST,
        0x0000,
        self._next_sequence(),
        len(param_data),
        0x0000,
    )

    return header + param_data


def extract_multi_read_data(self, response: Dict[str, Any], block_count: int) -> List[bytearray]:
    """Extract per-block data from a multi-variable read response.

    Parses the raw data section which contains N items, each with:
    - return_code (1 byte)
    - transport_size (1 byte)
    - bit_length (2 bytes, big-endian)
    - data (bit_length / 8 bytes)
    - fill byte (1 byte if byte_length is odd and not the last item)

    Args:
        response: Parsed S7 response from :meth:`parse_response`.
        block_count: Expected number of data items.

    Returns:
        List of bytearrays, one per block.

    Raises:
        S7ProtocolError: If any item has a non-success return code.
    """
    raw = response.get("raw_data", b"")
    if not raw:
        raise S7ProtocolError("No raw data in multi-read response")

    results: List[bytearray] = []
    pos = 0

    for i in range(block_count):
        # Each item starts with a 4-byte transport header.
        if pos + 4 > len(raw):
            raise S7ProtocolError(f"Multi-read response truncated at item {i}")

        return_code = raw[pos]
        transport_size = raw[pos + 1]
        bit_length = struct.unpack(">H", raw[pos + 2 : pos + 4])[0]
        pos += 4

        if return_code != 0xFF:  # 0xFF == success
            desc = get_return_code_description(return_code)
            raise S7ProtocolError(f"Multi-read item {i} failed: {desc} (0x{return_code:02x})")

        # Transport size 0x04 encodes the length in bits; other transport
        # sizes are treated as byte lengths here.
        byte_length = bit_length // 8 if transport_size == 0x04 else bit_length

        if pos + byte_length > len(raw):
            raise S7ProtocolError(f"Multi-read data truncated at item {i}")

        results.append(bytearray(raw[pos : pos + byte_length]))
        pos += byte_length

        # Odd-length items are followed by a fill byte, except the last item.
        if i < block_count - 1 and byte_length % 2 != 0:
            pos += 1

    return results
@@ -1335,6 +1431,7 @@ def parse_response(self, pdu: bytes) -> Dict[str, Any]: data_section = pdu[offset : offset + data_len] response["data"] = self._parse_data_section(data_section) + response["raw_data"] = data_section return response diff --git a/snap7/server/__init__.py b/snap7/server/__init__.py index d44f1760..ca9dbb41 100644 --- a/snap7/server/__init__.py +++ b/snap7/server/__init__.py @@ -734,65 +734,53 @@ def _handle_setup_communication(self, request: Dict[str, Any]) -> bytes: return header + parameters def _handle_read_area(self, request: Dict[str, Any], client_address: Tuple[str, int]) -> bytes: - """Handle read area request.""" + """Handle read area request (single or multi-item).""" try: - # Parse address specification from request parameters + params = request.get("parameters", {}) + item_count = params.get("item_count", 1) + + # Multi-item read + if item_count > 1 and "address_specs" in params: + return self._handle_multi_read_area(request, client_address) + + # Single-item read (original path) addr_info = self._parse_read_address(request) if not addr_info: - return self._build_error_response(request, 0x8001) # Invalid address + return self._build_error_response(request, 0x8001) area, db_number, start, count = addr_info - # Read data from registered memory area read_data = self._read_from_memory_area(area, db_number, start, count) if read_data is None: - return self._build_error_response(request, 0x8404) # Area not found + return self._build_error_response(request, 0x8404) - # Calculate data length - need to include transport header + data - data_len = 4 + len(read_data) # Transport header (4 bytes) + data + data_len = 4 + len(read_data) - # Build successful response - # S7 response header includes error class + error code header = struct.pack( ">BBHHHHBB", - 0x32, # Protocol ID - S7PDUType.ACK_DATA, # PDU type - 0x0000, # Reserved - request["sequence"], # Sequence (echo) - 0x0002, # Parameter length - data_len, # Data length - 0x00, # Error class 
# --- snap7/server/__init__.py: multi-item read handler (reconstructed) ---
# NOTE(review): the surrounding patch also restyles ``_handle_read_area``
# (whose success-path tail is elided from this view) to delegate here when
# the parsed parameters carry ``item_count > 1`` and ``address_specs``.

def _handle_multi_read_area(self, request: Dict[str, Any], client_address: Tuple[str, int]) -> bytes:
    """Handle multi-item read area request.

    Reads multiple address specifications and returns all data items in a
    single response with proper fill-byte alignment between items.
    """
    params = request["parameters"]
    address_specs: List[Dict[str, Any]] = params["address_specs"]
    item_count = len(address_specs)

    # Data section: concatenated items, odd-length items padded with a fill
    # byte (except after the last item).
    data_parts = bytearray()
    for i, addr in enumerate(address_specs):
        area = addr.get("area", S7Area.DB)
        db_number = addr.get("db_number", 0)
        start = addr.get("start", 0)
        count = addr.get("count", 1)
        word_len = addr.get("word_len", S7WordLen.BYTE)

        # Convert the element count to a byte count per word length.
        if word_len in (S7WordLen.TIMER, S7WordLen.COUNTER, S7WordLen.WORD):
            byte_count = count * 2
        elif word_len in (S7WordLen.DWORD, S7WordLen.REAL):
            byte_count = count * 4
        elif word_len == S7WordLen.BIT:
            byte_count = 1
        else:
            byte_count = count

        read_data = self._read_from_memory_area(area, db_number, start, byte_count)
        if read_data is None:
            # Per-item error: 0x0A = object does not exist, zero-length data.
            data_parts.extend(struct.pack(">BBH", 0x0A, 0x00, 0x0000))
        else:
            # Success: 0xFF return code, transport size 0x04, length in bits.
            data_parts.extend(struct.pack(">BBH", 0xFF, 0x04, len(read_data) * 8))
            data_parts.extend(read_data)
            if i < item_count - 1 and len(read_data) % 2 != 0:
                data_parts.append(0x00)  # fill byte keeps the next item aligned

    header = struct.pack(
        ">BBHHHHBB",
        0x32,                 # protocol id
        S7PDUType.ACK_DATA,   # PDU type
        0x0000,               # reserved
        request["sequence"],  # echo the request sequence
        0x0002,               # parameter length
        len(data_parts),      # data length
        0x00,                 # error class (success)
        0x00,                 # error code (success)
    )
    parameters = struct.pack(">BB", S7Function.READ_AREA, item_count)

    return header + parameters + bytes(data_parts)
@@ -1182,10 +1228,24 @@ def _parse_request_parameters(self, param_data: bytes) -> Dict[str, Any]: elif function_code == S7Function.READ_AREA: # Parse read area parameters if len(param_data) >= 14: # Minimum for read area request - # Function code (1) + item count (1) + address spec (12) + # Function code (1) + item count (1) + N * address spec (12 each) item_count = param_data[1] - # Parse address specification starting at byte 2 + if item_count > 1: + # Multi-item read: parse all address specs + address_specs: List[Dict[str, Any]] = [] + offset = 2 + for _ in range(item_count): + if offset + 12 > len(param_data): + break + addr_spec = param_data[offset : offset + 12] + parsed_addr = self._parse_address_specification(addr_spec) + if parsed_addr: + address_specs.append(parsed_addr) + offset += 12 + return {"function_code": function_code, "item_count": item_count, "address_specs": address_specs} + + # Single-item read if len(param_data) >= 14: addr_spec = param_data[2:14] # 12 bytes of address specification logger.debug(f"Extracted address spec from params: {addr_spec.hex()}") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 00000000..36b41661 --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,319 @@ +"""Tests for the multi-variable read optimizer.""" + +from __future__ import annotations + +import random +import time +from ctypes import c_char +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from snap7.client import Client + from snap7.server import Server + +from snap7.optimizer import ( + ReadItem, + ReadBlock, + ReadPacket, + sort_items, + merge_items, + packetize, + extract_results, +) +from snap7.type import Area, SrvArea + + +# --------------------------------------------------------------------------- +# Unit tests for sort_items +# --------------------------------------------------------------------------- + + +class TestSortItems: + """Tests for sort_items().""" + + def 
test_different_areas(self) -> None: + items = [ + ReadItem(area=0x84, db_number=1, byte_offset=0, bit_offset=0, byte_length=4, index=0), + ReadItem(area=0x83, db_number=0, byte_offset=0, bit_offset=0, byte_length=4, index=1), + ] + result = sort_items(items) + assert result[0].area == 0x83 # MK before DB + assert result[1].area == 0x84 + + def test_same_area_different_db(self) -> None: + items = [ + ReadItem(area=0x84, db_number=2, byte_offset=0, bit_offset=0, byte_length=4, index=0), + ReadItem(area=0x84, db_number=1, byte_offset=0, bit_offset=0, byte_length=4, index=1), + ] + result = sort_items(items) + assert result[0].db_number == 1 + assert result[1].db_number == 2 + + def test_same_offset_different_sizes(self) -> None: + items = [ + ReadItem(area=0x84, db_number=1, byte_offset=0, bit_offset=0, byte_length=2, index=0), + ReadItem(area=0x84, db_number=1, byte_offset=0, bit_offset=0, byte_length=8, index=1), + ] + result = sort_items(items) + # Larger item first (descending byte_length) + assert result[0].byte_length == 8 + assert result[1].byte_length == 2 + + def test_original_not_modified(self) -> None: + items = [ + ReadItem(area=0x84, db_number=2, byte_offset=0, bit_offset=0, byte_length=4, index=0), + ReadItem(area=0x84, db_number=1, byte_offset=0, bit_offset=0, byte_length=4, index=1), + ] + sort_items(items) + assert items[0].db_number == 2 # Original unchanged + + +# --------------------------------------------------------------------------- +# Unit tests for merge_items +# --------------------------------------------------------------------------- + + +class TestMergeItems: + """Tests for merge_items().""" + + def test_contiguous_merge(self) -> None: + items = sort_items( + [ + ReadItem(area=0x84, db_number=1, byte_offset=0, bit_offset=0, byte_length=4, index=0), + ReadItem(area=0x84, db_number=1, byte_offset=4, bit_offset=0, byte_length=4, index=1), + ] + ) + blocks = merge_items(items) + assert len(blocks) == 1 + assert blocks[0].start_offset == 0 + 
assert blocks[0].byte_length == 8 + + def test_gap_merge(self) -> None: + items = sort_items( + [ + ReadItem(area=0x84, db_number=1, byte_offset=0, bit_offset=0, byte_length=4, index=0), + ReadItem(area=0x84, db_number=1, byte_offset=8, bit_offset=0, byte_length=4, index=1), + ] + ) + blocks = merge_items(items, max_gap=5) + assert len(blocks) == 1 + assert blocks[0].byte_length == 12 + + def test_gap_split(self) -> None: + items = sort_items( + [ + ReadItem(area=0x84, db_number=1, byte_offset=0, bit_offset=0, byte_length=4, index=0), + ReadItem(area=0x84, db_number=1, byte_offset=100, bit_offset=0, byte_length=4, index=1), + ] + ) + blocks = merge_items(items, max_gap=5) + assert len(blocks) == 2 + + def test_different_areas_split(self) -> None: + items = sort_items( + [ + ReadItem(area=0x83, db_number=0, byte_offset=0, bit_offset=0, byte_length=4, index=0), + ReadItem(area=0x84, db_number=1, byte_offset=0, bit_offset=0, byte_length=4, index=1), + ] + ) + blocks = merge_items(items) + assert len(blocks) == 2 + + def test_max_block_size_split(self) -> None: + items = sort_items( + [ + ReadItem(area=0x84, db_number=1, byte_offset=0, bit_offset=0, byte_length=300, index=0), + ReadItem(area=0x84, db_number=1, byte_offset=300, bit_offset=0, byte_length=300, index=1), + ] + ) + blocks = merge_items(items, max_block_size=400) + assert len(blocks) == 2 + + def test_overlapping_items(self) -> None: + items = sort_items( + [ + ReadItem(area=0x84, db_number=1, byte_offset=0, bit_offset=0, byte_length=10, index=0), + ReadItem(area=0x84, db_number=1, byte_offset=5, bit_offset=0, byte_length=10, index=1), + ] + ) + blocks = merge_items(items) + assert len(blocks) == 1 + assert blocks[0].start_offset == 0 + assert blocks[0].byte_length == 15 # 0..15 + + def test_empty_input(self) -> None: + assert merge_items([]) == [] + + +# --------------------------------------------------------------------------- +# Unit tests for packetize +# 
# ---------------------------------------------------------------------------


class TestPacketize:
    """Unit tests for packetize(): packing ReadBlocks into PDU-sized packets."""

    def test_single_block_one_packet(self) -> None:
        """A single small block fits into exactly one packet."""
        blocks = [ReadBlock(area=0x84, db_number=1, start_offset=0, byte_length=10, items=[])]
        packets = packetize(blocks, pdu_size=480)
        assert len(packets) == 1
        assert len(packets[0].blocks) == 1

    def test_multiple_blocks_one_packet(self) -> None:
        """Two small blocks share one packet when the PDU budget allows it."""
        blocks = [
            ReadBlock(area=0x84, db_number=1, start_offset=0, byte_length=10, items=[]),
            ReadBlock(area=0x84, db_number=2, start_offset=0, byte_length=10, items=[]),
        ]
        packets = packetize(blocks, pdu_size=480)
        assert len(packets) == 1
        assert len(packets[0].blocks) == 2

    def test_request_budget_limit(self) -> None:
        """Packets split when the request-side budget is exhausted."""
        # Request overhead: 14 + 12*N. With pdu=60, max blocks = (60-14)/12 = 3
        # NOTE(review): >= 2 is deliberately loose; the exact split depends on
        # the reply budget as well, which also constrains these blocks.
        blocks = [ReadBlock(area=0x84, db_number=i, start_offset=0, byte_length=2, items=[]) for i in range(5)]
        packets = packetize(blocks, pdu_size=60)
        assert len(packets) >= 2

    def test_reply_budget_limit(self) -> None:
        """Packets split when the reply-side budget is exhausted."""
        # Reply overhead: 14 + sum(4 + ceil_even(length)).
        # Each block of 100 bytes costs 4+100=104 in reply.
        # With pdu=240: budget = 240-14 = 226. Fits 2 blocks (208), not 3 (312).
        blocks = [ReadBlock(area=0x84, db_number=i, start_offset=0, byte_length=100, items=[]) for i in range(3)]
        packets = packetize(blocks, pdu_size=240)
        assert len(packets) == 2

    def test_oversized_block_split(self) -> None:
        """A block larger than one PDU is split at item boundaries."""
        # A block larger than pdu - overhead should be split at item boundaries
        items = [
            ReadItem(area=0x84, db_number=1, byte_offset=0, bit_offset=0, byte_length=200, index=0),
            ReadItem(area=0x84, db_number=1, byte_offset=200, bit_offset=0, byte_length=200, index=1),
        ]
        blocks = [ReadBlock(area=0x84, db_number=1, start_offset=0, byte_length=400, items=items)]
        # pdu=240: max single block data = 240-12-2-4 = 222
        packets = packetize(blocks, pdu_size=240)
        total_blocks = sum(len(p.blocks) for p in packets)
        assert total_blocks == 2


# ---------------------------------------------------------------------------
# Unit tests for extract_results
# ---------------------------------------------------------------------------


class TestExtractResults:
    """Unit tests for extract_results(): slicing per-item data out of block buffers."""

    def test_correct_index_mapping(self) -> None:
        """Results are keyed by each item's original index, not its buffer order."""
        item_a = ReadItem(area=0x84, db_number=1, byte_offset=0, bit_offset=0, byte_length=4, index=1)
        item_b = ReadItem(area=0x84, db_number=1, byte_offset=4, bit_offset=0, byte_length=4, index=0)
        block = ReadBlock(area=0x84, db_number=1, start_offset=0, byte_length=8, items=[item_a, item_b])
        block.buffer = bytearray(b"\x01\x02\x03\x04\x05\x06\x07\x08")
        packet = ReadPacket(blocks=[block])

        results = extract_results([packet], 2)
        assert results[0] == bytearray(b"\x05\x06\x07\x08")  # index 0 -> item_b
        assert results[1] == bytearray(b"\x01\x02\x03\x04")  # index 1 -> item_a

    def test_overlapping_items(self) -> None:
        """Two items may read overlapping ranges of the same merged buffer."""
        item_a = ReadItem(area=0x84, db_number=1, byte_offset=0, bit_offset=0, byte_length=8, index=0)
        item_b = ReadItem(area=0x84, db_number=1, byte_offset=4, bit_offset=0, byte_length=4, index=1)
        block = ReadBlock(area=0x84, db_number=1, start_offset=0, byte_length=8, items=[item_a, item_b])
        block.buffer = bytearray(b"\x10\x20\x30\x40\x50\x60\x70\x80")
        packet = ReadPacket(blocks=[block])

        results = extract_results([packet], 2)
        assert results[0] == bytearray(b"\x10\x20\x30\x40\x50\x60\x70\x80")
        assert results[1] == bytearray(b"\x50\x60\x70\x80")


# ---------------------------------------------------------------------------
# Integration tests against the server
# ---------------------------------------------------------------------------


@pytest.mark.server
class TestMultiReadServer:
    """Integration tests for multi-read via server."""

    # Populated in setup_class; annotated here for type checkers.
    server: Server
    client: Client
    db1_data: bytearray
    db2_data: bytearray

    @classmethod
    def setup_class(cls) -> None:
        """Start a server exposing two DB areas and connect a client to it."""
        from snap7.server import Server as Srv
        from snap7.client import Client as Cli

        cls.server = Srv()

        # The c_char arrays share memory with the bytearrays (from_buffer),
        # so the expected values asserted below always mirror the served data.
        cls.db1_data = bytearray(range(100))
        db1_array = (c_char * 100).from_buffer(cls.db1_data)
        cls.server.register_area(SrvArea.DB, 1, db1_array)

        cls.db2_data = bytearray(range(100, 200))
        db2_array = (c_char * 100).from_buffer(cls.db2_data)
        cls.server.register_area(SrvArea.DB, 2, db2_array)

        # NOTE(review): a random port can collide with another process; a
        # bind-retry loop would be more robust -- acceptable for CI as-is.
        port = random.randint(20000, 40000)
        cls.server.start(tcp_port=port)
        time.sleep(0.2)  # give the server thread time to start listening

        cls.client = Cli()
        cls.client.connect("127.0.0.1", 0, 0, tcp_port=port)

    @classmethod
    def teardown_class(cls) -> None:
        """Stop server and disconnect client."""
        cls.client.disconnect()
        cls.server.stop()

    def test_multi_read_basic(self) -> None:
        """Read two items from DB1 and verify data."""
        items = [
            {"area": Area.DB, "db_number": 1, "start": 0, "size": 4},
            {"area": Area.DB, "db_number": 1, "start": 10, "size": 4},
        ]
        result_code, results = self.client.read_multi_vars(items)
        assert result_code == 0
        assert len(results) == 2
        assert results[0] == bytearray(self.db1_data[0:4])
        assert results[1] == bytearray(self.db1_data[10:14])

    def test_multi_read_different_dbs(self) -> None:
        """Read items from DB1 and DB2."""
        items = [
            {"area": Area.DB, "db_number": 1, "start": 0, "size": 4},
            {"area": Area.DB, "db_number": 2, "start": 0, "size": 4},
        ]
        result_code, results = self.client.read_multi_vars(items)
        assert result_code == 0
        assert results[0] == bytearray(self.db1_data[0:4])
        assert results[1] == bytearray(self.db2_data[0:4])

    def test_single_item_still_works(self) -> None:
        """A single item should use the non-optimized path."""
        items = [
            {"area": Area.DB, "db_number": 1, "start": 5, "size": 10},
        ]
        result_code, results = self.client.read_multi_vars(items)
        assert result_code == 0
        assert results[0] == bytearray(self.db1_data[5:15])

    def test_empty_items(self) -> None:
        """Empty list should return immediately with an empty result."""
        result_code, results = self.client.read_multi_vars([])
        assert result_code == 0
        # Fix: the original unpacked `results` but never checked it; an empty
        # request must yield an empty result set (read_multi_vars returns the
        # input unchanged for the empty case).
        assert results == []

    def test_many_items_multiple_packets(self) -> None:
        """Enough items to potentially require multiple packets with a small PDU."""
        items = [{"area": Area.DB, "db_number": 1, "start": i * 8, "size": 4} for i in range(10)]
        result_code, results = self.client.read_multi_vars(items)
        assert result_code == 0
        assert len(results) == 10
        for i in range(10):
            expected = bytearray(self.db1_data[i * 8 : i * 8 + 4])
            assert results[i] == expected, f"Mismatch at item {i}"