Skip to content

Commit

Permalink
implement type checking for sgir (#3)
Browse files Browse the repository at this point in the history
This is a first cut of implementing type checking for SGIR, including
adding some features that are required for us to make things work.
  • Loading branch information
aatxe authored Mar 6, 2024
1 parent 216ecef commit 6ee4899
Show file tree
Hide file tree
Showing 3 changed files with 435 additions and 24 deletions.
50 changes: 40 additions & 10 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,28 @@
use crate::sgir::Binding;
use crate::sgir::{Binding, TypeBinding};

mod sgir;

fn main() {
use sgir::Expression::*;
use sgir::Type;
use sgir::{Kind, Type};

let identity = Quantify {
parameters: vec![TypeBinding { id: "X".to_owned(), kind: Kind::Star }],
body: Box::new(Function {
parameters: vec![Binding { id: "x".to_owned(), typ: Type::Variable("X".to_owned()) }],
body: Box::new(Variable("x".to_owned())),
}),
};

let application = Application {
function: Box::new(Instantiate {
function: Box::new(identity),
arguments: vec![Type::Number]
}),
arguments: vec![
Number(11)
],
};

let prog = Application {
function: Box::new(Function {
Expand All @@ -15,18 +33,30 @@ fn main() {
Binding { id: "j".to_owned(), typ: Type::Number },
],
body: Box::new(Variable("alex!".to_owned())),
}),
}), // (number, boolean, number, number) -> number
arguments: vec![
Number(420),
Boolean(true),
Function {
parameters: vec![Binding { id: "x".to_owned(), typ: Type::Number }],
body: Box::new(Variable("x".to_owned())),
},
Number(694208008135),
application, // : number
Boolean(true), // : boolean
Application {
function: Box::new(Function {
parameters: vec![Binding { id: "x".to_owned(), typ: Type::Number }],
body: Box::new(Variable("x".to_owned())),
}), // : (number) -> number
arguments: vec![Number(42)], // : number
}, // : number
Number(694208008135), // : number
]
};

let typ = match sgir::check(prog.clone()) {
Ok(typ) => typ,
Err(type_error) => {
eprintln!("[ERROR] {:?}", type_error);
return
},
};
println!("TYPE: {:?}", typ);

let result = sgir::run(prog);
println!("{:?}", result);
}
202 changes: 200 additions & 2 deletions src/sgir/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
//! This module implements `sgir`, an intermediate representation for sanguinello.
//!
//! It is based on System Fω with explicit typing.
use std::collections::HashMap;
use thiserror::Error;

Expand All @@ -8,7 +12,10 @@ type Identifier = String;

