Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make dot notation be affected by export/open #6189

Merged
merged 2 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 39 additions & 62 deletions src/Lean/Elab/App.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1150,48 +1150,33 @@ private def throwLValError (e : Expr) (eType : Expr) (msg : MessageData) : TermE
throwError "{msg}{indentExpr e}\nhas type{indentExpr eType}"

/--
`findMethod? S fName` tries the following for each namespace `S'` in the resolution order for `S`:
- If `env` contains `S' ++ fName`, returns `(S', S' ++ fName)`
- Otherwise if `env` contains private name `prv` for `S' ++ fName`, returns `(S', prv)`
`findMethod? S fName` tries the for each namespace `S'` in the resolution order for `S` to resolve the name `S'.fname`.
If it resolves to `name`, returns `(S', name)`.
-/
private partial def findMethod? (structName fieldName : Name) : MetaM (Option (Name × Name)) := do
let env ← getEnv
let find? structName' : MetaM (Option (Name × Name)) := do
let fullName := structName' ++ fieldName
if env.contains fullName then
return some (structName', fullName)
let fullNamePrv := mkPrivateName env fullName
if env.contains fullNamePrv then
return some (structName', fullNamePrv)
return none
-- We do not want to make use of the current namespace for resolution.
let candidates := ResolveName.resolveGlobalName (← getEnv) Name.anonymous (← getOpenDecls) fullName
|>.filter (fun (_, fieldList) => fieldList.isEmpty)
|>.map Prod.fst
match candidates with
| [] => return none
| [fullName'] => return some (structName', fullName')
| _ => throwError "\
invalid field notation '{fieldName}', the name '{fullName}' is ambiguous, possible interpretations: \
{MessageData.joinSep (candidates.map (m!"'{.ofConstName ·}'")) ", "}"
-- Optimization: the first element of the resolution order is `structName`,
-- so we can skip computing the resolution order in the common case
-- of the name resolving in the `structName` namespace.
find? structName <||> do
let resolutionOrder ← if isStructure env structName then getStructureResolutionOrder structName else pure #[structName]
for h : i in [1:resolutionOrder.size] do
if let some res ← find? resolutionOrder[i] then
for ns in resolutionOrder[1:resolutionOrder.size] do
if let some res ← find? ns then
return res
return none

/--
Return `some (structName', fullName)` if `structName ++ fieldName` is an alias for `fullName`, and
`fullName` is of the form `structName' ++ fieldName`.

TODO: if there is more than one applicable alias, it returns `none`. We should consider throwing an error or
warning.
-/
private def findMethodAlias? (env : Environment) (structName fieldName : Name) : Option (Name × Name) :=
let fullName := structName ++ fieldName
-- We never skip `protected` aliases when resolving dot-notation.
let aliasesCandidates := getAliases env fullName (skipProtected := false) |>.filterMap fun alias =>
match alias.eraseSuffix? fieldName with
| none => none
| some structName' => some (structName', alias)
match aliasesCandidates with
| [r] => some r
| _ => none

private def throwInvalidFieldNotation (e eType : Expr) : TermElabM α :=
throwLValError e eType "invalid field notation, type is not of the form (C ...) where C is a constant"

Expand Down Expand Up @@ -1223,30 +1208,22 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L
throwLValError e eType m!"invalid projection, structure has only {numFields} field(s)"
| some structName, LVal.fieldName _ fieldName _ _ =>
let env ← getEnv
let searchEnv : Unit → TermElabM LValResolution := fun _ => do
if let some (baseStructName, fullName) ← findMethod? structName (.mkSimple fieldName) then
return LValResolution.const baseStructName structName fullName
else if let some (structName', fullName) := findMethodAlias? env structName (.mkSimple fieldName) then
return LValResolution.const structName' structName' fullName
else
throwLValError e eType
m!"invalid field '{fieldName}', the environment does not contain '{Name.mkStr structName fieldName}'"
-- search local context first, then environment
let searchCtx : Unit → TermElabM LValResolution := fun _ => do
let fullName := Name.mkStr structName fieldName
for localDecl in (← getLCtx) do
if localDecl.isAuxDecl then
if let some localDeclFullName := (← read).auxDeclToFullName.find? localDecl.fvarId then
if fullName == (privateToUserName? localDeclFullName).getD localDeclFullName then
/- LVal notation is being used to make a "local" recursive call. -/
return LValResolution.localRec structName fullName localDecl.toExpr
searchEnv ()
if isStructure env structName then
match findField? env structName (Name.mkSimple fieldName) with
| some baseStructName => return LValResolution.projFn baseStructName structName (Name.mkSimple fieldName)
| none => searchCtx ()
else
searchCtx ()
if let some baseStructName := findField? env structName (Name.mkSimple fieldName) then
return LValResolution.projFn baseStructName structName (Name.mkSimple fieldName)
-- Search the local context first
let fullName := Name.mkStr structName fieldName
for localDecl in (← getLCtx) do
if localDecl.isAuxDecl then
if let some localDeclFullName := (← read).auxDeclToFullName.find? localDecl.fvarId then
if fullName == (privateToUserName? localDeclFullName).getD localDeclFullName then
/- LVal notation is being used to make a "local" recursive call. -/
return LValResolution.localRec structName fullName localDecl.toExpr
-- Then search the environment
if let some (baseStructName, fullName) ← findMethod? structName (.mkSimple fieldName) then
return LValResolution.const baseStructName structName fullName
throwLValError e eType
m!"invalid field '{fieldName}', the environment does not contain '{Name.mkStr structName fieldName}'"
| none, LVal.fieldName _ _ (some suffix) _ =>
if e.isConst then
throwUnknownConstant (e.constName! ++ suffix)
Expand Down Expand Up @@ -1326,7 +1303,7 @@ Otherwise, if there isn't another parameter with the same name, we add `e` to `n

Remark: `fullName` is the name of the resolved "field" access function. It is used for reporting errors
-/
private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr) :
private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr) (explicit : Bool) :
MetaM (Array Arg × Array NamedArg) := do
withoutModifyingState <| go f (← inferType f) 0 namedArgs (namedArgs.map (·.name)) true
where
Expand All @@ -1351,25 +1328,25 @@ where
/- If there is named argument with name `xDecl.userName`, then it is accounted for and we can't make use of it. -/
remainingNamedArgs := remainingNamedArgs.eraseIdx idx
else
if (← typeMatchesBaseName xDecl.type baseName) then
/- We found a type of the form (baseName ...).
First, we check if the current argument is an explicit one,
if ← typeMatchesBaseName xDecl.type baseName then
/- We found a type of the form (baseName ...), or we found the first explicit argument in useFirstExplicit mode.
First, we check if the current argument is one that can be used positionally,
and if the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/
if h : argIdx ≤ args.size ∧ bInfo.isExplicit then
if h : argIdx ≤ args.size ∧ (explicit || bInfo.isExplicit) then
/- We can insert `e` as an explicit argument -/
return (args.insertIdx argIdx (Arg.expr e), namedArgs)
else
/- If we can't add `e` to `args`, we try to add it using a named argument, but this is only possible
if there isn't an argument with the same name occurring before it. -/
if !allowNamed || unusableNamedArgs.contains xDecl.userName then
throwError "\
invalid field notation, function '{fullName}' has argument with the expected type\
invalid field notation, function '{.ofConstName fullName}' has argument with the expected type\
{indentExpr xDecl.type}\n\
but it cannot be used"
else
return (args, namedArgs.push { name := xDecl.userName, val := Arg.expr e })
/- Advance `argIdx` and update seen named arguments. -/
if bInfo.isExplicit then
if explicit || bInfo.isExplicit then
argIdx := argIdx + 1
unusableNamedArgs := unusableNamedArgs.push xDecl.userName
/- If named arguments aren't allowed, then it must still be possible to pass the value as an explicit argument.
Expand All @@ -1380,7 +1357,7 @@ where
if let some f' ← coerceToFunction? (mkAppN f xs) then
return ← go f' (← inferType f') argIdx remainingNamedArgs unusableNamedArgs false
throwError "\
invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, \
invalid field notation, function '{.ofConstName fullName}' does not have argument with type ({.ofConstName baseName} ...) that can be used, \
it must be explicit or implicit with a unique name"

