blob: f35642a148e3b3c1258302999a2cd4a2a2a6e289 [file] [log] [blame]
#import <XCTest/XCTest.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/script.h>
/// Smoke tests that load the `.ptl` models bundled with the test target and run
/// a single forward pass through each of them.
@interface TestAppTests : XCTestCase
@end

/// Number of elements every bundled model is expected to produce for one input.
static const int64_t kExpectedOutputElementCount = 1000;

@implementation TestAppTests

#pragma mark - Helpers

/// Loads `<modelName>.ptl` from the test bundle, runs one forward pass on a
/// {1, 3, 224, 224} float tensor filled with ones, and asserts that the output
/// tensor holds `kExpectedOutputElementCount` elements.
///
/// @param modelName Resource name of the model, without the `.ptl` extension.
- (void)runForwardPassOnBundledModelNamed:(NSString*)modelName {
  NSBundle* testBundle = [NSBundle bundleForClass:[self class]];
  NSString* modelPath = [testBundle pathForResource:modelName ofType:@"ptl"];
  // A missing resource yields nil, and nil.UTF8String is NULL; handing NULL to
  // the loader (which builds a std::string from it) would crash the runner
  // instead of reporting a test failure. XCTFail does not stop execution by
  // itself, hence the explicit return.
  if (modelPath == nil) {
    XCTFail(@"Model resource '%@.ptl' was not found in the test bundle", modelName);
    return;
  }

  auto module = torch::jit::_load_for_mobile(modelPath.UTF8String);
  c10::InferenceMode mode;
  auto input = torch::ones({1, 3, 224, 224}, at::kFloat);
  auto outputTensor = module.forward({input}).toTensor();
  // XCTAssertEqual reports both values on failure, unlike XCTAssertTrue(a == b).
  XCTAssertEqual(outputTensor.numel(), kExpectedOutputElementCount);
}

#pragma mark - Tests

/// Runs the forward-pass check against the bundled `model_lite.ptl`.
- (void)testLiteInterpreter {
  [self runForwardPassOnBundledModelNamed:@"model_lite"];
}

/// Runs the forward-pass check against the bundled `model_coreml.ptl`.
- (void)testCoreML {
  [self runForwardPassOnBundledModelNamed:@"model_coreml"];
}

@end