From 3972a76fbcaef6c7957a3e8e50a49930edea0f11 Mon Sep 17 00:00:00 2001 From: nHackel Date: Fri, 16 Feb 2024 12:53:15 +0100 Subject: [PATCH] Add extension for ProximalCore/-Operators adapter --- Project.toml | 7 +++++++ .../RegularizedLeastSquaresProximalCore.jl | 19 +++++++++++++++++++ src/Regularization/Regularization.jl | 2 ++ 3 files changed, 28 insertions(+) create mode 100644 ext/RegularizedLeastSquaresProximalCore/RegularizedLeastSquaresProximalCore.jl diff --git a/Project.toml b/Project.toml index e33269f1..9ea450e2 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,9 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[weakdeps] +ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" + [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -30,6 +33,10 @@ VectorizationBase = "0.19, 0.21" LinearOperatorCollection = "1.0" LinearOperators = "2.3.3" FFTW = "1.0" +ProximalCore = "0.1" [targets] test = ["Test", "Random", "FFTW"] + +[extensions] +RegularizedLeastSquaresProximalCore = "ProximalCore" diff --git a/ext/RegularizedLeastSquaresProximalCore/RegularizedLeastSquaresProximalCore.jl b/ext/RegularizedLeastSquaresProximalCore/RegularizedLeastSquaresProximalCore.jl new file mode 100644 index 00000000..98af9bc7 --- /dev/null +++ b/ext/RegularizedLeastSquaresProximalCore/RegularizedLeastSquaresProximalCore.jl @@ -0,0 +1,19 @@ +module RegularizedLeastSquaresProximalCore + +using RegularizedLeastSquares, ProximalCore + +import RegularizedLeastSquares.prox!, RegularizedLeastSquares.ProximalCoreAdapter + +struct ProximalCoreAdapterImpl{T, F} <: ProximalCoreAdapter{T, F} + λ::T + op::F +end + +RegularizedLeastSquares.ProximalCoreAdapter(λ::T, op::F) where {T, F} = ProximalCoreAdapterImpl(λ, op) + +function prox!(reg::ProximalCoreAdapter, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} + ProximalCore.prox!(x, reg.op, x, λ) + return x +end + +end \ No newline at end of file diff --git a/src/Regularization/Regularization.jl b/src/Regularization/Regularization.jl index 760114b5..7aa42a26 100644 --- a/src/Regularization/Regularization.jl +++ b/src/Regularization/Regularization.jl @@ -67,6 +67,8 @@ include("TransformedRegularization.jl") include("MaskedRegularization.jl") include("PlugAndPlayRegularization.jl") +export ProximalCoreAdapter +abstract type ProximalCoreAdapter{T, F} <: AbstractParameterizedRegularization{T} where F end function findfirst(::Type{S}, reg::AbstractRegularization) where S <: AbstractRegularization regs = collect(reg)