Skip to content

Commit

Permalink
chore: avoid runtime array bounds checks (#6134)
Browse files Browse the repository at this point in the history
This PR avoids runtime array bounds checks in places where it can
trivially be done at compile time.

None of these changes are of particular consequence: I mostly wanted to
learn how much we do this, and what the obstacles are to doing it less.
  • Loading branch information
kim-em authored Nov 21, 2024
1 parent 56a80de commit 72e952e
Show file tree
Hide file tree
Showing 34 changed files with 130 additions and 130 deletions.
17 changes: 8 additions & 9 deletions src/Init/Meta.lean
Original file line number Diff line number Diff line change
Expand Up @@ -431,21 +431,20 @@ def getSubstring? (stx : Syntax) (withLeading := true) (withTrailing := true) :
}
| _, _ => none

@[specialize] private partial def updateLast {α} [Inhabited α] (a : Array α) (f : α → Option α) (i : Nat) : Option (Array α) :=
if i == 0 then
none
else
let i := i - 1
let v := a[i]!
@[specialize] private partial def updateLast {α} (a : Array α) (f : α → Option α) (i : Fin (a.size + 1)) : Option (Array α) :=
match i with
| 0 => none
| ⟨i + 1, h⟩ =>
let v := a[i]'(Nat.succ_lt_succ_iff.mp h)
match f v with
| some v => some <| a.set! i v
| none => updateLast a f i
| some v => some <| a.set i v (Nat.succ_lt_succ_iff.mp h)
| none => updateLast a f ⟨i, Nat.lt_of_succ_lt h⟩

partial def setTailInfoAux (info : SourceInfo) : Syntax → Option Syntax
| atom _ val => some <| atom info val
| ident _ rawVal val pre => some <| ident info rawVal val pre
| node info' k args =>
match updateLast args (setTailInfoAux info) args.size with
match updateLast args (setTailInfoAux info) args.size, by simp⟩ with
| some args => some <| node info' k args
| none => none
| _ => none
Expand Down
20 changes: 10 additions & 10 deletions src/Init/NotationExtra.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,28 @@ syntax explicitBinders := (ppSpace bracketedExplicitBinders)+ <|> unb

open TSyntax.Compat in
def expandExplicitBindersAux (combinator : Syntax) (idents : Array Syntax) (type? : Option Syntax) (body : Syntax) : MacroM Syntax :=
let rec loop (i : Nat) (acc : Syntax) := do
let rec loop (i : Nat) (h : i ≤ idents.size) (acc : Syntax) := do
match i with
| 0 => pure acc
| i+1 =>
let ident := idents[i]![0]
| i + 1 =>
let ident := idents[i][0]
let acc ← match ident.isIdent, type? with
| true, none => `($combinator fun $ident => $acc)
| true, some type => `($combinator fun $ident : $type => $acc)
| false, none => `($combinator fun _ => $acc)
| false, some type => `($combinator fun _ : $type => $acc)
loop i acc
loop idents.size body
loop i (Nat.le_of_succ_le h) acc
loop idents.size (by simp) body

def expandBrackedBindersAux (combinator : Syntax) (binders : Array Syntax) (body : Syntax) : MacroM Syntax :=
let rec loop (i : Nat) (acc : Syntax) := do
let rec loop (i : Nat) (h : i ≤ binders.size) (acc : Syntax) := do
match i with
| 0 => pure acc
| i+1 =>
let idents := binders[i]![1].getArgs
let type := binders[i]![3]
loop i (← expandExplicitBindersAux combinator idents (some type) acc)
loop binders.size body
let idents := binders[i][1].getArgs
let type := binders[i][3]
loop i (Nat.le_of_succ_le h) (← expandExplicitBindersAux combinator idents (some type) acc)
loop binders.size (by simp) body

