Skip to content

Commit c459def

Browse files
authoredFeb 16, 2023
Merge pull request #95 from toolgood/master
add TorchSharp replace
2 parents b0aa894 + 67e63f2 commit c459def

File tree

4 files changed

+744
-2
lines changed

4 files changed

+744
-2
lines changed
 

‎src/TorchCs/Resources/netstandard.cs

+147
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
#region License
2+
// Copyright 2023 ToolGood
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
#endregion
16+
using System.Runtime.CompilerServices;
17+
18+
namespace System
19+
{
20+
public static partial class TorchExtension
21+
{
22+
public static string format(this string str, params object[] strings)
23+
{
24+
return String.Format(str, strings);
25+
}
26+
27+
public static string[] split(this string str, string splitStr)
28+
{
29+
return str.Split(splitStr);
30+
}
31+
32+
public static ICollection<T1> keys<T1, T2>(this IDictionary<T1, T2> dict)
33+
{
34+
return dict.Keys;
35+
}
36+
/// <summary>
37+
/// Simplify code, similar to python syntax
38+
/// python code : B, L = queries.shape
39+
/// csharp code : var (B, L) = queries.shape.ToLong2();
40+
/// </summary>
41+
/// <param name="array"></param>
42+
/// <returns></returns>
43+
public static (long, long) ToLong2(this long[] array)
44+
{
45+
return (array[0], array[1]);
46+
}
47+
/// <summary>
48+
/// Simplify code, similar to python syntax
49+
/// python code : B, L, _ = queries.shape
50+
/// csharp code : var (B, L, _) = queries.shape.ToLong3();
51+
/// </summary>
52+
/// <param name="array"></param>
53+
/// <returns></returns>
54+
public static (long, long, long) ToLong3(this long[] array)
55+
{
56+
return (array[0], array[1], array[2]);
57+
}
58+
/// <summary>
59+
/// Simplify code, similar to python syntax
60+
/// python code : B, L, _, _ = queries.shape
61+
/// csharp code : var (B, L, _, _) = queries.shape.ToLong4();
62+
/// </summary>
63+
/// <param name="array"></param>
64+
/// <returns></returns>
65+
public static (long, long, long, long) ToLong4(this long[] array)
66+
{
67+
return (array[0], array[1], array[2], array[3]);
68+
}
69+
70+
}
71+
72+
public static partial class TorchEnumerable
73+
{
74+
public static IEnumerable<(TFirst First, TSecond Second)> zip<TFirst, TSecond>(this IEnumerable<TFirst> first, IEnumerable<TSecond> second)
75+
{
76+
if (first is null) {
77+
throw new ArgumentNullException(nameof(first));
78+
}
79+
80+
if (second is null) {
81+
throw new ArgumentNullException(nameof(second));
82+
}
83+
84+
return ZipIterator(first, second);
85+
}
86+
87+
public static IEnumerable<(TFirst First, TSecond Second, TThird Third)> zip<TFirst, TSecond, TThird>(this IEnumerable<TFirst> first, IEnumerable<TSecond> second, IEnumerable<TThird> third)
88+
{
89+
if (first is null) {
90+
throw new ArgumentNullException(nameof(first));
91+
}
92+
93+
if (second is null) {
94+
throw new ArgumentNullException(nameof(second));
95+
}
96+
97+
if (third is null) {
98+
throw new ArgumentNullException(nameof(third));
99+
}
100+
101+
return ZipIterator(first, second, third);
102+
}
103+
104+
private static IEnumerable<(TFirst First, TSecond Second)> ZipIterator<TFirst, TSecond>(IEnumerable<TFirst> first, IEnumerable<TSecond> second)
105+
{
106+
using (IEnumerator<TFirst> e1 = first.GetEnumerator())
107+
using (IEnumerator<TSecond> e2 = second.GetEnumerator()) {
108+
while (e1.MoveNext() && e2.MoveNext()) {
109+
yield return (e1.Current, e2.Current);
110+
}
111+
}
112+
}
113+
114+
private static IEnumerable<(TFirst First, TSecond Second, TThird Third)> ZipIterator<TFirst, TSecond, TThird>(IEnumerable<TFirst> first, IEnumerable<TSecond> second, IEnumerable<TThird> third)
115+
{
116+
using (IEnumerator<TFirst> e1 = first.GetEnumerator())
117+
using (IEnumerator<TSecond> e2 = second.GetEnumerator())
118+
using (IEnumerator<TThird> e3 = third.GetEnumerator()) {
119+
while (e1.MoveNext() && e2.MoveNext() && e3.MoveNext()) {
120+
yield return (e1.Current, e2.Current, e3.Current);
121+
}
122+
}
123+
}
124+
}
125+
126+
public static class os
127+
{
128+
public static void makedirs(string path)
129+
{
130+
Directory.CreateDirectory(path);
131+
}
132+
133+
public class path
134+
{
135+
public static string join(params string[] paths)
136+
{
137+
var ps = paths.ToList();
138+
ps.RemoveAll(q => q == null);
139+
return Path.Combine(ps.ToArray());
140+
}
141+
public static bool exists(string path)
142+
{
143+
return File.Exists(path) || Directory.Exists(path);
144+
}
145+
}
146+
}
147+
}

