@@ -84,6 +84,9 @@ public static string ReplaceCodes(string text)
84
84
text = Regex . Replace ( text , @"\bnp\.inf\b" , "np.Inf" ) ;
85
85
text = text . Replace ( "time.time()" , "DateTime.Now" ) ;
86
86
87
+ //Tenser.requires_grad
88
+ text = text . Replace ( ".require_grad = true;" , ".requires_grad = true;" ) ;
89
+ text = text . Replace ( ".require_grad = false;" , ".requires_grad = false;" ) ;
87
90
88
91
return text ;
89
92
}
@@ -185,7 +188,7 @@ private static string replaceFieldType(string text)
185
188
186
189
text = Regex . Replace ( text , @"public (object|void) (\w+_path;)" , "public string $2" ) ;
187
190
text = Regex . Replace ( text , @"public (object|void) (\w+_name;)" , "public string $2" ) ;
188
-
191
+
189
192
return text ;
190
193
}
191
194
@@ -196,22 +199,25 @@ private static string replaceFieldType3(string text)
196
199
foreach ( Match m in ms ) {
197
200
var name = m . Groups [ 2 ] . Value ;
198
201
if ( text . Contains ( $ "this.{ name } = { name } ;") ) {
199
- if ( text . Contains ( $ "int { name } = ") ) {
202
+ if ( Regex . IsMatch ( text , @ $ "int { name } \b ") ) {
200
203
text = text . Replace ( $ "public object { name } ;", $ "public int { name } ;") ;
201
204
text = text . Replace ( $ "public void { name } ;", $ "public int { name } ;") ;
202
- } else if ( text . Contains ( $ "long { name } = ") ) {
205
+ } else if ( Regex . IsMatch ( text , @ $ "long { name } \b ") ) {
203
206
text = text . Replace ( $ "public object { name } ;", $ "public long { name } ;") ;
204
207
text = text . Replace ( $ "public void { name } ;", $ "public long { name } ;") ;
205
- } else if ( text . Contains ( $ "doulbe { name } = ") ) {
208
+ } else if ( Regex . IsMatch ( text , @ $ "doulbe { name } \b ") ) {
206
209
text = text . Replace ( $ "public object { name } ;", $ "public doulbe { name } ;") ;
207
210
text = text . Replace ( $ "public void { name } ;", $ "public doulbe { name } ;") ;
208
- } else if ( text . Contains ( $ "string { name } = ") ) {
211
+ } else if ( Regex . IsMatch ( text , @ $ "string { name } \b ") ) {
209
212
text = text . Replace ( $ "public object { name } ;", $ "public string { name } ;") ;
210
213
text = text . Replace ( $ "public void { name } ;", $ "public string { name } ;") ;
211
- } else if ( text . Contains ( $ "bool { name } = ") ) {
214
+ } else if ( Regex . IsMatch ( text , @ $ "bool { name } \b ") ) {
212
215
text = text . Replace ( $ "public object { name } ;", $ "public bool { name } ;") ;
213
216
text = text . Replace ( $ "public void { name } ;", $ "public bool { name } ;") ;
214
217
}
218
+ } else if ( text . Contains ( $ "if (this.{ name } )") || text . Contains ( $ "if (!this.{ name } )") || text . Contains ( $ "if (this.{ name } == true)") || text . Contains ( $ "if (this.{ name } == false)") ) {
219
+ text = text . Replace ( $ "public object { name } ;", $ "public bool { name } ;") ;
220
+ text = text . Replace ( $ "public void { name } ;", $ "public bool { name } ;") ;
215
221
}
216
222
}
217
223
}
@@ -459,12 +465,13 @@ private static string replaceCallForwardMethod(string text)
459
465
/// <returns></returns>
460
466
private static string replaceTensorList ( string text )
461
467
{
462
- text = text . Replace ( " torch.cat(new List<object>" , " torch.cat(new List<Tensor>" ) ;
463
- text = text . Replace ( " torch.ones(new List<object>" , " torch.ones(new List<Tensor>" ) ;
464
- text = text . Replace ( " torch.zeros(new List<object>" , " torch.zeros(new List<Tensor>" ) ;
468
+ text = text . Replace ( "torch.cat(new List<object>" , "torch.cat(new List<Tensor>" ) ;
469
+ text = text . Replace ( "torch.ones(new List<object>" , "torch.ones(new long[]" ) ;
470
+ text = text . Replace ( "torch.ones(new List<int>" , "torch.ones(new long[]" ) ;
471
+ text = text . Replace ( "torch.zeros(new List<object>" , "torch.zeros(new long[]" ) ;
472
+ text = text . Replace ( "torch.zeros(new List<int>" , "torch.zeros(new long[]" ) ;
465
473
466
- text = text . Replace ( "var attns = new List<object>();" , "var attns = new List<Tensor>();" ) ;
467
- text = text . Replace ( "attns.append(attn);" , "attns.Add(attn);" ) ;
474
+ text = text . Replace ( "new List<object>();" , "new List<Tensor>();" ) ;
468
475
return text ;
469
476
}
470
477
/// <summary>
0 commit comments