Layer-wise sensitivity analysis disables quantization for one layer at a time (leaving the rest quantized), re-runs the COCO evaluation, and records the resulting AP for each layer:

def cmd_sensitive_analysis(weight, device, cocodir, summary_save, num_image):
    ...
    print("Sensitive analysis by each layer...")
    for i in range(0, len(model.model)):
        layer = model.model[i]
        if quantize.have_quantizer(layer):
            print(f"Quantization disable model.{i}")
            quantize.disable_quantization(layer).apply()
            ap = evaluate_coco(model, val_dataloader)
            summary.append([ap, f"model.{i}"])
            quantize.enable_quantization(layer).apply()
        else:
            print(f"ignore model.{i} because it is {type(layer)}")
    ...
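The quantize.disable_quantization(layer).apply() idiom suggests small helper classes that flip every TensorQuantizer under a module. A plausible sketch using only the public enable()/disable() API (an assumption, not necessarily the repo's exact code):

from pytorch_quantization import nn as quant_nn

class disable_quantization:
    """Turn off every TensorQuantizer under `model` when .apply() is called."""
    def __init__(self, model):
        self.model = model

    def apply(self):
        for _, module in self.model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.disable()

class enable_quantization:
    """Turn every TensorQuantizer under `model` back on."""
    def __init__(self, model):
        self.model = model

    def apply(self):
        for _, module in self.model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.enable()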
build_sensitivity_profile takes the opposite approach: it first disables every quantizer, then enables one layer's quantizers at a time and evaluates the model, so each measurement reflects only that layer being quantized:

def build_sensitivity_profile(model, data_loader_val, dataset_val, eval_model_callback: Callable = None):
    quant_layer_names = []
    # Disable every quantizer and collect the owning layer names
    for name, module in model.named_modules():
        if name.endswith("_quantizer"):
            print(f"use quant layer: {name}")
            module.disable()
            layer_name = name.replace("._input_quantizer", "").replace("._weight_quantizer", "")
            if layer_name not in quant_layer_names:
                quant_layer_names.append(layer_name)
    # Enable one layer at a time, evaluate, then disable it again
    for i, quant_layer in enumerate(quant_layer_names):
        print("Enable", quant_layer)
        for name, module in model.named_modules():
            if name.endswith("_quantizer") and quant_layer in name:
                module.enable()
                print(f"{name:40}: {module}")
        with torch.no_grad():
            eval_model_callback(model, data_loader_val, dataset_val)
        for name, module in model.named_modules():
            if name.endswith("_quantizer") and quant_layer in name:
                module.disable()
                print(f"{name:40}: {module}")
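A minimal way to drive it, assuming evaluate_coco stands in for whatever mAP evaluation you use (a hypothetical name, not part of the snippet above):

def eval_callback(model, data_loader_val, dataset_val):
    ap = evaluate_coco(model, data_loader_val)  # hypothetical evaluator
    print(f"AP with only this layer quantized: {ap}")

build_sensitivity_profile(model, val_dataloader, val_dataset, eval_model_callback=eval_callback)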
Initialization: quantize.initialize()
The code is as follows:
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor
# In the reference implementation, quant_logging is absl's logging module:
# from absl import logging as quant_logging

def initialize():
    # "max" or "histogram"; the other methods are all histogram based. Default is "max".
    quant_desc_input = QuantDescriptor(calib_method="histogram")
    quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantMaxPool2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

    # quant_desc_input = QuantDescriptor(calib_method="max")   # ["max", "histogram"]
    # quant_desc_weight = QuantDescriptor(calib_method="max")  # ["max", "histogram"]
    # quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
    # quant_nn.QuantConv2d.set_default_quant_desc_weight(quant_desc_weight)

    quant_logging.set_verbosity(quant_logging.ERROR)
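initialize() only sets class-level defaults, so it must run before any quantized modules are constructed. A quick sketch of the effect:

initialize()
conv = quant_nn.QuantConv2d(3, 16, kernel_size=3)  # picks up the histogram input descriptor
print(conv._input_quantizer)  # TensorQuantizer backed by a histogram calibrator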
Module replacement starts from a lookup table that maps each stock torch module type (keyed by id) to its quantized counterpart:

module_dict = {}
for entry in quant_modules._DEFAULT_QUANT_MAP:
    module = getattr(entry.orig_mod, entry.mod_name)  # e.g. torch.nn.Conv2d
    module_dict[id(module)] = entry.replace_mod       # e.g. quant_nn.QuantConv2d
def recursive_and_replace_module(module, prefix=""):
    if module is None:
        print(f"[WARNING] Module at '{prefix}' is None, skipping...")
        return
    for name in module._modules:
        submodule = module._modules[name]
        path = name if prefix == "" else prefix + "." + name
        recursive_and_replace_module(submodule, path)

        submodule_id = id(type(submodule))
        if submodule_id in module_dict:
            ignored = quantization_ignore_match(ignore_policy, path)
            if ignored:
                print(f"Quantization: {path} is ignored.")
                continue
            # Swap the stock module for its quantized counterpart
            # (transfer_torch_to_quantization builds it in the reference implementation)
            module._modules[name] = transfer_torch_to_quantization(submodule, module_dict[submodule_id])
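quantization_ignore_match is referenced but not shown; a minimal sketch under the assumption that the ignore policy is a callable, a regex pattern, or a list of patterns matched against the module path (an illustration, not the original code):

import re
from typing import Callable, List, Union

def quantization_ignore_match(ignore_policy: Union[str, List[str], Callable, None], path: str) -> bool:
    # No policy configured: nothing is ignored
    if ignore_policy is None:
        return False
    # A callable policy decides per module path
    if callable(ignore_policy):
        return ignore_policy(path)
    # A single pattern or a list of patterns, matched against the full path
    patterns = [ignore_policy] if isinstance(ignore_policy, str) else ignore_policy
    return any(re.match(pattern, path) for pattern in patterns)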
Custom rule matching runs on the exported ONNX graph. For each Concat node, every QuantizeLinear feeding a convolution around it is collected, and each is paired with the first one found (the "major" quantizer) so they can later share a single scale:

import onnx

def find_quantizer_pairs(onnx_file):
    model = onnx.load(onnx_file)
    match_pairs = []
    for node in model.graph.node:
        if node.op_type == "Concat":
            qnodes = find_all_with_input_node(model, node.output[0])
            major = None
            for qnode in qnodes:
                if qnode.op_type != "QuantizeLinear":
                    continue

                conv = find_quantizelinear_conv(model, qnode)
                if major is None:
                    major = find_quantize_conv_name(model, conv.input[1])
                else:
                    match_pairs.append([major, find_quantize_conv_name(model, conv.input[1])])

                for subnode in model.graph.node:
                    if len(subnode.input) > 0 and subnode.op_type == "QuantizeLinear" and subnode.input[0] in node.input:
                        subconv = find_quantizelinear_conv(model, subnode)
                        match_pairs.append([major, find_quantize_conv_name(model, subconv.input[1])])
    return match_pairs
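The graph helpers used above are not shown; minimal sketches of what they have to do, inferred from the call sites (assumed implementations, not the original code). find_quantize_conv_name, not sketched here, maps a conv's quantized weight input back to the module's dotted name.

def find_all_with_input_node(model, name):
    # Every node that consumes tensor `name` as one of its inputs
    return [node for node in model.graph.node
            if len(node.input) > 0 and name in node.input]

def find_with_input_node(model, name):
    # First node that consumes tensor `name` as an input
    for node in model.graph.node:
        if len(node.input) > 0 and name in node.input:
            return node

def find_quantizelinear_conv(model, qnode):
    # Follow QuantizeLinear -> DequantizeLinear -> Conv along tensor outputs
    dq = find_with_input_node(model, qnode.output[0])
    return find_with_input_node(model, dq.output[0])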
import os
from typing import Callable

def apply_custom_rules_to_quantizer(model: torch.nn.Module, export_onnx: Callable):
    # apply rules to graph
    export_onnx(model, "quantization-custom-rules-temp.onnx")
    pairs = find_quantizer_pairs("quantization-custom-rules-temp.onnx")
    for major, sub in pairs:
        print(f"Rules: {sub} match to {major}")
        if sub in [
            "model.img_backbone.stage1.1.blocks.0.addop",
            "model.img_backbone.stage2.1.blocks.0.addop",
            "model.img_backbone.stage2.1.blocks.1.addop",
            "model.img_backbone.stage3.1.blocks.0.addop",
            "model.img_backbone.stage3.1.blocks.1.addop",
            "model.img_backbone.stage4.1.blocks.0.addop",
        ]:
            major = get_attr_with_path(model, major)._input_quantizer
            get_attr_with_path(model, sub)._input0_quantizer = major
            get_attr_with_path(model, sub)._input1_quantizer = major
        else:
            get_attr_with_path(model, sub)._input_quantizer = get_attr_with_path(model, major)._input_quantizer
    os.remove("quantization-custom-rules-temp.onnx")
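get_attr_with_path resolves a dotted path such as "model.img_backbone.stage1" to the corresponding submodule. A minimal sketch (an assumed implementation):

from functools import reduce

def get_attr_with_path(module, path):
    # "a.b.c" -> module.a.b.c, walking one attribute at a time
    return reduce(getattr, path.split("."), module)

Recent PyTorch also offers model.get_submodule(path), which does the same walk for submodules.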
from tqdm import tqdm
from pytorch_quantization import calib

def compute_amax(model, **kwargs):
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
                # `device` comes from the enclosing scope in the original,
                # where compute_amax is nested inside the calibration routine
                module._amax = module._amax.to(device)

def collect_stats(model, data_loader, device, num_batch=200):
    """Feed data to the network and collect statistics"""
    # Enable calibrators
    model.eval()
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    # Feed data to the network for collecting stats
    with torch.no_grad():
        for i, datas in tqdm(enumerate(data_loader), total=num_batch, desc="Collect stats for calibrating"):
            imgs = datas[0].to(device, non_blocking=True).float() / 255.0
            model(imgs)

            if i >= num_batch:
                break

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()
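Taken together, PTQ calibration is a single pass of statistics collection followed by amax computation. A minimal sketch of the flow (val_dataloader is a placeholder; the keyword arguments are forwarded to the histogram calibrator's load_calib_amax, which also accepts "mse" and "entropy"):

# Assumes `model` already has its modules replaced by quant_nn equivalents
collect_stats(model, val_dataloader, device, num_batch=200)
compute_amax(model, method="percentile", percentile=99.99)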