Skip to content

Commit

Permalink
Use i16.
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamDue committed Jan 22, 2024
1 parent 4764ad3 commit 8706cde
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 13 deletions.
4 changes: 2 additions & 2 deletions bench.fut
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import "lib/github.com/diku-dk/sorts/radix_sort"
entry bitonic_sort_i32 = bitonic_sort (i32.<=)
entry merge_sort_i32 = merge_sort (i32.<=)
entry radix_sort_i32 = radix_sort 32 i32.get_bit
entry chunked_radix_sort_i32 = chunked_radix_sort 512 i32.highest 32 i32.get_bit
entry chunked_radix_sort_i32 = chunked_radix_sort 256 i32.highest 32 i32.get_bit

-- 64-bit keys
-- ==
Expand All @@ -24,4 +24,4 @@ entry chunked_radix_sort_i32 = chunked_radix_sort 512 i32.highest 32 i32.get_bit
entry bitonic_sort_i64 = bitonic_sort (i64.<=)
entry merge_sort_i64 = merge_sort (i64.<=)
entry radix_sort_i64 = radix_sort 64 i64.get_bit
entry chunked_radix_sort_i64 = chunked_radix_sort 512 i64.highest 64 i64.get_bit
entry chunked_radix_sort_i64 = chunked_radix_sort 256 i64.highest 64 i64.get_bit
51 changes: 40 additions & 11 deletions lib/github.com/diku-dk/sorts/radix_sort.fut
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
-- * `merge_sort`@term@"merge_sort"

local def radix_sort_step [n] 't (xs: [n]t) (get_bit: i32 -> t -> i32)
(digit_n: i32): ([n]t, (i64, i64, i64, i64)) =
(digit_n: i32): [n]t =
let num x = get_bit (digit_n+1) x * 2 + get_bit digit_n x
let pairwise op (a1,b1,c1,d1) (a2,b2,c2,d2) =
(a1 `op` a2, b1 `op` b2, c1 `op` c2, d1 `op` d2)
Expand All @@ -30,7 +30,7 @@ local def radix_sort_step [n] 't (xs: [n]t) (get_bit: i32 -> t -> i32)
+ c * (i64.bool (bin == 2)) + nc * (i64.bool (bin > 2))
+ d * (i64.bool (bin == 3))
let is = map2 f bins offsets
in (scatter (copy xs) is xs, (na, nb, nc, nd))
in scatter (copy xs) is xs

-- | The `num_bits` and `get_bit` arguments can be taken from one of
-- the numeric modules of module type `integral`@mtype@"/prelude/math"
Expand All @@ -51,7 +51,7 @@ local def radix_sort_step [n] 't (xs: [n]t) (get_bit: i32 -> t -> i32)
def radix_sort [n] 't (num_bits: i32) (get_bit: i32 -> t -> i32)
(xs: [n]t): [n]t =
let iters = if n == 0 then 0 else (num_bits+2-1)/2
in loop xs for i < iters do radix_sort_step xs get_bit (i*2) |> (.0)
in loop xs for i < iters do radix_sort_step xs get_bit (i*2)

def with_indices [n] 'a (xs: [n]a) : [n](a, i64) =
zip xs (iota n)
Expand Down Expand Up @@ -123,6 +123,29 @@ local def get_bin 't
(x: t): i64 =
i64.i32 <| get_bit (digit_n+1) x * 2 + get_bit digit_n x

