Type Classes and Recursion
Type Classes
- A `class` defines an interface: a set of functions that any type can implement
- An `instance` declaration provides the implementation for one specific type; Lean selects the right instance by type inference at the call site
- Write `{α : Type} [ClassName α]` to say a function works for any type that has an instance; the compiler rejects calls with types that do not
- `deriving BEq` generates an `==` instance automatically; `deriving Hashable` does the same for hashing; `deriving Repr` lets `#eval` print values of the type
- Any type can get an instance at any time without modifying the type definition, unlike inheritance, which requires a shared base class
-- class defines an interface: a set of functions any type can implement
class Describable (α : Type) where
  describe : α → String

-- Concrete types have no knowledge of each other
structure Point where
  x y : Float
  deriving Repr

structure Color where
  r g b : Nat
  deriving Repr

-- instance provides the implementation for a specific type
instance : Describable Point where
  describe p := s!"({p.x}, {p.y})"

instance : Describable Color where
  describe c := s!"rgb({c.r}, {c.g}, {c.b})"

-- [Describable α] constrains the type variable: works for any type with an instance
def label {α : Type} [Describable α] (x : α) : String :=
  s!"value: {Describable.describe x}"

def p1 : Point := { x := 1.0, y := 2.0 }
def c1 : Color := { r := 255, g := 0, b := 0 }
-- Float values are rendered with six decimal places by string interpolation
#eval label p1 -- "value: (1.000000, 2.000000)"
#eval label c1 -- "value: rgb(255, 0, 0)"

-- The same constraint applies to list operations
def describeAll {α : Type} [Describable α] (xs : List α) : List String :=
  xs.map Describable.describe

#eval describeAll [p1, { x := 0.0, y := 0.0 }]
-- ["(1.000000, 2.000000)", "(0.000000, 0.000000)"]
-- deriving BEq generates == without writing an instance by hand
structure Tag where
  name : String
  deriving BEq, Repr

#guard ({ name := "lean" } : Tag) == { name := "lean" }
#guard !( ({ name := "lean" } : Tag) == { name := "coq" } )
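The bullets above also mention `deriving Hashable` and the compile-time rejection of types that lack an instance, neither of which appears in the example. A minimal sketch of both follows; the names `UserId` and `Shape` are hypothetical and exist only for illustration.

```lean
-- Hypothetical type for illustration: derives equality, hashing, and printing
structure UserId where
  id : Nat
  deriving BEq, Hashable, Repr

-- Equal values produce the same hash
#guard hash ({ id := 7 } : UserId) == hash ({ id := 7 } : UserId)

-- Repr lets #eval print the value
#eval ({ id := 7 } : UserId)

-- Hypothetical type with no instances at all
structure Shape where
  sides : Nat

-- Shape has no ToString instance, so uncommenting the next line
-- is rejected at compile time with a "failed to synthesize" error:
-- #eval toString (Shape.mk 3)
```

The rejected call is left commented out so the rest of the file still compiles.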
Recursion and Mutation
- Lean checks that every recursive function terminates; it rejects functions with no provable decreasing measure
- Mark a function `partial` to opt out of termination checking; Lean accepts it on faith, and it may loop forever
- `termination_by expr` supplies an explicit decreasing measure when Lean cannot infer one automatically
- `where` attaches helper definitions to a function; they are namespaced under it (for example `frequencies.tally`) rather than added to the top level of the module
- `let mut x := v` inside a `do` block creates a mutable local variable; update it with `x := newVal`
- `Id.run do` runs a `do` block in the identity monad, giving access to `let mut` and `for` loops without requiring `IO`
-- Structural recursion: Lean verifies the argument shrinks at each call
def sumList : List Int → Int
  | [] => 0
  | x :: xs => x + sumList xs

#guard sumList [1, 2, 3, 4, 5] == 15
-- termination_by: explicit measure for when Lean cannot infer one
-- `if h : ...` names the condition, so the termination proof can use n ≠ 0
def countdown (n : Nat) : List Nat :=
  if h : n = 0 then [0] else n :: countdown (n - 1)
termination_by n

#guard countdown 3 == [3, 2, 1, 0]
-- partial: opt out of termination checking; the function may loop
partial def collatz : Nat → List Nat
  | 1 => [1]
  | n => n :: collatz (if n % 2 == 0 then n / 2 else 3 * n + 1)

#eval collatz 6 -- [6, 3, 10, 5, 16, 8, 4, 2, 1]
-- where: helper attached to its parent; outside, it is reachable only as frequencies.tally
def frequencies (xs : List String) : List (String × Nat) :=
  xs.foldl tally []
where
  tally (acc : List (String × Nat)) (s : String) : List (String × Nat) :=
    match acc.find? (·.1 == s) with
    | some (_, n) => acc.map fun p => if p.1 == s then (s, n + 1) else p
    | none => acc ++ [(s, 1)]

#eval frequencies ["a", "b", "a", "c", "b", "a"]
-- [("a", 3), ("b", 2), ("c", 1)]
-- Id.run do and let mut: local mutable state without IO
def countVowels (s : String) : Nat := Id.run do
  let mut n := 0
  for c in s.toList do
    if "aeiouAEIOU".contains c then
      n := n + 1
  return n

#guard countVowels "hello world" == 3
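The examples above use measures that shrink toward zero. As a complementary sketch, the block below counts upward, so the decreasing measure is the remaining distance `n - i` rather than an argument. It also shows, commented out, a definition that Lean rejects because no measure decreases. The names `countUp` and `loop` are hypothetical.

```lean
-- Hypothetical example: i grows on each call, but the distance n - i shrinks.
-- Naming the condition `h` makes i < n available to the termination proof.
def countUp (i n : Nat) : List Nat :=
  if h : i < n then i :: countUp (i + 1) n else []
termination_by n - i

#guard countUp 2 5 == [2, 3, 4]

-- Rejected: the argument grows forever, so Lean cannot show termination.
-- def loop (n : Nat) : Nat := loop (n + 1)
```

Marking `loop` as `partial` would make Lean accept it, at the cost of a function that never returns.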
Exercises
Summable Class
- Define `class Summable (α : Type)` with the methods `zero : α` and `add : α → α → α`
- Write instances for `Nat` and `Float`
- Write a polymorphic `sumAll : [Summable α] → List α → α` and verify it on both types
Tree Depth
- Using the `Tree` type from the Algebraic Data Types lesson, write `depth : Tree → Nat` that returns the length of the longest root-to-leaf path
- Verify Lean accepts the function without `partial` and that `Tree.Leaf` has depth `0`
Fibonacci with Fuel
- Write `fib (fuel n : Nat) : Nat` that returns the nth Fibonacci number, using `fuel` as the decreasing measure
- Verify `fib 20 10 == 55`
Mutable Character Count
- Using `let mut` and a `for` loop inside `Id.run do`, count how many times each distinct character appears in a string
- Return a `List (Char × Nat)` and verify it on a short string
Flatten Without Partial
- Write `flatten : List (List α) → List α` using structural recursion on the outer list
- Verify Lean accepts it without `partial` and that `flatten [[1, 2], [3], [4, 5]] == [1, 2, 3, 4, 5]`