SymNet_torch development repo
See requirementes.txt
. Python 3.7 + PyTorch 1.8.1
MIT
python run_symnet.py --network fc_obj --name MIT_obj_lr3e-3 --data MIT --epoch 1500 --batchnorm --lr 3e-3
python run_symnet.py --name MIT_best --data MIT --epoch 400 --obj_pred MIT_obj_lr3e-3_ep1120.pkl --batchnorm --lr 5e-4 --bz 512 --lambda_cls_attr 1 --lambda_cls_obj 0.01 --lambda_trip 0.03 --lambda_sym 0.05 --lambda_axiom 0.01 --rmd_metric softmax
UT
python run_symnet.py --network fc_obj --name UT_obj_lr1e-3 --data UT --epoch 300 --batchnorm --lr 1e-3
python run_symnet.py --name UT_best --data UT --epoch 700 --obj_pred UT_obj_lr1e-3_ep140.pkl --batchnorm --wordvec onehot --lr 1e-4 --bz 256 --lambda_cls_attr 1 --lambda_cls_obj 0.5 --lambda_trip 0.5 --lambda_sym 0.01 --lambda_axiom 0.03 --rmd_metric softmax
UT已经有正常分数了(51.3+) MIT14分(差很多)
- gczsl run/test (GCZSL evaluator还没有改过,可能会有问题)
- 多卡训练
训练loss时用self.args.rmd_metric
, test时原本(开源版本)是用"rmd",现在测试也用self.args.rmd_metric
- logs和weights合并到了logs
- args.weight_type 改成了可读的str类型
- args.trained_weight现在是直接的绝对/相对路径
- prediction现在不是dict是只有一个list了
- prob_pair, prob_attr开源时是分开产生的,现在是同一个
- yaml以及自动备份
- MSEloss在L2的时候不对:少个平方
- activation function和weight initializer没有设置
- args的key名字跟operator不太一样,可以考虑统一一下
- lr scheduler还没实现。如果要加的话还要存进statedict
- GRADIENT_CLIPPING还没实现
- focal loss not implemented
- loss的log精简一下,tb不要显示那么多(参考tf版本
- reshape->view
- symnet的compute_loss参数prob_RMD_plus, prob_RMD_minus太明显了 藏起来
- make this repo more Python3 (type, etc.)
- 检查snapshot继续训练时读取有没有错,分数是不是合理