‎src/TorchCs/TorchCs.csproj

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<TargetFramework>net6.0</TargetFramework>
5+
<ImplicitUsings>enable</ImplicitUsings>
6+
<Nullable>enable</Nullable>
7+
</PropertyGroup>
8+
9+
<ItemGroup>
10+
<Compile Remove="netstandard.cs" />
11+
</ItemGroup>
12+
13+
<ItemGroup>
14+
<PackageReference Include="TorchSharp" Version="0.99.2" />
15+
</ItemGroup>
16+
17+
<ItemGroup>
18+
<EmbeddedResource Include="Resources\netstandard.cs" />
19+
</ItemGroup>
20+
21+
</Project>

‎src/TorchCs/TorchUtil.cs

+568
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,568 @@
1+
#region License
2+
// Copyright 2023 ToolGood
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
#endregion
16+
using System.Reflection;
17+
using System.Text;
18+
using System.Text.RegularExpressions;
19+
20+
namespace TorchCs
21+
{
22+
public class TorchUtil
23+
{
24+
private const int MAX_LAYER = 5; // Number of times code contains code
25+
26+
/// <summary>
27+
/// Convert all *.py.cs files in the folder ,Replace grammar rules
28+
/// </summary>
29+
/// <param name="folder"></param>
30+
public static void ReplaceFolder(string folder)
31+
{
32+
var files = Directory.GetFiles(folder, "*.py.cs", SearchOption.AllDirectories);
33+
foreach (var file in files) {
34+
var text = File.ReadAllText(file);
35+
File.WriteAllText(file, ReplaceCodes(text));
36+
}
37+
}
38+
/// <summary>
39+
/// Convert file, Replace grammar rules
40+
/// </summary>
41+
/// <param name="file"></param>
42+
public static void ReplaceFile(string file)
43+
{
44+
var text = File.ReadAllText(file);
45+
File.WriteAllText(file, ReplaceCodes(text));
46+
}
47+
/// <summary>
48+
/// Convert code, Replace grammar rules
49+
/// </summary>
50+
/// <param name="text"></param>
51+
/// <returns></returns>
52+
public static string ReplaceCodes(string text)
53+
{
54+
text = Regex.Replace(text, @"object (\w+ = ""\w+""[,;)])", "string $1");
55+
text = Regex.Replace(text, @"object (\w+ = \d+[,;)])", "int $1");
56+
text = Regex.Replace(text, @"object (\w+ = \d+\.\d+[,;)])", "double $1");
57+
58+
text = replaceNamespace(text);
59+
text = replaceConstructor(text);
60+
text = replaceFieldType(text);
61+
text = replaceMethodParameterName(text);
62+
text = replaceMethodParamenterType(text);
63+
text = replaceMathMethod(text);
64+
text = replaceStringToEnum(text);
65+
66+
text = replaceForwardMethod(text);
67+
text = replaceCallForwardMethod(text);
68+
69+
text = replaceListSlice(text);
70+
71+
text = replaceTensorList(text);
72+
text = replaceIsType(text);
73+
74+
text = replaceStringToNetstandard(text);
75+
76+
text = text.Replace("using (var torch.no_grad())", "using (var _no_grad= torch.no_grad())");
77+
text = text.Replace("using (var torch.cuda.amp.autocast())", "using (var _autocast= torch.cuda.amp.autocast())");
78+
79+
text = Regex.Replace(text, @"\bnp\.inf\b", "np.Inf");
80+
text = text.Replace("time.time()", "DateTime.Now");
81+
82+
83+
return text;
84+
}
85+
/// <summary>
86+
/// Create netstandard.cs file.
87+
/// </summary>
88+
/// <param name="folder"></param>
89+
public static void CreateNetstandardCode(string folder)
90+
{
91+
Assembly myAssem = Assembly.GetExecutingAssembly();
92+
var manifestResourceStream = myAssem.GetManifestResourceStream("TorchCs.Resources.netstandard.cs");
93+
if (manifestResourceStream == null) { return; }
94+
95+
manifestResourceStream.Position = 0;
96+
using (StreamReader reader = new StreamReader(manifestResourceStream, Encoding.UTF8)) {
97+
var str = reader.ReadToEnd();
98+
File.WriteAllText(Path.Combine(folder, "netstandard.cs"), str);
99+
}
100+
}
101+
102+
/// <summary>
103+
/// Replace namespace grammar rules
104+
/// </summary>
105+
/// <param name="text"></param>
106+
/// <returns></returns>
107+
private static string replaceNamespace(string text)
108+
{
109+
text = text.Replace("using np = numpy;", "using NumpyDotNet;");
110+
text = text.Replace("using torch;", "using static TorchSharp.torch;\r\nusing torch = TorchSharp.torch;\r\nusing TorchSharp.Modules;");
111+
text = text.Replace("using nn = torch.nn;", "using nn = TorchSharp.torch.nn;");
112+
text = text.Replace("using F = torch.nn.functional;", "using F = TorchSharp.torch.nn.functional;");
113+
text = text.Replace("using optim = torch.optim;", "using optim = TorchSharp.torch.optim;");
114+
text = text.Replace("using DataLoader = torch.utils.data.DataLoader;", "using DataLoader = TorchSharp.torch.utils.data.DataLoader;");
115+
116+
text = text.Replace("using math;", "");
117+
text = text.Replace("using os;", "");
118+
text = text.Replace("using time;", "");
119+
text = text.Replace("using warnings;", "");
120+
121+
return text;
122+
}
123+
124+
/// <summary>
125+
/// Replace constructor grammar rules
126+
/// </summary>
127+
/// <param name="text"></param>
128+
/// <returns></returns>
129+
private static string replaceConstructor(string text)
130+
{
131+
var ms = Regex.Matches(text, @"public class (\S+)[\s \t]*: nn.Module");
132+
if (ms.Count > 0) {
133+
foreach (Match item in ms) {
134+
var name = item.Groups[1].Value.Trim();
135+
text = Regex.Replace(text, $@"(public {name}\([^)]*\))", $"$1:base(\"{name}\")");
136+
text = text.Replace($":base(\"{name}\"):base(\"{name}\")", $":base(\"{name}\")");
137+
}
138+
}
139+
return text;
140+
}
141+
142+
/// <summary>
143+
/// Replace field type
144+
/// </summary>
145+
/// <param name="text"></param>
146+
/// <returns></returns>
147+
private static string replaceFieldType(string text)
148+
{
149+
var nnType = typeof(TorchSharp.torch.nn);
150+
var nnMethods = nnType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public);
151+
foreach (var method in nnMethods) {
152+
var fieldType = method.ReturnType.Name;
153+
var methodName = method.Name;
154+
if (methodName == "ModuleDict" || methodName == "ModuleList") {
155+
continue;
156+
}
157+
var r = $@"this\.(\S+) = nn\.{methodName}\(";
158+
var ms = Regex.Matches(text, r);
159+
if (ms.Count > 0) {
160+
foreach (Match m in ms) {
161+
var name = m.Groups[1].Value;
162+
text = text.Replace($"public object {name};", $"public {fieldType} {name};");
163+
text = text.Replace($"public void {name};", $"public {fieldType} {name};");
164+
text = Regex.Replace(text, @$"\bthis\.{name}\(", $"this.{name}.forward(");
165+
}
166+
}
167+
}
168+
text = replaceFieldType3(text);
169+
170+
text = Regex.Replace(text, @"public (object|void) (\w+_len;)", "public int $2");
171+
text = Regex.Replace(text, @"public (object|void) (\w+_in;)", "public int $2");
172+
text = Regex.Replace(text, @"public (object|void) (\w+_model;)", "public int $2");
173+
text = Regex.Replace(text, @"public (object|void) (\w+_out;)", "public int $2");
174+
text = Regex.Replace(text, @"public (object|void) (\w+_channels;)", "public int $2");
175+
text = Regex.Replace(text, @"public (object|void) (num_\w+;)", "public int $2");
176+
177+
return text;
178+
}
179+
180+
private static string replaceFieldType3(string text)
181+
{
182+
var ms = Regex.Matches(text, @"public (object|void) (\S+);");
183+
if (ms.Count > 0) {
184+
foreach (Match m in ms) {
185+
var name = m.Groups[2].Value;
186+
if (text.Contains($"this.{name} = {name};")) {
187+
if (text.Contains($"int {name} =")) {
188+
text = text.Replace($"public object {name};", $"public int {name};");
189+
text = text.Replace($"public void {name};", $"public int {name};");
190+
} else if (text.Contains($"long {name} =")) {
191+
text = text.Replace($"public object {name};", $"public long {name};");
192+
text = text.Replace($"public void {name};", $"public long {name};");
193+
} else if (text.Contains($"doulbe {name} =")) {
194+
text = text.Replace($"public object {name};", $"public doulbe {name};");
195+
text = text.Replace($"public void {name};", $"public doulbe {name};");
196+
} else if (text.Contains($"string {name} =")) {
197+
text = text.Replace($"public object {name};", $"public string {name};");
198+
text = text.Replace($"public void {name};", $"public string {name};");
199+
} else if (text.Contains($"bool {name} =")) {
200+
text = text.Replace($"public object {name};", $"public bool {name};");
201+
text = text.Replace($"public void {name};", $"public bool {name};");
202+
}
203+
}
204+
}
205+
}
206+
return text;
207+
}
208+
/// <summary>
209+
/// Replace Method Parameter Name
210+
/// </summary>
211+
/// <param name="text"></param>
212+
/// <returns></returns>
213+
private static string replaceMethodParameterName(string text)
214+
{
215+
var nnType = typeof(TorchSharp.torch.nn);
216+
var nnMethods = nnType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public);
217+
var torchType = typeof(TorchSharp.torch);
218+
var torchMethods = torchType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public);
219+
220+
Dictionary<string, string> parameters = new Dictionary<string, string>() {
221+
{"inputChannel","in_channels" },
222+
{"outputChannel" ,"out_channels"},
223+
{"dimensions" ,"dim"},
224+
{"hasBias" ,"bias"},
225+
};
226+
227+
foreach (var methodInfo in nnMethods) {
228+
var ps = methodInfo.GetParameters();
229+
foreach (var p in ps) {
230+
text = replaceMethodParameterName(text, "nn." + methodInfo.Name, getPythonParameterName(p.Name), p.Name);
231+
if (parameters.ContainsKey(p.Name)) {
232+
text = replaceMethodParameterName(text, "nn." + methodInfo.Name, parameters[p.Name], p.Name);
233+
}
234+
}
235+
}
236+
foreach (var methodInfo in torchMethods) {
237+
var ps = methodInfo.GetParameters();
238+
foreach (var p in ps) {
239+
for (int i = 0; i < MAX_LAYER; i++) {
240+
text = replaceMethodParameterName(text, "torch." + methodInfo.Name, getPythonParameterName(p.Name), p.Name);
241+
if (parameters.ContainsKey(p.Name)) {
242+
text = replaceMethodParameterName(text, "torch." + methodInfo.Name, parameters[p.Name], p.Name);
243+
}
244+
}
245+
}
246+
}
247+
return text;
248+
}
249+
250+
private static string replaceMethodParameterName(string text, string methodName, string oldName, string newName)
251+
{
252+
if (oldName == newName) { return text; }
253+
var r = $"({methodName}\\([^;]*?)\\b{oldName}:";
254+
return Regex.Replace(text, r, new MatchEvaluator((m) => {
255+
return m.Groups[1].Value + newName + ":";
256+
}));
257+
}
258+
private static string getPythonParameterName(string text)
259+
{
260+
StringBuilder stringBuilder = new StringBuilder();
261+
262+
for (int i = 0; i < text.Length; i++) {
263+
var c = text[i];
264+
if (i == 0) {
265+
stringBuilder.Append(char.ToLower(c));
266+
} else if (c >= 'A' && c <= 'Z') {
267+
stringBuilder.Append('_');
268+
stringBuilder.Append(char.ToLower(c));
269+
} else {
270+
stringBuilder.Append(c);
271+
}
272+
}
273+
return stringBuilder.ToString();
274+
}
275+
276+
/// <summary>
277+
/// Replace Method Parameter Type
278+
/// </summary>
279+
/// <param name="text"></param>
280+
/// <returns></returns>
281+
private static string replaceMethodParamenterType(string text)
282+
{
283+
var tensorType = typeof(TorchSharp.torch.Tensor);
284+
var fields = tensorType.GetFields();
285+
HashSet<string> names = new HashSet<string>();
286+
287+
foreach (var field in fields) {
288+
var ms2 = Regex.Matches(text, @"\b(\w+)\." + field.Name + "\\b");
289+
foreach (Match m in ms2) {
290+
names.Add(m.Groups[1].Value);
291+
}
292+
}
293+
var properties = tensorType.GetProperties();
294+
foreach (var property in properties) {
295+
var ms2 = Regex.Matches(text, @"\b(\w+)\." + property.Name + "\\b");
296+
foreach (Match m in ms2) {
297+
names.Add(m.Groups[1].Value);
298+
}
299+
}
300+
var methodInfos = tensorType.GetMethods(BindingFlags.Public | BindingFlags.Instance);
301+
foreach (var method in methodInfos) {
302+
var ms2 = Regex.Matches(text, @"\b(\w+)\." + method.Name + "\\(");
303+
foreach (Match m in ms2) {
304+
names.Add(m.Groups[1].Value);
305+
}
306+
}
307+
var ms = Regex.Matches(text, @"\b(\w+) = torch\.");
308+
foreach (Match m in ms) {
309+
names.Add(m.Groups[1].Value);
310+
}
311+
foreach (var name in names) {
312+
text = text.Replace("object " + name + ",", "Tensor " + name + ",");
313+
text = text.Replace("void " + name + ",", "Tensor " + name + ",");
314+
text = text.Replace("object " + name + ";", "Tensor " + name + ";");
315+
text = text.Replace("void " + name + ";", "Tensor " + name + ";");
316+
text = text.Replace("object " + name + ")", "Tensor " + name + ")");
317+
text = text.Replace("void " + name + ")", "Tensor " + name + ")");
318+
}
319+
320+
text = Regex.Replace(text, @"(object|void) (\w+_len[,;)])", "int $2");
321+
text = Regex.Replace(text, @"(object|void) (\w+_in[,;)])", "int $2");
322+
text = Regex.Replace(text, @"(object|void) (\w+_model[,;)])", "int $2");
323+
text = Regex.Replace(text, @"(object|void) (\w+_out[,;)])", "int $2");
324+
text = Regex.Replace(text, @"(object|void) (\w+_channels[,;)])", "int $2");
325+
text = Regex.Replace(text, @"(object|void) (num_\w+[,;)])", "int $2");
326+
327+
328+
return text;
329+
}
330+
331+
/// <summary>
332+
/// Replace Math Method
333+
/// Convert 'math.log'(python) to 'Math.Log'(C#)
334+
/// </summary>
335+
/// <param name="text"></param>
336+
/// <returns></returns>
337+
private static string replaceMathMethod(string text)
338+
{
339+
var mathType = typeof(Math);
340+
var mathMethods = mathType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public);
341+
foreach (var methodInfo in mathMethods) {
342+
var name = methodInfo.Name;
343+
var nameL = name.ToLower();
344+
text = Regex.Replace(text, @$"\bmath\.{nameL}\(", $"Math.{name}(");
345+
}
346+
return text;
347+
}
348+
/// <summary>
349+
/// Replace forward method's return type and forward method's parameter type
350+
/// </summary>
351+
/// <param name="text"></param>
352+
/// <returns></returns>
353+
private static string replaceForwardMethod(string text)
354+
{
355+
text = text.Replace(" Tuple<object, object>", " (Tensor, Tensor)");
356+
text = text.Replace(" Tuple<object, void> forward(", " (Tensor, Tensor) forward(");
357+
text = text.Replace(" object[] forward(", " (Tensor, Tensor) forward(");
358+
text = text.Replace(" Tuple<object, List<object>> forward(", " (Tensor, List<Tensor>) forward(");
359+
text = text.Replace(" object forward(", " Tensor forward(");
360+
text = text.Replace(" void forward(", " Tensor forward(");
361+
text = text.Replace(" forward(object x", " forward(Tensor x");
362+
text = text.Replace(" forward(object t", " forward(Tensor t");
363+
text = text.Replace(" forward(object queries, object keys, object values", " forward(Tensor queries, Tensor keys, Tensor values");
364+
return text;
365+
}
366+
/// <summary>
367+
/// Replace common forward method calls
368+
/// </summary>
369+
/// <param name="text"></param>
370+
/// <returns></returns>
371+
private static string replaceCallForwardMethod(string text)
372+
{
373+
text = Regex.Replace(text, @"\bthis\.inner_attention\(", "this.inner_attention.forward(");
374+
text = Regex.Replace(text, @"\bthis\.dropout\(", "this.dropout.forward(");
375+
text = Regex.Replace(text, @"\bthis\.attention\(", "this.attention.forward(");
376+
text = Regex.Replace(text, @"\bthis\.self_attention\(", "this.self_attention.forward(");
377+
text = Regex.Replace(text, @"\bthis\.cross_attention\(", "this.cross_attention.forward(");
378+
text = Regex.Replace(text, @"\bthis\.projection\(", "this.projection.forward(");
379+
text = Regex.Replace(text, @"\bthis\.activation\(", "this.activation.forward(");
380+
text = Regex.Replace(text, @"\bthis\.norm\(", "this.norm.forward(");
381+
text = Regex.Replace(text, @"\bthis\.conv\(", "this.conv.forward(");
382+
text = Regex.Replace(text, @"\bthis\.decomp\(", "this.decomp.forward(");
383+
text = Regex.Replace(text, @"\bthis\.decomp1\(", "this.decomp1.forward(");
384+
text = Regex.Replace(text, @"\bthis\.decomp2\(", "this.decomp2.forward(");
385+
text = Regex.Replace(text, @"\bthis\.decomp3\(", "this.decomp3.forward(");
386+
text = Regex.Replace(text, @"\bthis\.decomp4\(", "this.decomp4.forward(");
387+
text = Regex.Replace(text, @"\bthis\.decomp5\(", "this.decomp5.forward(");
388+
text = Regex.Replace(text, @"\bthis\.conv1\(", "this.conv1.forward(");
389+
text = Regex.Replace(text, @"\bthis\.conv2\(", "this.conv2.forward(");
390+
text = Regex.Replace(text, @"\bthis\.conv3\(", "this.conv3.forward(");
391+
text = Regex.Replace(text, @"\bthis\.conv4\(", "this.conv4.forward(");
392+
text = Regex.Replace(text, @"\bthis\.conv5\(", "this.conv5.forward(");
393+
text = Regex.Replace(text, @"\bthis\.norm1\(", "this.norm1.forward(");
394+
text = Regex.Replace(text, @"\bthis\.norm2\(", "this.norm2.forward(");
395+
text = Regex.Replace(text, @"\bthis\.norm3\(", "this.norm3.forward(");
396+
text = Regex.Replace(text, @"\bthis\.norm4\(", "this.norm4.forward(");
397+
text = Regex.Replace(text, @"\bthis\.norm5\(", "this.norm5.forward(");
398+
399+
text = Regex.Replace(text, @"\bthis\.downConv\(", "this.downConv.forward(");
400+
text = Regex.Replace(text, @"\bthis\.maxPool\(", "this.maxPool.forward(");
401+
text = Regex.Replace(text, @"\bthis\.avg\(", "this.avg.forward(");
402+
text = Regex.Replace(text, @"\bthis\.layernorm\(", "this.layernorm.forward(");
403+
text = Regex.Replace(text, @"\bthis\.tokenConv\(", "this.tokenConv.forward(");
404+
405+
text = Regex.Replace(text, @"\bthis\.embedding\(", "this.embedding.forward(");
406+
text = Regex.Replace(text, @"\bthis\.emb\(", "this.emb.forward(");
407+
text = Regex.Replace(text, @"\bthis\.embed\(", "this.embed.forward(");
408+
text = Regex.Replace(text, @"\bthis\.position_embedding\(", "this.position_embedding.forward(");
409+
text = Regex.Replace(text, @"\bthis\.temporal_embedding\(", "this.temporal_embedding.forward(");
410+
text = Regex.Replace(text, @"\bthis\.value_embedding\(", "this.value_embedding.forward(");
411+
412+
text = Regex.Replace(text, @"\bthis\.month_embed\(", "this.month_embed.forward(");
413+
text = Regex.Replace(text, @"\bthis\.day_embed\(", "this.day_embed.forward(");
414+
text = Regex.Replace(text, @"\bthis\.hour_embed\(", "this.hour_embed.forward(");
415+
text = Regex.Replace(text, @"\bthis\.minute_embed\(", "this.minute_embed.forward(");
416+
text = Regex.Replace(text, @"\bthis\.weekday_embed\(", "this.weekday_embed.forward(");
417+
418+
text = Regex.Replace(text, @"\bthis\.enc_embedding\(", "this.enc_embedding.forward(");
419+
text = Regex.Replace(text, @"\bthis\.encoder\(", "this.encoder.forward(");
420+
text = Regex.Replace(text, @"\bthis\.dec_embedding\(", "this.dec_embedding.forward(");
421+
text = Regex.Replace(text, @"\bthis\.decoder\(", "this.decoder.forward(");
422+
423+
text = Regex.Replace(text, @"\bthis\.query_projection\(", "this.query_projection.forward(");
424+
text = Regex.Replace(text, @"\bthis\.key_projection\(", "this.key_projection.forward(");
425+
text = Regex.Replace(text, @"\bthis\.value_projection\(", "this.value_projection.forward(");
426+
text = Regex.Replace(text, @"\bthis\.out_projection\(", "this.out_projection.forward(");
427+
428+
text = Regex.Replace(text, @"\bthis\.attn\(", "this.attn.forward(");
429+
return text;
430+
}
431+
432+
/// <summary>
433+
/// Replace common Tensor list
434+
/// </summary>
435+
/// <param name="text"></param>
436+
/// <returns></returns>
437+
private static string replaceTensorList(string text)
438+
{
439+
text = text.Replace(" torch.cat(new List<object>", " torch.cat(new List<Tensor>");
440+
text = text.Replace(" torch.ones(new List<object>", " torch.ones(new List<Tensor>");
441+
text = text.Replace(" torch.zeros(new List<object>", " torch.zeros(new List<Tensor>");
442+
443+
text = text.Replace("var attns = new List<object>();", "var attns = new List<Tensor>();");
444+
text = text.Replace("attns.append(attn);", "attns.Add(attn);");
445+
return text;
446+
}
447+
/// <summary>
448+
/// Convert python's [:,:,:] syntax
449+
/// </summary>
450+
/// <param name="text"></param>
451+
/// <returns></returns>
452+
private static string replaceListSlice(string text)
453+
{
454+
text = Regex.Replace(text, @"\[([^\[\]]*?)\]", new MatchEvaluator(m => {
455+
if (m.Groups[1].Value.Contains(":") == false) {
456+
return m.Value;
457+
}
458+
var strs = m.Groups[1].Value.Split(',');
459+
List<string> list = new List<string>();
460+
foreach (var str in strs) {
461+
if (str.Trim() == "\":\"") {
462+
list.Add("TensorIndex.Ellipsis");
463+
} else if (str.Trim() == "") {
464+
list.Add("TensorIndex.Null");
465+
} else if (str.Contains(":")) {
466+
var ss = str.Trim().Split(':');
467+
string r = "TensorIndex.Slice(";
468+
for (int i = 0; i < ss.Length; i++) {
469+
var s = ss[i];
470+
if (i > 0) { r += ","; }
471+
if (s.Trim() == "") {
472+
r += "null";
473+
} else {
474+
if (s.StartsWith("self.")) {
475+
r += s.Replace("self.", "this.");
476+
} else {
477+
r += s;
478+
}
479+
}
480+
}
481+
r += ")";
482+
list.Add(r);
483+
} else {
484+
list.Add(str);
485+
}
486+
}
487+
return "[" + string.Join(",", list) + "]";
488+
}));
489+
return text;
490+
}
491+
492+
/// <summary>
493+
/// Convert 'xx is nn.Conv1d' to 'xx is Conv1d'
494+
/// </summary>
495+
/// <param name="text"></param>
496+
/// <returns></returns>
497+
private static string replaceIsType(string text)
498+
{
499+
var nnType = typeof(TorchSharp.torch.nn);
500+
var nnMethods = nnType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public);
501+
foreach (var method in nnMethods) {
502+
var fieldType = method.ReturnType.Name;
503+
var methodName = method.Name;
504+
if (methodName == "ModuleDict" || methodName == "ModuleList") {
505+
continue;
506+
}
507+
text = text.Replace($" is nn.{methodName}", $" is {methodName}");
508+
}
509+
return text;
510+
}
511+
512+
/// <summary>
513+
/// Replace String To Enum
514+
/// example: Convert 'paddingMode: "zeros"' to 'paddingMode: TorchSharp.PaddingModes.Zeros'
515+
/// </summary>
516+
/// <param name="text"></param>
517+
/// <returns></returns>
518+
private static string replaceStringToEnum(string text)
519+
{
520+
text = Regex.Replace(text, @"\bpaddingMode: ""zeros""", "paddingMode: TorchSharp.PaddingModes.Zeros");
521+
text = Regex.Replace(text, @"\bpaddingMode: ""reflect""", "paddingMode: TorchSharp.PaddingModes.Reflect");
522+
text = Regex.Replace(text, @"\bpaddingMode: ""replicate""", "paddingMode: TorchSharp.PaddingModes.Replicate");
523+
text = Regex.Replace(text, @"\bpaddingMode: ""circular""", "paddingMode: TorchSharp.PaddingModes.Circular");
524+
text = Regex.Replace(text, @"\bpaddingMode: ""constant""", "paddingMode: TorchSharp.PaddingModes.Constant");
525+
526+
text = Regex.Replace(text, @"\breduction: ""none""", "reduction: Reduction.None");
527+
text = Regex.Replace(text, @"\breduction: ""mean""", "reduction: Reduction.Mean");
528+
text = Regex.Replace(text, @"\breduction: ""sum""", "reduction: Reduction.Sum");
529+
530+
text = Regex.Replace(text, @"\bnonLinearity: ""relu""", "nonLinearity: NonLinearities.ReLU");
531+
text = Regex.Replace(text, @"\bnonLinearity: ""tanh""", "nonLinearity: NonLinearities.Tanh");
532+
533+
text = Regex.Replace(text, @"\bactivation: ""relu""", "activation: Activations.ReLU");
534+
text = Regex.Replace(text, @"\bactivation: ""gelu""", "activation: Activations.GELU");
535+
536+
text = Regex.Replace(text, @"\bmode: ""nearest""", "mode: UpsampleMode.Nearest");
537+
text = Regex.Replace(text, @"\bmode: ""linear""", "mode: UpsampleMode.Linear");
538+
text = Regex.Replace(text, @"\bmode: ""bilinear""", "mode: UpsampleMode.Bilinear");
539+
text = Regex.Replace(text, @"\bmode: ""bicubic""", "mode: UpsampleMode.Bicubic");
540+
text = Regex.Replace(text, @"\bmode: ""trilinear""", "mode: UpsampleMode.Trilinear");
541+
542+
return text;
543+
}
544+
545+
/// <summary>
546+
/// Convert to the syntax style of netstandard.cs
547+
/// </summary>
548+
/// <param name="text"></param>
549+
/// <returns></returns>
550+
private static string replaceStringToNetstandard(string text)
551+
{
552+
text = Regex.Replace(text, @" zip\(", " TorchEnumerable.zip(");
553+
554+
text = Regex.Replace(text, @"(\([A-Za-z_0-9]+,[A-Za-z_0-9 ]+\) = \w+\.shape);", "var $1.ToLong2();");
555+
text = Regex.Replace(text, @"(\([A-Za-z_0-9]+,[A-Za-z_0-9 ]+,[A-Za-z_0-9 ]+\) = \w+\.shape);", "var $1.ToLong3();");
556+
text = Regex.Replace(text, @"(\([A-Za-z_0-9]+,[A-Za-z_0-9 ]+,[A-Za-z_0-9 ]+,[A-Za-z_0-9 ]+\) = \w+\.shape);", "var $1.ToLong4();");
557+
558+
text = Regex.Replace(text, @"(\([A-Za-z_0-9]+,[A-Za-z_0-9 ]+\) = \w+\.size\(\));", "var $1.ToLong2();");
559+
text = Regex.Replace(text, @"(\([A-Za-z_0-9]+,[A-Za-z_0-9 ]+,[A-Za-z_0-9 ]+\) = \w+\.size\(\));", "var $1.ToLong3();");
560+
text = Regex.Replace(text, @"(\([A-Za-z_0-9]+,[A-Za-z_0-9 ]+,[A-Za-z_0-9 ]+,[A-Za-z_0-9 ]+\) = \w+\.size\(\));", "var $1.ToLong4();");
561+
562+
return text;
563+
}
564+
565+
566+
567+
}
568+
}