#[derive(Clone, Debug, PartialEq)]
pub enum Kind {
/// The type of types.
Star,

/// A type constructor, or type function.
Arrow {
from: Vec<Kind>,
to: Box<Kind>,
Expand All @@ -31,6 +38,7 @@ pub enum Type {
parameters: Vec<TypeBinding>,
typ: Box<Type>,
},

/// type instantiation, e.g. T<U...>
Instantiate {
typ: Box<Type>,
Expand All @@ -44,20 +52,84 @@ pub enum Type {
arguments: Vec<Type>,
result: Box<Type>,
},

/// a boolean
Boolean,

/// a number
Number,
}

type TypeSubstitution = HashMap<Identifier, Type>;

impl Type {
fn apply(self, subst: &TypeSubstitution) -> Type {
match self {
Type::Variable(id) => match subst.get(&id) {
Some(replacement) => replacement.clone(),
None => Type::Variable(id),
},

Type::ForAll { parameters, typ } => {
// to handle shadowing properly, we have to remove any occurrences
// of any of the `parameters` found in the substitution.
let mut extended_subst = subst.clone();
for TypeBinding { id, .. } in parameters.iter() {
extended_subst.remove(id);
}

Type::ForAll {
parameters,
typ: Box::new(typ.apply(&extended_subst)),
}
},

Type::Instantiate { typ, arguments } => Type::Instantiate {
typ: Box::new(typ.apply(&subst)),
arguments: arguments.into_iter().map(|typ| typ.apply(&subst)).collect(),
},

Type::Function { arguments, result } => Type::Function {
arguments: arguments.into_iter().map(|typ| typ.apply(&subst)).collect(),
result: Box::new(result.apply(&subst)),
},

Type::Boolean => Type::Boolean,
Type::Number => Type::Number,
}
}
}

#[derive(Debug, Error, Clone, PartialEq)]
enum TypeError {
pub enum TypeError {
#[error("kind mismatch: expected {expected:?}, found {found:?}")]
KindMismatch {
expected: Kind,
found: Kind,
},

#[error("type mismatch: expected {expected:?}, found {found:?}")]
TypeMismatch {
expected: Type,
found: Type,
},

#[error("arity mismatch: expected {expected} arguments, found {found}")]
ArityMismatch {
expected: usize,
found: usize,
},

#[error("cannot call a non-function: {found:?}")]
CannotCallNonFunction {
found: Type,
},

#[error("cannot instantiate a non-quantification: {found:?}")]
CannotCallNonQuantification {
found: Type,
},

#[error("kind mismatch: expected a quantifier in type {found:?}")]
ExpectedQuantifier {
found: Type,
Expand All @@ -67,7 +139,7 @@ enum TypeError {
UnboundIdentifier(Identifier),
}

type TC<T> = Result<T, TypeError>;
pub type TC<T> = Result<T, TypeError>;

type KindEnv = HashMap<Identifier, Kind>;

Expand Down Expand Up @@ -122,6 +194,18 @@ pub enum Expression {
Boolean(bool),
Number(i64), // haha, this should be a bignum

/// type quantification (i.e. \Lambda)
Quantify {
parameters: Vec<TypeBinding>,
body: Box<Expression>,
},

/// type instantiation (i.e. application of a \Lambda)
Instantiate {
function: Box<Expression>,
arguments: Vec<Type>,
},

Function {
parameters: Vec<Binding>,
body: Box<Expression>,
Expand All @@ -133,6 +217,116 @@ pub enum Expression {
},
}

type TypeEnv = HashMap<Identifier, Type>;

fn check_type(tenv: &TypeEnv, kenv: &KindEnv, expr: Expression) -> TC<Type> {
match expr {
Expression::Variable(id) => match tenv.get(&id) {
Some(ty) => Ok(ty.clone()),
None => Err(TypeError::UnboundIdentifier(id.clone())),
},

Expression::Boolean(_) => Ok(Type::Boolean),

Expression::Number(_) => Ok(Type::Number),

Expression::Quantify { parameters, body } => {
let mut extended_kenv = kenv.clone();
extended_kenv.extend(parameters.clone().into_iter()
.map(|TypeBinding { id, kind }| (id, kind)));

let typ = check_type(tenv, &extended_kenv, *body)?;
match check_kinds(kenv, typ.clone())? {
// the resulting type is a type...
Kind::Star => Ok(Type::ForAll { parameters , typ: Box::new(typ) }),

// the resulting type is a type function...
kind => Err(TypeError::KindMismatch { expected: Kind::Star, found: kind }),
}
}

Expression::Instantiate { function, arguments } => {
match check_type(tenv, kenv, *function)? {
Type::ForAll { parameters, typ } => {
if arguments.len() != parameters.len() {
return Err(TypeError::ArityMismatch { expected: parameters.len(), found: arguments.len() })
}

let argument_kind_pairs = arguments.clone().into_iter()
.zip(parameters.iter().map(|TypeBinding { kind, .. }| kind.clone()));

for (argument, expected_kind) in argument_kind_pairs {
let computed_kind = check_kinds(kenv, argument)?;

if computed_kind != expected_kind {
return Err(TypeError::KindMismatch { expected: expected_kind, found: computed_kind })
}
}

// we have to substitute the types in `arguments` for the type parameters in `parameters`
let subst: TypeSubstitution = parameters.into_iter()
.map(|TypeBinding { id, .. }| id)
.zip(arguments)
.collect();

Ok(typ.apply(&subst))
},

// Unexpected type here, it must be a forall!
found => Err(TypeError::CannotCallNonQuantification { found })
}
}

Expression::Function { parameters, body } => {
for Binding { typ, .. } in parameters.iter() {
if let kind@Kind::Arrow { .. } = check_kinds(kenv, typ.clone())? {
return Err(TypeError::KindMismatch { expected: Kind::Star, found: kind })
}
}

let arguments = parameters.iter()
.map(|Binding { typ, .. }| typ.clone())
.collect();

let mut extended_tenv = tenv.clone();
extended_tenv.extend(parameters.into_iter()
.map(|Binding { id, typ }| (id, typ)));
let result = Box::new(check_type(&extended_tenv, &kenv, *body)?);

Ok(Type::Function { arguments, result })
},

Expression::Application { function, arguments } => {
match check_type(tenv, kenv, *function)? {
Type::Function { arguments: expected_types, result: result_type } => {
if arguments.len() != expected_types.len() {
return Err(TypeError::ArityMismatch { expected: expected_types.len(), found: arguments.len() })
}

for (argument, expected_type) in arguments.into_iter().zip(expected_types) {
let computed_type = check_type(tenv, kenv, argument)?;

if computed_type != expected_type {
return Err(TypeError::TypeMismatch { expected: expected_type, found: computed_type })
}
}

Ok(*result_type)
},

// Unexpected type here, it must be a function!
found => Err(TypeError::CannotCallNonFunction { found })
}
},
}
}

pub fn check(expr: Expression) -> TC<Type> {
let type_env = HashMap::new();
let kind_env = HashMap::new();
check_type(&type_env, &kind_env, expr)
}

#[derive(Clone, Debug)]
pub enum Value {
// Primitives
Expand All @@ -152,6 +346,10 @@ fn eval(subst: &Substitution, expr: Expression) -> Value {
Expression::Variable(identifier) => subst[&identifier].clone(),
Expression::Boolean(value) => Value::Boolean(value),
Expression::Number(value) => Value::Number(value),
// quantification has no runtime semantics
Expression::Quantify { body, .. } => eval(subst, *body),
// instantiation has no runtime semantics
Expression::Instantiate { function, .. } => eval(subst, *function),
Expression::Function { parameters, body } => Value::Function { parameters: parameters.clone(), body: body.clone() },
Expression::Application { function, arguments } => match eval(subst, *function) {
Value::Function { parameters, body } => {
Expand Down
Loading

0 comments on commit 6ee4899

Please sign in to comment.