Skip to content

Commit

Permalink
Merge pull request Mikubill#486 from aiton-sd/main
Browse files Browse the repository at this point in the history
Support for specifying a range of numbers.
  • Loading branch information
Mikubill authored Mar 5, 2023
2 parents 2ff1154 + 92b105a commit d314d1b
Showing 1 changed file with 238 additions and 84 deletions.
322 changes: 238 additions & 84 deletions scripts/xyz_grid_support.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import numpy as np

from modules import scripts, shared

Expand Down Expand Up @@ -50,6 +51,239 @@ def is_all_included(target_list, check_list, allow_blank=False, stop=False):
return True


class ListParser():
"""This class restores a broken list caused by the following process
in the xyz_grid module.
-> valslist = [x.strip() for x in chain.from_iterable(
csv.reader(StringIO(vals)))]
It also performs type conversion,
adjusts the number of elements in the list, and other operations.
This class directly modifies the received list.
"""
numeric_pattern = {
int: {
"range": r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*",
"count": r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*"
},
float: {
"range": r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\(([+-]\d+(?:\.\d*)?)\s*\))?\s*",
"count": r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*\])?\s*"
}
}

################################################
#
# Initialization method from here.
#
################################################

def __init__(self, my_list, converter=None, allow_blank=True, exclude_list=None, run=True):
self.my_list = my_list
self.converter = converter
self.allow_blank = allow_blank
self.exclude_list = exclude_list
self.re_bracket_start = None
self.re_bracket_start_precheck = None
self.re_bracket_end = None
self.re_bracket_end_precheck = None
self.re_range = None
self.re_count = None
self.compile_regex()
if run:
self.auto_normalize()

def compile_regex(self):
exclude_pattern = "|".join(self.exclude_list) if self.exclude_list else None
if exclude_pattern is None:
self.re_bracket_start = re.compile(r"^\[")
self.re_bracket_end = re.compile(r"\]$")
else:
self.re_bracket_start = re.compile(fr"^\[(?!(?:{exclude_pattern})\])")
self.re_bracket_end = re.compile(fr"(?<!\[(?:{exclude_pattern}))\]$")

if self.converter not in self.numeric_pattern:
return self
# If the converter is either int or float.
self.re_range = re.compile(self.numeric_pattern[self.converter]["range"])
self.re_count = re.compile(self.numeric_pattern[self.converter]["count"])
self.re_bracket_start_precheck = None
self.re_bracket_end_precheck = self.re_count
return self

################################################
#
# Public method from here.
#
################################################

################################################
# This method is executed at the time of initialization.
#
def auto_normalize(self):
if not self.has_list_notation():
self.numeric_range_parser()
self.type_convert()
return self
else:
self.fix_structure()
self.numeric_range_parser()
self.type_convert()
self.fill_to_longest()
return self

def has_list_notation(self):
return any(self._search_bracket(s) for s in self.my_list)

def numeric_range_parser(self, my_list=None, depth=0):
if self.converter not in self.numeric_pattern:
return self

my_list = self.my_list if my_list is None else my_list
result = []
is_matched = False
for s in my_list:
if isinstance(s, list):
result.extend(self.numeric_range_parser(s, depth+1))
continue

match = self._numeric_range_to_list(s)
if s != match:
is_matched = True
result.extend(match if not depth else [match])
continue
else:
result.append(s)
continue

if depth:
return self._transpose(result) if is_matched else [result]
else:
my_list[:] = result
return self

def type_convert(self, my_list=None):
my_list = self.my_list if my_list is None else my_list
for i, s in enumerate(my_list):
if isinstance(s, list):
self.type_convert(s)
elif self.allow_blank and (str(s) in ["None", ""]):
my_list[i] = None
elif self.converter:
my_list[i] = self.converter(s)
else:
my_list[i] = s
return self

def fix_structure(self):
def is_same_length(list1, list2):
return len(list1) == len(list2)

start_indices, end_indices = [], []
for i, s in enumerate(self.my_list):
if is_same_length(start_indices, end_indices):
replace_string = self._search_bracket(s, "[", replace="")
if s != replace_string:
s = replace_string
start_indices.append(i)
if not is_same_length(start_indices, end_indices):
replace_string = self._search_bracket(s, "]", replace="")
if s != replace_string:
s = replace_string
end_indices.append(i + 1)
self.my_list[i] = s
if not is_same_length(start_indices, end_indices):
raise ValueError(f"Lengths of {start_indices} and {end_indices} are different.")
# Restore the structure of a list.
for i, j in zip(reversed(start_indices), reversed(end_indices)):
self.my_list[i:j] = [self.my_list[i:j]]
return self

def fill_to_longest(self, my_list=None, value=None, index=None):
my_list = self.my_list if my_list is None else my_list
if not self.sublist_exists(my_list):
return self
max_length = max(len(sub_list) for sub_list in my_list if isinstance(sub_list, list))
for i, sub_list in enumerate(my_list):
if isinstance(sub_list, list):
fill_value = value if index is None else sub_list[index]
my_list[i] = sub_list + [fill_value] * (max_length-len(sub_list))
return self

def sublist_exists(self, my_list=None):
my_list = self.my_list if my_list is None else my_list
return any(isinstance(item, list) for item in my_list)

def all_sublists(self, my_list=None): # Unused method
my_list = self.my_list if my_list is None else my_list
return all(isinstance(item, list) for item in my_list)

def get_list(self): # Unused method
return self.my_list

################################################
#
# Private method from here.
#
################################################

