Skip to content

Commit

Permalink
[Fix] Cleans the variable name to avoid comflict characters and numbe…
Browse files Browse the repository at this point in the history
…rs. pydn#10
  • Loading branch information
FelipeMurguia committed Aug 23, 2023
1 parent 48248e0 commit c0830e5
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions comfyui_to_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import random
import sys
import re
from typing import Dict, List, Any, Callable, Tuple

import black
Expand Down Expand Up @@ -217,7 +218,7 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
continue

class_type, import_statement, class_code = self.get_class_info(class_type)
initialized_objects[class_type] = class_type.lower().strip()
initialized_objects[class_type] = self.clean_variable_name(class_type)
if class_type in self.base_node_class_mappings.keys():
import_statements.add(import_statement)
if class_type not in self.base_node_class_mappings.keys():
Expand All @@ -234,9 +235,9 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
inputs['unique_id'] = random.randint(1, 2**64)

# Create executed variable and generate code
executed_variables[idx] = f'{class_type.lower().strip()}_{idx}'
executed_variables[idx] = f'{self.clean_variable_name(class_type)}_{idx}'
inputs = self.update_inputs(inputs, executed_variables)

if is_special_function:
special_functions_code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, **inputs))
else:
Expand Down Expand Up @@ -329,6 +330,21 @@ def assemble_python_code(self, import_statements: set, speical_functions_code: L

return final_code

def clean_variable_name(self, class_type: str) -> str:
clean_name = class_type.lower().strip()

# Convert to lowercase and replace spaces with underscores
clean_name = clean_name.lower().replace("-", "_").replace(" ", "_")

# Remove characters that are not letters, numbers, or underscores
clean_name = re.sub(r'[^a-z0-9_]', '', clean_name)

# Ensure that it doesn't start with a number
if clean_name[0].isdigit():
clean_name = "_" + clean_name

return clean_name

def get_class_info(self, class_type: str) -> Tuple[str, str, str]:
"""Generates and returns necessary information about class type.
Expand All @@ -339,10 +355,11 @@ def get_class_info(self, class_type: str) -> Tuple[str, str, str]:
Tuple[str, str, str]: Updated class type, import statement string, class initialization code.
"""
import_statement = class_type
variable_name = self.clean_variable_name(class_type)
if class_type in self.base_node_class_mappings.keys():
class_code = f'{class_type.lower().strip()} = {class_type.strip()}()'
class_code = f'{variable_name} = {class_type.strip()}()'
else:
class_code = f'{class_type.lower().strip()} = NODE_CLASS_MAPPINGS["{class_type}"]()'
class_code = f'{variable_name} = NODE_CLASS_MAPPINGS["{class_type}"]()'

return class_type, import_statement, class_code

Expand Down

0 comments on commit c0830e5

Please sign in to comment.