Skip to content

Commit 5d2864a

Browse files
author
linzhijun
committedFeb 18, 2023
Add the set append method and bool type judgment, Conversion field Tenser.requires_grad
1 parent b8f47a5 commit 5d2864a

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed
 

‎src/Extensions/TorchCs/Resources/netstandard.cs

+7-2
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ public static string format(this string str, params object[] strings)
2626
return String.Format(str, strings);
2727
}
2828

29-
public static string[] split(this string str, string splitStr)
29+
public static List<string> split(this string str, string splitStr)
3030
{
31-
return str.Split(splitStr);
31+
return str.Split(splitStr).ToList();
3232
}
3333
public static string upper(this string str)
3434
{
@@ -109,6 +109,10 @@ public static string rstrip(this string str)
109109
}
110110

111111

112+
public static void append<T>(this ICollection<T> list, T obj)
113+
{
114+
list.Add(obj);
115+
}
112116

113117
public static ICollection<T1> keys<T1, T2>(this IDictionary<T1, T2> dict)
114118
{
@@ -202,6 +206,7 @@ public static partial class TorchEnumerable
202206
}
203207
}
204208
}
209+
205210
}
206211

207212
public static class os

‎src/Extensions/TorchCs/TorchUtil.cs

+18-11
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ public static string ReplaceCodes(string text)
8484
text = Regex.Replace(text, @"\bnp\.inf\b", "np.Inf");
8585
text = text.Replace("time.time()", "DateTime.Now");
8686

87+
//Tenser.requires_grad
88+
text = text.Replace(".require_grad = true;", ".requires_grad = true;");
89+
text = text.Replace(".require_grad = false;", ".requires_grad = false;");
8790

8891
return text;
8992
}
@@ -185,7 +188,7 @@ private static string replaceFieldType(string text)
185188

186189
text = Regex.Replace(text, @"public (object|void) (\w+_path;)", "public string $2");
187190
text = Regex.Replace(text, @"public (object|void) (\w+_name;)", "public string $2");
188-
191+
189192
return text;
190193
}
191194

@@ -196,22 +199,25 @@ private static string replaceFieldType3(string text)
196199
foreach (Match m in ms) {
197200
var name = m.Groups[2].Value;
198201
if (text.Contains($"this.{name} = {name};")) {
199-
if (text.Contains($"int {name} =")) {
202+
if (Regex.IsMatch(text, @$"int {name}\b")) {
200203
text = text.Replace($"public object {name};", $"public int {name};");
201204
text = text.Replace($"public void {name};", $"public int {name};");
202-
} else if (text.Contains($"long {name} =")) {
205+
} else if (Regex.IsMatch(text, @$"long {name}\b")) {
203206
text = text.Replace($"public object {name};", $"public long {name};");
204207
text = text.Replace($"public void {name};", $"public long {name};");
205-
} else if (text.Contains($"doulbe {name} =")) {
208+
} else if (Regex.IsMatch(text, @$"doulbe {name}\b")) {
206209
text = text.Replace($"public object {name};", $"public doulbe {name};");
207210
text = text.Replace($"public void {name};", $"public doulbe {name};");
208-
} else if (text.Contains($"string {name} =")) {
211+
} else if (Regex.IsMatch(text, @$"string {name}\b")) {
209212
text = text.Replace($"public object {name};", $"public string {name};");
210213
text = text.Replace($"public void {name};", $"public string {name};");
211-
} else if (text.Contains($"bool {name} =")) {
214+
} else if (Regex.IsMatch(text, @$"bool {name}\b")) {
212215
text = text.Replace($"public object {name};", $"public bool {name};");
213216
text = text.Replace($"public void {name};", $"public bool {name};");
214217
}
218+
} else if (text.Contains($"if (this.{name})") || text.Contains($"if (!this.{name})") || text.Contains($"if (this.{name} == true)") || text.Contains($"if (this.{name} == false)")) {
219+
text = text.Replace($"public object {name};", $"public bool {name};");
220+
text = text.Replace($"public void {name};", $"public bool {name};");
215221
}
216222
}
217223
}
@@ -459,12 +465,13 @@ private static string replaceCallForwardMethod(string text)
459465
/// <returns></returns>
460466
private static string replaceTensorList(string text)
461467
{
462-
text = text.Replace(" torch.cat(new List<object>", " torch.cat(new List<Tensor>");
463-
text = text.Replace(" torch.ones(new List<object>", " torch.ones(new List<Tensor>");
464-
text = text.Replace(" torch.zeros(new List<object>", " torch.zeros(new List<Tensor>");
468+
text = text.Replace("torch.cat(new List<object>", "torch.cat(new List<Tensor>");
469+
text = text.Replace("torch.ones(new List<object>", "torch.ones(new long[]");
470+
text = text.Replace("torch.ones(new List<int>", "torch.ones(new long[]");
471+
text = text.Replace("torch.zeros(new List<object>", "torch.zeros(new long[]");
472+
text = text.Replace("torch.zeros(new List<int>", "torch.zeros(new long[]");
465473

466-
text = text.Replace("var attns = new List<object>();", "var attns = new List<Tensor>();");
467-
text = text.Replace("attns.append(attn);", "attns.Add(attn);");
474+
text = text.Replace("new List<object>();", "new List<Tensor>();");
468475
return text;
469476
}
470477
/// <summary>

0 commit comments

Comments
 (0)
Please sign in to comment.