| # Copyright (c) Qualcomm Innovation Center, Inc. |
| # All rights reserved |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import json |
| import os |
| import re |
| from multiprocessing.connection import Client |
| |
| import numpy as np |
| import piq |
| import torch |
| from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype |
| from executorch.examples.models.edsr import EdsrModel |
| from executorch.examples.qualcomm.utils import ( |
| build_executorch_binary, |
| make_output_dir, |
| parse_skip_delegation_node, |
| setup_common_args_and_variables, |
| SimpleADB, |
| ) |
| |
| from PIL import Image |
| from torch.utils.data import Dataset |
| from torchsr.datasets import B100 |
| from torchvision.transforms.functional import to_pil_image, to_tensor |
| |
| |
class SrDataset(Dataset):
    """In-memory dataset of paired high/low resolution image tensors.

    Every file in ``hr_dir`` and ``lr_dir`` is opened, resized and converted
    to a tensor when the dataset is constructed. Both directories are read in
    sorted filename order, so pairs are formed by position in that order.

    Args:
        hr_dir: Directory holding the high resolution reference images.
        lr_dir: Directory holding the low resolution input images.

    Raises:
        AssertionError: If the two directories contain a different number of
            files.
    """

    def __init__(self, hr_dir: str, lr_dir: str):
        # Base size used for the low resolution inputs; the high resolution
        # references are resized to twice this size (scale factor 2).
        self.input_size = np.asanyarray([224, 224])
        self.hr = [
            self._resize_img(os.path.join(hr_dir, file), 2)
            for file in sorted(os.listdir(hr_dir))
        ]
        self.lr = [
            self._resize_img(os.path.join(lr_dir, file), 1)
            for file in sorted(os.listdir(lr_dir))
        ]

        if len(self.hr) != len(self.lr):
            raise AssertionError(
                "The number of high resolution pics is not equal to low "
                "resolution pics"
            )

    def __getitem__(self, idx: int):
        """Return the ``(high_resolution, low_resolution)`` pair at ``idx``."""
        return self.hr[idx], self.lr[idx]

    def __len__(self):
        """Return the number of low resolution samples."""
        return len(self.lr)

    def _resize_img(self, file: str, scale: int):
        """Load ``file``, resize it to ``input_size * scale`` and tensorize it.

        The result of ``to_tensor`` gets an extra leading dimension via
        ``unsqueeze(0)``.
        """
        with Image.open(file) as img:
            return to_tensor(img.resize(tuple(self.input_size * scale))).unsqueeze(0)

    def get_input_list(self):
        """Return one ``input_<i>_0.raw`` line per low resolution sample.

        Each entry is terminated by a newline; an empty dataset yields an
        empty string.
        """
        return "".join(f"input_{i}_0.raw\n" for i, _ in enumerate(self.lr))
| |
| |
def get_b100(dataset_dir: str):
    """Build an ``SrDataset`` from the B100 benchmark under ``dataset_dir``.

    When either the HR or the X2 bicubic LR directory is missing, the
    torchsr ``B100`` constructor is invoked with ``download=True`` before the
    dataset is built.

    Args:
        dataset_dir: Directory under which ``sr_bm_dataset`` lives.

    Returns:
        An ``SrDataset`` built from the B100 HR and LR_bicubic/X2 folders.
    """
    hr_path = f"{dataset_dir}/sr_bm_dataset/SRBenchmarks/benchmark/B100/HR"
    lr_path = f"{dataset_dir}/sr_bm_dataset/SRBenchmarks/benchmark/B100/LR_bicubic/X2"

    # Both folders must already be on disk to skip the download step.
    already_available = all(os.path.exists(path) for path in (hr_path, lr_path))
    if not already_available:
        B100(root=f"{dataset_dir}/sr_bm_dataset", scale=2, download=True)

    return SrDataset(hr_path, lr_path)
| |
| |
def get_dataset(hr_dir: str, lr_dir: str, default_dataset: bool, dataset_dir: str):
    """Select and build the dataset to evaluate on.

    Exactly one source must be chosen: a custom dataset (both ``hr_dir`` and
    ``lr_dir`` non-empty) or the default B100 dataset.

    Args:
        hr_dir: Directory of high resolution reference images, or ``""``.
        lr_dir: Directory of low resolution input images, or ``""``.
        default_dataset: Truthy to use the B100 dataset via ``get_b100``.
        dataset_dir: Directory passed to ``get_b100`` when the default
            dataset is used.

    Returns:
        The ``SrDataset`` for the selected source.

    Raises:
        RuntimeError: If no complete source is selected, or if both a custom
            dataset and the default dataset are requested.
    """
    # A custom dataset only counts when both directories are given.
    has_custom = bool(lr_dir and hr_dir)

    if not has_custom and not default_dataset:
        raise RuntimeError(
            "Neither custom dataset is provided nor using default dataset."
        )

    if has_custom and default_dataset:
        raise RuntimeError("Either use custom dataset, or use default dataset.")

    if default_dataset:
        return get_b100(dataset_dir)

    return SrDataset(hr_dir, lr_dir)
