@@ -952,11 +952,7 @@ <h2 id="model-definition" class="anchor">Model definition </h2>
952952 < span style ="color: #008000 "> self</ span > < span style ="color: #666666 "> .</ span > dec2 < span style ="color: #666666 "> =</ span > nn< span style ="color: #666666 "> .</ span > ConvTranspose2d(< span style ="color: #666666 "> 64</ span > , c, < span style ="color: #666666 "> 3</ span > , padding< span style ="color: #666666 "> =1</ span > )
953953 < span style ="color: #008000 "> self</ span > < span style ="color: #666666 "> .</ span > act < span style ="color: #666666 "> =</ span > nn< span style ="color: #666666 "> .</ span > ReLU()
954954 < span style ="color: #408080; font-style: italic "> # timestep embedding to condition on t</ span >
955- < span style ="color: #008000 "> self</ span > < span style ="color: #666666 "> .</ span > time_mlp < span style ="color: #666666 "> =</ span > nn< span style ="color: #666666 "> .</ span > Sequential(
956- nn< span style ="color: #666666 "> .</ span > Linear(< span style ="color: #666666 "> 1</ span > , < span style ="color: #666666 "> 128</ span > ), < span style ="color: #408080; font-style: italic "> # Changed from 64 to 128</ span >
957- nn< span style ="color: #666666 "> .</ span > ReLU(),
958- nn< span style ="color: #666666 "> .</ span > Linear(< span style ="color: #666666 "> 128</ span > , < span style ="color: #666666 "> 128</ span > ), < span style ="color: #408080; font-style: italic "> # Changed from 64 to 128</ span >
959- )
955+ < span style ="color: #008000 "> self</ span > < span style ="color: #666666 "> .</ span > time_mlp < span style ="color: #666666 "> =</ span > nn< span style ="color: #666666 "> .</ span > Sequential(nn< span style ="color: #666666 "> .</ span > Linear(< span style ="color: #666666 "> 1</ span > , < span style ="color: #666666 "> 128</ span > ), nn< span style ="color: #666666 "> .</ span > ReLU(),nn< span style ="color: #666666 "> .</ span > Linear(< span style ="color: #666666 "> 128</ span > , < span style ="color: #666666 "> 128</ span > ))
960956
961957 < span style ="color: #008000; font-weight: bold "> def</ span > < span style ="color: #0000FF "> forward</ span > (< span style ="color: #008000 "> self</ span > , x, t):
962958 < span style ="color: #408080; font-style: italic "> # x: [B, C, H, W], t: [B]</ span >
@@ -965,7 +961,7 @@ <h2 id="model-definition" class="anchor">Model definition </h2>
965961 < span style ="color: #408080; font-style: italic "> # add time embedding</ span >
966962 t < span style ="color: #666666 "> =</ span > t< span style ="color: #666666 "> .</ span > unsqueeze(< span style ="color: #666666 "> -1</ span > )
967963 temb < span style ="color: #666666 "> =</ span > < span style ="color: #008000 "> self</ span > < span style ="color: #666666 "> .</ span > time_mlp(t)
968- temb < span style ="color: #666666 "> =</ span > temb< span style ="color: #666666 "> .</ span > view(< span style ="color: #666666 "> -1</ span > , < span style ="color: #666666 "> 128</ span > , < span style ="color: #666666 "> 1</ span > , < span style ="color: #666666 "> 1</ span > ) < span style =" color: #408080; font-style: italic " > # Changed from 64 to 128 </ span >
964+ temb < span style ="color: #666666 "> =</ span > temb< span style ="color: #666666 "> .</ span > view(< span style ="color: #666666 "> -1</ span > , < span style ="color: #666666 "> 128</ span > , < span style ="color: #666666 "> 1</ span > , < span style ="color: #666666 "> 1</ span > )
969965 h < span style ="color: #666666 "> =</ span > h < span style ="color: #666666 "> +</ span > temb
970966 h < span style ="color: #666666 "> =</ span > < span style ="color: #008000 "> self</ span > < span style ="color: #666666 "> .</ span > act(< span style ="color: #008000 "> self</ span > < span style ="color: #666666 "> .</ span > dec1(h))
971967 < span style ="color: #008000; font-weight: bold "> return</ span > < span style ="color: #008000 "> self</ span > < span style ="color: #666666 "> .</ span > dec2(h)
0 commit comments