/-- Adds the `TermInfo` for the field of a projection. See `Lean.Parser.Term.identProjKind`. -/
Expand Down Expand Up @@ -1426,15 +1403,15 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
let projFn ← mkConst constName
let projFn ← addProjTermInfo lval.getRef projFn
if lvals.isEmpty then
let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFn
let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFn explicit
elabAppArgs projFn namedArgs args expectedType? explicit ellipsis
else
let f ← elabAppArgs projFn #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false)
loop f lvals
| LValResolution.localRec baseName fullName fvar =>
let fvar ← addProjTermInfo lval.getRef fvar
if lvals.isEmpty then
let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvar
let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvar explicit
elabAppArgs fvar namedArgs args expectedType? explicit ellipsis
else
let f ← elabAppArgs fvar #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false)
Expand Down
109 changes: 109 additions & 0 deletions tests/lean/run/3031.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/-!
# Tests for generalized field notation through aliases and "top-level" dot notation
https://github.com/leanprover/lean4/issues/3031
-/

/-!
Alias dot notation. There used to be a different kind of alias dot notation;
in the following example, it would have looked for an argument of type `Common.String`.
Now it looks for one of type `String`, allowing libraries to add "extension methods" from within their own namespaces.
-/
def Common.String.a (s : String) : Nat := s.length