| |
| |
def main(args):
    """Build the EDSR program, run it on a device and report PSNR/SSIM.

    Steps performed:
      1. Create ``args.artifact`` and load the dataset via ``get_dataset``.
      2. Call ``build_executorch_binary`` with the eager EDSR model, the first
         low resolution sample as example input and every low resolution
         sample as the dataset argument (presumably used for 8-bit
         quantization calibration -- confirm against the helper).
      3. Unless ``args.compile_only`` is set, push inputs with ``SimpleADB``,
         execute, pull the outputs and save each one as a PNG.
      4. Average PSNR and SSIM against the high resolution references and
         either send them as JSON to ``(args.ip, args.port)`` or print them.

    Args:
        args: Parsed command line namespace. Reads ``artifact``,
            ``compile_only``, ``device``, ``hr_ref_dir``, ``lr_dir``,
            ``default_dataset``, ``model``, ``shared_buffer``,
            ``build_folder``, ``host``, ``ip`` and ``port``.

    Raises:
        RuntimeError: If a device serial is required but missing, or if fewer
            outputs were collected than there are reference images.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # Ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    # The artifact directory doubles as the download location for B100.
    dataset = get_dataset(
        args.hr_ref_dir, args.lr_dir, args.default_dataset, args.artifact
    )

    inputs, targets, input_list = dataset.lr, dataset.hr, dataset.get_input_list()
    pte_filename = "edsr_qnn_q8"
    instance = EdsrModel()

    build_executorch_binary(
        instance.get_eager_model().eval(),
        (inputs[0],),
        args.model,
        f"{args.artifact}/{pte_filename}",
        [(sample,) for sample in inputs],
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    output_pic_folder = f"{args.artifact}/output_pics"
    make_output_dir(output_data_folder)
    make_output_dir(output_pic_folder)

    output_raws = []
    # Files whose trailing index is 1-9 are discarded; everything else is
    # treated as a model output. The dot is escaped so only a literal
    # ".raw" suffix matches.
    extra_output_pattern = re.compile(r"^output_[0-9]+_[1-9]\.raw$")

    def post_process():
        """Decode pulled raw outputs into tensors and save them as PNGs."""
        cnt = 0
        output_shape = tuple(targets[0].size())
        # Sort numerically by the sample index so outputs line up with targets.
        for f in sorted(
            os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1])
        ):
            filename = os.path.join(output_data_folder, f)
            if extra_output_pattern.match(f):
                os.remove(filename)
            else:
                output = np.fromfile(filename, dtype=np.float32)
                # Clamp to [0, 1] before converting to an image.
                output = torch.tensor(output).reshape(output_shape).clamp(0, 1)
                output_raws.append(output)
                to_pil_image(output.squeeze(0)).save(
                    os.path.join(output_pic_folder, str(cnt) + ".png")
                )
                cnt += 1

    adb.pull(output_path=args.artifact, callback=post_process)

    # Fail with a clear message instead of an IndexError when outputs are
    # missing.
    if len(output_raws) < len(targets):
        raise RuntimeError(
            f"Expected {len(targets)} outputs but only {len(output_raws)} "
            "were collected."
        )

    psnr_list = []
    ssim_list = []
    for hr, sr in zip(targets, output_raws):
        psnr_list.append(piq.psnr(hr, sr))
        ssim_list.append(piq.ssim(hr, sr))

    avg_PSNR = sum(psnr_list).item() / len(psnr_list)
    avg_SSIM = sum(ssim_list).item() / len(ssim_list)
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({"PSNR": avg_PSNR, "SSIM": avg_SSIM}))
    else:
        print(f"Average of PSNR is: {avg_PSNR}")
        print(f"Average of SSIM is: {avg_SSIM}")
| |
| |
if __name__ == "__main__":
    # Script entry point: extend the shared example argument parser with the
    # EDSR-specific options, then run main().
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./edsr",
        default="./edsr",
        type=str,
    )

    parser.add_argument(
        "-r",
        "--hr_ref_dir",
        help="Path to the high resolution images",
        default="",
        type=str,
    )

    parser.add_argument(
        "-l",
        "--lr_dir",
        help="Path to the low resolution image inputs",
        default="",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--default_dataset",
        help="If specified, download and use B100 dataset by torchSR API",
        action="store_true",
        default=False,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        # With a listener configured, report the failure to it as JSON;
        # otherwise let the original exception propagate.
        # NOTE(review): args.ip / args.port are presumably defined by
        # setup_common_args_and_variables -- confirm.
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Bare raise keeps the original exception type and traceback.
            raise