blob: 4540043a1636f43834ec496ffef1e78444ba312b [file] [log] [blame]
// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
@testable import TensorFlowLite
import XCTest
class TensorTests: XCTestCase {

  // MARK: - Tensor

  /// Checks that `Tensor`'s initializer stores every argument unchanged.
  func testTensor_Init() {
    let tensorName = "InputTensor"
    let expectedType: TensorDataType = .uInt8
    let expectedShape = TensorShape(Constant.dimensions)
    // Encoding a Swift `String` as UTF-8 cannot fail, so the bytes are built non-optionally.
    let payload = Data(tensorName.utf8)
    let params = QuantizationParameters(scale: 0.5, zeroPoint: 1)

    let tensor = Tensor(
      name: tensorName,
      dataType: expectedType,
      shape: expectedShape,
      data: payload,
      quantizationParameters: params
    )

    XCTAssertEqual(tensor.name, tensorName)
    XCTAssertEqual(tensor.dataType, expectedType)
    XCTAssertEqual(tensor.shape, expectedShape)
    XCTAssertEqual(tensor.data, payload)
    XCTAssertEqual(tensor.quantizationParameters, params)
  }

  // MARK: - TensorShape

  /// Building a shape from an `[Int]` keeps its rank and dimensions.
  func testTensorShape_InitWithArray() {
    assertMatchesConstantDimensions(TensorShape(Constant.dimensions))
  }

  /// Building a shape from variadic elements keeps its rank and dimensions.
  func testTensorShape_InitWithElements() {
    assertMatchesConstantDimensions(TensorShape(2, 2, 3))
  }

  /// Building a shape from an array literal keeps its rank and dimensions.
  func testTensorShape_InitWithArrayLiteral() {
    let literalShape: TensorShape = [2, 2, 3]
    assertMatchesConstantDimensions(literalShape)
  }

  // MARK: - Helpers

  /// Asserts that `shape` has the rank and dimensions of `Constant.dimensions`.
  ///
  /// `file` and `line` are forwarded so that failures are reported at the calling test.
  private func assertMatchesConstantDimensions(
    _ shape: TensorShape,
    file: StaticString = #file,
    line: UInt = #line
  ) {
    XCTAssertEqual(shape.rank, Constant.dimensions.count, file: file, line: line)
    XCTAssertEqual(shape.dimensions, Constant.dimensions, file: file, line: line)
  }
}
// MARK: - Constants
private enum Constant {
  /// Dimensions of a rank-3 shape (2 × 2 × 3), for example the nested array
  /// `[[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]`.
  static let dimensions: [Int] = [2, 2, 3]
}
// MARK: - Extensions
extension TensorShape: Equatable {
  /// Two shapes are equal when their ranks match and their dimension arrays are identical.
  public static func == (lhs: TensorShape, rhs: TensorShape) -> Bool {
    guard lhs.rank == rhs.rank else { return false }
    return lhs.dimensions == rhs.dimensions
  }
}
extension Tensor: Equatable {
  /// Two tensors are equal when their name, data type, shape, raw data, and quantization
  /// parameters all match.
  public static func == (lhs: Tensor, rhs: Tensor) -> Bool {
    guard lhs.name == rhs.name,
      lhs.dataType == rhs.dataType,
      lhs.shape == rhs.shape
    else {
      return false
    }
    return lhs.data == rhs.data && lhs.quantizationParameters == rhs.quantizationParameters
  }
}