Skip to content

Commit 9a602ab

Browse files
Mahmood-SinanjalveszCopilotjvdp1
authored
feat: Add generalized lagtm routine supporting arbitrary values for alpha and beta (#1068)
* extended lagtm implemented, but only for real alpha and beta values * alpha and beta now supports complex values also * modularized extended lapack, but cmakefile has to be changed * changed cmakefiles.txt * added tests for random values of alpha and beta * Update src/lapack_extended/stdlib_extended_lapack.fypp Co-authored-by: Copilot <[email protected]> * Update src/lapack_extended/stdlib_extended_lapack.fypp Co-authored-by: Copilot <[email protected]> * changed file names from extended_lapack to lapack_extended, updated docs, updated alpha and beta checks * updated the docs --------- Co-authored-by: José Alves <[email protected]> Co-authored-by: Copilot <[email protected]> Co-authored-by: Jeremie Vandenplas <[email protected]>
1 parent 927c607 commit 9a602ab

File tree

10 files changed

+218
-15
lines changed

10 files changed

+218
-15
lines changed

doc/specs/stdlib_specialmatrices.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ Experimental
9090

9191
With the exception of `extended precision` and `quadruple precision`, all the types provided by `stdlib_specialmatrices` benefit from specialized kernels for matrix-vector products accessible via the common `spmv` interface.
9292

93-
- For `tridiagonal` matrices, the LAPACK `lagtm` backend is being used.
93+
- For `tridiagonal` matrices, the backend is either LAPACK `lagtm` or the generalized routine `glagtm`, depending on the values and types of `alpha` and `beta`.
9494

9595
#### Syntax
9696

@@ -110,10 +110,6 @@ With the exception of `extended precision` and `quadruple precision`, all the ty
110110

111111
- `op` (optional) : In-place operator identifier. Shall be a character(1) argument. It can have any of the following values: `N`: no transpose, `T`: transpose, `H`: Hermitian (conjugate) transpose.
112112

113-
@warning
114-
Due to limitations of the underlying `lapack` driver, currently `alpha` and `beta` can only take one of the values `[-1, 0, 1]` for `tridiagonal` and `symtridiagonal` matrices. See `lagtm` for more details.
115-
@endwarning
116-
117113
#### Examples
118114

119115
```fortran
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
ADD_EXAMPLE(specialmatrices_dp_spmv)
2+
ADD_EXAMPLE(specialmatrices_cdp_spmv)
23
ADD_EXAMPLE(tridiagonal_dp_type)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
program example_tridiagonal_matrix_cdp
   !! Example: matrix-vector product with a complex double-precision
   !! tridiagonal matrix and general complex scaling factors.
   !!
   !! The result of `spmv` (y <- alpha*A*x + beta*y) is printed next to the
   !! same quantity evaluated with `matmul` on the dense form of the matrix.
   use stdlib_linalg_constants, only: dp
   use stdlib_specialmatrices, only: tridiagonal_cdp_type, tridiagonal, dense, spmv
   implicit none

   integer, parameter :: n = 5
   type(tridiagonal_cdp_type) :: A
   complex(dp) :: dl(n-1), dv(n), du(n-1)
   complex(dp) :: x(n), y(n), y_dense(n)
   integer :: i
   complex(dp) :: alpha, beta

   ! Lower, main and upper diagonals of the matrix.
   dl = [(cmplx(i, i, dp), i=1, n - 1)]
   dv = [(cmplx(2*i, 2*i, dp), i=1, n)]
   du = [(cmplx(3*i, 3*i, dp), i=1, n - 1)]

   A = tridiagonal(dl, dv, du)

   x = (1.0_dp, 0.0_dp)
   y = (3.0_dp, -7.0_dp)
   y_dense = (0.0_dp, 0.0_dp)

   ! The kind must be passed explicitly: without it `cmplx` returns a
   ! default-kind complex, whatever the kind of its real arguments.
   alpha = cmplx(2.0_dp, 3.0_dp, kind=dp)
   beta = cmplx(-1.0_dp, 5.0_dp, kind=dp)

   ! Dense reference first, because `spmv` overwrites `y` in place.
   y_dense = alpha * matmul(dense(A), x) + beta * y
   call spmv(A, x, y, alpha, beta)

   print *, 'dense :', y_dense
   print *, 'Tridiagonal :', y
end program example_tridiagonal_matrix_cdp

src/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ if (NOT STDLIB_NO_BITSET)
33
endif()
44
add_subdirectory(blas)
55
add_subdirectory(lapack)
6+
add_subdirectory(lapack_extended)
67
if (NOT STDLIB_NO_STATS)
78
add_subdirectory(stats)
89
endif()
@@ -115,4 +116,4 @@ configure_stdlib_target(${PROJECT_NAME} f90Files fppFiles cppFiles)
115116
target_link_libraries(${PROJECT_NAME} PUBLIC
116117
$<$<NOT:$<BOOL:${STDLIB_NO_BITSET}>>:bitsets>
117118
$<$<NOT:$<BOOL:${STDLIB_NO_STATS}>>:stats>
118-
blas lapack)
119+
blas lapack lapack_extended)

src/lapack_extended/CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Build description for the `lapack_extended` target.
# NOTE(review): presumably these are the fypp templates preprocessed into
# Fortran sources for this target -- confirm against configure_stdlib_target.
set(lapack_extended_fppFiles
2+
../stdlib_kinds.fypp
3+
stdlib_lapack_extended_base.fypp
4+
stdlib_lapack_extended.fypp
5+
)
6+
# NOTE(review): the variable name suggests sources that also need the C
# preprocessor; the exact handling is defined elsewhere -- confirm.
set(lapack_extended_cppFiles
7+
../stdlib_linalg_constants.fypp
8+
)
9+
10+
# The second argument (the plain source-file list in the call made from
# src/CMakeLists.txt) is deliberately empty here.
configure_stdlib_target(lapack_extended "" lapack_extended_fppFiles lapack_extended_cppFiles)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#:include "common.fypp"
2+
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
3+
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
4+
#:set KINDS_TYPES = R_KINDS_TYPES+C_KINDS_TYPES
5+
6+
submodule(stdlib_lapack_extended_base) stdlib_lapack_extended
    !! Implementation of the generalized tridiagonal matrix-matrix product
    !! `glagtm` for every real and complex kind and every integer kind
    !! generated by the fypp loops below.
    implicit none
contains
    #:for ik,it,ii in LINALG_INT_KINDS_TYPES
    #:for k1,t1,s1 in KINDS_TYPES
    pure module subroutine stdlib${ii}$_glagtm_${s1}$(trans, n, nrhs, alpha, dl, d, du, x, ldx, beta, b, ldb)
        !! Compute `b := alpha * op(A) * x + beta * b` for a tridiagonal
        !! matrix `A` and arbitrary scalars `alpha` and `beta`.
        !!
        !! Arguments:
        !!   trans : selects op(A). 'N'/'n' -> A. For complex types 'C'/'c'
        !!           (and the synonym 'H'/'h') -> conjugate transpose of A.
        !!           Any other value -> transpose of A.
        !!   n     : order of A. Nothing is done when n <= 0.
        !!   nrhs  : number of columns of x and b. Nothing is done when nrhs <= 0.
        !!   alpha : scalar multiplying op(A) * x.
        !!   dl    : sub-diagonal of A, entries 1..n-1.
        !!   d     : main diagonal of A, entries 1..n.
        !!   du    : super-diagonal of A, entries 1..n-1.
        !!   x     : right-hand factor, leading dimension ldx, n-by-nrhs used.
        !!   beta  : scalar multiplying the input b. When beta is zero, b is
        !!           overwritten without being read.
        !!   b     : on input the matrix to be updated, on output the result;
        !!           leading dimension ldb, n-by-nrhs used.
        character, intent(in) :: trans
        integer(${ik}$), intent(in) :: ldb, ldx, n, nrhs
        ${t1}$, intent(in) :: alpha, beta
        ${t1}$, intent(inout) :: b(ldb,*)
        ${t1}$, intent(in) :: d(*), dl(*), du(*), x(ldx,*)

        ! Internal variables.
        integer(${ik}$) :: i, j
        ${t1}$ :: temp

        ! Quick return for empty problems. A negative n must not reach the
        ! loops below, which would otherwise read d(1) and x(1, j).
        if (n <= 0 .or. nrhs <= 0) return

        ! Scale b by beta. beta == 0 assigns zero so that NaN/Inf already
        ! stored in b do not propagate; beta == 1 leaves b untouched.
        if (beta == 0.0_${k1}$) then
            b(1:n, 1:nrhs) = 0.0_${k1}$
        else if (beta /= 1.0_${k1}$) then
            b(1:n, 1:nrhs) = beta * b(1:n, 1:nrhs)
        end if

        ! `trans` is matched case-insensitively, following the LAPACK
        ! convention, so that 'n' is not mistaken for a transpose request.
        select case (trans)
        case ('N', 'n')
            ! b := b + alpha * A * x
            do j = 1, nrhs
                if (n == 1) then
                    temp = d(1) * x(1, j)
                    b(1, j) = b(1, j) + alpha * temp
                else
                    temp = d(1) * x(1, j) + du(1) * x(2, j)
                    b(1, j) = b(1, j) + alpha * temp
                    do i = 2, n - 1
                        temp = dl(i - 1) * x(i - 1, j) + d(i) * x(i, j) + du(i) * x(i + 1, j)
                        b(i, j) = b(i, j) + alpha * temp
                    end do
                    temp = dl(n - 1) * x(n - 1, j) + d(n) * x(n, j)
                    b(n, j) = b(n, j) + alpha * temp
                end if
            end do
        #:if t1.startswith('complex')
        case ('C', 'c', 'H', 'h')
            ! b := b + alpha * A**H * x (roles of dl and du are swapped and
            ! every matrix entry is conjugated).
            do j = 1, nrhs
                if (n == 1) then
                    temp = conjg(d(1)) * x(1, j)
                    b(1, j) = b(1, j) + alpha * temp
                else
                    temp = conjg(d(1)) * x(1, j) + conjg(dl(1)) * x(2, j)
                    b(1, j) = b(1, j) + alpha * temp
                    do i = 2, n - 1
                        temp = conjg(du(i - 1)) * x(i - 1, j) + conjg(d(i)) * x(i, j) + conjg(dl(i)) * x(i + 1, j)
                        b(i, j) = b(i, j) + alpha * temp
                    end do
                    temp = conjg(du(n - 1)) * x(n - 1, j) + conjg(d(n)) * x(n, j)
                    b(n, j) = b(n, j) + alpha * temp
                end if
            end do
        #:endif
        case default
            ! b := b + alpha * A**T * x (roles of dl and du are swapped).
            do j = 1, nrhs
                if (n == 1) then
                    temp = d(1) * x(1, j)
                    b(1, j) = b(1, j) + alpha * temp
                else
                    temp = d(1) * x(1, j) + dl(1) * x(2, j)
                    b(1, j) = b(1, j) + alpha * temp
                    do i = 2, n - 1
                        temp = du(i - 1) * x(i - 1, j) + d(i) * x(i, j) + dl(i) * x(i + 1, j)
                        b(i, j) = b(i, j) + alpha * temp
                    end do
                    temp = du(n - 1) * x(n - 1, j) + d(n) * x(n, j)
                    b(n, j) = b(n, j) + alpha * temp
                end if
            end do
        end select
    end subroutine stdlib${ii}$_glagtm_${s1}$
    #:endfor
    #:endfor

end submodule stdlib_lapack_extended
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#:include "common.fypp"
2+
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
3+
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
4+
#:set KINDS_TYPES = R_KINDS_TYPES+C_KINDS_TYPES
5+
module stdlib_lapack_extended_base
    !! Interface declarations for the extended LAPACK-style kernels.
    !! The procedure bodies are provided by a submodule of this module.
    use stdlib_linalg_constants
    implicit none

    interface glagtm
        !! Generic name for the tridiagonal matrix-matrix product kernels;
        !! one specific procedure is generated per integer kind and per
        !! real/complex kind by the fypp loops below.
        #:for ik,it,ii in LINALG_INT_KINDS_TYPES
        #:for k1,t1,s1 in KINDS_TYPES
        pure module subroutine stdlib${ii}$_glagtm_${s1}$(trans, n, nrhs, alpha, dl, d, du, x, ldx, beta, b, ldb)
            ! Operation selector.
            character, intent(in) :: trans
            ! Problem sizes and leading dimensions.
            integer(${ik}$), intent(in) :: n, nrhs, ldx, ldb
            ! Scalars.
            ${t1}$, intent(in) :: alpha, beta
            ! Read-only arrays: the three diagonals and the input matrix.
            ${t1}$, intent(in) :: dl(*), d(*), du(*)
            ${t1}$, intent(in) :: x(ldx,*)
            ! Matrix updated in place.
            ${t1}$, intent(inout) :: b(ldb,*)
        end subroutine stdlib${ii}$_glagtm_${s1}$
        #:endfor
        #:endfor
    end interface glagtm
end module stdlib_lapack_extended_base

src/stdlib_specialmatrices.fypp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ module stdlib_specialmatrices
1212
use stdlib_constants
1313
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
1414
LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR
15+
use stdlib_lapack_extended_base
1516
implicit none
1617
private
1718
public :: tridiagonal
@@ -99,7 +100,7 @@ module stdlib_specialmatrices
99100
!! Matrix dimension.
100101
type(tridiagonal_${s1}$_type) :: A
101102
!! Corresponding Tridiagonal matrix.
102-
end function
103+
end function
103104

104105
module function initialize_tridiagonal_impure_${s1}$(dl, dv, du, err) result(A)
105106
!! Construct a `tridiagonal` matrix from the rank-1 arrays
@@ -122,7 +123,7 @@ module stdlib_specialmatrices
122123
!! Error handling.
123124
type(tridiagonal_${s1}$_type) :: A
124125
!! Corresponding Tridiagonal matrix.
125-
end function
126+
end function
126127
#:endfor
127128
end interface
128129

@@ -145,8 +146,8 @@ module stdlib_specialmatrices
145146
type(tridiagonal_${s1}$_type), intent(in) :: A
146147
${t1}$, intent(in), contiguous, target :: x${ranksuffix(rank)}$
147148
${t1}$, intent(inout), contiguous, target :: y${ranksuffix(rank)}$
148-
real(${k1}$), intent(in), optional :: alpha
149-
real(${k1}$), intent(in), optional :: beta
149+
${t1}$, intent(in), optional :: alpha
150+
${t1}$, intent(in), optional :: beta
150151
character(1), intent(in), optional :: op
151152
end subroutine
152153
#:endfor

src/stdlib_specialmatrices_tridiagonal.fypp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,18 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
155155
type(tridiagonal_${s1}$_type), intent(in) :: A
156156
${t1}$, intent(in), contiguous, target :: x${ranksuffix(rank)}$
157157
${t1}$, intent(inout), contiguous, target :: y${ranksuffix(rank)}$
158-
real(${k1}$), intent(in), optional :: alpha
159-
real(${k1}$), intent(in), optional :: beta
158+
${t1}$, intent(in), optional :: alpha
159+
${t1}$, intent(in), optional :: beta
160160
character(1), intent(in), optional :: op
161161

162162
! Internal variables.
163-
real(${k1}$) :: alpha_, beta_
163+
${t1}$ :: alpha_, beta_
164164
integer(ilp) :: n, nrhs, ldx, ldy
165165
character(1) :: op_
166+
#:if t1.startswith('real')
167+
logical :: is_alpha_special, is_beta_special
168+
#:endif
169+
166170
#:if rank == 1
167171
${t1}$, pointer :: xmat(:, :), ymat(:, :)
168172
#:endif
@@ -171,6 +175,10 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
171175
alpha_ = 1.0_${k1}$ ; if (present(alpha)) alpha_ = alpha
172176
beta_ = 0.0_${k1}$ ; if (present(beta)) beta_ = beta
173177
op_ = "N" ; if (present(op)) op_ = op
178+
#:if t1.startswith('real')
179+
is_alpha_special = (alpha_ == 1.0_${k1}$ .or. alpha_ == 0.0_${k1}$ .or. alpha_ == -1.0_${k1}$)
180+
is_beta_special = (beta_ == 1.0_${k1}$ .or. beta_ == 0.0_${k1}$ .or. beta_ == -1.0_${k1}$)
181+
#:endif
174182

175183
! Prepare Lapack arguments.
176184
n = A%n ; ldx = n ; ldy = n ;
@@ -179,9 +187,25 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
179187
#:if rank == 1
180188
! Pointer trick.
181189
xmat(1:n, 1:nrhs) => x ; ymat(1:n, 1:nrhs) => y
182-
call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
190+
#:if t1.startswith('complex')
191+
call glagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
183192
#:else
184-
call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, x, ldx, beta_, y, ldy)
193+
if(is_alpha_special .and. is_beta_special) then
194+
call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
195+
else
196+
call glagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
197+
end if
198+
#:endif
199+
#:else
200+
#:if t1.startswith('complex')
201+
call glagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, x, ldx, beta_, y, ldy)
202+
#:else
203+
if(is_alpha_special .and. is_beta_special) then
204+
call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, x, ldx, beta_, y, ldy)
205+
else
206+
call glagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, x, ldx, beta_, y, ldy)
207+
end if
208+
#:endif
185209
#:endif
186210
end subroutine
187211
#:endfor

test/linalg/test_linalg_specialmatrices.fypp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,39 @@ contains
8282
if (allocated(error)) return
8383
end do
8484
end do
85+
86+
! Test y = A @ x for random values of alpha and beta
87+
y1 = 0.0_wp
88+
call random_number(alpha)
89+
call random_number(beta)
90+
call random_number(y2)
91+
y1 = alpha * matmul(Amat, x) + beta * y2
92+
call spmv(A, x, y2, alpha=alpha, beta=beta)
93+
call check(error, all_close(y1, y2), .true.)
94+
if (allocated(error)) return
95+
96+
! Test y = A.T @ x for random values of alpha and beta
97+
y1 = 0.0_wp
98+
call random_number(alpha)
99+
call random_number(beta)
100+
call random_number(y2)
101+
y1 = alpha * matmul(transpose(Amat), x) + beta * y2
102+
call spmv(A, x, y2, alpha=alpha, beta=beta, op="T")
103+
call check(error, all_close(y1, y2), .true.)
104+
if (allocated(error)) return
105+
106+
#:if t1.startswith('complex')
107+
! Test y = A.H @ x for random values of alpha and beta
108+
y1 = 0.0_wp
109+
call random_number(alpha)
110+
call random_number(beta)
111+
call random_number(y2)
112+
y1 = alpha * matmul(transpose(conjg((Amat))), x) + beta * y2
113+
call spmv(A, x, y2, alpha=alpha, beta=beta, op="H")
114+
call check(error, all_close(y1, y2), .true.)
115+
if (allocated(error)) return
116+
#:endif
117+
85118
end block
86119
#:endfor
87120
end subroutine

0 commit comments

Comments
 (0)