Commit

Change to FileHandle
Watchful1 committed Jan 28, 2023
1 parent 87d2b22 commit 33b5b93
Showing 1 changed file with 119 additions and 80 deletions.
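
For reference, a minimal sketch of how the new FileHandle class can be driven on its own, assuming it is imported from scripts/combine_folder_multiprocess.py; the import path, file names and subreddit value below are hypothetical:

import json

# hypothetical import path and file names; FileHandle picks zst vs ndjson handling from the file extension
from combine_folder_multiprocess import FileHandle

reader = FileHandle("RC_2023-01.zst")
writer = FileHandle("askreddit_comments.ndjson")
for line, bytes_processed in reader.yield_lines():
	obj = json.loads(line)
	if obj.get("subreddit", "").lower() == "askreddit":
		writer.write_line(line)
writer.close()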
199 changes: 119 additions & 80 deletions scripts/combine_folder_multiprocess.py
@@ -61,6 +61,84 @@ def __str__(self):
return f"{self.input_path} : {self.output_path} : {self.file_size} : {self.complete} : {self.bytes_processed} : {self.lines_processed}"


# another convenience object to read and write from both zst files and ndjson files
class FileHandle:
	def __init__(self, path):
		self.path = path
		if self.path.endswith(".zst"):
			self.is_compressed = True
		elif self.path.endswith(".ndjson"):
			self.is_compressed = False
		else:
			raise TypeError(f"File type not supported for writing {self.path}")

		self.write_handle = None
		self.other_handle = None
		self.newline_encoded = "\n".encode('utf-8')

	# recursively decompress and decode a chunk of bytes. If there's a decode error then read another chunk and try with that, up to a limit of max_window_size bytes
	@staticmethod
	def read_and_decode(reader, chunk_size, max_window_size, previous_chunk=None, bytes_read=0):
		chunk = reader.read(chunk_size)
		bytes_read += chunk_size
		if previous_chunk is not None:
			chunk = previous_chunk + chunk
		try:
			return chunk.decode()
		except UnicodeDecodeError:
			if bytes_read > max_window_size:
				raise UnicodeError(f"Unable to decode frame after reading {bytes_read:,} bytes")
			return FileHandle.read_and_decode(reader, chunk_size, max_window_size, chunk, bytes_read)

	# open a zst compressed ndjson file, or a regular uncompressed ndjson file and yield lines one at a time
	# also passes back file progress
	def yield_lines(self):
		if self.is_compressed:
			with open(self.path, 'rb') as file_handle:
				buffer = ''
				reader = zstandard.ZstdDecompressor(max_window_size=2**31).stream_reader(file_handle)
				while True:
					chunk = FileHandle.read_and_decode(reader, 2**27, (2**29) * 2)
					if not chunk:
						break
					lines = (buffer + chunk).split("\n")

					for line in lines[:-1]:
						yield line, file_handle.tell()

					buffer = lines[-1]
				reader.close()

		else:
			with open(self.path, 'r') as file_handle:
				line = file_handle.readline()
				while line:
					yield line.rstrip('\n'), file_handle.tell()
					line = file_handle.readline()

	# write a line, opening the appropriate handle
	def write_line(self, line):
		if self.write_handle is None:
			if self.is_compressed:
				self.other_handle = open(self.path, 'wb')
				self.write_handle = zstandard.ZstdCompressor().stream_writer(self.other_handle)
			else:
				self.write_handle = open(self.path, 'w', encoding="utf-8")

		if self.is_compressed:
			self.write_handle.write(line.encode('utf-8'))
			self.write_handle.write(self.newline_encoded)
		else:
			self.write_handle.write(line)
			self.write_handle.write("\n")

	def close(self):
		if self.write_handle:
			self.write_handle.close()
		if self.other_handle:
			self.other_handle.close()


# used for calculating running average of read speed
class Queue:
def __init__(self, max_size):
@@ -108,72 +186,35 @@ def load_file_list(status_json):
return None, None, None


# recursively decompress and decode a chunk of bytes. If there's a decode error then read another chunk and try with that, up to a limit of max_window_size bytes
def read_and_decode(reader, chunk_size, max_window_size, previous_chunk=None, bytes_read=0):
	chunk = reader.read(chunk_size)
	bytes_read += chunk_size
	if previous_chunk is not None:
		chunk = previous_chunk + chunk
	try:
		return chunk.decode()
	except UnicodeDecodeError:
		if bytes_read > max_window_size:
			raise UnicodeError(f"Unable to decode frame after reading {bytes_read:,} bytes")
		return read_and_decode(reader, chunk_size, max_window_size, chunk, bytes_read)


# open a zst compressed ndjson file and yield lines one at a time
# also passes back file progress
def read_lines_zst(file_name):
	with open(file_name, 'rb') as file_handle:
		buffer = ''
		reader = zstandard.ZstdDecompressor(max_window_size=2**31).stream_reader(file_handle)
		while True:
			chunk = read_and_decode(reader, 2**27, (2**29) * 2)
			if not chunk:
				break
			lines = (buffer + chunk).split("\n")

			for line in lines[:-1]:
				yield line, file_handle.tell()

			buffer = lines[-1]
		reader.close()


# base of each separate process. Loads a file, iterates through lines and writes out
# the ones where the `field` of the object matches `value`. Also passes status
# information back to the parent via a queue
def process_file(file, queue, field, value, values, case_sensitive):
output_file = None
def process_file(file, queue, field, value, values):
queue.put(file)
input_handle = FileHandle(file.input_path)
output_handle = FileHandle(file.output_path)
try:
for line, file_bytes_processed in read_lines_zst(file.input_path):
for line, file_bytes_processed in input_handle.yield_lines():
try:
obj = json.loads(line)
matched = False
observed = obj[field] if case_sensitive else obj[field].lower()
observed = obj[field].lower()
if value is not None:
if observed == value:
matched = True
elif observed in values:
matched = True

if matched:
if output_file is None:
output_file = open(file.output_path, 'w', encoding="utf-8")
output_file.write(line)
output_file.write("\n")
output_handle.write_line(line)
except (KeyError, json.JSONDecodeError) as err:
file.error_lines += 1
file.lines_processed += 1
if file.lines_processed % 1000000 == 0:
file.bytes_processed = file_bytes_processed
queue.put(file)

if output_file is not None:
output_file.close()

output_handle.close()
file.complete = True
file.bytes_processed = file.file_size
except Exception as err:
@@ -191,8 +232,8 @@ def process_file(file, queue, field, value, values, case_sensitive):
parser.add_argument("--value", help="When deciding what lines to keep, compare the field to this value. Supports a comma separated list. This is case sensitive", default="pushshift")
parser.add_argument("--value_list", help="A file of newline separated values to use. Overrides the value param if it is set", default=None)
parser.add_argument("--processes", help="Number of processes to use", default=10, type=int)
parser.add_argument("--case-sensitive", help="Matching should be case sensitive", action="store_true")
parser.add_argument("--file_filter", help="Regex filenames have to match to be processed", default="^rc_|rs_")
parser.add_argument("--compress_intermediate", help="Compress the intermediate files, use if the filter will result in a very large amount of data", action="store_true")
parser.add_argument(
"--error_rate", help=
"Percentage as an integer from 0 to 100 of the lines where the field can be missing. For the subreddit field especially, "
@@ -201,7 +242,7 @@ def process_file(file, queue, field, value, values, case_sensitive):
script_type = "split"

args = parser.parse_args()
arg_string = f"{args.field}:{args.value}:{args.case_sensitive}"
arg_string = f"{args.field}:{(args.value if args.value else args.value_list)}"

if args.debug:
log.setLevel(logging.DEBUG)
@@ -212,28 +253,25 @@ def process_file(file, queue, field, value, values, case_sensitive):
else:
log.info(f"Writing output to working folder")

if not args.case_sensitive:
args.value = args.value.lower()

value = None
values = None
if args.value_list:
log.info(f"Reading {args.value_list} for values to compare")
with open(args.value_list, 'r') as value_list_handle:
values = set()
for line in value_list_handle:
values.add(line.strip())
values.add(line.strip().lower())
log.info(f"Comparing {args.field} against {len(values)} values")

else:
value_strings = args.value.split(",")
if len(value_strings) > 1:
values = set()
for value_inner in value_strings:
values.add(value_inner)
values.add(value_inner.lower())
log.info(f"Checking field {args.field} for values {(', '.join(value_strings))}")
elif len(value_strings) == 1:
value = value_strings[0]
value = value_strings[0].lower()
log.info(f"Checking field {args.field} for value {value}")
else:
log.info(f"Invalid value specified, aborting: {args.value}")
@@ -259,7 +297,7 @@ def process_file(file, queue, field, value, values, case_sensitive):
for file_name in files:
if file_name.endswith(".zst") and re.search(args.file_filter, file_name, re.IGNORECASE) is not None:
input_path = os.path.join(subdir, file_name)
output_path = os.path.join(args.working, file_name[:-4])
output_path = os.path.join(args.working, f"{file_name[:-4]}.{('zst' if args.compress_intermediate else 'ndjson')}")
input_files.append(FileConfig(input_path, output_path=output_path))

save_file_list(input_files, args.working, status_json, arg_string, script_type)
@@ -295,7 +333,7 @@ def process_file(file, queue, field, value, values, case_sensitive):
log.info(f"Processing file: {file.input_path}")
# start the workers
with multiprocessing.Pool(processes=min(args.processes, len(files_to_process))) as pool:
workers = pool.starmap_async(process_file, [(file, queue, args.field, value, values, args.case_sensitive) for file in files_to_process], chunksize=1, error_callback=log.info)
workers = pool.starmap_async(process_file, [(file, queue, args.field, value, values) for file in files_to_process], chunksize=1, error_callback=log.info)
while not workers.ready():
# loop until the workers are all done, pulling in status messages as they are sent
file_update = queue.get()
@@ -383,7 +421,7 @@ def process_file(file, queue, field, value, values, case_sensitive):
split = False
for working_file_path in working_file_paths:
files_combined += 1
log.info(f"Reading {files_combined}/{len(working_file_paths)} : {os.path.split(working_file_path)[1]}")
log.info(f"From {files_combined}/{len(working_file_paths)} files to {len(all_handles):,} output handles : {output_lines:,} lines : {os.path.split(working_file_path)[1]}")
working_file_name = os.path.split(working_file_path)[1]
if working_file_name.startswith("RS"):
file_type = "submissions"
@@ -392,37 +430,38 @@ def process_file(file, queue, field, value, values, case_sensitive):
else:
log.warning(f"Unknown working file type, skipping: {working_file_name}")
continue
input_handle = FileHandle(working_file_path)
if file_type not in output_handles:
output_handles[file_type] = {}
file_type_handles = output_handles[file_type]
with open(working_file_path, 'r') as input_file:
for line in input_file:
output_lines += 1
if split:
obj = json.loads(line)
observed_case = obj[args.field]
else:
observed_case = value
observed = observed_case if args.case_sensitive else observed_case.lower()
if observed not in file_type_handles:
if args.output:
if not os.path.exists(args.output):
os.makedirs(args.output)
output_file_path = os.path.join(args.output, f"{observed_case}_{file_type}.zst")
else:
output_file_path = f"{observed_case}_{file_type}.zst"
log.info(f"Writing to file {output_file_path}")
file_handle = open(output_file_path, 'wb')
writer = zstandard.ZstdCompressor().stream_writer(file_handle)
file_type_handles[observed] = writer
all_handles.append(writer)
all_handles.append(file_handle)

for line, file_bytes_processed in input_handle.yield_lines():
output_lines += 1
if split:
obj = json.loads(line)
observed_case = obj[args.field]
else:
observed_case = value
observed = observed_case.lower()
if observed not in file_type_handles:
if args.output:
if not os.path.exists(args.output):
os.makedirs(args.output)
output_file_path = os.path.join(args.output, f"{observed_case}_{file_type}.zst")
else:
writer = file_type_handles[observed]
output_file_path = f"{observed_case}_{file_type}.zst"
log.debug(f"Writing to file {output_file_path}")
output_handle = FileHandle(output_file_path)
file_type_handles[observed] = output_handle
all_handles.append(output_handle)
else:
output_handle = file_type_handles[observed]

encoded_line = line.encode('utf-8')
writer.write(encoded_line)
output_handle.write_line(line)
if output_lines % 100000 == 0:
log.info(f"From {files_combined}/{len(working_file_paths)} files to {len(all_handles):,} output handles : {output_lines:,} lines : {os.path.split(working_file_path)[1]}")

log.info(f"From {files_combined}/{len(working_file_paths)} files to {len(all_handles):,} output handles : {output_lines:,} lines")
for handle in all_handles:
handle.close()
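
As an aside on the read_and_decode helper that moved into FileHandle: the recursive retry exists because a multi-byte UTF-8 character can be split across two read chunks, so a failed decode is retried with the previous chunk prepended, up to max_window_size bytes. A standalone sketch of that failure and recovery, using a made-up two-chunk split:

# illustrative only: a multi-byte UTF-8 character split across a chunk boundary
data = "café".encode("utf-8")  # b'caf\xc3\xa9' - the 'é' is two bytes
first_chunk, next_chunk = data[:4], data[4:]
try:
	first_chunk.decode()  # fails: the chunk ends mid-character
except UnicodeDecodeError:
	# prepend the failed bytes to the next chunk and decode again,
	# which is what FileHandle.read_and_decode does
	print((first_chunk + next_chunk).decode())  # prints: café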