def expandExplicitBinders (combinatorDeclName : Name) (explicitBinders : Syntax) (body : Syntax) : MacroM Syntax := do
let combinator := mkCIdentFrom (← getRef) combinatorDeclName
Expand Down
12 changes: 6 additions & 6 deletions src/Init/System/Uri.lean
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ def decodeUri (uri : String) : String := Id.run do
let len := rawBytes.size
let mut i := 0
let percent := '%'.toNat.toUInt8
while i < len do
let c := rawBytes[i]!
(decoded, i) := if c == percent && i + 1 < len then
let h1 := rawBytes[i + 1]!
while h : i < len do
let c := rawBytes[i]
(decoded, i) := if h₁ : c == percent i + 1 < len then
let h1 := rawBytes[i + 1]
if let some hd1 := hexDigitToUInt8? h1 then
if i + 2 < len then
let h2 := rawBytes[i + 2]!
if h₂ : i + 2 < len then
let h2 := rawBytes[i + 2]
if let some hd2 := hexDigitToUInt8? h2 then
-- decode the hex digits into a byte.
(decoded.push (hd1 * 16 + hd2), i + 3)
Expand Down
6 changes: 3 additions & 3 deletions src/Lean/Compiler/IR/EmitC.lean
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,9 @@ def emitTag (x : VarId) (xType : IRType) : M Unit := do
emit x

def isIf (alts : Array Alt) : Option (Nat × FnBody × FnBody) :=
if alts.size != 2 then none
else match alts[0]! with
| Alt.ctor c b => some (c.cidx, b, alts[1]!.body)
if h : alts.size 2 then none
else match alts[0] with
| Alt.ctor c b => some (c.cidx, b, alts[1].body)
| _ => none

def emitInc (x : VarId) (n : Nat) (checkRef : Bool) : M Unit := do
Expand Down
13 changes: 7 additions & 6 deletions src/Lean/Compiler/IR/EmitLLVM.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1172,8 +1172,8 @@ def emitFnArgs (builder : LLVM.Builder llvmctx)
(needsPackedArgs? : Bool) (llvmfn : LLVM.Value llvmctx) (params : Array Param) : M llvmctx Unit := do
if needsPackedArgs? then do
let argsp ← LLVM.getParam llvmfn 0 -- lean_object **args
for i in List.range params.size do
let param := params[i]!
for h : i in [:params.size] do
let param := params[i]
-- argsi := (args + i)
let argsi ← LLVM.buildGEP2 builder (← LLVM.voidPtrType llvmctx) argsp #[← constIntUnsigned i] s!"packed_arg_{i}_slot"
let llvmty ← toLLVMType param.ty
Expand All @@ -1182,15 +1182,16 @@ def emitFnArgs (builder : LLVM.Builder llvmctx)
-- slot for arg[i] which is always void* ?
let alloca ← buildPrologueAlloca builder llvmty s!"arg_{i}"
LLVM.buildStore builder pv alloca
addVartoState params[i]!.x alloca llvmty
addVartoState param.x alloca llvmty
else
let n ← LLVM.countParams llvmfn
for i in (List.range n.toNat) do
let llvmty ← toLLVMType params[i]!.ty
for i in [:n.toNat] do
let param := params[i]!
let llvmty ← toLLVMType param.ty
let alloca ← buildPrologueAlloca builder llvmty s!"arg_{i}"
let arg ← LLVM.getParam llvmfn (UInt64.ofNat i)
let _ ← LLVM.buildStore builder arg alloca
addVartoState params[i]!.x alloca llvmty
addVartoState param.x alloca llvmty

