Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,10 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} =
_, chout = l.channel
heads = l.heads

Wxi = Wxj = l.dense_x(xj)
Wxi = Wxj = reshape(Wxj, chout, heads, :)

if xi !== xj
Wxi = l.dense_x(xi)
Wxi = reshape(Wxi, chout, heads, :)
end
Wxj = l.dense_x(xj)
Wxj = reshape(Wxj, chout, heads, :)
Wxi = l.dense_x(xi)
Wxi = reshape(Wxi, chout, heads, :)
Comment on lines +127 to +130
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Wxj = l.dense_x(xj)
Wxj = reshape(Wxj, chout, heads, :)
Wxi = l.dense_x(xi)
Wxi = reshape(Wxi, chout, heads, :)
Wxj = l.dense_x(xj)
Wxj = reshape(Wxj, chout, heads, :)
if xi !== xj
Wxi = l.dense_x(xi)
Wxi = reshape(Wxi, chout, heads, :)
else
Wxi = Wxj
end

would work?

Copy link
Contributor Author

@lenianiva lenianiva Jan 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't work. It seems like the above example triggers both branches and Zygote gets one branch confused for another. I have seen this kind of behaviour before with Zygote.


# a hand-written message passing
message = Fix1(gat_message, l)
Expand All @@ -142,8 +139,11 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} =

if !l.concat
x = mean(x, dims = 2)
width = size(x, 1)
else
width = size(x, 1) * size(x, 2)
end
x = reshape(x, :, size(x, 3)) # return a matrix
x = reshape(x, width, size(x, 3)) # return a matrix
x = l.σ.(x .+ l.bias)

return x
Expand Down Expand Up @@ -194,8 +194,11 @@ function gatv2_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix}

if !l.concat
x = mean(x, dims = 2)
width = size(x, 1)
else
width = size(x, 1) * size(x, 2)
end
x = reshape(x, :, size(x, 3))
x = reshape(x, width, size(x, 3))
x = l.σ.(x .+ l.bias)
return x
end
Expand Down
Loading