‎src/pytocs.sln

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
Microsoft Visual Studio Solution File, Format Version 12.00
3-
# Visual Studio 15
4-
VisualStudioVersion = 15.0.27130.2010
3+
# Visual Studio Version 17
4+
VisualStudioVersion = 17.4.33213.308
55
MinimumVisualStudioVersion = 10.0.40219.1
66
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{55CBDA16-9AF7-42BC-BBCF-6A59E81F7BC3}"
77
ProjectSection(SolutionItems) = preProject
@@ -20,6 +20,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "UnitTests", "Pytocs.Tests\U
2020
EndProject
2121
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Tests", "Tests", "{18B7ABDD-2A4A-43E1-886F-9F7DE2294D93}"
2222
EndProject
23+
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TorchCs", "TorchCs\TorchCs.csproj", "{16D1F367-128B-45FA-9950-B03F221E7511}"
24+
EndProject
2325
Global
2426
GlobalSection(SolutionConfigurationPlatforms) = preSolution
2527
Debug|Any CPU = Debug|Any CPU
@@ -46,6 +48,10 @@ Global
4648
{4A4B8856-2696-4ABF-895C-EB4F1D13EA03}.Debug|Any CPU.Build.0 = Debug|Any CPU
4749
{4A4B8856-2696-4ABF-895C-EB4F1D13EA03}.Release|Any CPU.ActiveCfg = Release|Any CPU
4850
{4A4B8856-2696-4ABF-895C-EB4F1D13EA03}.Release|Any CPU.Build.0 = Release|Any CPU
51+
{16D1F367-128B-45FA-9950-B03F221E7511}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
52+
{16D1F367-128B-45FA-9950-B03F221E7511}.Debug|Any CPU.Build.0 = Debug|Any CPU
53+
{16D1F367-128B-45FA-9950-B03F221E7511}.Release|Any CPU.ActiveCfg = Release|Any CPU
54+
{16D1F367-128B-45FA-9950-B03F221E7511}.Release|Any CPU.Build.0 = Release|Any CPU
4955
EndGlobalSection
5056
GlobalSection(SolutionProperties) = preSolution
5157
HideSolutionNode = FALSE

0 commit comments

Comments
 (0)
Please sign in to comment.