Skip to content

Commit

Permalink
reworked getindex and subsetting. Removed subset as it was type pir…
Browse files Browse the repository at this point in the history
…acy. Added batch and unbatch from MLUtils
  • Loading branch information
simonmandlik committed Nov 4, 2024
1 parent 38f3871 commit f3a5dd9
Show file tree
Hide file tree
Showing 19 changed files with 138 additions and 116 deletions.
2 changes: 1 addition & 1 deletion docs/src/api/data_nodes.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Mill.metadata
datasummary
dropmeta
catobs
Mill.subset
Mill.metadata_getindex
Mill.mapdata
removeinstances
Expand Down
54 changes: 26 additions & 28 deletions docs/src/examples/musk/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.11.0"
julia_version = "1.11.1"
manifest_format = "2.0"
project_hash = "c980577c0415d2c228751f8ca1d6579dbbffa98a"

Expand Down Expand Up @@ -42,9 +42,9 @@ version = "0.1.38"

[[deps.Adapt]]
deps = ["LinearAlgebra", "Requires"]
git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099"
git-tree-sha1 = "50c3c56a52972d78e8be9fd135bfb91c9574c140"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "4.0.4"
version = "4.1.1"
weakdeps = ["StaticArrays"]

[deps.Adapt.extensions]
Expand Down Expand Up @@ -294,9 +294,9 @@ version = "0.12.32"

[[deps.Flux]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLDataDevices", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "Setfield", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"]
git-tree-sha1 = "37fa32a50c69c10c6ea1465d3054d98c75bd7777"
git-tree-sha1 = "df520a0727f843576801a0294f5be1a94be28e23"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.14.22"
version = "0.14.25"

[deps.Flux.extensions]
FluxAMDGPUExt = "AMDGPU"
Expand All @@ -305,22 +305,20 @@ version = "0.14.22"
FluxEnzymeExt = "Enzyme"
FluxMPIExt = "MPI"
FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
FluxMetalExt = "Metal"

