Skip to content

Commit

Permalink
tensorrtqat modified
Browse files Browse the repository at this point in the history
  • Loading branch information
Akash-guna committed Aug 2, 2024
1 parent 1181b04 commit ecb121a
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 40 deletions.
3 changes: 3 additions & 0 deletions quant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def initialize_initialization(algoname, task):
from .nncf import NNCFQATObjectDetection

return NNCFQATObjectDetection
elif algoname == "TensorRTQAT":
from .tensorrt import TensorRTQAT
return TensorRTQAT

else:
return None
2 changes: 1 addition & 1 deletion quant/tensorrt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .ptq import TensorRT
#from .ptq import TensorRT
from .qat import TensorRTQAT

__all__ = [
Expand Down
1 change: 1 addition & 0 deletions quant/tensorrt/qat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .main import TensorRTQAT
78 changes: 42 additions & 36 deletions quant/tensorrt/qat/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import shutil
from .utils import *
from ....core.utils.mmutils import create_input_image, customize_config
from vision.core.utils.mmutils import create_input_image, customize_config


class TensorRTQAT:
Expand Down Expand Up @@ -44,25 +44,17 @@ def __init__(self, model, loaders=None, **kwargs):
self.val_interval = kwargs.get("VALIDATION_INTERVAL", 1)
self.weight_decay = kwargs.get("WEIGHT_DECAY", 0.0005)
self.model_path = kwargs.get("MODEL_PATH", "models")
self.logging_path = kwargs.get("LOGGING_PATH", "logs")

self.logging_path = logging.getLogger(__name__)
self.work_dir = os.getcwd()
self.logger = logging.getLogger(__name__)
self.logger.info(f"Experiment Arguments: {self.kwargs}")
self.job_id = kwargs.get("JOB_ID", "1")
self.fake_quantize_step = kwargs.get("FAKE")
if self.wandb:
wandb.init(project="Kompress Tensorrt QAT", name=str(self.job_id))
wandb.config.update(self.kwargs)
self.qat = kwargs.get("QAT", False)
self.platform = kwargs.get("PLATFORM", "mmdet")

def list_subdirectories(self, directory):
subdirectories = [
d
for d in os.listdir(directory)
if os.path.isdir(os.path.join(directory, d))
]
return subdirectories

def compress_model(self):
if (
self.custom_model_path
Expand All @@ -84,6 +76,7 @@ def compress_model(self):
self.max_box,
self.pre_top_k,
self.keep_top_k,
self.cache_path
)
config = build_quantization_config(
self.ckpt_path,
Expand All @@ -97,33 +90,46 @@ def compress_model(self):
self.factor,
self.weight_decay,
)
customize_config(config, self.data_path, self.model_path, self.batch_size)
customize_config(config, self.data_path, self.model_path, self.batch_size,self.cache_path)
quant_config_path = f"{self.cache_path}/current_quant_final.py"

os.system(
f"python /mmrazor/tools/train.py current_quant_final.py --work-dir {self.job_path}"
)
self.quantized_pth_location = None
if self.fake_quantize_step == True:
os.system(
f"python {self.work_dir}/vision/core/utils/mmrazortrain.py {quant_config_path} --work-dir {self.job_path}"
)
self.quantized_pth_location = None

if "last_checkpoint" in os.listdir(self.model_path):
with open(os.path.join(self.model_path, "last_checkpoint"), "r") as f:
self.quantized_pth_location = f.readline()
self.logger.info(
f"Fake Quantized pth is present at {self.quantized_pth_location}"
)
if "last_checkpoint" in os.listdir(self.model_path):
with open(os.path.join(self.model_path, "last_checkpoint"), "r") as f:
self.quantized_pth_location = f.readline()
self.logger.info(
f"Fake Quantized pth is present at {self.quantized_pth_location}"
)

if self.quantized_pth_location == None:
self.logger.info("Fake Quantization Unsuccessful, checkpoint not found")
raise Exception("Fake Quantization Unsuccessful, checkpoint not found")
if self.quantized_pth_location == None:
self.logger.info("Fake Quantization Unsuccessful, checkpoint not found")
raise Exception("Fake Quantization Unsuccessful, checkpoint not found")
else:
self.logger.info("Fake Quantization Successful")
else:
self.logger.info("Fake Quantization Successful")
# deploy config
build_mmdeploy_config(self.imsize)
create_input_image(self.loaders["test"])
os.system(
f"docker exec -it Nyun_Kompress python mmdeploy/tools/deploy.py current_tensorrt_deploy_config.py current_quant_final.py {self.quantized_pth_location} demo_image.png"
)
self.logger.info("Deployment Successful")
shutil.move("/workspace/end2end.xml", os.path.join(self.model_path, "mds.xml"))
shutil.move("/workspace/end2end.bin", os.path.join(self.model_path, "mds.bin"))
# deploy config
self.quantized_pth_location = self.kwargs.get("FAKE_QUANTIZED_PATH", "")

if self.quantized_pth_location == None:
self.logger.info(f"Fake Quantized Path is None")
raise Exception("Fake Quantized Path is None")
elif not os.path.exists(self.quantized_pth_location):
self.logger.info(f"Fake Quantized Path is Not Present at {self.quantized_pth_location}")
raise Exception(f"Fake Quantized Path is Not Present at {self.quantized_pth_location}")
build_mmdeploy_config(self.imsize,self.cache_path)
create_input_image(self.loaders["test"],self.cache_path)
deploy_config_path = f"{self.cache_path}/current_tensorrt_deploy_config.py"
demo_image_path = f"{self.cache_path}/demo_image.png"
os.system(
f"python {self.work_dir}/vision/core/utils/mmrazordeploy.py {deploy_config_path} {quant_config_path} {self.quantized_pth_location} {demo_image_path}"
)
self.logger.info("Deployment Successful")
shutil.move("end2end.xml", os.path.join(self.model_path, "mds.xml"))
shutil.move("end2end.bin", os.path.join(self.model_path, "mds.bin"))

return self.model, __name__
7 changes: 4 additions & 3 deletions quant/tensorrt/qat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ def write_deploy_cfg(
max_box,
pre_top_k,
keep_top_k,
cache_path
):
cfg = f"""deploy_cfg = dict(
onnx_config=dict(
Expand Down Expand Up @@ -66,7 +67,7 @@ def write_deploy_cfg(
'mmdet.models.detectors.single_stage_instance_seg.SingleStageInstanceSegmentor.forward', # noqa: E501
'torch.cat'
])"""
with open("current_base_tensorrt_deploy_config.py", "w") as f:
with open(f"{cache_path}current_base_tensorrt_deploy_config.py", "w") as f:
f.write(cfg)


Expand Down Expand Up @@ -151,7 +152,7 @@ def build_quantization_config(
custom_hooks = []"""


def build_mmdeploy_config(insize):
def build_mmdeploy_config(insize,cache_path):
config = f"""_base_ = ['../_base_/base_static.py', '../../_base_/backends/tensorrt-int8.py']
onnx_config = dict(input_shape=(320, 320))
Expand All @@ -166,5 +167,5 @@ def build_mmdeploy_config(insize):
opt_shape=[1, 3, {insize}, {insize}],
max_shape=[1, 3, {insize}, {insize}])))
])"""
with open("current_tensorrt_deploy_config.py", "w") as f:
with open(f"{cache_path}/current_tensorrt_deploy_config.py", "w") as f:
f.write(config)

0 comments on commit ecb121a

Please sign in to comment.