diff --git a/PhysLean/Relativity/Tensors/Elab.lean b/PhysLean/Relativity/Tensors/Elab.lean index 55ecacff6..c000ed0a1 100644 --- a/PhysLean/Relativity/Tensors/Elab.lean +++ b/PhysLean/Relativity/Tensors/Elab.lean @@ -23,6 +23,7 @@ public import PhysLean.Relativity.Tensors.Tensorial - We can also write e.g. `{T | μ ν}ᵀ.tensor` to get the tensor itself. - `{- T | μ ν}ᵀ` is `neg (tensorNode T)`. - `{T | 0 ν}ᵀ` is `eval 0 0 (tensorNode T)`. +- `{T | [μ] ν}ᵀ` is `eval 0 μ (tensorNode T)`. - `{T | μ ν + T' | μ ν}ᵀ` is `addNode (tensorNode T) (perm _ (tensorNode T'))`, where here `_` will be the identity permutation so does nothing. - `{T | μ ν = T' | μ ν}ᵀ` is `(tensorNode T).tensor = (perm _ (tensorNode T')).tensor`. @@ -75,12 +76,21 @@ syntax num : indexExpr /-- Notation to describe the jiggle of a tensor index. -/ syntax "τ(" ident ")" : indexExpr +/-- Notation to describe the evaluation of a tensor index. -/ +syntax "[" ident "]" : indexExpr + /-- Bool which is true if an index is a num. -/ def indexExprIsNum (stx : Syntax) : Bool := match stx with | `(indexExpr|$_:num) => true | _ => false +/-- Bool which is true if an index is evaluated bracket `[μ]`. -/ +def indexExprIsBracketEval(stx : Syntax) : Bool := + match stx with + | `(indexExpr|[$_]) => true + | _ => false + /-- If an index is a num - the underlying natural number. -/ def indexToNum (stx : Syntax) : TermElabM Nat := match stx with @@ -96,6 +106,7 @@ def indexToIdent (stx : Syntax) : TermElabM Ident := match stx with | `(indexExpr|$a:ident) => return a | `(indexExpr| τ($a:ident)) => return a + | `(indexExpr| [$a:ident]) => return a | _ => throwError "Unsupported expression syntax in indexToIdent: {stx}" @@ -144,13 +155,28 @@ def getEvalPos (ind : List (TSyntax `indexExpr)) : TermElabM (List (ℕ × ℕ)) let pos := evalAdjustPos (evals.map (fun x => x.2)) return List.zip pos evals2 +/-- For list of `indexExpr` e.g. `[α, 3, β, 2, [γ]]`, `getEvalPos` + returns a list of pairs `ℕ × Term` related to indices which are evaluated + e.g. `[μ]`. + The second element of each pair is the value corresponding to that index. + The first element is the position of that number in the list of indices when + all other numbered indices before it are removed. Thus for the example given + `getEvalBracketPos` outputs `[(4, γ)]`. -/ +def getEvalBracketPos (ind : List (TSyntax `indexExpr)) : TermElabM (List (ℕ × Term)) := do + let indEnum := ind.zipIdx + let evals := indEnum.filter (fun x => indexExprIsBracketEval x.1) + let evals2 ← (evals.mapM (fun x => indexToIdent x.1)) + let pos := evalAdjustPos (evals.map (fun x => x.2)) + return List.zip pos evals2 + /-- For list of `indexExpr` e.g. `[α, 3, β, α, 2, γ]`, `getContrPos` first removes all indices which are numbers (e.g. `[α, β, α, γ]`). It then outputs pairs `(a, b)` in `ℕ × ℕ` of positions of this list with `a < b` such that the index at `a` is equal to the index at `b`. It checks whether or not an element is contracted more then once. -/ def getContrPos (ind : List (TSyntax `indexExpr)) : TermElabM (List (ℕ × ℕ)) := do - let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x) + let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x + ∧ ¬ indexExprIsBracketEval x) let indEnum := indFilt.zipIdx let bind := List.flatMap (fun a => indEnum.map (fun b => (a, b))) indEnum let filt ← bind.filterMapM (fun x => indexPosEq x.1 x.2) @@ -388,6 +414,20 @@ def evalTermMap (l : List (ℕ × ℕ)) (T : Term) : Term := l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``Tensor.evalT) #[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x2), T']) T +/-- Given a list `l` of pairs `ℕ × Term` and a term `T` corresponding to a tensor tree, + for each `(a, b)` in `l`, `evalSyntax` applies `TensorTree.eval a b` to `T` recursively. + Here `a` is the position of the index to be evaluated and + `b` is the value it is evaluated to from the `[μ]` syntax. + + For example, if `l` is `[(1, μ), (1, ν)]` and `T` is a tensor tree then `evalSyntax l T` + is `TensorTree.eval 1 ν (TensorTree.eval 1 μ T)`. + + The list `l` is expected to be the output of `getEvalBracketPos`. +-/ +def evalTermBracketMap (l : List (ℕ × Term)) (T : Term) : Term := + l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``Tensor.evalT) + #[Syntax.mkNumLit (toString x1), x2, T']) T + /-- For each element of `l : List (ℕ × ℕ)` applies `TensorTree.contr` to the given term. -/ def contrTermMap (n : ℕ) (l : List (ℕ × ℕ)) (T : Term) : Term := let proofTerm := Syntax.mkApp (mkIdent ``Tensor.contrT_decide) #[mkIdent ``rfl] @@ -439,7 +479,8 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do throwError "The expected number of indices {rawIndex} does not match the tensor {T}." let tensorNodeSyntax := nodeTermMap T let evalSyntax := evalTermMap (← getEvalPos indices) tensorNodeSyntax - let contrSyntax := contrTermMap indices.length (← getContrPos indices) evalSyntax + let evalBracketSyntax := evalTermBracketMap (← getEvalBracketPos indices) evalSyntax + let contrSyntax := contrTermMap indices.length (← getContrPos indices) evalBracketSyntax return contrSyntax | `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do let prodSyntax := prodTermMap (← syntaxFull a) (← syntaxFull b)