[deps.Flux.weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[[deps.ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"]
git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad"
git-tree-sha1 = "a9ce73d3c827adab2d70bf168aaece8cce196898"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.36"
version = "0.10.37"
weakdeps = ["StaticArrays"]

[deps.ForwardDiff.extensions]
Expand Down Expand Up @@ -411,9 +409,9 @@ version = "1.0.0"

[[deps.JLD2]]
deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"]
git-tree-sha1 = "aeab5c68eb2cf326619bf71235d8f4561c62fe22"
git-tree-sha1 = "783c1be5213a09609b23237a0c9e5dfd258ae6f2"
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
version = "0.5.5"
version = "0.5.7"

[[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"]
Expand All @@ -429,9 +427,9 @@ version = "0.2.4"

[[deps.KernelAbstractions]]
deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
git-tree-sha1 = "04e52f596d0871fa3890170fa79cb15e481e4cd8"
git-tree-sha1 = "e73a077abc7fe798fe940deabe30ef6c66bdde52"
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
version = "0.9.28"
version = "0.9.29"

[deps.KernelAbstractions.extensions]
EnzymeExt = "EnzymeCore"
Expand All @@ -444,10 +442,10 @@ version = "0.9.28"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[deps.LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"]
git-tree-sha1 = "4ad43cb0a4bb5e5b1506e1d1f48646d7e0c80363"
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Unicode"]
git-tree-sha1 = "d422dfd9707bec6617335dc2ea3c5172a87d5908"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "9.1.2"
version = "9.1.3"

[deps.LLVM.extensions]
BFloat16sExt = "BFloat16s"
Expand All @@ -462,9 +460,9 @@ uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
version = "0.0.34+0"

[[deps.LaTeXStrings]]
git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec"
git-tree-sha1 = "dda21b8cbd6a6c40d9d02a73230f9d70fed6918c"
uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
version = "1.3.1"
version = "1.4.0"

[[deps.LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
Expand Down Expand Up @@ -526,10 +524,10 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
version = "1.11.0"

[[deps.MLDataDevices]]
deps = ["Adapt", "Functors", "Preferences", "Random"]
git-tree-sha1 = "e16288e37e76d68c3f1c418e0a2bec88d98d55fc"
deps = ["Adapt", "Compat", "Functors", "Preferences", "Random"]
git-tree-sha1 = "5cffc52b59227864b665459e1f7bcc4d3c4fb47b"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
version = "1.2.0"
version = "1.4.2"

[deps.MLDataDevices.extensions]
MLDataDevicesAMDGPUExt = "AMDGPU"
Expand Down Expand Up @@ -602,7 +600,7 @@ version = "0.2.0"
deps = ["Accessors", "ChainRulesCore", "Combinatorics", "Compat", "DataFrames", "DataStructures", "FiniteDifferences", "Flux", "HierarchicalUtils", "LinearAlgebra", "MLUtils", "MacroTools", "OneHotArrays", "PooledArrays", "Preferences", "SparseArrays", "Statistics", "Test"]
path = "../../../.."
uuid = "1d0525e4-8992-11e8-313c-e310e1f6ddea"
version = "2.10.5"
version = "2.10.6"

[[deps.Missings]]
deps = ["DataAPI"]
Expand Down Expand Up @@ -774,9 +772,9 @@ version = "0.7.0"

[[deps.SentinelArrays]]
deps = ["Dates", "Random"]
git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a"
git-tree-sha1 = "d0553ce4031a081cc42387a9b9c8441b7d99f32d"
uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
version = "1.4.5"
version = "1.4.7"

[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Expand Down Expand Up @@ -838,9 +836,9 @@ version = "0.1.15"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50"
git-tree-sha1 = "777657803913ffc7e8cc20f0fd04b634f871af8f"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.9.7"
version = "1.9.8"
weakdeps = ["ChainRulesCore", "Statistics"]

[deps.StaticArrays.extensions]
Expand Down Expand Up @@ -983,9 +981,9 @@ version = "1.2.13+1"

[[deps.Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "f2f85ad73ca67b5d3c94239b0fde005e0fe2d900"
git-tree-sha1 = "f816633be6dc5c0ed9ffedda157ecfda0b3b6a69"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.71"
version = "0.6.72"

[deps.Zygote.extensions]
ZygoteColorsExt = "Colors"
Expand Down
11 changes: 4 additions & 7 deletions docs/src/manual/custom.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ pm(ds)
The solution using [`LazyNode`](@ref) is sufficient in most scenarios. For other cases, it is recommended to equip custom nodes with the following functionality:

* allow nesting (if needed)
* implement [`Mill.subset`](@ref) and optionally `Base.getindex` to obtain subsets of observations.
[`Mill.jl`](https://github.com/CTUAvastLab/Mill.jl) already defines [`Mill.subset`](@ref) for
common datatypes, which can be used.
* implement `Base.getindex` to obtain subsets of observations. We make use of [`Mill.metadata_getindex`](@ref) to index the metadata.
* allow concatenation of nodes with [`catobs`](@ref). Optionally, implement `reduce(catobs, ...)` as well to avoid excessive compilation if the number of arguments varies a lot
* define a specialized method for `MLUtils.numobs`, which we can however import directly from
[`Mill.jl`](https://github.com/CTUAvastLab/Mill.jl).
Expand All @@ -103,11 +101,10 @@ end
PathNode(data::Vector{S}) where {S <: AbstractString} = PathNode(data, nothing)
Base.show(io::IO, n::PathNode) = print(io, "PathNode ($(numobs(n)) obs)")
Base.ndims(n::PathNode) = Colon()
Base.ndims(::PathNode) = Colon()
numobs(n::PathNode) = length(n.data)
# Concatenate `PathNode`s: `vcat` their data and `catobs` their metadata.
catobs(ns::PathNode...) = PathNode(vcat(data.(ns)...), catobs(metadata.(ns)...))
Base.getindex(n::PathNode, i::VecOrRange{<:Int}) = PathNode(subset(data(x), i),
subset(metadata(x), i))
Base.getindex(n::PathNode, i::VecOrRange{<:Int}) = PathNode(n.data[i], Mill.metadata_getindex(n.metadata, i))
NodeType(::Type{<:PathNode}) = LeafNode()
nothing # hide
```
Expand All @@ -122,7 +119,7 @@ struct PathModel{T, F} <: AbstractMillModel
end
Flux.@layer :ignore PathModel
show(io::IO, n::PathModel) = print(io, "PathModel")
show(io::IO, ::PathModel) = print(io, "PathModel")
NodeType(::Type{<:PathModel}) = LeafNode()
path2mill(ds::PathNode) = path2mill(ds.data)
Expand Down
5 changes: 3 additions & 2 deletions src/Mill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ using Preferences
using SparseArrays
using Statistics

using Accessors: PropertyLens, IndexLens, ComposedOptic
using Base: CodeUnits, nameof
using ChainRulesCore: NotImplemented, NotImplementedException
using HierarchicalUtils: encode, stringify
using Accessors: PropertyLens, IndexLens, ComposedOptic
using MLUtils: batch, unbatch

import Base: *, ==

Expand Down Expand Up @@ -51,7 +52,7 @@ include("datanodes/datanode.jl")
export AbstractMillNode, AbstractProductNode, AbstractBagNode
export ArrayNode, BagNode, WeightedBagNode, ProductNode, LazyNode
export numobs, getobs, catobs, removeinstances, dropmeta
@compat public subset, data, metadata, mapdata, unpack2mill
@compat public data, metadata, mapdata, unpack2mill

include("special_arrays/special_arrays.jl")
export MaybeHotVector, MaybeHotMatrix, maybehot, maybehotbatch, maybecold
Expand Down
6 changes: 3 additions & 3 deletions src/aggregations/aggregations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ for p in filter(p -> length(p) > 1, collect(powerset(collect(1:length(names)))))
s = Symbol("Segmented", names[p]...)
@eval begin
"""
$($(s))([t::Type, ]d::Int)
$($(s))([t::Type, ]d::Integer)
Construct [`AggregationStack`](@ref) consisting of $($(
join("[`Segmented" .* names[p] .* "`](@ref)", ", ", " and ")
Expand All @@ -83,11 +83,11 @@ for p in filter(p -> length(p) > 1, collect(powerset(collect(1:length(names)))))
See also: [`AbstractAggregation`](@ref), [`AggregationStack`](@ref), [`SegmentedSum`](@ref),
[`SegmentedMax`](@ref), [`SegmentedMean`](@ref), [`SegmentedPNorm`](@ref), [`SegmentedLSE`](@ref).
"""
function $s(d::Int)
function $s(d::Integer)
AggregationStack($((Expr(:call, Symbol("Segmented", n), :d) for n in names[p])...))
end
end
@eval function $s(::Type{T}, d::Int) where T
@eval function $s(::Type{T}, d::Integer) where T
AggregationStack($((Expr(:call, Symbol("Segmented", n), :T, :d) for n in names[p])...))
end
@eval export $s
Expand Down
4 changes: 2 additions & 2 deletions src/aggregations/segmented_lse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ end

Flux.@layer :ignore SegmentedLSE

SegmentedLSE(T::Type, d::Int) = SegmentedLSE(zeros(T, d), randn(T, d))
SegmentedLSE(d::Int) = SegmentedLSE(Float32, d)
SegmentedLSE(T::Type, d::Integer) = SegmentedLSE(zeros(T, d), randn(T, d))
SegmentedLSE(d::Integer) = SegmentedLSE(Float32, d)

Flux.@forward SegmentedLSE.ψ Base.getindex, Base.length, Base.size, Base.firstindex, Base.lastindex,
Base.first, Base.last, Base.iterate, Base.eltype
Expand Down
4 changes: 2 additions & 2 deletions src/aggregations/segmented_max.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ end

Flux.@layer :ignore SegmentedMax

SegmentedMax(T::Type, d::Int) = SegmentedMax(zeros(T, d))
SegmentedMax(d::Int) = SegmentedMax(Float32, d)
SegmentedMax(T::Type, d::Integer) = SegmentedMax(zeros(T, d))
SegmentedMax(d::Integer) = SegmentedMax(Float32, d)

Flux.@forward SegmentedMax.ψ Base.getindex, Base.length, Base.size, Base.firstindex, Base.lastindex,
Base.first, Base.last, Base.iterate, Base.eltype
Expand Down
4 changes: 2 additions & 2 deletions src/aggregations/segmented_mean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ end

Flux.@layer :ignore SegmentedMean

SegmentedMean(T::Type, d::Int) = SegmentedMean(zeros(T, d))
SegmentedMean(d::Int) = SegmentedMean(Float32, d)
SegmentedMean(T::Type, d::Integer) = SegmentedMean(zeros(T, d))
SegmentedMean(d::Integer) = SegmentedMean(Float32, d)

Flux.@forward SegmentedMean.ψ Base.getindex, Base.length, Base.size, Base.firstindex, Base.lastindex,
Base.first, Base.last, Base.iterate, Base.eltype
Expand Down
4 changes: 2 additions & 2 deletions src/aggregations/segmented_pnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ end

Flux.@layer :ignore SegmentedPNorm

SegmentedPNorm(T::Type, d::Int) = SegmentedPNorm(zeros(T, d), randn(T, d), zeros(T, d))
SegmentedPNorm(d::Int) = SegmentedPNorm(Float32, d)
SegmentedPNorm(T::Type, d::Integer) = SegmentedPNorm(zeros(T, d), randn(T, d), zeros(T, d))
SegmentedPNorm(d::Integer) = SegmentedPNorm(Float32, d)

Flux.@forward SegmentedPNorm.ψ Base.getindex, Base.length, Base.size, Base.firstindex, Base.lastindex,
Base.first, Base.last, Base.iterate, Base.eltype
Expand Down
4 changes: 2 additions & 2 deletions src/aggregations/segmented_sum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ end

Flux.@layer :ignore SegmentedSum

SegmentedSum(T::Type, d::Int) = SegmentedSum(zeros(T, d))
SegmentedSum(d::Int) = SegmentedSum(Float32, d)
SegmentedSum(T::Type, d::Integer) = SegmentedSum(zeros(T, d))
SegmentedSum(d::Integer) = SegmentedSum(Float32, d)

Flux.@forward SegmentedSum.ψ Base.getindex, Base.length, Base.size, Base.firstindex, Base.lastindex,
Base.first, Base.last, Base.iterate, Base.eltype
Expand Down
7 changes: 5 additions & 2 deletions src/datanodes/arraynode.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""
ArrayNode{A <: AbstractArray, C} <: AbstractMillNode
Data node for storing array-like data of type `A` and metadata of type `C`. The convention is that samples are stored along the last axis, e.g. in columns of a matrix.
Data node for storing array-like data of type `A` and metadata of type `C`. The convention is that
samples are stored along the last axis, e.g. in columns of a matrix.
See also: [`AbstractMillNode`](@ref), [`ArrayModel`](@ref).
"""
Expand Down Expand Up @@ -53,7 +54,9 @@ function Base.reduce(::typeof(hcat), as::Vector{<:ArrayNode})
ArrayNode(reduce(hcat, data.(as)), _cat_meta(hcat, metadata.(as)))
end

Base.getindex(x::ArrayNode, i::VecOrRange{<:Int}) = ArrayNode(subset(x.data, i), subset(x.metadata, i))
function Base.getindex(x::ArrayNode, i::VecOrRange{<:Integer})
ArrayNode(x.data[:, i], metadata_getindex(x.metadata, i))
end

_arraynode(m) = ArrayNode(m)
_arraynode(m::AbstractMillNode) = m
Expand Down
8 changes: 5 additions & 3 deletions src/datanodes/bagnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ mapdata(f, x::BagNode) = BagNode(mapdata(f, x.data), x.bags, x.metadata)

dropmeta(x::BagNode) = BagNode(dropmeta(x.data), x.bags)

function Base.getindex(x::BagNode, i::VecOrRange{<:Int})
function Base.getindex(x::BagNode, i::VecOrRange{<:Integer})
nb, ii = remapbags(x.bags, i)
emptyismissing() && isempty(ii) && return(BagNode(missing, nb, nothing))
BagNode(subset(x.data,ii), nb, subset(x.metadata, i))
BagNode(x.data[ii], nb, metadata_getindex(x.metadata, i))
end

function Base.reduce(::typeof(catobs), as::Vector{<:BagNode})
Expand All @@ -68,7 +68,9 @@ function Base.reduce(::typeof(catobs), as::Vector{<:BagNode})
)
end

removeinstances(a::BagNode, mask) = BagNode(subset(a.data, findall(mask)), adjustbags(a.bags, mask), a.metadata)
function removeinstances(a::BagNode, mask)
BagNode(a.data[mask], adjustbags(a.bags, mask), a.metadata)
end

Base.hash(n::BagNode, h::UInt) = hash((n.data, n.bags, n.metadata), h)
(n1::BagNode == n2::BagNode) = n1.data == n2.data && n1.bags == n2.bags && n1.metadata == n2.metadata
Expand Down
Loading

0 comments on commit f3a5dd9

Please sign in to comment.