def _search_bracket(self, string, bracket="[", replace=None):
if bracket == "[":
pattern = self.re_bracket_start
precheck = self.re_bracket_start_precheck # None
elif bracket == "]":
pattern = self.re_bracket_end
precheck = self.re_bracket_end_precheck
else:
raise ValueError(f"Invalid argument provided. (bracket: {bracket})")

if precheck and precheck.fullmatch(string):
return None if replace is None else string
elif replace is None:
return pattern.search(string)
else:
return pattern.sub(replace, string)

def _numeric_range_to_list(self, string):
match = self.re_range.fullmatch(string)
if match is not None:
if self.converter == int:
start = int(match.group(1))
end = int(match.group(2)) + 1
step = int(match.group(3)) if match.group(3) is not None else 1
return list(range(start, end, step))
else: # float
start = float(match.group(1))
end = float(match.group(2))
step = float(match.group(3)) if match.group(3) is not None else 1
return np.arange(start, end + step, step).tolist()

match = self.re_count.fullmatch(string)
if match is not None:
if self.converter == int:
start = int(match.group(1))
end = int(match.group(2))
num = int(match.group(3)) if match.group(3) is not None else 1
return [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
else: # float
start = float(match.group(1))
end = float(match.group(2))
num = int(match.group(3)) if match.group(3) is not None else 1
return np.linspace(start=start, stop=end, num=num).tolist()
return string

def _transpose(self, my_list=None):
my_list = self.my_list if my_list is None else my_list
my_list = [item if isinstance(item, list) else [item] for item in my_list]
self.fill_to_longest(my_list, index=-1)
return np.array(my_list, dtype=object).T.tolist()

################################################
#
# The methods of ListParser class end here.
#
################################################

################################################################
################################################################
#
Expand Down Expand Up @@ -106,86 +340,6 @@ def copy(self):
instance_list.append(instance)
return instance_list

def normalize_list(valslist, type_func=None, allow_blank=True, excluded=None):
"""This function restores a broken list caused by the following process
in the xyz_grid module.
-> valslist = [x.strip() for x in chain.from_iterable(
csv.reader(StringIO(vals)))]
It also performs type conversion,
adjusts the number of elements in the list, and other operations.
"""
def search_bracket(string, bracket="[", replace=None, excluded=None):
if excluded:
excluded = "|".join(excluded)

if bracket == "[":
pattern = rf"^\[(?!(?:{excluded})\])" if excluded else r"^\["
elif bracket == "]":
pattern = rf"(?<!\[(?:{excluded}))\]$" if excluded else r"\]$"
else:
raise ValueError(f"Invalid argument provided. (bracket: {bracket})")

if replace is None:
return re.search(pattern, string)
else:
return re.sub(pattern, replace, string)

def sublist_exists(valslist, excluded=None):
return any(search_bracket(s, excluded=excluded) for s in valslist)

def type_convert(valslist, type_func, allow_blank=True):
for i, s in enumerate(valslist):
if isinstance(s, list):
type_convert(s, type_func, allow_blank)
elif allow_blank and (str(s) in ["None", ""]):
valslist[i] = None
elif type_func:
valslist[i] = type_func(s)
else:
valslist[i] = s

def fix_list_structure(valslist, excluded=None):
def is_same_length(list1, list2):
return len(list1) == len(list2)

start_indices = []
end_indices = []
for i, s in enumerate(valslist):
if is_same_length(start_indices, end_indices):
replace_string = search_bracket(s, "[", replace="", excluded=excluded)
if s != replace_string:
s = replace_string
start_indices.append(i)
if not is_same_length(start_indices, end_indices):
replace_string = search_bracket(s, "]", replace="", excluded=excluded)
if s != replace_string:
s = replace_string
end_indices.append(i + 1)
valslist[i] = s
if not is_same_length(start_indices, end_indices):
raise ValueError(f"Lengths of {start_indices} and {end_indices} are different.")
# Restore the structure of a list.
for i, j in zip(reversed(start_indices), reversed(end_indices)):
valslist[i:j] = [valslist[i:j]]

def pad_to_longest(valslist):
max_length = max(len(sub_list) for sub_list in valslist if isinstance(sub_list, list))
for i, sub_list in enumerate(valslist):
if isinstance(sub_list, list):
valslist[i] = sub_list + [None] * (max_length-len(sub_list))

################################################
# Starting the main process of the normalize_list function.
#
if not sublist_exists(valslist, excluded):
type_convert(valslist, type_func, allow_blank)
return
else:
fix_list_structure(valslist, excluded)
type_convert(valslist, type_func, allow_blank)
pad_to_longest(valslist)
return

################################################
################################################
#
Expand Down Expand Up @@ -233,17 +387,17 @@ def identity(x):
def confirm(func_or_str):
@debug_info
def confirm_(p, xs):
if callable(func_or_str): # func_or_str is type_func
normalize_list(xs, func_or_str, allow_blank=True)
if callable(func_or_str): # func_or_str is converter
ListParser(xs, func_or_str, allow_blank=True)
return

elif isinstance(func_or_str, str): # func_or_str is keyword
valid_data = find_dict(validation_data, func_or_str, stop=True)
type_func = valid_data["type"]
converter = valid_data["type"]
exclude_list = valid_data["exclude"]() if valid_data["exclude"] else None
check_list = valid_data["check"]()

normalize_list(xs, type_func, allow_blank=True, excluded=exclude_list)
ListParser(xs, converter, allow_blank=True, exclude_list=exclude_list)
is_all_included(xs, check_list, allow_blank=True, stop=True)
return

Expand Down

0 comments on commit d314d1b

Please sign in to comment.