Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions PhysLean/Relativity/Tensors/Elab.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public import PhysLean.Relativity.Tensors.Tensorial
- We can also write e.g. `{T | μ ν}ᵀ.tensor` to get the tensor itself.
- `{- T | μ ν}ᵀ` is `neg (tensorNode T)`.
- `{T | 0 ν}ᵀ` is `eval 0 0 (tensorNode T)`.
- `{T | [μ] ν}ᵀ` is `eval 0 μ (tensorNode T)`.
- `{T | μ ν + T' | μ ν}ᵀ` is `addNode (tensorNode T) (perm _ (tensorNode T'))`, where
here `_` will be the identity permutation so does nothing.
- `{T | μ ν = T' | μ ν}ᵀ` is `(tensorNode T).tensor = (perm _ (tensorNode T')).tensor`.
Expand Down Expand Up @@ -75,12 +76,21 @@ syntax num : indexExpr
/-- Notation to describe the jiggle of a tensor index. -/
syntax "τ(" ident ")" : indexExpr

/-- Notation to describe the evaluation of a tensor index. -/
syntax "[" ident "]" : indexExpr

/-- Bool which is true if an index is a num. -/
def indexExprIsNum (stx : Syntax) : Bool :=
match stx with
| `(indexExpr|$_:num) => true
| _ => false

/-- Bool which is true if an index is evaluated bracket `[μ]`. -/
def indexExprIsBracketEval(stx : Syntax) : Bool :=
match stx with
| `(indexExpr|[$_]) => true
| _ => false

/-- If an index is a num - the underlying natural number. -/
def indexToNum (stx : Syntax) : TermElabM Nat :=
match stx with
Expand All @@ -96,6 +106,7 @@ def indexToIdent (stx : Syntax) : TermElabM Ident :=
match stx with
| `(indexExpr|$a:ident) => return a
| `(indexExpr| τ($a:ident)) => return a
| `(indexExpr| [$a:ident]) => return a
| _ =>
throwError "Unsupported expression syntax in indexToIdent: {stx}"

Expand Down Expand Up @@ -144,13 +155,28 @@ def getEvalPos (ind : List (TSyntax `indexExpr)) : TermElabM (List (ℕ × ℕ))
let pos := evalAdjustPos (evals.map (fun x => x.2))
return List.zip pos evals2

/-- For list of `indexExpr` e.g. `[α, 3, β, 2, [γ]]`, `getEvalPos`
returns a list of pairs `ℕ × Term` related to indices which are evaluated
e.g. `[μ]`.
The second element of each pair is the value corresponding to that index.
The first element is the position of that number in the list of indices when
all other numbered indices before it are removed. Thus for the example given
`getEvalBracketPos` outputs `[(4, γ)]`. -/
def getEvalBracketPos (ind : List (TSyntax `indexExpr)) : TermElabM (List (ℕ × Term)) := do
let indEnum := ind.zipIdx
let evals := indEnum.filter (fun x => indexExprIsBracketEval x.1)
let evals2 ← (evals.mapM (fun x => indexToIdent x.1))
let pos := evalAdjustPos (evals.map (fun x => x.2))
return List.zip pos evals2

/-- For list of `indexExpr` e.g. `[α, 3, β, α, 2, γ]`, `getContrPos`
first removes all indices which are numbers (e.g. `[α, β, α, γ]`).
It then outputs pairs `(a, b)` in `ℕ × ℕ` of positions of this list with `a < b`
such that the index at `a` is equal to the index at `b`. It checks whether or not
an element is contracted more then once. -/
def getContrPos (ind : List (TSyntax `indexExpr)) : TermElabM (List (ℕ × ℕ)) := do
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x
∧ ¬ indexExprIsBracketEval x)
let indEnum := indFilt.zipIdx
let bind := List.flatMap (fun a => indEnum.map (fun b => (a, b))) indEnum
let filt ← bind.filterMapM (fun x => indexPosEq x.1 x.2)
Expand Down Expand Up @@ -388,6 +414,20 @@ def evalTermMap (l : List (ℕ × ℕ)) (T : Term) : Term :=
l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``Tensor.evalT)
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x2), T']) T

/-- Given a list `l` of pairs `ℕ × Term` and a term `T` corresponding to a tensor tree,
for each `(a, b)` in `l`, `evalSyntax` applies `TensorTree.eval a b` to `T` recursively.
Here `a` is the position of the index to be evaluated and
`b` is the value it is evaluated to from the `[μ]` syntax.

For example, if `l` is `[(1, μ), (1, ν)]` and `T` is a tensor tree then `evalSyntax l T`
is `TensorTree.eval 1 ν (TensorTree.eval 1 μ T)`.

The list `l` is expected to be the output of `getEvalBracketPos`.
-/
def evalTermBracketMap (l : List (ℕ × Term)) (T : Term) : Term :=
l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``Tensor.evalT)
#[Syntax.mkNumLit (toString x1), x2, T']) T

/-- For each element of `l : List (ℕ × ℕ)` applies `TensorTree.contr` to the given term. -/
def contrTermMap (n : ℕ) (l : List (ℕ × ℕ)) (T : Term) : Term :=
let proofTerm := Syntax.mkApp (mkIdent ``Tensor.contrT_decide) #[mkIdent ``rfl]
Expand Down Expand Up @@ -439,7 +479,8 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do
throwError "The expected number of indices {rawIndex} does not match the tensor {T}."
let tensorNodeSyntax := nodeTermMap T
let evalSyntax := evalTermMap (← getEvalPos indices) tensorNodeSyntax
let contrSyntax := contrTermMap indices.length (← getContrPos indices) evalSyntax
let evalBracketSyntax := evalTermBracketMap (← getEvalBracketPos indices) evalSyntax
let contrSyntax := contrTermMap indices.length (← getContrPos indices) evalBracketSyntax
return contrSyntax
| `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do
let prodSyntax := prodTermMap (← syntaxFull a) (← syntaxFull b)
Expand Down
Loading