Commit ec0756c

Add files via upload
HawkingC authored Jun 9, 2023
1 parent e5d2a55 commit ec0756c
Showing 3 changed files with 25 additions and 11 deletions.
14 changes: 12 additions & 2 deletions run.py
@@ -6,6 +6,15 @@
 from importlib import import_module
 import mindspore.ops as ops
 import argparse
+import mindspore.nn as nn
+
+class MyNet(nn.Cell):
+    def __init__(self):
+        super(MyNet, self).__init__()
+
+# The following implements mindspore.nn.Cell.get_parameters() with MindSpore.
+net = MyNet()
+

 parser = argparse.ArgumentParser(description='Chinese Text Classification')
 parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
@@ -43,8 +52,9 @@

 # train
 config.n_vocab = len(vocab)
-model = x.Model(config).to(config.device)
+model = x.Model(config)
 if model_name != 'Transformer':
     init_network(model)
-print(model.parameters)
+for params in net.get_parameters():
+    print("params:", params)
 train(config, model, train_iter, dev_iter, test_iter)
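Note that the `MyNet` added above defines no layers, so iterating `get_parameters()` on it yields nothing. A minimal sketch of what the new loop prints once the `Cell` actually owns parameters — the `Dense` layer here is an illustrative assumption, not part of the commit:

```python
import mindspore.nn as nn

class TinyNet(nn.Cell):
    """Hypothetical Cell with one layer, so get_parameters() has Parameters to yield."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = nn.Dense(4, 2)  # owns 'fc.weight' of shape (2, 4) and 'fc.bias' of shape (2,)

    def construct(self, x):
        return self.fc(x)

net = TinyNet()
# get_parameters() iterates over all Parameters of the Cell;
# trainable_params() would return only those with requires_grad=True.
for params in net.get_parameters():
    print("params:", params)
```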
11 changes: 8 additions & 3 deletions train_eval.py
@@ -9,11 +9,16 @@
 from mindspore.common.initializer import initializer, XavierNormal, Normal, HeNormal, Constant
 from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, LossMonitor

+class MyNet(nn.Cell):
+    def __init__(self):
+        super(MyNet, self).__init__()

+# The following implements mindspore.nn.Cell.get_parameters() with MindSpore.
+net = MyNet()

 # Weight initialization; default is Xavier
 def init_network(model, method='xavier', exclude='embedding', seed=123):
-    for name, w in model.named_parameters():
+    for name, w in net.trainable_params():
         if exclude not in name:
             if 'weight' in name:
                 if method == 'xavier':
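One caveat with the new loop: `trainable_params()` returns `Parameter` objects, not `(name, parameter)` pairs, so unpacking into `name, w` fails at run time; `parameters_and_names()` yields the pairs this code expects. A minimal sketch of the Xavier branch under that assumption, with an illustrative `Dense` net standing in for the text-classification model:

```python
import mindspore.nn as nn
from mindspore.common.initializer import initializer, XavierNormal, Constant

def init_network_sketch(net, exclude='embedding'):
    # parameters_and_names() yields (name, Parameter) pairs.
    for name, w in net.parameters_and_names():
        if exclude in name:
            continue
        if 'weight' in name:
            # Re-initialize the parameter in place with Xavier-normal values.
            w.set_data(initializer(XavierNormal(), w.shape, w.dtype))
        elif 'bias' in name:
            w.set_data(initializer(Constant(0), w.shape, w.dtype))

net = nn.Dense(8, 4)  # illustrative stand-in for the real model
init_network_sketch(net)
```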
@@ -30,8 +35,8 @@ def init_network(model, method='xavier', exclude='embedding', seed=123):

 def train(config, model, train_iter, dev_iter, test_iter):
     start_time = time.time()
-    model.train()
-    optimizer = nn.Adam(model.parameters(), lr=config.learning_rate)
+    model = Model(net)
+    optimizer = nn.Adam(net.trainable_params(), learning_rate=0.001)

     # Exponential learning-rate decay: each epoch, lr = gamma * lr
     # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
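Unlike PyTorch's `model.train()`, which only toggles training mode, MindSpore's high-level `Model.train()` runs the training loop itself, so the wrapper needs the loss and optimizer up front. A minimal sketch of that setup; the network, loss, and the commented-out dataset name are illustrative assumptions, not the commit's code:

```python
import mindspore.nn as nn
from mindspore.train import Model, LossMonitor

net = nn.Dense(300, 10)  # illustrative classifier stand-in
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = nn.Adam(net.trainable_params(), learning_rate=0.001)

# Model bundles network, loss, and optimizer into one training wrapper.
model = Model(net, loss_fn=loss_fn, optimizer=optimizer)
# model.train(epoch=3, train_dataset=train_ds, callbacks=[LossMonitor()])
# train_ds would be a mindspore.dataset pipeline yielding (data, label) pairs.
```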
11 changes: 5 additions & 6 deletions utils.py
@@ -71,22 +71,21 @@ def load_dataset(path, pad_size=32):
     return vocab, train, dev, test

 class DatasetIterater(object):
-    def __init__(self, batches, batch_size, device):
+    def __init__(self, batches, batch_size):
         self.batch_size = batch_size
         self.batches = batches
         self.n_batches = len(batches) // batch_size
         self.residue = False  # records whether the batch count divides evenly
         if len(batches) % self.n_batches != 0:
             self.residue = True
         self.index = 0
-        self.device = device

     def _to_tensor(self, datas):
-        x = Tensor.long([_[0] for _ in datas]).to(self.device)
-        y = Tensor.long([_[1] for _ in datas]).to(self.device)
+        x = Tensor.long([_[0] for _ in datas])
+        y = Tensor.long([_[1] for _ in datas])

         # length before padding (anything longer than pad_size is capped at pad_size)
-        seq_len = Tensor.long([_[2] for _ in datas]).to(self.device)
+        seq_len = Tensor.long([_[2] for _ in datas])
         return (x, seq_len), y

     def __next__(self):
@@ -115,7 +114,7 @@ def __len__(self):
         return self.n_batches

 def build_iterator(dataset, config):
-    iter = DatasetIterater(dataset, config.batch_size, config.device)
+    iter = DatasetIterater(dataset, config.batch_size)
     return iter