def emitDeclAux (mod : LLVM.Module llvmctx) (builder : LLVM.Builder llvmctx) (d : Decl) : M llvmctx Unit := do
let env ← getEnv
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Compiler/IR/ExpandResetReuse.lean
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ abbrev Mask := Array (Option VarId)
partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (keep : Array FnBody) : Array FnBody × Mask :=
let done (_ : Unit) := (bs ++ keep.reverse, mask)
let keepInstr (b : FnBody) := eraseProjIncForAux y bs.pop mask (keep.push b)
if bs.size < 2 then done ()
if h : bs.size < 2 then done ()
else
let b := bs.back!
match b with
| .vdecl _ _ (.sproj _ _ _) _ => keepInstr b
| .vdecl _ _ (.uproj _ _) _ => keepInstr b
| .inc z n c p _ =>
if n == 0 then done () else
let b' := bs[bs.size - 2]!
let b' := bs[bs.size - 2]
match b' with
| .vdecl w _ (.proj i x) _ =>
if w == z && y == x then
Expand Down
10 changes: 5 additions & 5 deletions src/Lean/Compiler/LCNF/SpecInfo.lean
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do
let specArgs? := getSpecializationArgs? (← getEnv) decl.name
let contains (i : Nat) : Bool := specArgs?.getD #[] |>.contains i
let mut paramsInfo : Array SpecParamInfo := #[]
for i in [:decl.params.size] do
let param := decl.params[i]!
for h :i in [:decl.params.size] do
let param := decl.params[i]
let info ←
if contains i then
pure .user
Expand Down Expand Up @@ -181,14 +181,14 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do
declsInfo := declsInfo.push paramsInfo
if declsInfo.any fun paramsInfo => paramsInfo.any (· matches .user | .fixedInst | .fixedHO) then
let m := mkFixedParamsMap decls
for i in [:decls.size] do
let decl := decls[i]!
for hi : i in [:decls.size] do
let decl := decls[i]
let mut paramsInfo := declsInfo[i]!
let some mask := m.find? decl.name | unreachable!
trace[Compiler.specialize.info] "{decl.name} {mask}"
paramsInfo := paramsInfo.zipWith mask fun info fixed => if fixed || info matches .user then info else .other
for j in [:paramsInfo.size] do
let mut info := paramsInfo[j]!
let mut info := paramsInfo[j]!
if info matches .fixedNeutral && !hasFwdDeps decl paramsInfo j then
paramsInfo := paramsInfo.set! j .other
if paramsInfo.any fun info => info matches .fixedInst | .fixedHO | .user then
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Compiler/LCNF/ToLCNF.lean
Original file line number Diff line number Diff line change
Expand Up @@ -499,8 +499,8 @@ where
match app with
| .fvar f =>
let mut argsNew := #[]
for i in [arity : args.size] do
argsNew := argsNew.push (← visitAppArg args[i]!)
for h :i in [arity : args.size] do
argsNew := argsNew.push (← visitAppArg args[i])
letValueToArg <| .fvar f argsNew
| .erased | .type .. => return .erased

