Skip to content

MingyuLau/SymNet_torch

Repository files navigation

SymNet_torch_dev

SymNet_torch development repo

requirement

See requirementes.txt. Python 3.7 + PyTorch 1.8.1

usage

MIT

python run_symnet.py --network fc_obj --name MIT_obj_lr3e-3 --data MIT --epoch 1500 --batchnorm --lr 3e-3
python run_symnet.py --name MIT_best --data MIT --epoch 400 --obj_pred MIT_obj_lr3e-3_ep1120.pkl --batchnorm --lr 5e-4 --bz 512 --lambda_cls_attr 1 --lambda_cls_obj 0.01 --lambda_trip 0.03 --lambda_sym 0.05 --lambda_axiom 0.01 --rmd_metric softmax

UT

python run_symnet.py --network fc_obj --name UT_obj_lr1e-3 --data UT --epoch 300 --batchnorm --lr 1e-3
python run_symnet.py --name UT_best --data UT --epoch 700 --obj_pred UT_obj_lr1e-3_ep140.pkl --batchnorm  --wordvec onehot  --lr 1e-4 --bz 256 --lambda_cls_attr 1 --lambda_cls_obj 0.5 --lambda_trip 0.5 --lambda_sym 0.01 --lambda_axiom 0.03 --rmd_metric softmax

progress

UT已经有正常分数了(51.3+) MIT14分(差很多)

  1. gczsl run/test (GCZSL evaluator还没有改过,可能会有问题)
  2. 多卡训练

changes/notes

训练loss时用self.args.rmd_metric, test时原本(开源版本)是用"rmd",现在测试也用self.args.rmd_metric

  1. logs和weights合并到了logs
  2. args.weight_type 改成了可读的str类型
  3. args.trained_weight现在是直接的绝对/相对路径
  4. prediction现在不是dict是只有一个list了
  5. prob_pair, prob_attr开源时是分开产生的,现在是同一个

TODOs

  1. yaml以及自动备份
  2. MSEloss在L2的时候不对:少个平方
  3. activation function和weight initializer没有设置
  4. args的key名字跟operator不太一样,可以考虑统一一下
  5. lr scheduler还没实现。如果要加的话还要存进statedict
  6. GRADIENT_CLIPPING还没实现
  7. focal loss not implemented
  8. loss的log精简一下,tb不要显示那么多(参考tf版本
  9. reshape->view
  10. symnet的compute_loss参数prob_RMD_plus, prob_RMD_minus太明显了 藏起来
  11. make this repo more Python3 (type, etc.)
  12. 检查snapshot继续训练时读取有没有错,分数是不是合理

About

No description, website, or topics provided.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published