import mmcv
import matplotlib.pyplot as plt
from fastcore.basics import *
from fastai.vision.all import *
from fastai.torch_basics import *
import warnings
# Silence all Python warnings for the rest of the session.
warnings.filterwarnings("ignore")
import kornia
from kornia.constants import Resample
from kornia.color import *
from kornia import augmentation as K
# import kornia.augmentation as F
import kornia.augmentation.random_generator as rg
from torchvision.transforms import functional as tvF
from torchvision.transforms import transforms
from torchvision.transforms import PILToTensor
from functools import partial
from timm.models.layers import trunc_normal_, DropPath
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.vision_transformer import _cfg
from einops import rearrange
from timm.models.registry import register_model
# NOTE(review): `set_seed` is presumably provided by the fastai star-imports above — confirm.
set_seed(105)
# Code borrowed from https://github.com/tky823/DNN-based_source_separation/blob/main/src/models/d2net.py
EPS = 1e-12


def choose_layer_norm(name, num_features, n_dims=2, eps=EPS, **kwargs):
    """Build the normalization layer selected by ``name``.

    Args:
        name: Normalization identifier. Only the batch-norm aliases
            'BN', 'batch' and 'batch_norm' are supported.
        num_features: Number of features passed to the batch-norm layer.
        n_dims: 1 selects ``nn.BatchNorm1d``, 2 selects ``nn.BatchNorm2d``.
        eps: Forwarded as the layer's ``eps`` argument.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        An ``nn.BatchNorm1d`` or ``nn.BatchNorm2d`` instance.

    Raises:
        NotImplementedError: If ``name`` is not a batch-norm alias, or if
            ``n_dims`` is neither 1 nor 2.
    """
    if name not in ('BN', 'batch', 'batch_norm'):
        raise NotImplementedError("Not support {} layer normalization.".format(name))

    if n_dims == 1:
        return nn.BatchNorm1d(num_features, eps=eps)
    if n_dims == 2:
        return nn.BatchNorm2d(num_features, eps=eps)

    raise NotImplementedError("n_dims is expected 1 or 2, but give {}.".format(n_dims))
def choose_nonlinear(name, **kwargs):
    """Build the activation module selected by ``name``.

    Args:
        name: Activation identifier. Only 'relu' is supported.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        A new ``nn.ReLU`` instance.

    Raises:
        NotImplementedError: If ``name`` is anything other than 'relu'.
    """
    if name != 'relu':
        raise NotImplementedError("Invalid nonlinear function is specified. Choose 'relu' instead of {}.".format(name))

    return nn.ReLU()