export Common (String.a)

/-- info: String.a "x" : Nat -/
#guard_msgs in #check String.a "x"
/-- info: String.a "x" : Nat -/
#guard_msgs in #check "x".a

/-!
Declarations take precedence over aliases
-/
def String.a (s : String) : Nat := s.length + 100
/-- info: "x".a : Nat -/
#guard_msgs in #check "x".a
/-- info: 100 -/
#guard_msgs in #eval "".a

/-!
Private declarations take precedence over aliases
-/
private def String.b (s : String) : Nat := 0
def Common.String.b (s : String) : Nat := 1
export Common (String.b)
/-- info: 0 -/
#guard_msgs in #eval "".b

/-!
Multiple aliases is an error
-/
def Common.String.c (s : String) : Nat := 0
def Common'.String.c (s : String) : Nat := 0
export Common (String.c)
export Common' (String.c)
/--
error: invalid field notation 'c', the name 'String.c' is ambiguous, possible interpretations: 'Common'.String.c', 'Common.String.c'
-/
#guard_msgs in #eval "".c

/-!
Aliases work with inheritance
-/
namespace Ex1
structure A
structure B extends A
def Common.A.x (_ : A) : Nat := 0
export Common (A.x)
/-- info: fun b => A.x b.toA : B → Nat -/
#guard_msgs in #check fun (b : B) => b.x
end Ex1

/-!
`open` also works
-/
def Common.String.parse (_ : String) : List Nat := []

namespace ExOpen1
/--
error: invalid field 'parse', the environment does not contain 'String.parse'
""
has type
String
-/
#guard_msgs in #check "".parse
section
open Common
/-- info: String.parse "" : List Nat -/
#guard_msgs in #check "".parse
end
section
open Common (String.parse)
/-- info: String.parse "" : List Nat -/
#guard_msgs in #check "".parse
end
end ExOpen1


namespace Ex2
class A (n : Nat) where
x : Nat

/-!
Incidental fix: `@` for generalized field notation was failing if there were implicit arguments.
True projections were ok.
-/
def A.x' {n : Nat} (a : A n) := a.x

/-- info: fun a => a.x' : A 2 → Nat -/
#guard_msgs in #check fun (a : A 2) => @a.x'
end Ex2

namespace Ex3
variable (f : α → β) (g : β → γ)
/-!
Functions use the "top-level" dot notation rule: they use the first explicit argument, rather than the first function argument.
-/
/-- info: g ∘ f : α → γ -/
#guard_msgs in #check g.comp f
end Ex3
13 changes: 12 additions & 1 deletion tests/lean/run/DVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ example (v : Vec Nat 1) : Nat :=

#check @Vec.hd

-- works
-- Does not work: Aliases find that `v` could be the `TypeVec` argument since `TypeVec` is an abbrev for `Vec`.
/--
error: application type mismatch
@Vec.hd ?_ v
argument
v
has type
Vec Nat 1 : Type
but is expected to have type
TypeVec (?_ + 1) : Type (_ + 1)
-/
#guard_msgs in set_option pp.mvars false in
example (v : Vec Nat 1) : Nat :=
v.hd
21 changes: 19 additions & 2 deletions tests/lean/run/alias.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,31 @@ def Set (α : Type) := α → Prop
def Set.union (s₁ s₂ : Set α) : Set α :=
fun a => s₁ a ∨ s₂ a

def FinSet (n : Nat) := Fin n → Prop
def FinSet (n : Nat) := Set (Fin n)

/-!
The type of `x` is unfolded to find `Set.union`
-/
example (x y : FinSet 10) : FinSet 10 :=
x.union y

namespace FinSet
export Set (union)
export Set (union)
end FinSet

/-!
Since the types are defeq, this alias works:
-/
example (x y : FinSet 10) : FinSet 10 :=
FinSet.union x y

/-!
However, this dot notation fails since there is no `FinSet` argument.
However, unfolding is the preferred method.
-/
/--
error: invalid field notation, function 'FinSet.union' does not have argument with type (FinSet ...) that can be used, it must be explicit or implicit with a unique name
-/
#guard_msgs in
example (x y : FinSet 10) : FinSet 10 :=
x.union y
Loading