diff --git a/train_cls.py b/train_cls.py index 89fbf0b7..089ac00f 100644 --- a/train_cls.py +++ b/train_cls.py @@ -230,10 +230,10 @@ def get_args(): # prompt tuning hyper-parameter parser.add_argument("--n_ctx", type=int, default=4) - parser.add_argument("--n_ctx_ab", type=int, default=100) + parser.add_argument("--n_ctx_ab", type=int, default=1) parser.add_argument("--n_pro", type=int, default=3) parser.add_argument("--n_pro_ab", type=int, default=4) - parser.add_argument("--Epoch", type=int, default=1) + parser.add_argument("--Epoch", type=int, default=100) # optimizer parser.add_argument("--lr", type=float, default=0.002)