Skip to content

Commit 854d90f

Browse files
Fix GPU Cholesky cache initialization for non-square matrices
When using DefaultLinearSolver with non-square GPU matrices (e.g., for least squares problems), the init_cacheval function for CholeskyFactorization would fail because it tried to compute cholesky(A) on a non-square matrix. The fix checks assumptions.issq before attempting Cholesky factorization and returns nothing for non-square matrices, allowing the DefaultLinearSolver to properly use QRFactorization instead. Fixes SciML/NonlinearSolve.jl#746 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 409ab7f commit 854d90f

File tree

2 files changed

+77
-19
lines changed

2 files changed

+77
-19
lines changed

src/factorization.jl

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -317,21 +317,24 @@ end
317317
const PREALLOCATED_QR_ColumnNorm = ArrayInterface.qr_instance(rand(1, 1), ColumnNorm())
318318

319319
function init_cacheval(alg::QRFactorization{ColumnNorm}, A::Matrix{Float64}, b, u, Pl, Pr,
320-
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
320+
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
321+
assumptions::OperatorAssumptions)
321322
return PREALLOCATED_QR_ColumnNorm
322323
end
323324

324325
function init_cacheval(
325326
alg::QRFactorization, A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr,
326-
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
327+
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
328+
assumptions::OperatorAssumptions)
327329
A isa GPUArraysCore.AnyGPUArray && return qr(A)
328330
return qr(A, alg.pivot)
329331
end
330332

331333
const PREALLOCATED_QR_NoPivot = ArrayInterface.qr_instance(rand(1, 1))
332334

333335
function init_cacheval(alg::QRFactorization{NoPivot}, A::Matrix{Float64}, b, u, Pl, Pr,
334-
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
336+
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
337+
assumptions::OperatorAssumptions)
335338
return PREALLOCATED_QR_NoPivot
336339
end
337340

@@ -388,13 +391,18 @@ function init_cacheval(alg::CholeskyFactorization, A::SMatrix{S1, S2}, b, u, Pl,
388391
end
389392

390393
function init_cacheval(alg::CholeskyFactorization, A::GPUArraysCore.AnyGPUArray, b, u, Pl,
391-
Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
392-
cholesky(A; check = false)
394+
Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
395+
assumptions::OperatorAssumptions)
396+
# Cholesky requires square matrices - return nothing for non-square to avoid errors
397+
# during DefaultLinearSolver cache initialization
398+
# See https://github.com/SciML/NonlinearSolve.jl/issues/746
399+
assumptions.issq ? cholesky(A; check = false) : nothing
393400
end
394401

395402
function init_cacheval(
396403
alg::CholeskyFactorization, A::AbstractArray{<:BLASELTYPES}, b, u, Pl, Pr,
397-
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
404+
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
405+
assumptions::OperatorAssumptions)
398406
if LinearSolve.is_cusparse_csc(A)
399407
nothing
400408
elseif LinearSolve.is_cusparse_csr(A) && !LinearSolve.cudss_loaded(A)
@@ -1012,7 +1020,8 @@ const PREALLOCATED_NORMALCHOLESKY_SYMMETRIC = ArrayInterface.cholesky_instance(
10121020
Symmetric(rand(1, 1)), NoPivot())
10131021

10141022
function init_cacheval(alg::NormalCholeskyFactorization, A::Matrix{Float64}, b, u, Pl, Pr,
1015-
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
1023+
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
1024+
assumptions::OperatorAssumptions)
10161025
return PREALLOCATED_NORMALCHOLESKY_SYMMETRIC
10171026
end
10181027

@@ -1164,7 +1173,8 @@ function init_cacheval(alg::SparspakFactorization,
11641173
end
11651174

11661175
function init_cacheval(::SparspakFactorization, ::StaticArray, b, u, Pl, Pr,
1167-
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
1176+
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
1177+
assumptions::OperatorAssumptions)
11681178
nothing
11691179
end
11701180

@@ -1190,9 +1200,8 @@ struct CliqueTreesFactorization{A, S} <: AbstractSparseFactorization
11901200
alg::A = nothing,
11911201
snd::S = nothing,
11921202
reuse_symbolic = true,
1193-
throwerror = true,
1194-
) where {A, S}
1195-
1203+
throwerror = true
1204+
) where {A, S}
11961205
ext = Base.get_extension(@__MODULE__, :LinearSolveCliqueTreesExt)
11971206

11981207
if throwerror && isnothing(ext)
@@ -1203,30 +1212,36 @@ struct CliqueTreesFactorization{A, S} <: AbstractSparseFactorization
12031212
end
12041213
end
12051214

