7
7
import torch
8
8
import torch .nn as nn
9
9
from .ocean import Ocean_
10
- from .oceanplus import OceanPlus_
11
- from .oceanplusTRT import OceanPlusTRT_
10
+ # from .oceanplus import OceanPlus_
11
+ # from .oceanplusTRT import OceanPlusTRT_
12
12
from .oceanTRT import OceanTRT_
13
13
from .siamfc import SiamFC_
14
14
from .connect import box_tower , AdjustLayer , AlignHead , Corr_Up , MultiDiCorr , OceanCorr
@@ -49,24 +49,24 @@ def __init__(self, online=False, align=False):
49
49
self .connect_model2 = OceanCorr ()
50
50
51
51
52
- class OceanPlus (OceanPlus_ ):
53
- def __init__ (self , online = False ):
54
- super (OceanPlus , self ).__init__ ()
55
- self .features = ResNet50 (used_layers = [3 ], online = online ) # in param
56
- self .neck = AdjustLayer (in_channels = 1024 , out_channels = 256 )
57
- self .connect_model = box_tower (inchannels = 256 , outchannels = 256 , towernum = 4 )
58
- self .mask_model = MultiRefine (addCorr = True , mulOradd = 'add' )
52
+ # class OceanPlus(OceanPlus_):
53
+ # def __init__(self, online=False):
54
+ # super(OceanPlus, self).__init__()
55
+ # self.features = ResNet50(used_layers=[3], online=online) # in param
56
+ # self.neck = AdjustLayer(in_channels=1024, out_channels=256)
57
+ # self.connect_model = box_tower(inchannels=256, outchannels=256, towernum=4)
58
+ # self.mask_model = MultiRefine(addCorr=True, mulOradd='add')
59
59
60
60
61
- class OceanPlusTRT (OceanPlusTRT_ ):
62
- def __init__ (self , online = False ):
63
- super (OceanPlusTRT , self ).__init__ ()
64
- self .features = ResNet50 (used_layers = [3 ], online = online ) # in param
65
- self .neck = AdjustLayer (in_channels = 1024 , out_channels = 256 )
66
- self .connect_model0 = MultiDiCorr (inchannels = 256 , outchannels = 256 )
67
- self .connect_model1 = box_tower (inchannels = 256 , outchannels = 256 , towernum = 4 )
68
- self .connect_model2 = OceanCorr ()
69
- self .mask_model = MultiRefineTRT (addCorr = True , mulOradd = 'add' )
61
+ # class OceanPlusTRT(OceanPlusTRT_):
62
+ # def __init__(self, online=False):
63
+ # super(OceanPlusTRT, self).__init__()
64
+ # self.features = ResNet50(used_layers=[3], online=online) # in param
65
+ # self.neck = AdjustLayer(in_channels=1024, out_channels=256)
66
+ # self.connect_model0 = MultiDiCorr(inchannels=256, outchannels=256)
67
+ # self.connect_model1 = box_tower(inchannels=256, outchannels=256, towernum=4)
68
+ # self.connect_model2 = OceanCorr()
69
+ # self.mask_model = MultiRefineTRT(addCorr=True, mulOradd='add')
70
70
71
71
72
72
# ------------------------------
@@ -348,4 +348,4 @@ def ONLINEnet50(filter_size=4, optim_iter=5, optim_init_step=0.9, optim_init_reg
348
348
349
349
# ONLINE network
350
350
net = ONLINEnet (feature_extractor = backbone_net , classifier = classifier , classification_layer = classification_layer )
351
- return net
351
+ return net
0 commit comments