From 2289d5b8fbeb33f9121acdb5c9961f755c8f99a3 Mon Sep 17 00:00:00 2001 From: jstoobysmith <72603918+jstoobysmith@users.noreply.github.com> Date: Fri, 20 Mar 2026 06:35:31 +0000 Subject: [PATCH 1/3] feat: Add better eval for tensors --- PhysLean/Relativity/Tensors/Elab.lean | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/PhysLean/Relativity/Tensors/Elab.lean b/PhysLean/Relativity/Tensors/Elab.lean index 55ecacff6..6175f27a1 100644 --- a/PhysLean/Relativity/Tensors/Elab.lean +++ b/PhysLean/Relativity/Tensors/Elab.lean @@ -23,6 +23,7 @@ public import PhysLean.Relativity.Tensors.Tensorial - We can also write e.g. `{T | μ ν}ᵀ.tensor` to get the tensor itself. - `{- T | μ ν}ᵀ` is `neg (tensorNode T)`. - `{T | 0 ν}ᵀ` is `eval 0 0 (tensorNode T)`. +- `{T | [μ] ν}ᵀ` is `eval 0 μ (tensorNode T)`. - `{T | μ ν + T' | μ ν}ᵀ` is `addNode (tensorNode T) (perm _ (tensorNode T'))`, where here `_` will be the identity permutation so does nothing. - `{T | μ ν = T' | μ ν}ᵀ` is `(tensorNode T).tensor = (perm _ (tensorNode T')).tensor`. @@ -75,12 +76,20 @@ syntax num : indexExpr /-- Notation to describe the jiggle of a tensor index. -/ syntax "τ(" ident ")" : indexExpr +/-- Notation to describe the evaulation of a tensor index. -/ +syntax "[" ident "]" : indexExpr + /-- Bool which is true if an index is a num. -/ def indexExprIsNum (stx : Syntax) : Bool := match stx with | `(indexExpr|$_:num) => true | _ => false +def indexExprIsBracketEval(stx : Syntax) : Bool := + match stx with + | `(indexExpr|[$_]) => true + | _ => false + /-- If an index is a num - the underlying natural number. -/ def indexToNum (stx : Syntax) : TermElabM Nat := match stx with @@ -96,6 +105,7 @@ def indexToIdent (stx : Syntax) : TermElabM Ident := match stx with | `(indexExpr|$a:ident) => return a | `(indexExpr| τ($a:ident)) => return a + | `(indexExpr| [$a:ident]) => return a | _ => throwError "Unsupported expression syntax in indexToIdent: {stx}" @@ -144,13 +154,21 @@ def getEvalPos (ind : List (TSyntax `indexExpr)) : TermElabM (List (ℕ × ℕ)) let pos := evalAdjustPos (evals.map (fun x => x.2)) return List.zip pos evals2 +def getEvalBracketPos (ind : List (TSyntax `indexExpr)) : TermElabM (List (ℕ × Term)) := do + let indEnum := ind.zipIdx + let evals := indEnum.filter (fun x => indexExprIsBracketEval x.1) + let evals2 ← (evals.mapM (fun x => indexToIdent x.1)) + let pos := evalAdjustPos (evals.map (fun x => x.2)) + return List.zip pos evals2 + /-- For list of `indexExpr` e.g. `[α, 3, β, α, 2, γ]`, `getContrPos` first removes all indices which are numbers (e.g. `[α, β, α, γ]`). It then outputs pairs `(a, b)` in `ℕ × ℕ` of positions of this list with `a < b` such that the index at `a` is equal to the index at `b`. It checks whether or not an element is contracted more then once. -/ def getContrPos (ind : List (TSyntax `indexExpr)) : TermElabM (List (ℕ × ℕ)) := do - let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x) + let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x + ∧ ¬ indexExprIsBracketEval x) let indEnum := indFilt.zipIdx let bind := List.flatMap (fun a => indEnum.map (fun b => (a, b))) indEnum let filt ← bind.filterMapM (fun x => indexPosEq x.1 x.2) @@ -388,6 +406,10 @@ def evalTermMap (l : List (ℕ × ℕ)) (T : Term) : Term := l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``Tensor.evalT) #[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x2), T']) T +def evalTermBracketMap (l : List (ℕ × Term)) (T : Term) : Term := + l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``Tensor.evalT) + #[Syntax.mkNumLit (toString x1), x2, T']) T + /-- For each element of `l : List (ℕ × ℕ)` applies `TensorTree.contr` to the given term. -/ def contrTermMap (n : ℕ) (l : List (ℕ × ℕ)) (T : Term) : Term := let proofTerm := Syntax.mkApp (mkIdent ``Tensor.contrT_decide) #[mkIdent ``rfl] @@ -439,7 +461,8 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do throwError "The expected number of indices {rawIndex} does not match the tensor {T}." let tensorNodeSyntax := nodeTermMap T let evalSyntax := evalTermMap (← getEvalPos indices) tensorNodeSyntax - let contrSyntax := contrTermMap indices.length (← getContrPos indices) evalSyntax + let evalBracketSyntax := evalTermBracketMap (← getEvalBracketPos indices) evalSyntax + let contrSyntax := contrTermMap indices.length (← getContrPos indices) evalBracketSyntax return contrSyntax | `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do let prodSyntax := prodTermMap (← syntaxFull a) (← syntaxFull b) From ec14528660109c9d2e499531bd3c3555ccba53ed Mon Sep 17 00:00:00 2001 From: jstoobysmith <72603918+jstoobysmith@users.noreply.github.com> Date: Fri, 20 Mar 2026 06:52:25 +0000 Subject: [PATCH 2/3] docs: Add documentation --- PhysLean/Relativity/Tensors/Elab.lean | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/PhysLean/Relativity/Tensors/Elab.lean b/PhysLean/Relativity/Tensors/Elab.lean index 6175f27a1..3e97fb51d 100644 --- a/PhysLean/Relativity/Tensors/Elab.lean +++ b/PhysLean/Relativity/Tensors/Elab.lean @@ -85,6 +85,7 @@ def indexExprIsNum (stx : Syntax) : Bool := | `(indexExpr|$_:num) => true | _ => false +/-- Bool which is true if an index is evaluated bracket `[μ]`. -/ def indexExprIsBracketEval(stx : Syntax) : Bool := match stx with | `(indexExpr|[$_]) => true @@ -154,6 +155,13 @@ def getEvalPos (ind : List (TSyntax `indexExpr)) : TermElabM (List (ℕ × ℕ)) let pos := evalAdjustPos (evals.map (fun x => x.2)) return List.zip pos evals2 +/-- For list of `indexExpr` e.g. `[α, 3, β, 2, [γ]]`, `getEvalPos` + returns a list of pairs `ℕ × Term` related to indices which are evaluated + e.g. `[μ]`. + The second element of each pair is the value corresponding to that index. + The first element is the position of that number in the list of indices when + all other numbered indices before it are removed. Thus for the example given + `getEvalBracketPos` outputs `[(4, γ)]`. -/ def getEvalBracketPos (ind : List (TSyntax `indexExpr)) : TermElabM (List (ℕ × Term)) := do let indEnum := ind.zipIdx let evals := indEnum.filter (fun x => indexExprIsBracketEval x.1) @@ -406,6 +414,16 @@ def evalTermMap (l : List (ℕ × ℕ)) (T : Term) : Term := l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``Tensor.evalT) #[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x2), T']) T +/-- Given a list `l` of pairs `ℕ × Term` and a term `T` corresponding to a tensor tree, + for each `(a, b)` in `l`, `evalSyntax` applies `TensorTree.eval a b` to `T` recursively. + Here `a` is the position of the index to be evaluated and + `b` is the value it is evaluated to from the `[μ]` syntax. + + For example, if `l` is `[(1, μ), (1, ν)]` and `T` is a tensor tree then `evalSyntax l T` + is `TensorTree.eval 1 ν (TensorTree.eval 1 μ T)`. + + The list `l` is expected to be the output of `getEvalBracketPos`. +-/ def evalTermBracketMap (l : List (ℕ × Term)) (T : Term) : Term := l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``Tensor.evalT) #[Syntax.mkNumLit (toString x1), x2, T']) T From 279d92ddb8b31a6638e0db0a4d75ef395bb6050c Mon Sep 17 00:00:00 2001 From: jstoobysmith <72603918+jstoobysmith@users.noreply.github.com> Date: Fri, 20 Mar 2026 06:55:33 +0000 Subject: [PATCH 3/3] lint: spelling --- PhysLean/Relativity/Tensors/Elab.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PhysLean/Relativity/Tensors/Elab.lean b/PhysLean/Relativity/Tensors/Elab.lean index 3e97fb51d..c000ed0a1 100644 --- a/PhysLean/Relativity/Tensors/Elab.lean +++ b/PhysLean/Relativity/Tensors/Elab.lean @@ -76,7 +76,7 @@ syntax num : indexExpr /-- Notation to describe the jiggle of a tensor index. -/ syntax "τ(" ident ")" : indexExpr -/-- Notation to describe the evaulation of a tensor index. -/ +/-- Notation to describe the evaluation of a tensor index. -/ syntax "[" ident "]" : indexExpr /-- Bool which is true if an index is a num. -/