local def radix_sort_step_i16 [n] 't (xs: [n]t)
(get_bit: i32 -> t -> i32)
(digit_n: i32): ([n]t, (i16, i16, i16, i16)) =
let num x = i16.i32 (get_bit (digit_n+1) x * 2 + get_bit digit_n x)
let pairwise op (a1,b1,c1,d1) (a2,b2,c2,d2) =
(a1 `op` a2, b1 `op` b2, c1 `op` c2, d1 `op` d2)
let bins = xs |> map num
let flags = bins |> map (\x ->
( i16.bool (x==0)
, i16.bool (x==1)
, i16.bool (x==2)
, i16.bool (x==3) ) )
let offsets = scan (pairwise (+)) (0,0,0,0) flags
let (na,nb,nc,nd) = last offsets
let f bin (a,b,c,d) = i64.i16 ((-1)
+ a * (i16.bool (bin == 0)) + na * (i16.bool (bin > 0))
+ b * (i16.bool (bin == 1)) + nb * (i16.bool (bin > 1))
+ c * (i16.bool (bin == 2)) + nc * (i16.bool (bin > 2))
+ d * (i16.bool (bin == 3)))
let is = map2 f bins offsets
in (scatter (copy xs) is xs, (na, nb, nc, nd))


local def chunked_radix_sort_step [n] [m] 't
(get_bit: i32 -> t -> i32)
(digit_n: i32)
Expand All @@ -132,8 +155,13 @@ local def chunked_radix_sort_step [n] [m] 't
unflatten xs
|> map (
\arr ->
let (ys, hist') = radix_sort_step arr get_bit digit_n
in (ys, (\(a, b, c, d) -> sized hist_size [a, b, c, d]) hist')
let (ys, (a, b, c, d)) =
radix_sort_step_i16 arr get_bit digit_n
let hist' = sized hist_size [i64.i16 a
,i64.i16 b
,i64.i16 c
,i64.i16 d]
in (ys, hist')
)
|> unzip
let ys = flatten xs'
Expand Down Expand Up @@ -193,12 +221,13 @@ local def (////) (a: i64) (b: i64) : i64 =
-- Symposium on Parallel & Distributed Processing, Rome, Italy, 2009,
-- pp. 1-10, doi: 10.1109/IPDPS.2009.5161005.
def chunked_radix_sort [n] 't
(chunk: i64)
(chunk: i16)
(highest: t)
(num_bits: i32)
(get_bit: i32 -> t -> i32)
(xs: [n]t): [n]t =
let iters = if n == 0 then 0 else (num_bits + 2 - 1) / 2
let chunk = i64.i16 chunk
let n_chunks = n //// chunk
let padding = replicate (n_chunks * chunk - n) highest
let xs = sized (n_chunks * chunk) (xs ++ padding)
Expand All @@ -208,7 +237,7 @@ def chunked_radix_sort [n] 't

-- | Like `radix_sort_by_key` but chunked.
def chunked_radix_sort_by_key [n] 't 'k
(chunk: i64)
(chunk: i16)
(highest: k)
(key: t -> k)
(num_bits: i32)
Expand All @@ -219,7 +248,7 @@ def chunked_radix_sort_by_key [n] 't 'k

-- | Like `radix_sort_by_int` but chunked.
def chunked_radix_sort_int [n] 't
(chunk: i64)
(chunk: i16)
(highest: t)
(num_bits: i32)
(get_bit: i32 -> t -> i32)
Expand All @@ -231,7 +260,7 @@ def chunked_radix_sort_int [n] 't

-- | Like `radix_sort_int_by_key` but chunked.
def chunked_radix_sort_int_by_key [n] 't 'k
(chunk: i64)
(chunk: i16)
(highest: k)
(key: t -> k)
(num_bits: i32)
Expand All @@ -242,7 +271,7 @@ def chunked_radix_sort_int_by_key [n] 't 'k

-- | Like `radix_sort_float` but chunked.
def chunked_radix_sort_float [n] 't
(chunk: i64)
(chunk: i16)
(highest: t)
(num_bits: i32)
(get_bit: i32 -> t -> i32)
Expand All @@ -255,7 +284,7 @@ def chunked_radix_sort_float [n] 't

-- | Like `radix_sort_float_by_key` but chunked.
def chunked_radix_sort_float_by_key [n] 't 'k
(chunk: i64)
(chunk: i16)
(highest: k)
(key: t -> k)
(num_bits: i32)
Expand Down

0 comments on commit 8706cde

Please sign in to comment.