# Day 2: Multiple Dispatch

Goals for today:

- Understand multiple dispatch
- Use multiple dispatch
- Use interfaces
- Write your own types

**New website**: [https://mpf-optimization-laboratory.github.io/julia-workshop/](https://mpf-optimization-laboratory.github.io/julia-workshop/)

## Multiple Dispatch

- This is the core paradigm of Julia
  - Called multimethods if you're coming from the Lisp world
- One _function_ may have many _methods_

In [1]:
bar(x)          = x  # implicitly ::Any
bar(x::Integer) = 2x
bar(x::Int8)    = 3x

bar (generic function with 3 methods)

In [2]:
bar(1.0) # Float64

1.0

In [3]:
bar(1) # Int64

2

In [4]:
bar(Int8(1)) # Int8

3

In [5]:
@which bar(10)

- This is the basis for how Julia works
- Even basic operators are functions that use multiple dispatch!

In [6]:
@which +(1, 2)

In [7]:
@which +(1.0, 2.0)

In [8]:
@which +(1.0, 3.0im)
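The `bar` example dispatches on a single argument. Julia actually selects a method using the types of *all* arguments and picks the most specific match. Here is a small sketch with a hypothetical `describe` function. Without its last method, `describe(1, 2)` would be ambiguous, and Julia would throw a `MethodError` instead of guessing:

```julia
# operators are ordinary generic functions: prefix and infix calls are the same call
@assert +(1, 2) == 1 + 2

# `describe` is a hypothetical function, used only to illustrate dispatch on two arguments
describe(x::Number,  y::Number)  = "number, number"
describe(x::Integer, y::Number)  = "integer, number"
describe(x::Number,  y::Integer) = "number, integer"
describe(x::Integer, y::Integer) = "integer, integer"  # resolves the ambiguity between the two above

@assert describe(1.0, 2.0) == "number, number"
@assert describe(1, 2.0) == "integer, number"
@assert describe(1.0, 2) == "number, integer"
@assert describe(1, 2) == "integer, integer"
@assert length(methods(describe)) == 4
```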

## Writing Our Own Type

- Let's do integers modulo $n$

In [9]:
struct IntModN <: Integer
    value::Int
    modulus::Int
end

In [10]:
IntModN(10, 3)

IntModN(10, 3)

In [11]:
import Base.show # needed to add a method to someone else's function
show(io::IO, x::IntModN) = print(io, "IntModN($(x.value) mod $(x.modulus))")

show (generic function with 278 methods)

In [12]:
IntModN(10, 3)

IntModN(10 mod 3)
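Defining the two-argument `show` is enough to change how a value is printed everywhere. `print`, `string`, `repr`, string interpolation and REPL display all fall back to it. Writing `Base.show(...)` directly is an alternative to `import Base.show`. Here is a sketch with a hypothetical `Point` type:

```julia
struct Point   # hypothetical type, only for this illustration
    x::Int
    y::Int
end

# a single method customizes every text rendering of a Point
Base.show(io::IO, p::Point) = print(io, "(", p.x, ", ", p.y, ")")

p = Point(1, 2)
@assert repr(p) == "(1, 2)"
@assert string(p) == "(1, 2)"
@assert "p = $p" == "p = (1, 2)"
```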

- Now, let's define addition

In [13]:
struct MismatchedModulusException <: Exception
    # typed as Integer rather than IntModN so the ModN methods further down can reuse this exception
    first::Integer
    second::Integer
end

In [14]:
import Base.+
function +(x::IntModN, y::IntModN)
    if x.modulus != y.modulus
        throw(MismatchedModulusException(x, y))
    end
    IntModN(x.value + y.value, x.modulus)
end

+ (generic function with 190 methods)

In [15]:
IntModN(5, 10) + IntModN(23, 10)

IntModN(28 mod 10)

In [16]:
IntModN(1, 2) + IntModN(1, 3)

LoadError: MismatchedModulusException(IntModN(1 mod 2), IntModN(1 mod 3))
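The error above is printed in the default field-by-field form. The message comes from Base's `showerror` function, so adding a method for it gives a more readable error. This sketch uses a hypothetical `ModulusMismatch` exception, which is a simpler stand-in for `MismatchedModulusException` that only stores the two moduli:

```julia
struct ModulusMismatch <: Exception   # hypothetical stand-in, not the workshop's exception
    first_modulus::Int
    second_modulus::Int
end

# showerror controls the text shown when the exception is reported
function Base.showerror(io::IO, e::ModulusMismatch)
    print(io, "ModulusMismatch: cannot combine values mod ", e.first_modulus,
          " and mod ", e.second_modulus)
end

err = try
    throw(ModulusMismatch(2, 3))
catch e
    e
end
@assert err isa ModulusMismatch
@assert sprint(showerror, err) == "ModulusMismatch: cannot combine values mod 2 and mod 3"
```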

In [17]:
function congruent(x::IntModN, y::IntModN)
    if x.modulus != y.modulus
        throw(MismatchedModulusException(x, y))
    end
    mod(x.value, x.modulus) == mod(y.value, x.modulus)
end

congruent (generic function with 1 method)

In [18]:
congruent(IntModN(1, 10), IntModN(11, 10))

true

In [19]:
congruent(IntModN(1, 10) + IntModN(1, 10), IntModN(22, 10))

true
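We need `congruent` because `+` never reduces `value`, which makes `IntModN(2, 10)` and `IntModN(22, 10)` different objects. An alternative design reduces the value in an *inner constructor*, so that every instance is stored in canonical form and the default `==` already means congruence. The sketch below uses a hypothetical `CanonicalModN` type. It is not what this workshop uses, and it is a plain struct to keep it small:

```julia
struct CanonicalModN   # hypothetical alternative to IntModN
    value::Int
    modulus::Int
    # inner constructor: every instance stores the representative in 0:modulus-1
    function CanonicalModN(value::Int, modulus::Int)
        modulus > 0 || throw(ArgumentError("modulus must be positive"))
        return new(mod(value, modulus), modulus)
    end
end

function Base.:+(x::CanonicalModN, y::CanonicalModN)
    x.modulus == y.modulus || throw(ArgumentError("moduli differ"))
    return CanonicalModN(x.value + y.value, x.modulus)  # the constructor reduces the sum
end

@assert CanonicalModN(28, 10).value == 8
@assert CanonicalModN(-3, 10).value == 7
# default == compares the (already reduced) fields
@assert CanonicalModN(1, 10) + CanonicalModN(1, 10) == CanonicalModN(22, 10)
```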

## Making Our Own Parametric Type

- We explicitly used `Int` for our storage
- But what if we want to make it generic?
  - For example, if we're dealing with numbers larger than `typemax(Int)`
  - Or, we need to store a huge number of them but we know they'll all be small
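Both motivations are easy to see at the REPL. `Int` arithmetic silently wraps around at `typemax(Int)`, while `BigInt` does not. A narrower element type also shrinks storage in proportion:

```julia
# fixed-width integers wrap around on overflow
@assert typemax(Int) + 1 == typemin(Int)
# BigInt grows as needed instead
@assert big(typemax(Int)) + 1 > typemax(Int)

# 1000 UInt8 values take 1000 bytes; 1000 Int64 values take 8000
@assert sizeof(fill(0x01, 1000)) == 1000
@assert sizeof(fill(Int64(1), 1000)) == 8000
```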

In [21]:
struct ModN{T <: Integer} <: Integer
    value::T
    modulus::T
end

In [22]:
function show(io::IO, x::ModN{T}) where {T <: Integer}
    print(io, "ModN{$T}($(x.value) mod $(x.modulus))")
end

show (generic function with 279 methods)

In [23]:
ModN(0x1, 0x2)

ModN{UInt8}(1 mod 2)

In [24]:
ModN(1, 2)

ModN{Int64}(1 mod 2)

In [25]:
ModN(0x1, 2)

LoadError: MethodError: no method matching ModN(::UInt8, ::Int64)

Closest candidates are:
  ModN(::T, ::T) where T<:Integer
   @ Main In[21]:2
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792


## Promotion

- To do operations like `+` and `*` on arguments of different types, Julia first _promotes_ them to a common type, usually by widening one of them
- For example, `+(1.0, 1)` is essentially `+(1.0, convert(Float64, 1))`
    - We saw this last time when building lists
- Julia picks the type using `promote_rule`

In [54]:
promote(1.0, 2)

(1.0, 2.0)

In [27]:
typeof(promote(1.0, 1))

Tuple{Float64, Float64}

In [28]:
promote_rule(Float64, Int64)

Float64
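`promote_rule` methods are not usually called directly. `promote_type` combines them, trying both argument orders, and returns the common type. Here are a few queries that explain results in this section, including why `ModN(0x1, 2)` needs a `UInt8`/`Int64` common type:

```julia
@assert promote_type(UInt8, Int64) == Int64        # the wider signed type wins
@assert promote_type(Float32, Int64) == Float32    # floats win over integers, even narrower ones
@assert promote_type(Int64, Rational{Int64}) == Rational{Int64}
@assert promote_type(Float64, ComplexF32) == ComplexF64

# promote converts the values themselves, and + does this implicitly
@assert promote(0x1, 2) === (1, 2)
@assert 0x1 + 2 === 3
```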

- So, let's use this system!
- We'll define constructors for `ModN`
  - Constructors in Julia are just functions that have the same name as the type
- We saw before that we had a method `ModN(::T, ::T) where {T <: Integer}`

In [29]:
ModN(value::Integer, modulus::Integer) = ModN(promote(value, modulus)...)

ModN

In [30]:
ModN(0x1, 2)

ModN{Int64}(1 mod 2)

In [57]:
function +(x::ModN{T}, y::ModN{U}) where {T, U}
    if x.modulus != y.modulus
        throw(MismatchedModulusException(x, y))
    end
    ModN(x.value + y.value, x.modulus)
end

function congruent(x::ModN, y::ModN)
    if x.modulus != y.modulus
        throw(MismatchedModulusException(x, y))
    end
    mod(x.value, x.modulus) == mod(y.value, y.modulus)
end

congruent (generic function with 2 methods)

In [58]:
ModN(0x1, 0x5) + ModN(1, 5)

ModN{Int64}(2 mod 5)

### Problem: Complex Numbers

- Julia has them built in
- But, as an exercise, let's write our own
- Only worry about floats (subtypes of `AbstractFloat`)

```julia
import Base: real, imag, +
struct MyComplex{???}
    real::???
    imag::???
end
```

- Complex numbers: $c = a + bi$; $i^2 = -1$
    - $\Re(c) = \text{real}(c) = a$
    - $\Im(c) = \text{imag}(c) = b$
    - $c_1 + c_2 = \left(\strut\Re(c_1) + \Re(c_2)\right) + \left(\strut\Im(c_1) + \Im(c_2)\right)i$

In [33]:
import Base: real, imag, +
struct MyComplex{T <: AbstractFloat} <: Number
    real::T
    imag::T
end

In [34]:
real(c::MyComplex) = c.real
imag(c::MyComplex) = c.imag

imag (generic function with 17 methods)

In [35]:
MyComplex(real::AbstractFloat, imag::AbstractFloat) = MyComplex(promote(real, imag)...)

MyComplex

In [36]:
MyComplex(Float32(1.0), Float16(2.0))

MyComplex{Float32}(1.0f0, 2.0f0)

In [37]:
+(x::MyComplex, y::MyComplex) = MyComplex(real(x) + real(y), imag(x) + imag(y))

+ (generic function with 192 methods)

In [38]:
MyComplex(1.0, 2.0f0) + MyComplex(Float16(1), Float16(2))

MyComplex{Float64}(2.0, 4.0)
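The exercise only asks for `+`. Multiplication follows the same pattern from $(a + bi)(c + di) = (ac - bd) + (ad + bc)i$. Here is a sketch that repeats the definitions above so the block runs on its own, and checks the result against Julia's built-in complex numbers:

```julia
struct MyComplex{T <: AbstractFloat} <: Number
    real::T
    imag::T
end
# promoting constructor, as above
MyComplex(a::AbstractFloat, b::AbstractFloat) = MyComplex(promote(a, b)...)
Base.real(c::MyComplex) = c.real
Base.imag(c::MyComplex) = c.imag

# (a + bi)(c + di) = (ac - bd) + (ad + bc)i
Base.:*(x::MyComplex, y::MyComplex) =
    MyComplex(real(x) * real(y) - imag(x) * imag(y),
              real(x) * imag(y) + imag(x) * real(y))

z = MyComplex(1.0, 2.0) * MyComplex(3.0, 4.0)
@assert (real(z), imag(z)) == (-5.0, 10.0)
@assert (1.0 + 2.0im) * (3.0 + 4.0im) == -5.0 + 10.0im   # built-in agrees

# mixed precisions promote just like +
@assert MyComplex(1.0f0, 1.0f0) * MyComplex(2.0, 0.0) isa MyComplex{Float64}
```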

## Iterators

- One use of dispatch: easy to iterate!
- If you have a type that you'd like to iterate over, you just need to define two methods of `iterate`
- Julia turns:

```julia
for x in iterable
    # do something with x
end
```

into:

```julia
next = iterate(iterable)
while !isnothing(next)
    (x, state) = next
    # do something with x
    next = iterate(iterable, state)
end
```
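The lowered loop can be written out as an ordinary function. `manual_collect` below is a hypothetical helper (not part of Base) that performs exactly this rewriting. The state is whatever the iterable chooses. For a `String` it is a byte index, which is why it jumps by 2 over the two-byte character `é`:

```julia
# hypothetical helper: the `for` loop above, rewritten by hand
function manual_collect(iterable)
    items = Any[]
    next = iterate(iterable)
    while next !== nothing
        (x, state) = next
        push!(items, x)
        next = iterate(iterable, state)
    end
    return items
end

@assert manual_collect(1:3) == [1, 2, 3]
@assert manual_collect("héllo") == ['h', 'é', 'l', 'l', 'o']
@assert manual_collect(Dict(:a => 1)) == [:a => 1]

# the String state is a byte index, not a character count
@assert iterate("héllo", 2) == ('é', 4)
```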

In [39]:
v = ['J', 'u', 'l', 'i', 'a']
for character in v
    print(character)
end

Julia

In [40]:
@show iterated_item, state = iterate(v)
@show iterated_item, state = iterate(v, state)
@show iterated_item, state = iterate(v, state)
@show iterated_item, state = iterate(v, state)
@show iterated_item, state = iterate(v, state)
@show iterate(v, state)

(iterated_item, state) = iterate(v) = ('J', 2)
(iterated_item, state) = iterate(v, state) = ('u', 3)
(iterated_item, state) = iterate(v, state) = ('l', 4)
(iterated_item, state) = iterate(v, state) = ('i', 5)
(iterated_item, state) = iterate(v, state) = ('a', 6)
iterate(v, state) = nothing


- So, if we want to replicate that:

In [41]:
struct MyVector{Element}
    vector::Vector{Element}
end

In [42]:
import Base: iterate

# first method: return (first_item, next_state)
function iterate(v::MyVector)
    if length(v.vector) == 0
        nothing
    else
        # the element returned, the next state
        (v.vector[1], 2)
    end
end

iterate (generic function with 247 methods)

In [43]:
# second method: return (next_item, next_state)
function iterate(v::MyVector, state)
    if state > length(v.vector)
        nothing
    else
        (v.vector[state], state + 1)
    end
end

iterate (generic function with 248 methods)

In [44]:
mv = MyVector(v)
for x in mv
    print(x)
end

Julia

- Or, if we want to iterate over the first $n$ squares:

In [45]:
struct Squares
    max::Int
end

In [46]:
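# first method: return (first_item, next_state), or nothing if there is no first item
iterate(s::Squares) = iterate(s, 1)  # delegating means Squares(0) correctly yields nothing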
# second method: return (next_item, next_state)
iterate(s::Squares, i::Int) = if i > s.max
    nothing
else
    (i^2, i + 1)
end

iterate (generic function with 250 methods)

In [59]:
for s in Squares(6)
    print(s, ", ")
end

1, 4, 9, 16, 25, 36, 

In [48]:
s = Squares(5)

@show iterated_item, next_state = iterate(s)
@show iterated_item, next_state = iterate(s, next_state)
@show iterated_item, next_state = iterate(s, next_state)
@show iterated_item, next_state = iterate(s, next_state)
@show iterated_item, next_state = iterate(s, next_state)
@show iter = iterate(s, next_state)

(iterated_item, next_state) = iterate(s) = (1, 2)
(iterated_item, next_state) = iterate(s, next_state) = (4, 3)
(iterated_item, next_state) = iterate(s, next_state) = (9, 4)
(iterated_item, next_state) = iterate(s, next_state) = (16, 5)
(iterated_item, next_state) = iterate(s, next_state) = (25, 6)
iter = iterate(s, next_state) = nothing
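Because these functions only loop over their argument, `sum` and `in` already work on `Squares` with nothing but `iterate`. `collect` is different. By default Base assumes an iterator knows its length (`Base.IteratorSize` defaults to `HasLength()`), so `collect` asks for `length` first and fails until we define it. Defining `eltype` as well lets `collect` build a `Vector{Int}` instead of a `Vector{Any}`. This sketch repeats the definition, with `iterate` written using a default argument:

```julia
struct Squares
    max::Int
end
Base.iterate(s::Squares, i::Int=1) = i > s.max ? nothing : (i^2, i + 1)

@assert sum(Squares(4)) == 30   # 1 + 4 + 9 + 16
@assert 9 in Squares(4)

# collect needs length(::Squares), which does not exist yet
threw = try
    collect(Squares(4))
    false
catch err
    err isa MethodError
end
@assert threw

Base.length(s::Squares) = s.max
Base.eltype(::Type{Squares}) = Int
@assert collect(Squares(4)) == [1, 4, 9, 16]
@assert collect(Squares(4)) isa Vector{Int}
```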


### Aside: Dispatch and Defaults

- Julia (like many languages) has default parameters
- In Python:

```python
>>> def f(x, y=0):
...     print(f"({x}, {y})")
... 
>>> f(1, 2)
(1, 2)
>>> f(1)
(1, 0)
>>> 
```

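- In Julia, default arguments are implemented with dispatch rather than special function-call machinery: a definition with defaults generates one method for each number of arguments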

```julia
hasdefault(x, y=0) = println("($x, $y)")
```

is the same as:

In [49]:
hasdefault(x, y) = println("($x, $y)")
hasdefault(x) = hasdefault(x, 0)
hasdefault(1)
hasdefault(1, 2)

(1, 0)
(1, 2)
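You can check that a default argument really produces separate methods by counting them with `methods`. `withdefault` is a hypothetical name:

```julia
withdefault(x, y=0) = (x, y)   # hypothetical function

@assert length(methods(withdefault)) == 2   # one method for 1 argument, one for 2
@assert withdefault(1) == (1, 0)
@assert withdefault(1, 2) == (1, 2)
```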


- So, this would've worked as well:

```julia
iterate(s::Squares, i::Int=1) = if i > s.max
    nothing
else
    (i^2, i + 1)
end
```

## Interfaces

- This iteration is an example of one of Julia's _interfaces_
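- Unlike, e.g., Rust traits or Haskell type classes, Julia's interfaces are informal: nothing checks at compile time that a type defines all the required methods
- But they are still powerful!
- Other useful interfaces:
  - Indexing (`getindex`, `setindex!`, `firstindex`, `lastindex`)
  - Abstract arrays (subtype `AbstractArray`; at minimum define `size` and `getindex`)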
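Here is a sketch of the array interface that follows the pattern in the Julia manual's interfaces page. `Evens` is a hypothetical type. Subtyping `AbstractVector{Int}` and defining `size` and `getindex` is enough for comparison, `sum`, `end`, slicing and `in` to work, because Base builds all of these on those two methods:

```julia
# hypothetical type: the first n even numbers, computed on demand rather than stored
struct Evens <: AbstractVector{Int}
    n::Int
end

Base.size(v::Evens) = (v.n,)
Base.IndexStyle(::Type{<:Evens}) = IndexLinear()
function Base.getindex(v::Evens, i::Int)
    @boundscheck checkbounds(v, i)   # throws BoundsError outside 1:n
    return 2i
end

evens = Evens(4)
@assert evens == [2, 4, 6, 8]
@assert sum(evens) == 20
@assert evens[end] == 8
@assert evens[2:3] == [4, 6]
@assert 6 in evens
@assert (try evens[5]; false catch err; err isa BoundsError end)
```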

## Exercise

- Write a type which represents the first `n` integer multiples of a constant `factor`
    - I.e., if you build it with `factor = 1.5` and `n = 4`, iterating over it should produce `1.5, 3.0, 4.5, 6.0`
- It should store `n` as an `Int`, but have a type parameter `{T <: Number}` for `factor`
- Write the two `iterate` methods for it
    - [docs.julialang.org/en/v1/manual/interfaces/](https://docs.julialang.org/en/v1/manual/interfaces/)
- Then, using the documentation linked above, write a `length` method and a `getindex` method for it
- Bonus: what is `eltype` for this type?

In [50]:
struct Multiples{T <: Number}
    n::Int
    factor::T
end

iterate(m::Multiples) = if m.n < 1  # nothing to iterate when n ≤ 0
    nothing
else
    (m.factor * 1, 2)
end

iterate(m::Multiples, state::Int) = if state > m.n
    nothing
else
    (m.factor * state, state + 1)
end

import Base: length, getindex

length(m::Multiples) = m.n

getindex(m::Multiples, i) = if i > length(m) || i < 1
    throw(BoundsError(m, i))
else
    m.factor * i
end

getindex (generic function with 186 methods)

In [51]:
m = Multiples(4, 1.5)
println(m)
for x in m
    println(x)
end

Multiples{Float64}(4, 1.5)
1.5
3.0
4.5
6.0


In [52]:
m = Multiples(3, 0.5im)
println(m)
for x in m
    println(x)
end

Multiples{ComplexF64}(3, 0.0 + 0.5im)
0.0 + 0.5im
0.0 + 1.0im
0.0 + 1.5im


In [53]:
Multiples(5, 2//3)[2]

4//3
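For the bonus: each element is `m.factor * i` with `i::Int`, so the element type is whatever `T * Int` produces. That is not always `T`; for example, `true * 2` is an `Int`. One way to express this, as a sketch rather than the official solution, is to compute the type from `one(T) * 1`. The block repeats the definitions (with `iterate` written using a default argument) so it runs on its own:

```julia
struct Multiples{T <: Number}
    n::Int
    factor::T
end
Base.iterate(m::Multiples, state::Int=1) = state > m.n ? nothing : (m.factor * state, state + 1)
Base.length(m::Multiples) = m.n

# element type is the result type of T * Int, found from a representative product
Base.eltype(::Type{Multiples{T}}) where {T} = typeof(one(T) * 1)

@assert eltype(Multiples(4, 1.5)) == Float64
@assert eltype(Multiples(3, 2//3)) == Rational{Int}
@assert eltype(Multiples(3, true)) == Int          # Bool * Int is Int, not Bool
@assert collect(Multiples(4, 1.5)) == [1.5, 3.0, 4.5, 6.0]
@assert collect(Multiples(4, 1.5)) isa Vector{Float64}
```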

## Problems

- There's starter code on the website with some basic tests

1. Write an iterator over the prime numbers $\leq n$ using a Sieve of Eratosthenes
2. Write a matrix type which is the outer product of two vectors, without storing the full matrix
3. Implement the natural numbers using types!