1206-
function init_cacheval(::CliqueTreesFactorization, ::Union{AbstractMatrix, Nothing, AbstractSciMLOperator}, b, u, Pl, Pr,
1207-
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
1215+
function init_cacheval(::CliqueTreesFactorization,
1216+
::Union{AbstractMatrix, Nothing, AbstractSciMLOperator}, b, u, Pl, Pr,
1217+
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
1218+
assumptions::OperatorAssumptions)
12081219
nothing
12091220
end
12101221

12111222
function init_cacheval(::CliqueTreesFactorization, ::StaticArray, b, u, Pl, Pr,
1212-
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
1223+
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
1224+
assumptions::OperatorAssumptions)
12131225
nothing
12141226
end
12151227

12161228
# Fallback init_cacheval for extension-based algorithms when extensions aren't loaded
12171229
# These return nothing since the actual implementations are in the extensions
12181230
function init_cacheval(::BLISLUFactorization, A, b, u, Pl, Pr,
1219-
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
1231+
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
1232+
assumptions::OperatorAssumptions)
12201233
nothing
12211234
end
12221235

12231236
function init_cacheval(::CudaOffloadLUFactorization, A, b, u, Pl, Pr,
1224-
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
1237+
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
1238+
assumptions::OperatorAssumptions)
12251239
nothing
12261240
end
12271241

12281242
function init_cacheval(::MetalLUFactorization, A, b, u, Pl, Pr,
1229-
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
1243+
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
1244+
assumptions::OperatorAssumptions)
12301245
nothing
12311246
end
12321247

test/gpu/cuda.jl

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ x2 = zero(b);
7575
prob1 = LinearProblem(A1, b1; u0 = x1)
7676
prob2 = LinearProblem(A2, b2; u0 = x2)
7777

78-
cache_kwargs = (;abstol = 1e-8, reltol = 1e-8, maxiter = 30)
78+
cache_kwargs = (; abstol = 1e-8, reltol = 1e-8, maxiter = 30)
7979

8080
function test_interface(alg, prob1, prob2)
8181
A1 = prob1.A
@@ -103,7 +103,8 @@ function test_interface(alg, prob1, prob2)
103103
return
104104
end
105105

106-
@testset "$alg" for alg in (CudaOffloadLUFactorization(), CudaOffloadQRFactorization(), NormalCholeskyFactorization())
106+
@testset "$alg" for alg in (CudaOffloadLUFactorization(), CudaOffloadQRFactorization(),
107+
NormalCholeskyFactorization())
107108
test_interface(alg, prob1, prob2)
108109
end
109110

@@ -171,3 +172,45 @@ if Base.find_package("CUSOLVERRF") !== nothing
171172
include("cusolverrf.jl")
172173
end
173174
end
175+
176+
# Test for non-square GPU matrices (least squares problems)
177+
# See https://github.com/SciML/NonlinearSolve.jl/issues/746
178+
@testset "Non-square GPU matrices" begin
179+
# Overdetermined system: more rows than columns (4x2)
180+
A_rect = cu(Float32[1.0 2.0; 3.0 4.0; 5.0 6.0; 7.0 8.0])
181+
b_rect = cu(Float32[1.0, 2.0, 3.0, 4.0])
182+
183+
prob_rect = LinearProblem(A_rect, b_rect)
184+
185+
# Test that default solver works (should use QRFactorization)
186+
@testset "Default solver for non-square" begin
187+
sol = solve(prob_rect)
188+
@test sol.alg.alg == LinearSolve.DefaultAlgorithmChoice.QRFactorization
189+
# Verify least squares solution
190+
@test norm(A_rect * sol.u - b_rect) < norm(b_rect) # residual should be smaller than b
191+
end
192+
193+
# Test explicit QRFactorization
194+
@testset "QRFactorization for non-square" begin
195+
sol = solve(prob_rect, QRFactorization())
196+
@test norm(A_rect * sol.u - b_rect) < norm(b_rect)
197+
end
198+
199+
# Test NormalCholeskyFactorization (should work via A'*A)
200+
@testset "NormalCholeskyFactorization for non-square" begin
201+
sol = solve(prob_rect, NormalCholeskyFactorization())
202+
@test norm(A_rect * sol.u - b_rect) < norm(b_rect)
203+
end
204+
205+
# Underdetermined system: more columns than rows (2x4)
206+
A_under = cu(Float32[1.0 2.0 3.0 4.0; 5.0 6.0 7.0 8.0])
207+
b_under = cu(Float32[1.0, 2.0])
208+
209+
prob_under = LinearProblem(A_under, b_under)
210+
211+
@testset "Default solver for underdetermined" begin
212+
sol = solve(prob_under)
213+
# Should still work and give a solution
214+
@test norm(A_under * sol.u - b_under) < 1e-4
215+
end
216+
end

0 commit comments

Comments
 (0)