from torch.nn.modules.utils import _pair
_pair(1)  # -> (1, 1)
class ConvBlock2d(nn.Module):
    """Size-preserving 2-D convolution block: (norm) -> (nonlinearity) -> pad -> conv.

    The input is padded by ``(kernel - 1) * dilation`` in total along each
    spatial axis before the stride-1 convolution, so the output keeps the
    input's height and width.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, norm=True, nonlinear='relu', eps=EPS):
        """
        Args:
            in_channels <int>: # of input channels
            out_channels <int>: # of output channels
            kernel_size <int> or <tuple<int>>: Kernel size
            stride <int>: Must be 1.
            dilation <int> or <tuple<int>>: Dilation of the convolution.
            norm <bool> or <str>: Falsy disables normalization, `True` selects
                'BN', any other truthy value is passed to `choose_layer_norm` as the name.
            nonlinear <str>: Name passed to `choose_nonlinear`; a falsy value
                (`None`, `False`, '') disables the nonlinearity.
            eps <float>: eps forwarded to the normalization layer.
        """
        super().__init__()

        assert stride == 1, "`stride` is expected 1"

        self.kernel_size = _pair(kernel_size)
        self.dilation = _pair(dilation)
        self.norm = norm
        self.nonlinear = nonlinear

        if self.norm:
            name = 'BN' if type(self.norm) is bool else self.norm
            # Normalization is applied before the convolution, hence `in_channels`.
            self.norm2d = choose_layer_norm(name, in_channels, n_dims=2, eps=eps)

        # Use the same truthiness test as forward(): previously this was
        # `is not None`, so `nonlinear=False` / '' raised here although
        # forward() would have skipped the activation anyway.
        if self.nonlinear:
            self.nonlinear2d = choose_nonlinear(self.nonlinear)

        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation)

    def forward(self, input):
        """
        Args:
            input (batch_size, in_channels, H, W)
        Returns:
            output (batch_size, out_channels, H, W)
        """
        Kh, Kw = self.kernel_size
        Dh, Dw = self.dilation

        # Total padding needed per axis to keep H and W unchanged at stride 1.
        padding_height = (Kh - 1) * Dh
        padding_width = (Kw - 1) * Dw
        # When the total is odd, the extra pixel goes to the bottom / right.
        padding_up = padding_height // 2
        padding_bottom = padding_height - padding_up
        padding_left = padding_width // 2
        padding_right = padding_width - padding_left

        x = input

        if self.norm:
            x = self.norm2d(x)
        if self.nonlinear:
            x = self.nonlinear2d(x)

        x = F.pad(x, (padding_left, padding_right, padding_up, padding_bottom))
        output = self.conv2d(x)

        return output
# Quick shape check: a bare convolution (no norm, no nonlinearity) keeps H and W.
temp = ConvBlock2d(3, 128, 3, 1, 1, norm=False, nonlinear=None)
temp(torch.randn(1, 3, 32, 32)).shape  # -> torch.Size([1, 128, 32, 32])
EPS = 1e-12


class D2BlockFixedDilation(nn.Module):
    """Stack of `ConvBlock2d` layers that all share one fixed dilation.

    Layer ``idx`` outputs ``sum(growth_rate[idx:])`` channels. After each layer
    the running sum is split: its first ``growth_rate[idx - 1]`` channels feed
    the next layer, and the remainder is carried forward and added to that
    layer's output.
    """

    def __init__(self, in_channels, growth_rate, kernel_size, dilation=1, norm=True, nonlinear='relu', depth=None, eps=EPS):
        """
        Args:
            in_channels <int>: # of input channels
            growth_rate <int> or <list<int>>: # of output channels
            kernel_size <int> or <tuple<int>>: Kernel size
            dilation <int>: Dilation of dilated convolution.
            norm <bool> or <list<bool>>: Applies batch normalization.
            nonlinear <str>, <bool>, None or <list>: Applies nonlinear function.
                A scalar `None` disables the nonlinearity in every layer.
            depth <int>: If `growth_rate` is given by list, len(growth_rate) must be equal to `depth`.
            eps <float>: eps forwarded to every `ConvBlock2d`.
        Raises:
            ValueError: If an argument has an unsupported type.
            AssertionError: If `depth` is needed but missing, or disagrees with a list length.
        """
        super().__init__()

        growth_rate, depth = self._per_layer(growth_rate, (int,), depth, "growth_rate")

        if type(dilation) is not int:
            # The original formatted the undefined name `dilated` here, turning
            # the intended ValueError into a NameError.
            raise ValueError("Not support dilation={}".format(dilation))

        norm, depth = self._per_layer(norm, (bool,), depth, "norm")
        nonlinear, depth = self._per_layer(nonlinear, (bool, str, type(None)), depth, "nonlinear")

        self.growth_rate = growth_rate
        self.depth = depth

        net = []

        for idx in range(depth):
            # The first layer sees the block input; later layers see the
            # leading `growth_rate[idx - 1]` channels of the running sum.
            _in_channels = in_channels if idx == 0 else growth_rate[idx - 1]
            _out_channels = sum(growth_rate[idx:])

            conv_block = ConvBlock2d(_in_channels, _out_channels, kernel_size=kernel_size, stride=1, dilation=dilation, norm=norm[idx], nonlinear=nonlinear[idx], eps=eps)
            net.append(conv_block)

        self.net = nn.Sequential(*net)

    @staticmethod
    def _per_layer(value, scalar_types, depth, name):
        """Broadcast a scalar option to `depth` entries, or check a list against `depth`.

        Returns the per-layer list and the (possibly newly inferred) depth.
        """
        if type(value) in scalar_types:
            assert depth is not None, "Specify `depth`"
            return [value] * depth, depth
        if type(value) is list:
            if depth is not None:
                assert depth == len(value), "`depth` is different from `len({})`".format(name)
            return value, len(value)
        raise ValueError("Not support {}={}".format(name, value))

    def forward(self, input):
        """
        Args:
            input: (batch_size, in_channels, H, W)
        Returns:
            output: (batch_size, out_channels, H, W), where out_channels = growth_rate[-1].
        """
        growth_rate, depth = self.growth_rate, self.depth

        x, x_residual = input, 0

        for idx in range(depth):
            if idx > 0:
                # Peel off the channels consumed by this layer; carry the rest.
                sections = [growth_rate[idx - 1], sum(growth_rate[idx:])]
                x, x_residual = torch.split(x_residual, sections, dim=1)

            x = self.net[idx](x)
            x_residual = x_residual + x

        output = x_residual

        return output
class D2Block(nn.Module):
    """Stack of `ConvBlock2d` layers with optional per-layer dilation.

    Layer ``idx`` outputs ``sum(growth_rate[idx:])`` channels and uses dilation
    ``2**idx`` when ``dilated[idx]`` is true, otherwise 1. After each layer the
    running sum is split: its first ``growth_rate[idx - 1]`` channels feed the
    next layer, and the remainder is carried forward and added to that layer's
    output.
    """

    def __init__(self, in_channels, growth_rate, kernel_size, dilated=True, norm=True, nonlinear='relu', depth=None, eps=EPS):
        """
        Args:
            in_channels <int>: # of input channels
            growth_rate <int> or <list<int>>: # of output channels
            kernel_size <int> or <tuple<int>>: Kernel size
            dilated <bool> or <list<bool>>: Applies dilated convolution.
            norm <bool> or <list<bool>>: Applies batch normalization.
            nonlinear <str>, <bool>, None or <list>: Applies nonlinear function.
                A scalar `None` disables the nonlinearity in every layer.
            depth <int>: If `growth_rate` is given by list, len(growth_rate) must be equal to `depth`.
            eps <float>: eps forwarded to every `ConvBlock2d`.
        Raises:
            ValueError: If an argument has an unsupported type.
            AssertionError: If `depth` is needed but missing, or disagrees with a list length.
        """
        super().__init__()

        growth_rate, depth = self._per_layer(growth_rate, (int,), depth, "growth_rate")
        dilated, depth = self._per_layer(dilated, (bool,), depth, "dilated")
        norm, depth = self._per_layer(norm, (bool,), depth, "norm")
        nonlinear, depth = self._per_layer(nonlinear, (bool, str, type(None)), depth, "nonlinear")

        self.growth_rate = growth_rate
        self.depth = depth

        net = []

        for idx in range(depth):
            # The first layer sees the block input; later layers see the
            # leading `growth_rate[idx - 1]` channels of the running sum.
            _in_channels = in_channels if idx == 0 else growth_rate[idx - 1]
            _out_channels = sum(growth_rate[idx:])

            # Dilation doubles with each layer when enabled for that layer.
            dilation = 2**idx if dilated[idx] else 1

            conv_block = ConvBlock2d(_in_channels, _out_channels, kernel_size=kernel_size, stride=1, dilation=dilation, norm=norm[idx], nonlinear=nonlinear[idx], eps=eps)
            net.append(conv_block)

        self.net = nn.Sequential(*net)

    @staticmethod
    def _per_layer(value, scalar_types, depth, name):
        """Broadcast a scalar option to `depth` entries, or check a list against `depth`.

        Returns the per-layer list and the (possibly newly inferred) depth.
        """
        if type(value) in scalar_types:
            assert depth is not None, "Specify `depth`"
            return [value] * depth, depth
        if type(value) is list:
            if depth is not None:
                assert depth == len(value), "`depth` is different from `len({})`".format(name)
            return value, len(value)
        raise ValueError("Not support {}={}".format(name, value))

    def forward(self, input):
        """
        Args:
            input: (batch_size, in_channels, H, W)
        Returns:
            output: (batch_size, out_channels, H, W), where out_channels = growth_rate[-1].
        """
        growth_rate, depth = self.growth_rate, self.depth

        for idx in range(depth):
            if idx == 0:
                x, x_residual = input, 0
            else:
                # Peel off the channels consumed by this layer; carry the rest.
                sections = [growth_rate[idx - 1], sum(growth_rate[idx:])]
                x, x_residual = torch.split(x_residual, sections, dim=1)

            x = self.net[idx](x)
            x_residual = x_residual + x

        output = x_residual

        return output
def _test_d2block():
    """Smoke-test `D2Block`: build one with an int growth rate and print its shapes."""
    batch_size = 4
    n_bins, n_frames = 64, 64
    in_channels = 3
    kernel_size = (3, 3)
    depth = 4

    # Named `x` rather than `input` to avoid shadowing the builtin.
    x = torch.randn(batch_size, in_channels, n_bins, n_frames)

    print("-"*10, "D2 Block when `growth_rate` is given as int and `dilated` is given as bool.", "-"*10)

    growth_rate = 2
    dilated = True
    model = D2Block(in_channels, growth_rate, kernel_size=kernel_size, dilated=dilated, depth=depth)

    print("-"*10, "D2 Block", "-"*10)
    print(model)
    output = model(x)
    print(x.size(), output.size())
    print()

    # Disabled cases, kept for reference:
    # print("-"*10, "D2 Block when `growth_rate` is given as list and `dilated` is given as bool.", "-"*10)
    # growth_rate = [3, 4, 5, 6] # depth = 4
    # dilated = False
    # model = D2Block(in_channels, growth_rate, kernel_size=kernel_size, dilated=dilated)
    # print(model)
    # output = model(x)
    # print(x.size(), output.size())
    # print()
    # print("-"*10, "D2 Block when `growth_rate` is given as list and `dilated` is given as list.", "-"*10)
    # growth_rate = [3, 4, 5, 6] # depth = 4
    # dilated = [True, False, False, True] # depth = 4
    # model = D2Block(in_channels, growth_rate, kernel_size=kernel_size, dilated=dilated)
    # print(model)
    # output = model(x)
    # print(x.size(), output.size())
# Print a section banner, then run the D2Block smoke test.
print("{0} D2 Block {0}".format("=" * 10))
_test_d2block()
========== D2 Block ==========
---------- D2 Block when `growth_rate` is given as int and `dilated` is given as bool. ----------
---------- D2 Block ----------
D2Block(
(net): Sequential(
(0): ConvBlock2d(
(norm2d): BatchNorm2d(3, eps=1e-12, momentum=0.1, affine=True, track_running_stats=True)
(nonlinear2d): ReLU()
(conv2d): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1))
)
(1): ConvBlock2d(
(norm2d): BatchNorm2d(2, eps=1e-12, momentum=0.1, affine=True, track_running_stats=True)
(nonlinear2d): ReLU()
(conv2d): Conv2d(2, 6, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
)
(2): ConvBlock2d(
(norm2d): BatchNorm2d(2, eps=1e-12, momentum=0.1, affine=True, track_running_stats=True)
(nonlinear2d): ReLU()
(conv2d): Conv2d(2, 4, kernel_size=(3, 3), stride=(1, 1), dilation=(4, 4))
)
(3): ConvBlock2d(
(norm2d): BatchNorm2d(2, eps=1e-12, momentum=0.1, affine=True, track_running_stats=True)
(nonlinear2d): ReLU()
(conv2d): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), dilation=(8, 8))
)
)
)
Traceback (most recent call last):
File "/home/ubuntu/miniconda3/envs/new/lib/python3.8/site-packages/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_vars.py", line 478, in change_attr_expression
value = eval(expression, frame.f_globals, frame.f_locals)
File "<string>", line 1
tensor([[[[ 1.3090e-01, -4.1537e-01, -1.9816e-01, ..., -1.1555e-01, -2.2067e-01, -2.6661e-01], [-1.6991e-01, -3.7272e-02, -5.9325e-02, ..., -5.6490e-01, 2.5688e-01, -2.1485e-01], [-4.1483e-01, -8.1407e-02, 3.9883e-02, ..., -8.5913e-02, 1.7721e-02, 3.6666e-01], ..., [ 1.0038e-01, -2.2772e-01, -3.6661e-01, ..., 3.1680e-01, 1.4326e-01, -1.7341e-02], [ 2.5856e-01, -3.5614e-01, -3.0179e-02, ..., -7.1122e-01, -3.5760e-02, -2.0752e-01], [-2.4753e-01, -8.6356e-02, -9.8095e-03, ..., 1.0479e-01, -1.5521e-01, -3.1733e-01]], [[ 5.7790e-02, -1.1537e-01, -3.3660e-01, ..., 1.7819e-01, -3.5242e-02, -7.4898e-02], [ 1.1687e-01, -3.9582e-01, -8.4538e-02, ..., 1.6796e-01, -2.5183e-01, -1.7742e-01], [ 1.3049e-01, 1.6287e-02, -1.0771e-01, ..., -1.0322e+00, 4.9804e-01, 1.2598e-01], ..., [-3.1282e-01, -1.8550e-02, 9.8423e-02, ..., -1.0855e-02, -6.2761e-02, 6.5281e-03], [-2.6171e-01, -4.2940e-01, 2.0093e-01, ..., -9.7036e-02, -3.5372e-01, 2.1260e-01], [ 8.7151e-02, -1.3352e-01, 1.3649e-02, ..., -2.0162e-01, 1.6274e-01, -1.8531e-01]], [[ 1.9531e-01, 2.1278e-01, 2.5971e-01, ..., -2.1074e-01, -4.4137e-01, 2.1402e-01], [ 3.8864e-01, 1.5082e-01, 1.4350e-01, ..., 1.0126e+00, -9.5265e-02, 3.1971e-02], [ 3.2864e-01, -1.0494e-01, 4.4397e-01, ..., 2.2869e-01, -4.5287e-01, -6.4710e-01], ..., [ 2.0047e-01, 4.1493e-01, 3.7252e-01, ..., 1.4543e-01, -2.4189e-02, -8.7822e-02], [ 2.5906e-01, 2.1651e-01, -2.7468e-02, ..., 5.5440e-01, -1.1034e-01, 1.8114e-01], [ 3.5780e-01, -6.4934e-03, 3.1841e-01, ..., -7.3575e-02, 4.5266e-01, 1.7474e-01]], [[ 2.6019e-01, -4.9918e-03, 4.9222e-02, ..., 4.9808e-01, 7.7268e-01, 2.4371e-01], [ 9.6638e-04, 3.0653e-01, 1.7191e-01, ..., 5.1715e-03, 3.4442e-01, 1.1149e-01], [ 3.3931e-01, 2.6264e-01, 5.1961e-01, ..., -2.9706e-01, 7.4714e-01, 1.0039e+00], ..., [-2.2798e-02, -9.0747e-02, 1.7157e-01, ..., 2.4040e-01, 3.1987e-01, 2.4248e-01], [-6.1366e-02, 2.4639e-02, 5.1985e-01, ..., 1.8540e-01, 6.7501e-01, 4.5124e-01], [ 1.9541e-01, -8.2181e-02, 1.5417e-01, ..., -2.4319e-01, 
-1.6177e-03, -3.6424e-02]], [[ 5.8356e-02, 3.1807e-01, 3.7832e-01, ..., 5.6302e-01, 7.0850e-01, 1.2015e-01], [ 1.9481e-02, 8.6745e-02, -1.1216e-01, ..., 6.8010e-01, 4.3849e-01, 2.1097e-01], [ 4.7859e-01, 1.3357e-01, 3.6910e-01, ..., 1.6181e-01, 6.6159e-01, 7.6544e-01], ..., [-9.6094e-02, 6.2552e-01, 7.3016e-01, ..., 2.5252e-01, 4.1118e-01, 5.3823e-02], [-1.4834e-01, 1.0271e-01, 2.6638e-01, ..., 9.2894e-01, 6.1686e-01, 4.3951e-01], [ 2.5091e-01, 1.1119e-01, 3.5358e-01, ..., -1.8557e-01, 3.4108e-01, 5.0081e-01]], [[ 3.9389e-02, 1.1824e-01, -5.2395e-03, ..., -5.7838e-01, -4.4865e-02, 7.4557e-02], [-6.5465e-02, -4.7576e-01, -6.0455e-02, ..., -1.2431e-01, 4.9461e-02, 3.8061e-02], [ 1.7141e-01, 6.3351e-02, 2.4807e-01, ..., 1.9720e-01, -4.8070e-01, 1.3536e-01], ..., [-4.2736e-01, -3.8638e-01, 4.1305e-01, ..., -8.2113e-01, -1.7334e-03, 6.8984e-02], [-3.0139e-01, -6.1888e-01, 4.0083e-02, ..., 1.8445e-01, -3.9581e-02, -1.6408e-01], [ 1.3904e-01, -2.3817e-01, 8.3790e-02, ..., -1.6300e-01, -1.1475e-01, -7.9784e-02]]], [[[ 1.4356e-01, 8.8703e-03, 6.0797e-02, ..., 2.9055e-02, -1.7647e-01, -1.8921e-01], [-7.3982e-01, -4.1122e-01, -2.0923e-01, ..., 1.7836e-01, -5.5942e-01, -1.4614e-01], [ 3.3344e-01, 2.3998e-02, -3.5281e-01, ..., -4.1098e-01, 2.5884e-01, -9.2192e-04], ..., [-3.5500e-01, -2.9288e-01, 1.9643e-01, ..., 3.4869e-01, 5.0756e-01, -4.1429e-01], [ 1.4664e-01, 1.5499e-01, -9.7463e-02, ..., -5.8297e-01, -6.3597e-02, 3.4977e-01], [-9.2896e-02, -2.7687e-01, 1.7470e-01, ..., 1.9764e-01, 1.5789e-01, 2.0480e-01]], [[-1.0954e-01, -2.6265e-01, 4.6650e-02, ..., -4.5176e-01, -1.5760e-01, -1.2904e-01], [ 1.4853e-01, -6.7585e-02, 9.0678e-02, ..., -4.9944e-01, -1.5137e-01, -1.7148e-01], [-8.4873e-02, 2.6314e-01, -5.3545e-01, ..., 1.9188e-01, -1.7622e-01, -7.6135e-02], ..., [-1.7582e-01, -2.1637e-01, -3.2816e-02, ..., -3.4230e-01, -2.4278e-01, -8.0402e-01], [-4.0788e-01, 4.0176e-02, -5.4633e-01, ..., -2.9893e-01, -1.4787e-01, -4.2392e-01], [ 4.5329e-02, -3.3867e-01, -4.6240e-01, ..., 
-2.0694e-01, -3.9650e-01, -4.4701e-01]], [[-1.8092e-01, 3.8991e-01, 3.6510e-01, ..., 7.7998e-03, 4.9165e-01, -1.5063e-01], [ 3.7547e-01, 2.3853e-01, 9.3737e-01, ..., 7.0939e-01, 2.4517e-01, 1.2706e-01], [-1.7488e-01, 7.8685e-01, 7.2161e-02, ..., 6.9449e-01, -1.8818e-01, 9.1750e-02], ..., [ 3.2627e-01, 3.2579e-01, -4.3036e-01, ..., 5.3843e-01, 8.0948e-01, -3.4934e-01], [-2.9014e-01, 3.3120e-01, 4.5309e-01, ..., -5.2109e-02, 8.9996e-01, -1.8583e-01], [ 4.6105e-01, 5.5430e-01, 1.9388e-01, ..., -1.6808e-01, 8.7509e-01, 1.5337e-02]], [[-1.4089e-01, -1.2434e-01, 4.6653e-01, ..., -3.4566e-02, 3.4214e-01, 3.2655e-01], [-1.4582e-01, 6.0733e-01, 4.8809e-01, ..., -1.9225e-01, 1.7463e-01, 9.5746e-02], [-4.0453e-01, 3.3570e-01, -1.4281e-01, ..., 5.2611e-02, 3.7899e-01, 1.5127e-01], ..., [ 3.6659e-01, -3.7853e-02, 8.7981e-02, ..., 2.4901e-01, 1.2098e-01, 2.5506e-01], [ 4.1407e-01, 6.2428e-01, 3.5363e-01, ..., 2.3824e-01, 6.8821e-01, 2.0267e-02], [ 8.2573e-02, 2.2750e-01, 2.2867e-02, ..., -1.2145e-01, 2.3449e-01, -2.6963e-01]], [[-8.6067e-02, 3.5226e-01, 7.2136e-02, ..., 4.4474e-01, 7.7592e-01, 2.9618e-01], [ 5.4441e-01, 7.5268e-01, 1.9816e-01, ..., -1.8014e-01, 4.8248e-01, 1.9109e-01], [ 1.5967e-02, 4.2697e-01, -1.7190e-01, ..., 6.5677e-01, 3.9801e-01, 6.6904e-04], ..., [ 2.4438e-01, 1.8839e-01, 2.1768e-01, ..., 1.7727e-01, -2.7860e-01, -8.4015e-02], [ 2.7309e-01, 2.0295e-01, 2.4740e-01, ..., 3.1903e-01, 8.4223e-02, -1.5729e-01], [ 8.9682e-02, 3.8643e-01, -1.8528e-01, ..., -1.6292e-01, 1.0070e-01, -2.9439e-01]], [[-2.0255e-01, 3.8269e-01, -6.5444e-03, ..., 6.0098e-02, 5.6779e-01, -1.7993e-01], [-8.0255e-01, 1.3640e-01, 1.1707e-02, ..., -7.4824e-01, 1.9584e-02, -1.4349e-01], [-1.2504e-01, 2.2301e-01, -5.5353e-01, ..., -2.1798e-01, 1.4390e-02, -1.3819e-01], ..., [-1.1011e-01, 1.1887e-01, -7.1152e-02, ..., -3.5547e-01, -1.6031e-01, -5.2682e-01], [ 9.6868e-02, 8.8074e-02, -2.7179e-01, ..., -2.6354e-01, 1.2194e-01, -4.9255e-01], [-2.0110e-01, -2.8900e-01, -2.6629e-01, ..., 
-1.4134e-01, -2.0730e-01, -5.0878e-01]]], [[[ 5.6126e-02, -2.5404e-02, -4.1849e-01, ..., -1.2481e-02, -7.2202e-01, -1.1423e-01], [ 1.0557e-01, -4.9420e-01, 2.0791e-01, ..., -1.2306e-02, 2.3137e-01, -5.8570e-02], [-4.1903e-01, 3.3547e-01, 2.6958e-01, ..., 1.0241e-01, 2.3915e-01, -3.3904e-01], ..., [ 3.4523e-02, 1.6014e-01, -4.1179e-01, ..., 1.2029e-01, 7.2006e-02, 3.4913e-02], [-1.0237e-01, -2.7229e-01, 2.0040e-01, ..., -7.1299e-01, 4.2478e-02, -2.4862e-01], [ 2.7004e-01, -1.0191e-01, -4.2506e-01, ..., 2.5071e-01, 1.8583e-01, -3.9743e-01]], [[-2.1440e-02, 2.9341e-01, -1.0359e+00, ..., 5.9206e-02, -4.3417e-01, -2.0778e-01], [ 8.9767e-02, -9.1228e-02, -4.0957e-01, ..., 3.5518e-02, -5.6046e-01, 1.7276e-02], [ 4.3644e-01, -6.1414e-01, -2.8989e-01, ..., -2.2452e-01, -4.4003e-02, -3.7126e-01], ..., [-4.0940e-01, -7.2889e-02, -1.1963e-01, ..., 1.9516e-02, -4.6536e-01, -2.0901e-01], [-3.6333e-01, -4.9914e-01, 1.8891e-01, ..., -1.9815e-01, 9.3363e-02, -1.0661e-02], [-3.8549e-01, -1.7502e-01, -1.0987e-01, ..., -5.7205e-01, -1.0420e-01, -2.6746e-01]], [[-1.6177e-01, 4.3002e-01, -1.7754e-01, ..., 6.6978e-02, 5.5545e-01, 3.2683e-01], [ 1.1671e-01, 6.4944e-01, -2.4066e-01, ..., 9.1957e-01, -3.6894e-02, 2.2699e-01], [ 4.5652e-01, 5.8080e-02, 4.6585e-01, ..., 5.6875e-01, 4.5249e-01, -8.2676e-02], ..., [ 1.1102e-02, 7.2689e-01, 5.3866e-01, ..., 2.2208e-01, 4.7787e-01, 1.3757e-01], [ 4.4706e-01, 8.6774e-01, 3.5510e-01, ..., 5.3292e-01, 2.0083e-01, 6.7404e-01], [ 1.1253e-01, 3.5022e-01, 2.6619e-01, ..., -3.2192e-02, 4.6967e-01, 1.1762e-01]], [[ 4.0240e-01, 7.9310e-01, 6.5073e-02, ..., 5.6078e-01, 4.4267e-01, 3.8579e-01], [ 3.7860e-01, 5.9048e-01, 1.6851e-01, ..., -1.3461e-01, -1.0553e-01, 5.0805e-01], [-1.1219e-01, 4.1083e-01, 7.1022e-02, ..., -1.1048e-02, 1.8843e-01, 1.6078e-01], ..., [ 4.1006e-02, 1.5088e-01, -5.6003e-02, ..., 1.2331e-01, 2.2651e-01, 6.0313e-02], [ 2.2112e-01, -3.8992e-02, -2.8221e-02, ..., 3.9878e-01, 7.0052e-01, 9.5154e-02], [-2.3062e-01, 6.3578e-02, 8.0448e-02, 
..., -5.0950e-01, 3.7462e-02, -1.7051e-01]], [[ 4.2451e-01, 3.2481e-01, 2.0671e-01, ..., 6.2420e-01, 7.4580e-01, 5.3606e-01], [ 1.2897e-01, 3.5371e-01, 1.4678e-01, ..., 2.7612e-01, 2.5483e-01, 5.9154e-01], [ 2.5120e-01, 2.6160e-01, 3.5832e-01, ..., 2.3914e-01, 1.1501e-01, 3.4694e-01], ..., [-4.5868e-02, 6.3412e-01, 5.5902e-01, ..., 9.2754e-02, 3.6898e-01, 1.7211e-01], [ 1.9321e-01, 4.5260e-01, 7.2063e-02, ..., 3.5093e-01, 3.3534e-01, 5.2381e-03], [-4.5890e-01, 1.8137e-01, 2.6613e-01, ..., -2.4522e-01, 2.4212e-01, 5.9463e-02]], [[ 1.8611e-01, 9.0050e-01, -1.1732e-01, ..., 6.1897e-01, 2.1543e-02, 1.7325e-02], [-1.0462e-01, 4.7290e-01, -4.7110e-01, ..., -4.9565e-01, -3.6357e-01, 2.3255e-01], [-3.0707e-01, 3.1617e-01, -1.1300e+00, ..., -4.3797e-01, -7.2460e-01, -2.3399e-01], ..., [-1.9016e-01, -3.3995e-01, -6.3654e-01, ..., -4.3309e-01, -1.7814e-02, -4.9589e-01], [-6.3855e-01, -3.1695e-01, -6.9847e-02, ..., -5.6335e-01, -8.6568e-02, -8.5706e-02], [-5.0186e-01, -9.4074e-02, -4.1813e-01, ..., -3.6641e-01, -9.6743e-02, -1.6949e-01]]], [[[-8.1937e-02, 2.1681e-01, -5.2875e-02, ..., -4.4092e-01, 2.6772e-01, -2.4320e-01], [-5.9855e-01, 4.1130e-02, 4.1174e-01, ..., 7.3259e-02, -5.0143e-01, -1.3249e-01], [ 4.0575e-01, 4.6065e-01, -4.4359e-01, ..., -9.5927e-02, -2.1272e-01, 3.1202e-01], ..., [ 2.7425e-01, 4.4362e-01, -1.9357e-01, ..., 7.7396e-01, 9.9243e-02, -2.1813e-01], [ 1.3178e-01, -3.6618e-01, -1.2347e-01, ..., -1.1928e-01, -1.0700e+00, 7.3413e-01], [-1.5958e-02, 2.6776e-01, 3.8485e-01, ..., 1.5780e-01, 5.9023e-01, 7.9637e-02]], [[-1.6003e-01, -2.2649e-02, -1.5973e-01, ..., 5.9173e-02, 1.7555e-01, 3.1533e-02], [-4.5756e-01, -6.2946e-01, -4.6876e-01, ..., -2.9179e-01, 1.7184e-01, -1.4880e-01], [-7.4202e-01, 2.3086e-01, -9.0355e-02, ..., 5.0716e-01, 1.3546e-01, 2.3537e-02], ..., [-2.7658e-01, -2.8616e-01, -8.9722e-01, ..., -5.4695e-01, -1.9833e-01, -4.1697e-01], [-3.7228e-01, -5.7368e-02, -2.9318e-01, ..., 4.6841e-02, 3.2001e-01, 1.1788e-01], [-1.1929e-01, -3.6330e-01, 
-1.8080e-01, ..., -2.0359e-01, -4.8472e-01, 1.7183e-01]], [[-8.2337e-02, 1.0726e-01, 2.8453e-02, ..., -1.9916e-02, -1.3832e-01, 1.9504e-01], [ 5.1024e-01, 3.0944e-01, 1.3533e-01, ..., 2.9242e-01, 3.8062e-01, -3.4277e-01], [-1.7284e-01, -1.2114e-01, 1.4709e-01, ..., 5.8859e-02, 3.1158e-01, -7.0858e-02], ..., [ 2.9475e-01, -2.5680e-01, 1.4026e-01, ..., -7.1591e-02, 5.5963e-01, -4.0796e-01], [ 2.1493e-01, 5.9311e-01, 5.2926e-01, ..., 2.5379e-01, 5.6689e-01, -7.1665e-01], [ 2.9765e-01, 6.3153e-02, 7.2703e-01, ..., 6.0896e-01, 2.1499e-01, 7.1044e-01]], [[ 4.6605e-01, 6.7653e-01, -7.1833e-02, ..., 3.0033e-02, 7.0425e-01, 4.8540e-01], [ 4.8456e-01, 4.0220e-01, 4.0125e-02, ..., -2.0781e-01, 3.0123e-01, 8.6450e-02], [-4.0775e-01, 6.8943e-01, 4.0457e-01, ..., 6.9082e-01, 8.5107e-03, 3.6140e-01], ..., [-1.4108e-02, 2.0690e-01, -4.4909e-02, ..., -6.0628e-02, 6.1117e-01, 1.0082e-01], [ 1.7421e-01, 7.4176e-02, 4.2727e-01, ..., 2.4365e-01, 1.3029e-01, 8.5371e-01], [ 1.4887e-01, -3.1160e-02, 1.5774e-01, ..., 1.0996e-01, -8.8344e-01, 1.0888e-01]], [[ 3.8795e-01, 3.5317e-01, 9.9006e-02, ..., 8.5936e-01, 6.3873e-01, 2.4603e-01], [ 6.0738e-01, 3.6789e-01, -1.1297e-01, ..., 1.5349e-01, 5.1060e-01, 4.1978e-01], [-2.5957e-01, 6.7207e-01, 5.5577e-01, ..., 6.6401e-01, 2.3646e-01, -6.0896e-02], ..., [-1.6632e-01, -1.1759e-01, 4.6079e-01, ..., -1.4599e-01, 3.4830e-01, 4.4242e-02], [-2.3171e-01, 3.3373e-01, 5.0213e-01, ..., 6.5282e-01, 1.2472e+00, 6.5165e-01], [-8.8611e-02, -3.0489e-01, -1.3409e-01, ..., -1.3034e-01, -5.1689e-01, 2.6151e-02]], [[-3.5143e-01, 3.8954e-02, -1.7949e-01, ..., 2.7233e-02, 3.8368e-01, 1.3782e-01], [-6.3582e-01, -4.1529e-02, -3.4359e-01, ..., -2.6442e-02, -4.9910e-01, -1.1147e-01], [-5.8319e-01, 2.9186e-01, -3.6855e-01, ..., -1.9133e-01, -1.4835e-01, -2.4480e-02], ..., [-2.1798e-01, 7.9533e-02, 2.9324e-01, ..., -5.3775e-01, -1.0022e-01, -2.4489e-01], [-4.7352e-01, -2.6569e-01, 2.9632e-01, ..., 2.5473e-02, -4.2064e-01, 3.0822e-01], [-7.4871e-02, -1.8836e-01, 
-7.7294e-01, ..., -2.6800e-01, -4.2740e-01, -3.2816e-01]]]], grad_fn=<SplitWithSizesBackward>)
^
SyntaxError: invalid syntax
torch.Size([4, 3, 64, 64]) torch.Size([4, 2, 64, 64])