Expand Down
7 changes: 4 additions & 3 deletions src/Lean/Compiler/Specialize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ private def elabSpecArgs (declName : Name) (args : Array Syntax) : MetaM (Array
if let some idx := arg.isNatLit? then
if idx == 0 then throwErrorAt arg "invalid specialization argument index, index must be greater than 0"
let idx := idx - 1
if idx >= argNames.size then
if h : idx >= argNames.size then
throwErrorAt arg "invalid argument index, `{declName}` has #{argNames.size} arguments"
if result.contains idx then throwErrorAt arg "invalid specialization argument index, `{argNames[idx]!}` has already been specified as a specialization candidate"
result := result.push idx
else
if result.contains idx then throwErrorAt arg "invalid specialization argument index, `{argNames[idx]}` has already been specified as a specialization candidate"
result := result.push idx
else
let argName := arg.getId
if let some idx := argNames.indexOf? argName then
Expand Down
8 changes: 4 additions & 4 deletions src/Lean/Elab/Deriving/Util.lean
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ invoking ``mkInstImplicitBinders `BarClass foo #[`α, `n, `β]`` gives `` `([Bar
def mkInstImplicitBinders (className : Name) (indVal : InductiveVal) (argNames : Array Name) : TermElabM (Array Syntax) :=
forallBoundedTelescope indVal.type indVal.numParams fun xs _ => do
let mut binders := #[]
for i in [:xs.size] do
for h : i in [:xs.size] do
try
let x := xs[i]!
let x := xs[i]
let c ← mkAppM className #[x]
if (← isTypeCorrect c) then
let argName := argNames[i]!
Expand Down Expand Up @@ -86,8 +86,8 @@ def mkContext (fnPrefix : String) (typeName : Name) : TermElabM Context := do

def mkLocalInstanceLetDecls (ctx : Context) (className : Name) (argNames : Array Name) : TermElabM (Array (TSyntax ``Parser.Term.letDecl)) := do
let mut letDecls := #[]
for i in [:ctx.typeInfos.size] do
let indVal := ctx.typeInfos[i]!
for h : i in [:ctx.typeInfos.size] do
let indVal := ctx.typeInfos[i]
let auxFunName := ctx.auxFunNames[i]!
let currArgNames ← mkInductArgNames indVal
let numParams := indVal.numParams
Expand Down
16 changes: 8 additions & 8 deletions src/Lean/Elab/Do.lean
Original file line number Diff line number Diff line change
Expand Up @@ -796,10 +796,10 @@ Note that we are not restricting the macro power since the
actions to be in the same universe.
-/
private def mkTuple (elems : Array Syntax) : MacroM Syntax := do
if elems.size == 0 then
if elems.size = 0 then
mkUnit
else if elems.size == 1 then
return elems[0]!
else if h : elems.size = 1 then
return elems[0]
else
elems.extract 0 (elems.size - 1) |>.foldrM (init := elems.back!) fun elem tuple =>
``(MProd.mk $elem $tuple)
Expand Down Expand Up @@ -831,10 +831,10 @@ def isDoExpr? (doElem : Syntax) : Option Syntax :=
We use this method when expanding the `for-in` notation.
-/
private def destructTuple (uvars : Array Var) (x : Syntax) (body : Syntax) : MacroM Syntax := do
if uvars.size == 0 then
if uvars.size = 0 then
return body
else if uvars.size == 1 then
`(let $(uvars[0]!):ident := $x; $body)
else if h : uvars.size = 1 then
`(let $(uvars[0]):ident := $x; $body)
else
destruct uvars.toList x body
where
Expand Down Expand Up @@ -1314,9 +1314,9 @@ private partial def expandLiftMethodAux (inQuot : Bool) (inBinder : Bool) : Synt
else if liftMethodDelimiter k then
return stx
-- For `pure` if-then-else, we only lift `(<- ...)` occurring in the condition.
else if args.size >= 2 && (k == ``termDepIfThenElse || k == ``termIfThenElse) then do
else if h : args.size >= 2 (k == ``termDepIfThenElse || k == ``termIfThenElse) then do
let inAntiquot := stx.isAntiquot && !stx.isEscapedAntiquot
let arg1 ← expandLiftMethodAux (inQuot && !inAntiquot || stx.isQuot) inBinder args[1]!
let arg1 ← expandLiftMethodAux (inQuot && !inAntiquot || stx.isQuot) inBinder args[1]
let args := args.set! 1 arg1
return Syntax.node i k args
else if k == ``Parser.Term.liftMethod && !inQuot then withFreshMacroScope do
Expand Down
20 changes: 10 additions & 10 deletions src/Lean/Elab/Inductive.lean
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,15 @@ private def checkUnsafe (rs : Array ElabHeaderResult) : TermElabM Unit := do
throwErrorAt r.view.ref "invalid inductive type, cannot mix unsafe and safe declarations in a mutually inductive datatypes"

private def InductiveView.checkLevelNames (views : Array InductiveView) : TermElabM Unit := do
if views.size > 1 then
let levelNames := views[0]!.levelNames
if h : views.size > 1 then
let levelNames := views[0].levelNames
for view in views do
unless view.levelNames == levelNames do
throwErrorAt view.ref "invalid inductive type, universe parameters mismatch in mutually inductive datatypes"

private def ElabHeaderResult.checkLevelNames (rs : Array ElabHeaderResult) : TermElabM Unit := do
if rs.size > 1 then
let levelNames := rs[0]!.levelNames
if h : rs.size > 1 then
let levelNames := rs[0].levelNames
for r in rs do
unless r.levelNames == levelNames do
throwErrorAt r.view.ref "invalid inductive type, universe parameters mismatch in mutually inductive datatypes"
Expand Down Expand Up @@ -433,8 +433,8 @@ where
let mut args := e.getAppArgs
unless args.size ≥ params.size do
throwError "unexpected inductive type occurrence{indentExpr e}"
for i in [:params.size] do
let param := params[i]!
for h : i in [:params.size] do
let param := params[i]
let arg := args[i]!
unless (← isDefEq param arg) do
throwError "inductive datatype parameter mismatch{indentExpr arg}\nexpected{indentExpr param}"
Expand Down Expand Up @@ -694,8 +694,8 @@ private def collectLevelParamsInInductive (indTypes : List InductiveType) : Arra
private def mkIndFVar2Const (views : Array InductiveView) (indFVars : Array Expr) (levelNames : List Name) : ExprMap Expr := Id.run do
let levelParams := levelNames.map mkLevelParam;
let mut m : ExprMap Expr := {}
for i in [:views.size] do
let view := views[i]!
for h : i in [:views.size] do
let view := views[i]
let indFVar := indFVars[i]!
m := m.insert indFVar (mkConst view.declName levelParams)
return m
Expand Down Expand Up @@ -856,9 +856,9 @@ private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) :
withInductiveLocalDecls rs fun params indFVars => do
trace[Elab.inductive] "indFVars: {indFVars}"
let mut indTypesArray := #[]
for i in [:views.size] do
for h : i in [:views.size] do
let indFVar := indFVars[i]!
Term.addLocalVarInfo views[i]!.declId indFVar
Term.addLocalVarInfo views[i].declId indFVar
let r := rs[i]!
/- At this point, because of `withInductiveLocalDecls`, the only fvars that are in context are the ones related to the first inductive type.
Because of this, we need to replace the fvars present in each inductive type's header of the mutual block with those of the first inductive.
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Elab/LetRec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ private def elabLetRecDeclValues (view : LetRecView) : TermElabM (Array Expr) :=
view.decls.mapM fun view => do
forallBoundedTelescope view.type view.binderIds.size fun xs type => do
-- Add new info nodes for new fvars. The server will detect all fvars of a binder by the binder's source location.
for i in [0:view.binderIds.size] do
addLocalVarInfo view.binderIds[i]! xs[i]!
for h : i in [0:view.binderIds.size] do
addLocalVarInfo view.binderIds[i] xs[i]!
withDeclName view.declName do
withInfoContext' view.valStx
(mkInfo := (pure <| .inl <| mkBodyInfo view.valStx ·))
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Elab/Match.lean
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,8 @@ where
let dArg := dArgs[i]!
unless (← isDefEq tArg dArg) do
return i :: (← goType tArg dArg)
for i in [info.numParams : tArgs.size] do
let tArg := tArgs[i]!
for h : i in [info.numParams : tArgs.size] do
let tArg := tArgs[i]
let dArg := dArgs[i]!
unless (← isDefEq tArg dArg) do
return i :: (← goIndex tArg dArg)
Expand Down
8 changes: 4 additions & 4 deletions src/Lean/Elab/Open.lean
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ private def resolveNameUsingNamespacesCore (nss : List Name) (idStx : Syntax) :
exs := exs.push ex
if exs.size == nss.length then
withRef idStx do
if exs.size == 1 then
throw exs[0]!
if h : exs.size = 1 then
throw exs[0]
else
throwErrorWithNestedErrors "failed to open" exs
if result.size == 1 then
return result[0]!
if h : result.size = 1 then
return result[0]
else
withRef idStx do throwError "ambiguous identifier '{idStx.getId}', possible interpretations: {result.map mkConst}"

Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Elab/PreDefinition/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def checkCodomainsLevel (preDefs : Array PreDefinition) : MetaM Unit := do
lambdaTelescope preDef.value fun xs _ => return xs.size
forallBoundedTelescope preDefs[0]!.type arities[0]! fun _ type₀ => do
let u₀ ← getLevel type₀
for i in [1:preDefs.size] do
forallBoundedTelescope preDefs[i]!.type arities[i]! fun _ typeᵢ =>
for h : i in [1:preDefs.size] do
forallBoundedTelescope preDefs[i].type arities[i]! fun _ typeᵢ =>
unless ← isLevelDefEq u₀ (← getLevel typeᵢ) do
withOptions (fun o => pp.sanitizeNames.set o false) do
throwError m!"invalid mutual definition, result types must be in the same universe " ++
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Elab/PreDefinition/TerminationArgument.lean
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def TerminationArgument.elab (funName : Name) (type : Expr) (arity extraParams :
(hint : TerminationBy) : TermElabM TerminationArgument := withDeclName funName do
assert! extraParams ≤ arity

if hint.vars.size > extraParams then
if h : hint.vars.size > extraParams then
let mut msg := m!"{parameters hint.vars.size} bound in `termination_by`, but the body of " ++
m!"{funName} only binds {parameters extraParams}."
if let `($ident:ident) := hint.vars[0]! then
if let `($ident:ident) := hint.vars[0] then
if ident.getId.isSuffixOf funName then
msg := msg ++ m!" (Since Lean v4.6.0, the `termination_by` clause no longer " ++
"expects the function name here.)"
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Elab/PreDefinition/TerminationHint.lean
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ lambda of `value`, and throws appropriate errors.
-/
def TerminationBy.checkVars (funName : Name) (extraParams : Nat) (tb : TerminationBy) : MetaM Unit := do
unless tb.synthetic do
if tb.vars.size > extraParams then
if h : tb.vars.size > extraParams then
let mut msg := m!"{parameters tb.vars.size} bound in `termination_by`, but the body of " ++
m!"{funName} only binds {parameters extraParams}."
if let `($ident:ident) := tb.vars[0]! then
if let `($ident:ident) := tb.vars[0] then
if ident.getId.isSuffixOf funName then
msg := msg ++ m!" (Since Lean v4.6.0, the `termination_by` clause no longer " ++
"expects the function name here.)"
Expand Down
Loading

0 comments on commit 72e952e

Please sign in to comment.