99 changes: 93 additions & 6 deletions mypy/build.py
@@ -870,6 +870,9 @@ def __init__(
# Mapping from SCC id to corresponding SCC instance. This is populated
# in process_graph().
self.scc_by_id: dict[int, SCC] = {}
# Mapping from module id to the SCC it belongs to. This is populated
# in process_graph().
self.scc_by_mod_id: dict[str, SCC] = {}
# Global topological order for SCCs. This exists to make order of processing
# SCCs more predictable.
self.top_order: list[int] = []
@@ -892,6 +895,8 @@ def __init__(
# raw parsed trees not analyzed with mypy. We use these to find absolute
# location of a symbol used as a location for an error message.
self.extra_trees: dict[str, MypyFile] = {}
# Cache for transitive dependency check (expensive).
self.transitive_deps_cache: dict[tuple[int, int], bool] = {}

def dump_stats(self) -> None:
if self.options.dump_build_stats:
@@ -1203,6 +1208,20 @@ def wait_for_done_workers(self) -> tuple[list[SCC], bool, dict[str, tuple[str, l
results,
)

def is_transitive_scc_dep(self, from_scc_id: int, to_scc_id: int) -> bool:
"""Check if one SCC is a (transitive) dependency of another."""
edge = (from_scc_id, to_scc_id)
if (cached := self.transitive_deps_cache.get(edge)) is not None:
return cached
if to_scc_id in self.scc_by_id[from_scc_id].deps:
self.transitive_deps_cache[edge] = True
return True
for dep in self.scc_by_id[from_scc_id].deps:
if self.is_transitive_scc_dep(dep, to_scc_id):
return True
self.transitive_deps_cache[edge] = False
return False
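
For reference, a minimal standalone sketch of the memoized reachability check this method performs over the SCC condensation DAG; the toy `deps` mapping and function names below are illustrative, not mypy's API:

```python
# Illustrative sketch: memoized reachability over a DAG of SCC ids.
# `deps` maps each SCC id to the ids of its direct SCC dependencies.
deps: dict[int, set[int]] = {1: {2, 3}, 2: {4}, 3: set(), 4: set()}
cache: dict[tuple[int, int], bool] = {}

def is_transitive_dep(from_id: int, to_id: int) -> bool:
    edge = (from_id, to_id)
    if (cached := cache.get(edge)) is not None:
        return cached
    # Direct edge, or any direct dependency that reaches the target.
    result = to_id in deps[from_id] or any(
        is_transitive_dep(dep, to_id) for dep in deps[from_id]
    )
    cache[edge] = result
    return result

assert is_transitive_dep(1, 4)      # 1 -> 2 -> 4
assert not is_transitive_dep(3, 4)  # no path from 3
```

Because the condensation of the import graph is acyclic, the recursion terminates without cycle detection, and caching both outcomes keeps repeated queries cheap.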


def deps_to_json(x: dict[str, set[str]]) -> bytes:
return json_dumps({k: list(v) for k, v in x.items()})
Expand Down Expand Up @@ -1841,6 +1860,7 @@ def write_cache(
dep_prios: list[int],
dep_lines: list[int],
old_interface_hash: bytes,
trans_dep_hash: bytes,
source_hash: str,
ignore_all: bool,
manager: BuildManager,
@@ -1957,6 +1977,7 @@ def write_cache(
dep_prios=dep_prios,
dep_lines=dep_lines,
interface_hash=interface_hash,
trans_dep_hash=trans_dep_hash,
version_id=manager.version_id,
ignore_all=ignore_all,
plugin_data=plugin_data,
@@ -2175,6 +2196,12 @@ class State:
# Contains a hash of the public interface in incremental mode
interface_hash: bytes = b""

# Hash of the import structure this module depends on. It is not 1:1 with
# the set of transitive dependencies, but if two hashes are equal, the
# transitive dependencies are guaranteed to be identical. Some expensive
# checks can be skipped if this value is unchanged for a module.
trans_dep_hash: bytes = b""

# Options, specialized for this file
options: Options

@@ -2322,15 +2349,15 @@ def new_state(
if temporary:
state.load_tree(temporary=True)
if not manager.use_fine_grained_cache():
# Special case: if there were a previously missing package imported here
# Special case: if there was a previously missing package imported here,
# and it is not present, then we need to re-calculate dependencies.
# This is to support patterns like this:
# from missing_package import missing_module # type: ignore
# At first mypy doesn't know that `missing_module` is a module
# (it may be a variable, a class, or a function), so it is not added to
# suppressed dependencies. Therefore, when the package with module is added,
# we need to re-calculate dependencies.
# NOTE: see comment below for why we skip this in fine grained mode.
# NOTE: see comment below for why we skip this in fine-grained mode.
if exist_added_packages(suppressed, manager, options):
state.parse_file() # This is safe because the cache is anyway stale.
state.compute_dependencies()
@@ -2350,6 +2377,7 @@ def new_state(
# We don't need parsed trees in coordinator process, we parse only to
# compute dependencies.
state.tree = None
del manager.ast_cache[id]

return state

@@ -3012,6 +3040,7 @@ def write_cache(self) -> tuple[CacheMeta, str] | None:
dep_prios,
dep_lines,
self.interface_hash,
self.trans_dep_hash,
self.source_hash,
self.ignore_all,
self.manager,
@@ -3774,6 +3803,27 @@ def order_ascc_ex(graph: Graph, ascc: SCC) -> list[str]:
return scc


def verify_transitive_deps(ascc: SCC, graph: Graph, manager: BuildManager) -> str | None:
"""Verify all indirect dependencies of this SCC are still reachable via direct ones.

Return first unreachable dependency id, or None.
"""
for id in ascc.mod_ids:
st = graph[id]
assert st.meta is not None, "Must be called on fresh SCCs only"
if st.trans_dep_hash == st.meta.trans_dep_hash:
# Import graph unchanged, skip this module.
continue
for dep in st.dependencies:
if st.priorities.get(dep) == PRI_INDIRECT:
dep_scc_id = manager.scc_by_mod_id[dep].id
if dep_scc_id == ascc.id:
continue
if not manager.is_transitive_scc_dep(ascc.id, dep_scc_id):
return dep
return None
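
To make the invariant concrete, here is a hedged toy illustration (all names are invented stand-ins for mypy's State/SCC structures): a cached indirect (PRI_INDIRECT) dependency must stay reachable through the chain of direct SCC dependency edges; if an intermediate direct import disappears, the module is stale even though the indirect target still exists in the build.

```python
# Toy data: SCC 0 depends directly on SCC 1; module "c" (SCC 2) was once
# reachable through a direct import that has since been removed.
direct_scc_deps: dict[int, set[int]] = {0: {1}, 1: set()}
scc_of_module = {"a": 0, "b": 1, "c": 2}
indirect_deps_of = {"a": ["c"], "b": []}  # recorded in a's cached metadata

def reaches(from_scc: int, to_scc: int) -> bool:
    direct = direct_scc_deps.get(from_scc, set())
    return to_scc in direct or any(reaches(d, to_scc) for d in direct)

def first_unreachable(mod: str) -> str | None:
    for dep in indirect_deps_of[mod]:
        if scc_of_module[dep] == scc_of_module[mod]:
            continue  # same SCC: trivially reachable
        if not reaches(scc_of_module[mod], scc_of_module[dep]):
            return dep
    return None

assert first_unreachable("a") == "c"   # stale: SCC 2 unreachable from SCC 0
assert first_unreachable("b") is None  # fresh
```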


def find_stale_sccs(
sccs: list[SCC], graph: Graph, manager: BuildManager
) -> tuple[list[SCC], list[SCC]]:
@@ -3782,7 +3832,8 @@ def find_stale_sccs(
Fresh SCCs are those where:
* We have valid cache files for all modules in the SCC.
* There are no changes in dependencies (files removed from/added to the build).
* The interface hashes of direct dependents matches those recorded in the cache.
* The interface hashes of dependencies match those recorded in the cache.
* All indirect dependencies are still reachable via direct ones.
The first and second conditions are verified by is_fresh().
"""
stale_sccs = []
@@ -3799,6 +3850,15 @@ def find_stale_sccs(
stale_deps.add(dep)
fresh = fresh and not stale_deps

# Verify the invariant that indirect dependencies are a subset of transitive
# direct dependencies. Note: the case where an indirect dependency is removed
# from the graph completely is already handled above.
stale_indirect = None
if fresh:
stale_indirect = verify_transitive_deps(ascc, graph, manager)
if stale_indirect is not None:
fresh = False

if fresh:
fresh_msg = "fresh"
elif stale_scc:
@@ -3807,8 +3867,11 @@
fresh_msg += f" ({' '.join(sorted(stale_scc))})"
if stale_deps:
fresh_msg += f" with stale deps ({' '.join(sorted(stale_deps))})"
else:
elif stale_deps:
fresh_msg = f"stale due to deps ({' '.join(sorted(stale_deps))})"
else:
assert stale_indirect is not None
fresh_msg = f"stale due to stale indirect dep(s): first {stale_indirect}"

scc_str = " ".join(ascc.mod_ids)
if fresh:
@@ -3860,6 +3923,9 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
scc_by_id = {scc.id: scc for scc in sccs}
manager.scc_by_id = scc_by_id
manager.top_order = [scc.id for scc in sccs]
for scc in sccs:
for mod_id in scc.mod_ids:
manager.scc_by_mod_id[mod_id] = scc

# Broadcast SCC structure to the parallel workers, since they don't compute it.
sccs_message = SccsDataMessage(sccs=sccs)
@@ -3904,8 +3970,8 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
# type-checking this is already done and results should be empty here.
if not manager.workers:
assert not results
for id, (interface_cache, errors) in results.items():
new_hash = bytes.fromhex(interface_cache)
for id, (interface_hash, errors) in results.items():
new_hash = bytes.fromhex(interface_hash)
if new_hash != graph[id].interface_hash:
graph[id].mark_interface_stale()
graph[id].interface_hash = new_hash
@@ -3917,6 +3983,7 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
if not scc_by_id[dependent].not_ready_deps:
not_ready.remove(scc_by_id[dependent])
ready.append(scc_by_id[dependent])
manager.trace(f"Transitive deps cache size: {sys.getsizeof(manager.transitive_deps_cache)}")
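
The ready/not_ready bookkeeping above is Kahn-style topological scheduling: when an SCC finishes, each direct dependent's unfinished-dependency count drops, and it becomes ready at zero. A standalone sketch under invented names (mypy tracks the same information via SCC.not_ready_deps and SCC.direct_dependents):

```python
from collections import deque

# deps_left: unfinished direct dependencies per SCC id; dependents: reverse edges.
deps_left = {1: 0, 2: 1, 3: 2}
dependents = {1: [2, 3], 2: [3], 3: []}

ready = deque(scc for scc, n in deps_left.items() if n == 0)
order = []
while ready:
    scc = ready.popleft()
    order.append(scc)  # process (type-check) this SCC
    for dependent in dependents[scc]:
        deps_left[dependent] -= 1
        if deps_left[dependent] == 0:
            ready.append(dependent)

assert order == [1, 2, 3]
```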


def order_ascc(graph: Graph, ascc: AbstractSet[str], pri_max: int = PRI_INDIRECT) -> list[str]:
@@ -4168,6 +4235,11 @@ def sorted_components(graph: Graph) -> list[SCC]:
scc.size_hint = sum(graph[mid].size_hint for mid in scc.mod_ids)
for dep in scc_dep_map[scc]:
dep.direct_dependents.append(scc.id)
# We compute the dependencies hash here since we know no direct
# dependencies will be added or suppressed after this point.
trans_dep_hash = transitive_dep_hash(scc, graph)
for id in scc.mod_ids:
graph[id].trans_dep_hash = trans_dep_hash
res.extend(sorted_ready)
return res

@@ -4201,6 +4273,21 @@ def deps_filtered(graph: Graph, vertices: AbstractSet[str], id: str, pri_max: in
]


def transitive_dep_hash(scc: SCC, graph: Graph) -> bytes:
"""Compute stable snapshot of transitive import structure for given SCC."""
all_direct_deps = {
dep
for id in scc.mod_ids
for dep in graph[id].dependencies
if graph[id].priorities.get(dep) != PRI_INDIRECT
}
trans_dep_hash_map = {
dep_id: "" if dep_id in scc.mod_ids else graph[dep_id].trans_dep_hash.hex()
for dep_id in all_direct_deps
}
return hash_digest_bytes(json_dumps(trans_dep_hash_map))
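
This realizes the property stated on State.trans_dep_hash: equal hashes guarantee identical transitive import structure, because each SCC's hash folds in the already-computed hashes of its direct dependencies, Merkle-style, bottom-up over the condensation DAG. A standalone sketch, with hashlib/json (and sorted keys for stability) standing in for mypy's hash_digest_bytes and json_dumps helpers:

```python
import hashlib
import json

# Stand-ins for mypy's hash_digest_bytes / json_dumps helpers (assumed).
def hash_digest_bytes(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()

def scc_hash(direct_dep_hashes: dict[str, bytes], mod_ids: set[str]) -> bytes:
    # Deps inside the SCC map to "" so the hash never references itself;
    # external deps contribute their own, already-computed hash.
    snapshot = {
        dep: "" if dep in mod_ids else h.hex()
        for dep, h in direct_dep_hashes.items()
    }
    return hash_digest_bytes(json.dumps(snapshot, sort_keys=True).encode())

# Bottom-up over the condensation DAG: leaves first, then their dependents.
h_c = scc_hash({}, {"c"})
h_b = scc_hash({"c": h_c}, {"b"})
h_a = scc_hash({"b": h_b, "c": h_c}, {"a"})
# Any change to c's direct imports changes h_c, and therefore h_b and h_a.
```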


def missing_stubs_file(cache_dir: str) -> str:
return os.path.join(cache_dir, "missing_stubs")

8 changes: 7 additions & 1 deletion mypy/cache.py
@@ -69,7 +69,7 @@
from mypy_extensions import u8

# High-level cache layout format
CACHE_VERSION: Final = 2
CACHE_VERSION: Final = 3

SerializedError: _TypeAlias = tuple[str | None, int | str, int, int, int, str, str, str | None]

@@ -95,6 +95,7 @@ def __init__(
dep_lines: list[int],
dep_hashes: list[bytes],
interface_hash: bytes,
trans_dep_hash: bytes,
error_lines: list[SerializedError],
version_id: str,
ignore_all: bool,
@@ -117,6 +118,7 @@ def __init__(
# dep_hashes list is aligned with dependencies only
self.dep_hashes = dep_hashes # list of interface_hash for dependencies
self.interface_hash = interface_hash # hash representing the public interface
self.trans_dep_hash = trans_dep_hash # hash of import structure (transitive)
self.error_lines = error_lines
self.version_id = version_id # mypy version for cache invalidation
self.ignore_all = ignore_all # if errors were ignored
@@ -138,6 +140,7 @@ def serialize(self) -> dict[str, Any]:
"dep_lines": self.dep_lines,
"dep_hashes": [dep.hex() for dep in self.dep_hashes],
"interface_hash": self.interface_hash.hex(),
"trans_dep_hash": self.trans_dep_hash.hex(),
"error_lines": self.error_lines,
"version_id": self.version_id,
"ignore_all": self.ignore_all,
@@ -165,6 +168,7 @@ def deserialize(cls, meta: dict[str, Any], data_file: str) -> CacheMeta | None:
dep_lines=meta["dep_lines"],
dep_hashes=[bytes.fromhex(dep) for dep in meta["dep_hashes"]],
interface_hash=bytes.fromhex(meta["interface_hash"]),
trans_dep_hash=bytes.fromhex(meta["trans_dep_hash"]),
error_lines=[tuple(err) for err in meta["error_lines"]],
version_id=meta["version_id"],
ignore_all=meta["ignore_all"],
@@ -191,6 +195,7 @@ def write(self, data: WriteBuffer) -> None:
write_int_list(data, self.dep_lines)
write_bytes_list(data, self.dep_hashes)
write_bytes(data, self.interface_hash)
write_bytes(data, self.trans_dep_hash)
write_errors(data, self.error_lines)
write_str(data, self.version_id)
write_bool(data, self.ignore_all)
@@ -219,6 +224,7 @@ def read(cls, data: ReadBuffer, data_file: str) -> CacheMeta | None:
dep_lines=read_int_list(data),
dep_hashes=read_bytes_list(data),
interface_hash=read_bytes(data),
trans_dep_hash=read_bytes(data),
error_lines=read_errors(data),
version_id=read_str(data),
ignore_all=read_bool(data),
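
Since the binary metadata format is positional, write() and read() must emit and consume the new field at the same offset (here, immediately after interface_hash). A minimal sketch of that symmetry, assuming a length-prefixed bytes codec; BytesIO/struct stand in for mypy's WriteBuffer/ReadBuffer and write_bytes/read_bytes:

```python
import struct
from io import BytesIO

# Assumed length-prefixed codec, illustrative only.
def write_bytes(buf: BytesIO, value: bytes) -> None:
    buf.write(struct.pack("<I", len(value)))
    buf.write(value)

def read_bytes(buf: BytesIO) -> bytes:
    (n,) = struct.unpack("<I", buf.read(4))
    return buf.read(n)

buf = BytesIO()
write_bytes(buf, b"\x01\x02")  # interface_hash
write_bytes(buf, b"\x03\x04")  # trans_dep_hash: new field, same position
buf.seek(0)
assert read_bytes(buf) == b"\x01\x02"
assert read_bytes(buf) == b"\x03\x04"
```

An old cache written without the field would desynchronize this reader, which is what the CACHE_VERSION bump from 2 to 3 above guards against.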
6 changes: 0 additions & 6 deletions mypy/semanal_main.py
@@ -31,7 +31,6 @@
from itertools import groupby
from typing import TYPE_CHECKING, Final, TypeAlias as _TypeAlias

import mypy.build
import mypy.state
from mypy.checker import FineGrainedDeferredNode
from mypy.errors import Errors
@@ -416,11 +415,6 @@ def semantic_analyze_target(
)
if isinstance(node, Decorator):
infer_decorator_signature_if_simple(node, analyzer)
for dep in analyzer.imports:
state.add_dependency(dep)
priority = mypy.build.PRI_LOW
if priority <= state.priorities.get(dep, priority):
state.priorities[dep] = priority

# Clear out some stale data to avoid memory leaks and astmerge
# validity check confusion
8 changes: 5 additions & 3 deletions mypy/test/testgraph.py
@@ -64,9 +64,10 @@ def test_sorted_components(self) -> None:
"d": State.new_state("d", None, "pass", manager),
"b": State.new_state("b", None, "import c", manager),
"c": State.new_state("c", None, "import b, d", manager),
"builtins": State.new_state("builtins", None, "", manager),
}
res = [scc.mod_ids for scc in sorted_components(graph)]
assert_equal(res, [{"d"}, {"c", "b"}, {"a"}])
assert_equal(res, [{"builtins"}, {"d"}, {"c", "b"}, {"a"}])

def test_order_ascc(self) -> None:
manager = self._make_manager()
@@ -75,9 +76,10 @@ def test_order_ascc(self) -> None:
"d": State.new_state("d", None, "def f(): import a", manager),
"b": State.new_state("b", None, "import c", manager),
"c": State.new_state("c", None, "import b, d", manager),
"builtins": State.new_state("builtins", None, "", manager),
}
res = [scc.mod_ids for scc in sorted_components(graph)]
assert_equal(res, [frozenset({"a", "d", "c", "b"})])
ascc = res[0]
assert_equal(res, [{"builtins"}, {"a", "d", "c", "b"}])
ascc = res[1]
scc = order_ascc(graph, ascc)
assert_equal(scc, ["d", "c", "b", "a"])