|
| 1 | +#region License |
| 2 | +// Copyright 2023 ToolGood |
| 3 | +// |
| 4 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +// you may not use this file except in compliance with the License. |
| 6 | +// You may obtain a copy of the License at |
| 7 | +// |
| 8 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +// |
| 10 | +// Unless required by applicable law or agreed to in writing, software |
| 11 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +// See the License for the specific language governing permissions and |
| 14 | +// limitations under the License. |
| 15 | +#endregion |
| 16 | +using System.Reflection; |
| 17 | +using System.Text; |
| 18 | +using System.Text.RegularExpressions; |
| 19 | + |
| 20 | +namespace TorchCs |
| 21 | +{ |
| 22 | + public class TorchUtil |
| 23 | + { |
| 24 | + private const int MAX_LAYER = 5; // Number of times code contains code |
| 25 | + |
| 26 | + /// <summary> |
| 27 | + /// Convert all *.py.cs files in the folder ,Replace grammar rules |
| 28 | + /// </summary> |
| 29 | + /// <param name="folder"></param> |
| 30 | + public static void ReplaceFolder(string folder) |
| 31 | + { |
| 32 | + var files = Directory.GetFiles(folder, "*.py.cs", SearchOption.AllDirectories); |
| 33 | + foreach (var file in files) { |
| 34 | + var text = File.ReadAllText(file); |
| 35 | + File.WriteAllText(file, ReplaceCodes(text)); |
| 36 | + } |
| 37 | + } |
| 38 | + /// <summary> |
| 39 | + /// Convert file, Replace grammar rules |
| 40 | + /// </summary> |
| 41 | + /// <param name="file"></param> |
| 42 | + public static void ReplaceFile(string file) |
| 43 | + { |
| 44 | + var text = File.ReadAllText(file); |
| 45 | + File.WriteAllText(file, ReplaceCodes(text)); |
| 46 | + } |
| 47 | + /// <summary> |
| 48 | + /// Convert code, Replace grammar rules |
| 49 | + /// </summary> |
| 50 | + /// <param name="text"></param> |
| 51 | + /// <returns></returns> |
| 52 | + public static string ReplaceCodes(string text) |
| 53 | + { |
| 54 | + text = Regex.Replace(text, @"object (\w+ = ""\w+""[,;)])", "string $1"); |
| 55 | + text = Regex.Replace(text, @"object (\w+ = \d+[,;)])", "int $1"); |
| 56 | + text = Regex.Replace(text, @"object (\w+ = \d+\.\d+[,;)])", "double $1"); |
| 57 | + |
| 58 | + text = replaceNamespace(text); |
| 59 | + text = replaceConstructor(text); |
| 60 | + text = replaceFieldType(text); |
| 61 | + text = replaceMethodParameterName(text); |
| 62 | + text = replaceMethodParamenterType(text); |
| 63 | + text = replaceMathMethod(text); |
| 64 | + text = replaceStringToEnum(text); |
| 65 | + |
| 66 | + text = replaceForwardMethod(text); |
| 67 | + text = replaceCallForwardMethod(text); |
| 68 | + |
| 69 | + text = replaceListSlice(text); |
| 70 | + |
| 71 | + text = replaceTensorList(text); |
| 72 | + text = replaceIsType(text); |
| 73 | + |
| 74 | + text = replaceStringToNetstandard(text); |
| 75 | + |
| 76 | + text = text.Replace("using (var torch.no_grad())", "using (var _no_grad= torch.no_grad())"); |
| 77 | + text = text.Replace("using (var torch.cuda.amp.autocast())", "using (var _autocast= torch.cuda.amp.autocast())"); |
| 78 | + |
| 79 | + text = Regex.Replace(text, @"\bnp\.inf\b", "np.Inf"); |
| 80 | + text = text.Replace("time.time()", "DateTime.Now"); |
| 81 | + |
| 82 | + |
| 83 | + return text; |
| 84 | + } |
| 85 | + /// <summary> |
| 86 | + /// Create netstandard.cs file. |
| 87 | + /// </summary> |
| 88 | + /// <param name="folder"></param> |
| 89 | + public static void CreateNetstandardCode(string folder) |
| 90 | + { |
| 91 | + Assembly myAssem = Assembly.GetExecutingAssembly(); |
| 92 | + var manifestResourceStream = myAssem.GetManifestResourceStream("TorchCs.Resources.netstandard.cs"); |
| 93 | + if (manifestResourceStream == null) { return; } |
| 94 | + |
| 95 | + manifestResourceStream.Position = 0; |
| 96 | + using (StreamReader reader = new StreamReader(manifestResourceStream, Encoding.UTF8)) { |
| 97 | + var str = reader.ReadToEnd(); |
| 98 | + File.WriteAllText(Path.Combine(folder, "netstandard.cs"), str); |
| 99 | + } |
| 100 | + } |
| 101 | + |
| 102 | + /// <summary> |
| 103 | + /// Replace namespace grammar rules |
| 104 | + /// </summary> |
| 105 | + /// <param name="text"></param> |
| 106 | + /// <returns></returns> |
| 107 | + private static string replaceNamespace(string text) |
| 108 | + { |
| 109 | + text = text.Replace("using np = numpy;", "using NumpyDotNet;"); |
| 110 | + text = text.Replace("using torch;", "using static TorchSharp.torch;\r\nusing torch = TorchSharp.torch;\r\nusing TorchSharp.Modules;"); |
| 111 | + text = text.Replace("using nn = torch.nn;", "using nn = TorchSharp.torch.nn;"); |
| 112 | + text = text.Replace("using F = torch.nn.functional;", "using F = TorchSharp.torch.nn.functional;"); |
| 113 | + text = text.Replace("using optim = torch.optim;", "using optim = TorchSharp.torch.optim;"); |
| 114 | + text = text.Replace("using DataLoader = torch.utils.data.DataLoader;", "using DataLoader = TorchSharp.torch.utils.data.DataLoader;"); |
| 115 | + |
| 116 | + text = text.Replace("using math;", ""); |
| 117 | + text = text.Replace("using os;", ""); |
| 118 | + text = text.Replace("using time;", ""); |
| 119 | + text = text.Replace("using warnings;", ""); |
| 120 | + |
| 121 | + return text; |
| 122 | + } |
| 123 | + |
| 124 | + /// <summary> |
| 125 | + /// Replace constructor grammar rules |
| 126 | + /// </summary> |
| 127 | + /// <param name="text"></param> |
| 128 | + /// <returns></returns> |
| 129 | + private static string replaceConstructor(string text) |
| 130 | + { |
| 131 | + var ms = Regex.Matches(text, @"public class (\S+)[\s \t]*: nn.Module"); |
| 132 | + if (ms.Count > 0) { |
| 133 | + foreach (Match item in ms) { |
| 134 | + var name = item.Groups[1].Value.Trim(); |
| 135 | + text = Regex.Replace(text, $@"(public {name}\([^)]*\))", $"$1:base(\"{name}\")"); |
| 136 | + text = text.Replace($":base(\"{name}\"):base(\"{name}\")", $":base(\"{name}\")"); |
| 137 | + } |
| 138 | + } |
| 139 | + return text; |
| 140 | + } |
| 141 | + |
| 142 | + /// <summary> |
| 143 | + /// Replace field type |
| 144 | + /// </summary> |
| 145 | + /// <param name="text"></param> |
| 146 | + /// <returns></returns> |
| 147 | + private static string replaceFieldType(string text) |
| 148 | + { |
| 149 | + var nnType = typeof(TorchSharp.torch.nn); |
| 150 | + var nnMethods = nnType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public); |
| 151 | + foreach (var method in nnMethods) { |
| 152 | + var fieldType = method.ReturnType.Name; |
| 153 | + var methodName = method.Name; |
| 154 | + if (methodName == "ModuleDict" || methodName == "ModuleList") { |
| 155 | + continue; |
| 156 | + } |
| 157 | + var r = $@"this\.(\S+) = nn\.{methodName}\("; |
| 158 | + var ms = Regex.Matches(text, r); |
| 159 | + if (ms.Count > 0) { |
| 160 | + foreach (Match m in ms) { |
| 161 | + var name = m.Groups[1].Value; |
| 162 | + text = text.Replace($"public object {name};", $"public {fieldType} {name};"); |
| 163 | + text = text.Replace($"public void {name};", $"public {fieldType} {name};"); |
| 164 | + text = Regex.Replace(text, @$"\bthis\.{name}\(", $"this.{name}.forward("); |
| 165 | + } |
| 166 | + } |
| 167 | + } |
| 168 | + text = replaceFieldType3(text); |
| 169 | + |
| 170 | + text = Regex.Replace(text, @"public (object|void) (\w+_len;)", "public int $2"); |
| 171 | + text = Regex.Replace(text, @"public (object|void) (\w+_in;)", "public int $2"); |
| 172 | + text = Regex.Replace(text, @"public (object|void) (\w+_model;)", "public int $2"); |
| 173 | + text = Regex.Replace(text, @"public (object|void) (\w+_out;)", "public int $2"); |
| 174 | + text = Regex.Replace(text, @"public (object|void) (\w+_channels;)", "public int $2"); |
| 175 | + text = Regex.Replace(text, @"public (object|void) (num_\w+;)", "public int $2"); |
| 176 | + |
| 177 | + return text; |
| 178 | + } |
| 179 | + |
| 180 | + private static string replaceFieldType3(string text) |
| 181 | + { |
| 182 | + var ms = Regex.Matches(text, @"public (object|void) (\S+);"); |
| 183 | + if (ms.Count > 0) { |
| 184 | + foreach (Match m in ms) { |
| 185 | + var name = m.Groups[2].Value; |
| 186 | + if (text.Contains($"this.{name} = {name};")) { |
| 187 | + if (text.Contains($"int {name} =")) { |
| 188 | + text = text.Replace($"public object {name};", $"public int {name};"); |
| 189 | + text = text.Replace($"public void {name};", $"public int {name};"); |
| 190 | + } else if (text.Contains($"long {name} =")) { |
| 191 | + text = text.Replace($"public object {name};", $"public long {name};"); |
| 192 | + text = text.Replace($"public void {name};", $"public long {name};"); |
| 193 | + } else if (text.Contains($"doulbe {name} =")) { |
| 194 | + text = text.Replace($"public object {name};", $"public doulbe {name};"); |
| 195 | + text = text.Replace($"public void {name};", $"public doulbe {name};"); |
| 196 | + } else if (text.Contains($"string {name} =")) { |
| 197 | + text = text.Replace($"public object {name};", $"public string {name};"); |
| 198 | + text = text.Replace($"public void {name};", $"public string {name};"); |
| 199 | + } else if (text.Contains($"bool {name} =")) { |
| 200 | + text = text.Replace($"public object {name};", $"public bool {name};"); |
| 201 | + text = text.Replace($"public void {name};", $"public bool {name};"); |
| 202 | + } |
| 203 | + } |
| 204 | + } |
| 205 | + } |
| 206 | + return text; |
| 207 | + } |
| 208 | + /// <summary> |
| 209 | + /// Replace Method Parameter Name |
| 210 | + /// </summary> |
| 211 | + /// <param name="text"></param> |
| 212 | + /// <returns></returns> |
| 213 | + private static string replaceMethodParameterName(string text) |
| 214 | + { |
| 215 | + var nnType = typeof(TorchSharp.torch.nn); |
| 216 | + var nnMethods = nnType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public); |
| 217 | + var torchType = typeof(TorchSharp.torch); |
| 218 | + var torchMethods = torchType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public); |
| 219 | + |
| 220 | + Dictionary<string, string> parameters = new Dictionary<string, string>() { |
| 221 | + {"inputChannel","in_channels" }, |
| 222 | + {"outputChannel" ,"out_channels"}, |
| 223 | + {"dimensions" ,"dim"}, |
| 224 | + {"hasBias" ,"bias"}, |
| 225 | + }; |
| 226 | + |
| 227 | + foreach (var methodInfo in nnMethods) { |
| 228 | + var ps = methodInfo.GetParameters(); |
| 229 | + foreach (var p in ps) { |
| 230 | + text = replaceMethodParameterName(text, "nn." + methodInfo.Name, getPythonParameterName(p.Name), p.Name); |
| 231 | + if (parameters.ContainsKey(p.Name)) { |
| 232 | + text = replaceMethodParameterName(text, "nn." + methodInfo.Name, parameters[p.Name], p.Name); |
| 233 | + } |
| 234 | + } |
| 235 | + } |
| 236 | + foreach (var methodInfo in torchMethods) { |
| 237 | + var ps = methodInfo.GetParameters(); |
| 238 | + foreach (var p in ps) { |
| 239 | + for (int i = 0; i < MAX_LAYER; i++) { |
| 240 | + text = replaceMethodParameterName(text, "torch." + methodInfo.Name, getPythonParameterName(p.Name), p.Name); |
| 241 | + if (parameters.ContainsKey(p.Name)) { |
| 242 | + text = replaceMethodParameterName(text, "torch." + methodInfo.Name, parameters[p.Name], p.Name); |
| 243 | + } |
| 244 | + } |
| 245 | + } |
| 246 | + } |
| 247 | + return text; |
| 248 | + } |
| 249 | + |
| 250 | + private static string replaceMethodParameterName(string text, string methodName, string oldName, string newName) |
| 251 | + { |
| 252 | + if (oldName == newName) { return text; } |
| 253 | + var r = $"({methodName}\\([^;]*?)\\b{oldName}:"; |
| 254 | + return Regex.Replace(text, r, new MatchEvaluator((m) => { |
| 255 | + return m.Groups[1].Value + newName + ":"; |
| 256 | + })); |
| 257 | + } |
| 258 | + private static string getPythonParameterName(string text) |
| 259 | + { |
| 260 | + StringBuilder stringBuilder = new StringBuilder(); |
| 261 | + |
| 262 | + for (int i = 0; i < text.Length; i++) { |
| 263 | + var c = text[i]; |
| 264 | + if (i == 0) { |
| 265 | + stringBuilder.Append(char.ToLower(c)); |
| 266 | + } else if (c >= 'A' && c <= 'Z') { |
| 267 | + stringBuilder.Append('_'); |
| 268 | + stringBuilder.Append(char.ToLower(c)); |
| 269 | + } else { |
| 270 | + stringBuilder.Append(c); |
| 271 | + } |
| 272 | + } |
| 273 | + return stringBuilder.ToString(); |
| 274 | + } |
| 275 | + |
| 276 | + /// <summary> |
| 277 | + /// Replace Method Parameter Type |
| 278 | + /// </summary> |
| 279 | + /// <param name="text"></param> |
| 280 | + /// <returns></returns> |
| 281 | + private static string replaceMethodParamenterType(string text) |
| 282 | + { |
| 283 | + var tensorType = typeof(TorchSharp.torch.Tensor); |
| 284 | + var fields = tensorType.GetFields(); |
| 285 | + HashSet<string> names = new HashSet<string>(); |
| 286 | + |
| 287 | + foreach (var field in fields) { |
| 288 | + var ms2 = Regex.Matches(text, @"\b(\w+)\." + field.Name + "\\b"); |
| 289 | + foreach (Match m in ms2) { |
| 290 | + names.Add(m.Groups[1].Value); |
| 291 | + } |
| 292 | + } |
| 293 | + var properties = tensorType.GetProperties(); |
| 294 | + foreach (var property in properties) { |
| 295 | + var ms2 = Regex.Matches(text, @"\b(\w+)\." + property.Name + "\\b"); |
| 296 | + foreach (Match m in ms2) { |
| 297 | + names.Add(m.Groups[1].Value); |
| 298 | + } |
| 299 | + } |
| 300 | + var methodInfos = tensorType.GetMethods(BindingFlags.Public | BindingFlags.Instance); |
| 301 | + foreach (var method in methodInfos) { |
| 302 | + var ms2 = Regex.Matches(text, @"\b(\w+)\." + method.Name + "\\("); |
| 303 | + foreach (Match m in ms2) { |
| 304 | + names.Add(m.Groups[1].Value); |
| 305 | + } |
| 306 | + } |
| 307 | + var ms = Regex.Matches(text, @"\b(\w+) = torch\."); |
| 308 | + foreach (Match m in ms) { |
| 309 | + names.Add(m.Groups[1].Value); |
| 310 | + } |
| 311 | + foreach (var name in names) { |
| 312 | + text = text.Replace("object " + name + ",", "Tensor " + name + ","); |
| 313 | + text = text.Replace("void " + name + ",", "Tensor " + name + ","); |
| 314 | + text = text.Replace("object " + name + ";", "Tensor " + name + ";"); |
| 315 | + text = text.Replace("void " + name + ";", "Tensor " + name + ";"); |
| 316 | + text = text.Replace("object " + name + ")", "Tensor " + name + ")"); |
| 317 | + text = text.Replace("void " + name + ")", "Tensor " + name + ")"); |
| 318 | + } |
| 319 | + |
| 320 | + text = Regex.Replace(text, @"(object|void) (\w+_len[,;)])", "int $2"); |
| 321 | + text = Regex.Replace(text, @"(object|void) (\w+_in[,;)])", "int $2"); |
| 322 | + text = Regex.Replace(text, @"(object|void) (\w+_model[,;)])", "int $2"); |
| 323 | + text = Regex.Replace(text, @"(object|void) (\w+_out[,;)])", "int $2"); |
| 324 | + text = Regex.Replace(text, @"(object|void) (\w+_channels[,;)])", "int $2"); |
| 325 | + text = Regex.Replace(text, @"(object|void) (num_\w+[,;)])", "int $2"); |
| 326 | + |
| 327 | + |
| 328 | + return text; |
| 329 | + } |
| 330 | + |
| 331 | + /// <summary> |
| 332 | + /// Replace Math Method |
| 333 | + /// Convert 'math.log'(python) to 'Math.Log'(C#) |
| 334 | + /// </summary> |
| 335 | + /// <param name="text"></param> |
| 336 | + /// <returns></returns> |
| 337 | + private static string replaceMathMethod(string text) |
| 338 | + { |
| 339 | + var mathType = typeof(Math); |
| 340 | + var mathMethods = mathType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public); |
| 341 | + foreach (var methodInfo in mathMethods) { |
| 342 | + var name = methodInfo.Name; |
| 343 | + var nameL = name.ToLower(); |
| 344 | + text = Regex.Replace(text, @$"\bmath\.{nameL}\(", $"Math.{name}("); |
| 345 | + } |
| 346 | + return text; |
| 347 | + } |
| 348 | + /// <summary> |
| 349 | + /// Replace forward method's return type and forward method's parameter type |
| 350 | + /// </summary> |
| 351 | + /// <param name="text"></param> |
| 352 | + /// <returns></returns> |
| 353 | + private static string replaceForwardMethod(string text) |
| 354 | + { |
| 355 | + text = text.Replace(" Tuple<object, object>", " (Tensor, Tensor)"); |
| 356 | + text = text.Replace(" Tuple<object, void> forward(", " (Tensor, Tensor) forward("); |
| 357 | + text = text.Replace(" object[] forward(", " (Tensor, Tensor) forward("); |
| 358 | + text = text.Replace(" Tuple<object, List<object>> forward(", " (Tensor, List<Tensor>) forward("); |
| 359 | + text = text.Replace(" object forward(", " Tensor forward("); |
| 360 | + text = text.Replace(" void forward(", " Tensor forward("); |
| 361 | + text = text.Replace(" forward(object x", " forward(Tensor x"); |
| 362 | + text = text.Replace(" forward(object t", " forward(Tensor t"); |
| 363 | + text = text.Replace(" forward(object queries, object keys, object values", " forward(Tensor queries, Tensor keys, Tensor values"); |
| 364 | + return text; |
| 365 | + } |
| 366 | + /// <summary> |
| 367 | + /// Replace common forward method calls |
| 368 | + /// </summary> |
| 369 | + /// <param name="text"></param> |
| 370 | + /// <returns></returns> |
| 371 | + private static string replaceCallForwardMethod(string text) |
| 372 | + { |
| 373 | + text = Regex.Replace(text, @"\bthis\.inner_attention\(", "this.inner_attention.forward("); |
| 374 | + text = Regex.Replace(text, @"\bthis\.dropout\(", "this.dropout.forward("); |
| 375 | + text = Regex.Replace(text, @"\bthis\.attention\(", "this.attention.forward("); |
| 376 | + text = Regex.Replace(text, @"\bthis\.self_attention\(", "this.self_attention.forward("); |
| 377 | + text = Regex.Replace(text, @"\bthis\.cross_attention\(", "this.cross_attention.forward("); |
| 378 | + text = Regex.Replace(text, @"\bthis\.projection\(", "this.projection.forward("); |
| 379 | + text = Regex.Replace(text, @"\bthis\.activation\(", "this.activation.forward("); |
| 380 | + text = Regex.Replace(text, @"\bthis\.norm\(", "this.norm.forward("); |
| 381 | + text = Regex.Replace(text, @"\bthis\.conv\(", "this.conv.forward("); |
| 382 | + text = Regex.Replace(text, @"\bthis\.decomp\(", "this.decomp.forward("); |
| 383 | + text = Regex.Replace(text, @"\bthis\.decomp1\(", "this.decomp1.forward("); |
| 384 | + text = Regex.Replace(text, @"\bthis\.decomp2\(", "this.decomp2.forward("); |
| 385 | + text = Regex.Replace(text, @"\bthis\.decomp3\(", "this.decomp3.forward("); |
| 386 | + text = Regex.Replace(text, @"\bthis\.decomp4\(", "this.decomp4.forward("); |
| 387 | + text = Regex.Replace(text, @"\bthis\.decomp5\(", "this.decomp5.forward("); |
| 388 | + text = Regex.Replace(text, @"\bthis\.conv1\(", "this.conv1.forward("); |
| 389 | + text = Regex.Replace(text, @"\bthis\.conv2\(", "this.conv2.forward("); |
| 390 | + text = Regex.Replace(text, @"\bthis\.conv3\(", "this.conv3.forward("); |
| 391 | + text = Regex.Replace(text, @"\bthis\.conv4\(", "this.conv4.forward("); |
| 392 | + text = Regex.Replace(text, @"\bthis\.conv5\(", "this.conv5.forward("); |
| 393 | + text = Regex.Replace(text, @"\bthis\.norm1\(", "this.norm1.forward("); |
| 394 | + text = Regex.Replace(text, @"\bthis\.norm2\(", "this.norm2.forward("); |
| 395 | + text = Regex.Replace(text, @"\bthis\.norm3\(", "this.norm3.forward("); |
| 396 | + text = Regex.Replace(text, @"\bthis\.norm4\(", "this.norm4.forward("); |
| 397 | + text = Regex.Replace(text, @"\bthis\.norm5\(", "this.norm5.forward("); |
| 398 | + |
| 399 | + text = Regex.Replace(text, @"\bthis\.downConv\(", "this.downConv.forward("); |
| 400 | + text = Regex.Replace(text, @"\bthis\.maxPool\(", "this.maxPool.forward("); |
| 401 | + text = Regex.Replace(text, @"\bthis\.avg\(", "this.avg.forward("); |
| 402 | + text = Regex.Replace(text, @"\bthis\.layernorm\(", "this.layernorm.forward("); |
| 403 | + text = Regex.Replace(text, @"\bthis\.tokenConv\(", "this.tokenConv.forward("); |
| 404 | + |
| 405 | + text = Regex.Replace(text, @"\bthis\.embedding\(", "this.embedding.forward("); |
| 406 | + text = Regex.Replace(text, @"\bthis\.emb\(", "this.emb.forward("); |
| 407 | + text = Regex.Replace(text, @"\bthis\.embed\(", "this.embed.forward("); |
| 408 | + text = Regex.Replace(text, @"\bthis\.position_embedding\(", "this.position_embedding.forward("); |
| 409 | + text = Regex.Replace(text, @"\bthis\.temporal_embedding\(", "this.temporal_embedding.forward("); |
| 410 | + text = Regex.Replace(text, @"\bthis\.value_embedding\(", "this.value_embedding.forward("); |
| 411 | + |
| 412 | + text = Regex.Replace(text, @"\bthis\.month_embed\(", "this.month_embed.forward("); |
| 413 | + text = Regex.Replace(text, @"\bthis\.day_embed\(", "this.day_embed.forward("); |
| 414 | + text = Regex.Replace(text, @"\bthis\.hour_embed\(", "this.hour_embed.forward("); |
| 415 | + text = Regex.Replace(text, @"\bthis\.minute_embed\(", "this.minute_embed.forward("); |
| 416 | + text = Regex.Replace(text, @"\bthis\.weekday_embed\(", "this.weekday_embed.forward("); |
| 417 | + |
| 418 | + text = Regex.Replace(text, @"\bthis\.enc_embedding\(", "this.enc_embedding.forward("); |
| 419 | + text = Regex.Replace(text, @"\bthis\.encoder\(", "this.encoder.forward("); |
| 420 | + text = Regex.Replace(text, @"\bthis\.dec_embedding\(", "this.dec_embedding.forward("); |
| 421 | + text = Regex.Replace(text, @"\bthis\.decoder\(", "this.decoder.forward("); |
| 422 | + |
| 423 | + text = Regex.Replace(text, @"\bthis\.query_projection\(", "this.query_projection.forward("); |
| 424 | + text = Regex.Replace(text, @"\bthis\.key_projection\(", "this.key_projection.forward("); |
| 425 | + text = Regex.Replace(text, @"\bthis\.value_projection\(", "this.value_projection.forward("); |
| 426 | + text = Regex.Replace(text, @"\bthis\.out_projection\(", "this.out_projection.forward("); |
| 427 | + |
| 428 | + text = Regex.Replace(text, @"\bthis\.attn\(", "this.attn.forward("); |
| 429 | + return text; |
| 430 | + } |
| 431 | + |
| 432 | + /// <summary> |
| 433 | + /// Replace common Tensor list |
| 434 | + /// </summary> |
| 435 | + /// <param name="text"></param> |
| 436 | + /// <returns></returns> |
| 437 | + private static string replaceTensorList(string text) |
| 438 | + { |
| 439 | + text = text.Replace(" torch.cat(new List<object>", " torch.cat(new List<Tensor>"); |
| 440 | + text = text.Replace(" torch.ones(new List<object>", " torch.ones(new List<Tensor>"); |
| 441 | + text = text.Replace(" torch.zeros(new List<object>", " torch.zeros(new List<Tensor>"); |
| 442 | + |
| 443 | + text = text.Replace("var attns = new List<object>();", "var attns = new List<Tensor>();"); |
| 444 | + text = text.Replace("attns.append(attn);", "attns.Add(attn);"); |
| 445 | + return text; |
| 446 | + } |
| 447 | + /// <summary> |
| 448 | + /// Convert python's [:,:,:] syntax |
| 449 | + /// </summary> |
| 450 | + /// <param name="text"></param> |
| 451 | + /// <returns></returns> |
| 452 | + private static string replaceListSlice(string text) |
| 453 | + { |
| 454 | + text = Regex.Replace(text, @"\[([^\[\]]*?)\]", new MatchEvaluator(m => { |
| 455 | + if (m.Groups[1].Value.Contains(":") == false) { |
| 456 | + return m.Value; |
| 457 | + } |
| 458 | + var strs = m.Groups[1].Value.Split(','); |
| 459 | + List<string> list = new List<string>(); |
| 460 | + foreach (var str in strs) { |
| 461 | + if (str.Trim() == "\":\"") { |
| 462 | + list.Add("TensorIndex.Ellipsis"); |
| 463 | + } else if (str.Trim() == "") { |
| 464 | + list.Add("TensorIndex.Null"); |
| 465 | + } else if (str.Contains(":")) { |
| 466 | + var ss = str.Trim().Split(':'); |
| 467 | + string r = "TensorIndex.Slice("; |
| 468 | + for (int i = 0; i < ss.Length; i++) { |
| 469 | + var s = ss[i]; |
| 470 | + if (i > 0) { r += ","; } |
| 471 | + if (s.Trim() == "") { |
| 472 | + r += "null"; |
| 473 | + } else { |
| 474 | + if (s.StartsWith("self.")) { |
| 475 | + r += s.Replace("self.", "this."); |
| 476 | + } else { |
| 477 | + r += s; |
| 478 | + } |
| 479 | + } |
| 480 | + } |
| 481 | + r += ")"; |
| 482 | + list.Add(r); |
| 483 | + } else { |
| 484 | + list.Add(str); |
| 485 | + } |
| 486 | + } |
| 487 | + return "[" + string.Join(",", list) + "]"; |
| 488 | + })); |
| 489 | + return text; |
| 490 | + } |
| 491 | + |
| 492 | + /// <summary> |
| 493 | + /// Convert 'xx is nn.Conv1d' to 'xx is Conv1d' |
| 494 | + /// </summary> |
| 495 | + /// <param name="text"></param> |
| 496 | + /// <returns></returns> |
| 497 | + private static string replaceIsType(string text) |
| 498 | + { |
| 499 | + var nnType = typeof(TorchSharp.torch.nn); |
| 500 | + var nnMethods = nnType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public); |
| 501 | + foreach (var method in nnMethods) { |
| 502 | + var fieldType = method.ReturnType.Name; |
| 503 | + var methodName = method.Name; |
| 504 | + if (methodName == "ModuleDict" || methodName == "ModuleList") { |
| 505 | + continue; |
| 506 | + } |
| 507 | + text = text.Replace($" is nn.{methodName}", $" is {methodName}"); |
| 508 | + } |
| 509 | + return text; |
| 510 | + } |
| 511 | + |
| 512 | + /// <summary> |
| 513 | + /// Replace String To Enum |
| 514 | + /// example: Convert 'paddingMode: "zeros"' to 'paddingMode: TorchSharp.PaddingModes.Zeros' |
| 515 | + /// </summary> |
| 516 | + /// <param name="text"></param> |
| 517 | + /// <returns></returns> |
| 518 | + private static string replaceStringToEnum(string text) |
| 519 | + { |
| 520 | + text = Regex.Replace(text, @"\bpaddingMode: ""zeros""", "paddingMode: TorchSharp.PaddingModes.Zeros"); |
| 521 | + text = Regex.Replace(text, @"\bpaddingMode: ""reflect""", "paddingMode: TorchSharp.PaddingModes.Reflect"); |
| 522 | + text = Regex.Replace(text, @"\bpaddingMode: ""replicate""", "paddingMode: TorchSharp.PaddingModes.Replicate"); |
| 523 | + text = Regex.Replace(text, @"\bpaddingMode: ""circular""", "paddingMode: TorchSharp.PaddingModes.Circular"); |
| 524 | + text = Regex.Replace(text, @"\bpaddingMode: ""constant""", "paddingMode: TorchSharp.PaddingModes.Constant"); |
| 525 | + |
| 526 | + text = Regex.Replace(text, @"\breduction: ""none""", "reduction: Reduction.None"); |
| 527 | + text = Regex.Replace(text, @"\breduction: ""mean""", "reduction: Reduction.Mean"); |
| 528 | + text = Regex.Replace(text, @"\breduction: ""sum""", "reduction: Reduction.Sum"); |
| 529 | + |
| 530 | + text = Regex.Replace(text, @"\bnonLinearity: ""relu""", "nonLinearity: NonLinearities.ReLU"); |
| 531 | + text = Regex.Replace(text, @"\bnonLinearity: ""tanh""", "nonLinearity: NonLinearities.Tanh"); |
| 532 | + |
| 533 | + text = Regex.Replace(text, @"\bactivation: ""relu""", "activation: Activations.ReLU"); |
| 534 | + text = Regex.Replace(text, @"\bactivation: ""gelu""", "activation: Activations.GELU"); |
| 535 | + |
| 536 | + text = Regex.Replace(text, @"\bmode: ""nearest""", "mode: UpsampleMode.Nearest"); |
| 537 | + text = Regex.Replace(text, @"\bmode: ""linear""", "mode: UpsampleMode.Linear"); |
| 538 | + text = Regex.Replace(text, @"\bmode: ""bilinear""", "mode: UpsampleMode.Bilinear"); |
| 539 | + text = Regex.Replace(text, @"\bmode: ""bicubic""", "mode: UpsampleMode.Bicubic"); |
| 540 | + text = Regex.Replace(text, @"\bmode: ""trilinear""", "mode: UpsampleMode.Trilinear"); |
| 541 | + |
| 542 | + return text; |
| 543 | + } |
| 544 | + |
| 545 | + /// <summary> |
| 546 | + /// Convert to the syntax style of netstandard.cs |
| 547 | + /// </summary> |
| 548 | + /// <param name="text"></param> |
| 549 | + /// <returns></returns> |
| 550 | + private static string replaceStringToNetstandard(string text) |
| 551 | + { |
| 552 | + text = Regex.Replace(text, @" zip\(", " TorchEnumerable.zip("); |
| 553 | + |
| 554 | + text = Regex.Replace(text, @"(\([A-Za-z_0-9]+,[A-Za-z_0-9 ]+\) = \w+\.shape);", "var $1.ToLong2();"); |
| 555 | + text = Regex.Replace(text, @"(\([A-Za-z_0-9]+,[A-Za-z_0-9 ]+,[A-Za-z_0-9 ]+\) = \w+\.shape);", "var $1.ToLong3();"); |
| 556 | + text = Regex.Replace(text, @"(\([A-Za-z_0-9]+,[A-Za-z_0-9 ]+,[A-Za-z_0-9 ]+,[A-Za-z_0-9 ]+\) = \w+\.shape);", "var $1.ToLong4();"); |
| 557 | + |
| 558 | + text = Regex.Replace(text, @"(\([A-Za-z_0-9]+,[A-Za-z_0-9 ]+\) = \w+\.size\(\));", "var $1.ToLong2();"); |
| 559 | + text = Regex.Replace(text, @"(\([A-Za-z_0-9]+,[A-Za-z_0-9 ]+,[A-Za-z_0-9 ]+\) = \w+\.size\(\));", "var $1.ToLong3();"); |
| 560 | + text = Regex.Replace(text, @"(\([A-Za-z_0-9]+,[A-Za-z_0-9 ]+,[A-Za-z_0-9 ]+,[A-Za-z_0-9 ]+\) = \w+\.size\(\));", "var $1.ToLong4();"); |
| 561 | + |
| 562 | + return text; |
| 563 | + } |
| 564 | + |
| 565 | + |
| 566 | + |
| 567 | + } |
| 568 | +} |
0 commit comments