Skip to content

Make Op class an immutable Generic#1997

Draft
lucianopaz wants to merge 2 commits intopymc-devs:v3from
lucianopaz:generic_op
Draft

Make Op class an immutable Generic#1997
lucianopaz wants to merge 2 commits intopymc-devs:v3from
lucianopaz:generic_op

Conversation

@lucianopaz
Copy link
Copy Markdown
Member

This PR aims to solve a subset of the type hint issues that pytensor has, namely the Op class expects to have Variable input instances and output Variable instances. This makes it impossible for type checkers to realize that an Op actually operates with Variable subtypes like TensorVariable, and so the type checker will never know that an Op's output will have all of the tensor traits.

The PR tries to handle this by making Op a Generic class. After much thought, I decided to make the Op an immutable Generic. I decided against covariant or contravariant because subclasses of subclasses of Op's might have outputs that are not subclasses of their parent class output, so immutable seemed like the safest choice.

@lucianopaz lucianopaz marked this pull request as draft March 24, 2026 10:19


class Scan(Op, ScanMethodsMixin, HasInnerGraph):
class Scan(Op[TensorVariable], ScanMethodsMixin, HasInnerGraph):
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scan can output any variable type (as untraced_sit_sot)

The Scan materialized with an inner graph should determine the output type



class CSM(Op):
class CSM(Op[TensorVariable]):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be SparseTensorVariable?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a such a thing? It would be great if there were, but I didn't find it.

Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is,

class SparseVariable(_sparse_py_operators, TensorVariable): # type: ignore[misc]

SparseVariable not SparseTensorVariable. Making a subclass of TensorVariable was still a mistake imo, so we should change that later down the road



class RNGConsumerOp(Op):
class RNGConsumerOp(Op[TensorVariable]):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also outputs RandomType variable (the next rng state)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how to type hint that 😬😬😬
My type hints only work for ops that always return the same type. I don't know how to add more output types on the fly.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why, can't it take a tuple?

Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output signature of RandomVariable (not RNGConsumerOp) is tuple[RandomTypeVariable, TensorVariable]. However there is the annoying default_output, which means the signature of call is just Tensorvariable. What about multi-output ops in general?

raise TypeError(
f"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}"
)
self.outputs = cast(ApplyOutputsType, tuple(_outputs))
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds nit, but apply creation is hot loop for us, don't wast cycles with cast function (even a no-op like this), just use type-ignore if mypy doesn't believe us

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants