# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This import registers the ``quantized_decomposed`` operator library
# (the torch.library definitions that the abstract impl below attaches to).
import executorch.exir.passes._quant_patterns_and_replacements  # noqa

import torch
from torch.library import impl_abstract


@impl_abstract("quantized_decomposed::embedding_byte")
def embedding_byte_meta(
    weight,
    weight_scales,
    weight_zero_points,
    weight_quant_min,
    weight_quant_max,
    indices,
):
    """Abstract (meta) kernel for ``quantized_decomposed::embedding_byte``.

    Validates dtypes and shapes of the quantized weight, scales, and zero
    points, then returns an empty meta tensor of shape
    ``indices.size() + [weight.size(1)]`` so exporters can propagate shapes
    without running the real kernel.
    """
    assert weight.dtype in [
        torch.int8,
        torch.uint8,
    ], f"Expecting weights to be of dtype in [torch.int8, torch.uint8], but got {weight.dtype}"
    assert (
        weight.dim() == 2
    ), f"Expecting weight tensor to have dim()==2, but found {weight.dim()}"

    assert weight_scales.dtype in [
        torch.float16,
        torch.float32,
    ], f"Expecting weight_scales to be of dtype in [torch.float16, torch.float32], but got {weight_scales.dtype}"
    assert (
        weight_scales.dim() == 1
    ), f"Expecting weight_scales tensor to have dim()==1, but found {weight_scales.dim()}"
    assert weight_scales.size(0) == weight.size(
        0
    ), f"Expecting weight and scale tensor to have same number of rows, but found {weight.size()} and {weight_scales.size()}"

    assert (
        weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
    ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
    assert (
        weight_zero_points is None or weight_zero_points.dim() == 1
    ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
    assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
        0
    ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"

    output_shape = list(indices.size()) + [weight.size(1)]
    return torch.empty(output_shape, device="meta", dtype=weight_scales.dtype)
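

# Minimal usage sketch (not part of the original module): once the op library
# is imported and the abstract impl above is registered, calling the op on
# "meta" tensors performs shape inference only. It assumes the op schema has
# the same argument order as embedding_byte_meta; the tensor sizes and the
# quant_min/quant_max values below are illustrative assumptions.
if __name__ == "__main__":
    weight = torch.empty(10, 4, dtype=torch.uint8, device="meta")
    weight_scales = torch.empty(10, dtype=torch.float32, device="meta")
    weight_zero_points = torch.empty(10, dtype=torch.float32, device="meta")
    indices = torch.empty(3, dtype=torch.long, device="meta")
    out = torch.ops.quantized_decomposed.embedding_byte(
        weight, weight_scales, weight_zero_points, 0, 255, indices
    )
    # Expected: indices.size() + [weight.size(1)] -> torch.Size([3, 4])
    print(out.shape)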