blob: 720d3e15e032f78d6fc55908cd40297e6b4ea371 [file] [log] [blame]
# Copyright 2016 The Gemmlowp Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""."""
import common
def _AlignForLanes(lanes_count):
if lanes_count is 8 or lanes_count is 4:
return 256
elif lanes_count is 6 or lanes_count is 2:
return 128
else:
return 64
def _AlignForSums(lanes_count):
if lanes_count is 8:
return 256
elif lanes_count in [2, 4, 6]:
return 128
else:
return 64
def _GenerateInputs(emitter, registers, lanes_count, input_address, stride):
"""."""
inputs = []
last_address_register = input_address
for i in range(lanes_count):
if not i:
inputs.append(input_address)
else:
address_register = registers.GeneralRegister()
inputs.append(address_register)
emitter.EmitAdd(address_register, last_address_register, stride)
last_address_register = address_register
return inputs
def _GenerateClear(emitter, clear_type, block):
for row in block:
emitter.EmitVMov(clear_type, row, emitter.ImmediateConstant(0))
def _GenerateLoadAggregateStore(emitter, registers, lanes_count, elements_count,
                                aggregators, inputs, output):
  """Emit one load/aggregate/store step for `lanes_count` row lanes.

  Loads `elements_count` 8-bit elements from each address in `inputs` into a
  temporary double register. Each loaded row is added ('u8' VAddw) into the
  matching aggregator. The whole block (8 * lanes_count bytes) is then stored
  to `output`, and the temporaries are freed.

  Args:
    emitter: assembly emitter receiving the instructions.
    registers: register allocator for the current asm block.
    lanes_count: number of lanes processed together.
    elements_count: elements loaded per lane; 8 for a full step, fewer for
      the leftovers tail.
    aggregators: one accumulator register per lane.
    inputs: one address register per lane.
    output: register holding the output address.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Load Aggregate Store: %dx%d.' % (lanes_count,
                                                       elements_count))
  block = [registers.DoubleRegister() for unused_i in range(lanes_count)]

  # For a partial load, zero the block first so bytes not written by the load
  # contribute nothing to the sums or the stored data. (Assumes EmitVLoadE
  # leaves the remaining bytes untouched -- TODO confirm.) Value comparison:
  # `is not 8` relied on CPython's small-int cache.
  if elements_count != 8:
    _GenerateClear(emitter, 'i8', block)

  for (row, input_address) in zip(block, inputs):
    emitter.EmitVLoadE(8, elements_count, row, input_address, None)

  for (aggregator, row) in zip(aggregators, block):
    emitter.EmitVAddw('u8', aggregator, aggregator, row)

  emitter.EmitVStoreAE(8, 8 * lanes_count, block, output,
                       _AlignForLanes(lanes_count))
  registers.FreeRegisters(block)
def _LoadMemoryParameter(emitter, registers, name, source):
register = registers.GeneralRegister()
emitter.EmitLdr(register, registers.MapMemoryParameter(name, source))
return register
def _GenerateAggregatorReductionLowRegisters(emitter, registers,
                                             aggregators, output_address):
  """Reduce and store the lane sums, reading the sum offsets with loads.

  Both `params.multiplicative_sum_offset` and `params.additive_sum_offset`
  are fetched through _LoadMemoryParameter (explicit LDRs) rather than
  registers.MapParameter. The row-major generator picks this variant when it
  packs more than 6 lanes.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Aggregator Reduction.')
  multiplicative = _LoadMemoryParameter(emitter, registers,
                                        'multiplicative_sum_offset',
                                        'params.multiplicative_sum_offset')
  additive = _LoadMemoryParameter(emitter, registers, 'additive_sum_offset',
                                  'params.additive_sum_offset')
  _GenerateAggregatorReduction(emitter, registers, aggregators,
                               output_address, multiplicative, additive)
def _GenerateAggregatorReductionHighRegisters(emitter, registers,
                                              aggregators, output_address):
  """Reduce and store the lane sums, binding the sum offsets as parameters.

  Both offsets are obtained through registers.MapParameter instead of explicit
  memory loads. The row-major generator uses this variant for 6 or fewer lanes;
  the column-major generator always uses it.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Aggregator Reduction.')
  multiplicative = registers.MapParameter('multiplicative_sum_offset',
                                          'params.multiplicative_sum_offset')
  additive = registers.MapParameter('additive_sum_offset',
                                    'params.additive_sum_offset')
  _GenerateAggregatorReduction(emitter, registers, aggregators,
                               output_address, multiplicative, additive)
def _GenerateAggregatorReduction(emitter, registers, aggregators,
                                 output_address, multiplicative_sum_offset,
                                 additive_sum_offset):
  """Reduce each lane's aggregator to one 32-bit sum, scale, offset and store.

  Steps emitted:
    1. Put `multiplicative_sum_offset` in lane 0 of a double register and
       broadcast `additive_sum_offset` into a quad register.
    2. Pairwise-add ('u16' VPaddl) each aggregator in place.
    3. Sum-reduce all aggregators into ceil(n / 4) registers, reusing the
       first aggregators as destinations.
    4. Multiply each reduced register by the scalar multiplier and add the
       broadcast offset.
    5. Store the reduced registers to `output_address`.

  Args:
    emitter: assembly emitter receiving the instructions.
    registers: register allocator for the current asm block.
    aggregators: one accumulator register per lane; clobbered.
    output_address: register holding the destination of the sums.
    multiplicative_sum_offset: operand holding the sum multiplier.
    additive_sum_offset: operand holding the value added to every sum.
  """
  multiplier = registers.DoubleRegister()
  emitter.EmitVMov('32',
                   emitter.Lane(32, multiplier, 0), multiplicative_sum_offset)
  offset = registers.QuadRegister()
  emitter.EmitVDup('32', offset, additive_sum_offset)

  for aggregator in aggregators:
    emitter.EmitVPaddl('u16', aggregator, aggregator)

  # Ceil-divide by 4 (presumably four 32-bit sums per quad register). Floor
  # division keeps this an int on Python 3, where `/` produced a float and
  # made the slice below raise TypeError; `//` is unchanged on Python 2.
  reduced_count = (len(aggregators) + 3) // 4
  reduced = aggregators[:reduced_count]
  emitter.EmitVSumReduce('u32', len(aggregators), 4, reduced, aggregators)

  for temp in reduced:
    emitter.EmitVMulScalar('i32', temp, temp, emitter.Lane(32, multiplier, 0))
  for temp in reduced:
    emitter.EmitVAdd('i32', temp, temp, offset)

  emitter.EmitVStoreA(1, 32, reduced,
                      emitter.Dereference(output_address,
                                          _AlignForSums(len(aggregators))))
class RowMajorWithSumUInt8x8(common.StreamGenerator):
  """Packing stream for row-major uint8 input that also accumulates row sums.

  The generated asm copies `lanes_count` rows, 8 bytes per row per iteration,
  into the output. It keeps one aggregator per row and finally emits the
  reduced, scaled and offset row sums to `out`.
  """

  def __init__(self, emitter, asm_emitter):
    common.StreamGenerator.__init__(self, emitter, 'RowMajorWithSum')
    self.asm_emitter = asm_emitter

  def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
    """Emit one specialization of the row-major-with-sum stream.

    Args:
      in_type: C element type name; only 'uint8_t' is supported.
      lanes_count: number of rows packed together.
      pack_size: elements per row consumed per loop iteration; must be 8.
      leftovers: number of tail elements handled after the main loop
        (0 means no tail code is emitted).
    """
    # Value comparisons: `is` on an int/str literal tests identity and only
    # worked thanks to CPython caching/interning.
    assert pack_size == 8, 'Only pack_size 8 is supported.'
    assert in_type == 'uint8_t', 'Only uint8_t input is supported.'

    registers = self.asm_emitter.CreateRegisters()

    # `count` is decremented inside the asm, so it works on a local copy.
    self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')

    self.asm_emitter.PushIndent(self.emitter.indent)
    self.asm_emitter.EmitAsmBegin()

    count = registers.MapOutputParameter('count', 'params_count_copy')
    output = registers.MapOutputParameter('out')
    inputs = _GenerateInputs(self.asm_emitter, registers, lanes_count,
                             registers.MapOutputParameter('in'),
                             registers.MapParameter('stride', 'params.stride'))
    aggregators = [registers.QuadRegister(8) for unused_i in range(lanes_count)]

    _GenerateClear(self.asm_emitter, 'i16', aggregators)

    if leftovers:
      # Peel the tail off `count`; if nothing else remains, jump straight to
      # the leftovers code at label 2.
      self.asm_emitter.EmitNewline()
      self.asm_emitter.EmitComment('Reduce count by leftovers.')
      self.asm_emitter.EmitSubs(count, count,
                                self.asm_emitter.ImmediateConstant(leftovers))
      self.asm_emitter.EmitBeqFront(2)

    # Main loop: 8 elements per lane per iteration until count reaches zero.
    self.asm_emitter.EmitNewline()
    self.asm_emitter.EmitNumericalLabel(1)
    self.asm_emitter.EmitSubs(count, count,
                              self.asm_emitter.ImmediateConstant(8))

    _GenerateLoadAggregateStore(self.asm_emitter, registers, lanes_count, 8,
                                aggregators, inputs, output)

    self.asm_emitter.EmitNewline()
    self.asm_emitter.EmitBneBack(1)

    if leftovers:
      self.asm_emitter.EmitNewline()
      self.asm_emitter.EmitNumericalLabel(2)
      _GenerateLoadAggregateStore(self.asm_emitter, registers, lanes_count,
                                  leftovers, aggregators, inputs, output)

    registers.FreeRegisters(inputs)

    # With many lanes, load the sum offsets from memory instead of mapping
    # them as parameters (presumably to limit register pressure).
    if len(inputs) <= 6:
      _GenerateAggregatorReductionHighRegisters(
          self.asm_emitter, registers, aggregators, output)
    else:
      _GenerateAggregatorReductionLowRegisters(
          self.asm_emitter, registers, aggregators, output)

    self.asm_emitter.EmitAsmEnd(registers)
    self.asm_emitter.PopIndent(len(self.emitter.indent))
def _GenerateColLoadAggregateStore(emitter, registers, lanes_count,
                                   elements_count, aggregators, input_address,
                                   stride, output):
  """Emit one load/aggregate/store step for `lanes_count` column lanes.

  Loads a column block via EmitLoadColBlock and uses the block it returns.
  Each resulting lane is added ('u8' VAddw) into the matching aggregator. The
  block (8 * lanes_count bytes) is stored to `output`, and its registers are
  freed.

  Args:
    emitter: assembly emitter receiving the instructions.
    registers: register allocator for the current asm block.
    lanes_count: number of lanes processed together.
    elements_count: elements loaded per lane; 8 for a full step, fewer for
      the leftovers tail.
    aggregators: one accumulator register per lane.
    input_address: register holding the input address.
    stride: stride operand passed to EmitLoadColBlock.
    output: register holding the output address.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Load Aggregate Store - column major %dx%d' %
                      (lanes_count, elements_count))
  block = [registers.DoubleRegister() for unused_i in range(lanes_count)]

  # Zero the block before a partial load so unloaded bytes add nothing.
  # Value comparison: `is not 8` relied on CPython's small-int cache.
  if elements_count != 8:
    _GenerateClear(emitter, 'i8', block)

  block = emitter.EmitLoadColBlock(registers, 8, lanes_count, elements_count,
                                   block, input_address, stride)

  for (aggregator, row) in zip(aggregators, block):
    emitter.EmitVAddw('u8', aggregator, aggregator, row)

  emitter.EmitVStoreAE(8, 8 * lanes_count, block, output,
                       _AlignForLanes(lanes_count))
  registers.FreeRegisters(block)
class ColumnMajorWithSumUInt8x8(common.StreamGenerator):
  """Packing stream for column-major uint8 input that also accumulates sums.

  The generated asm reads `lanes_count` lanes, 8 elements per lane per
  iteration, through EmitLoadColBlock. It copies them to the output, keeps
  one aggregator per lane, and finally emits the reduced, scaled and offset
  sums to `out`.
  """

  def __init__(self, emitter, asm_emitter):
    common.StreamGenerator.__init__(self, emitter, 'ColumnMajorWithSum')
    self.asm_emitter = asm_emitter

  def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
    """Emit one specialization of the column-major-with-sum stream.

    Args:
      in_type: C element type name; only 'uint8_t' is supported.
      lanes_count: number of lanes packed together.
      pack_size: elements per lane consumed per loop iteration; must be 8.
      leftovers: number of tail elements handled after the main loop
        (0 means no tail code is emitted).
    """
    # Value comparisons: `is` on an int/str literal tests identity and only
    # worked thanks to CPython caching/interning.
    assert pack_size == 8, 'Only pack_size 8 is supported.'
    assert in_type == 'uint8_t', 'Only uint8_t input is supported.'

    registers = self.asm_emitter.CreateRegisters()

    # Both values are mapped as asm output parameters (count is decremented
    # and stride is passed as EmitColBlockStride's destination), so the asm
    # works on local copies.
    self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')
    self.emitter.EmitDeclare('int', 'params_stride_copy', 'params.stride')

    self.asm_emitter.PushIndent(self.emitter.indent)
    self.asm_emitter.EmitAsmBegin()

    count = registers.MapOutputParameter('count', 'params_count_copy')
    input_address = registers.MapOutputParameter('in')
    output_address = registers.MapOutputParameter('out')
    aggregators = [registers.QuadRegister(8) for unused_i in range(lanes_count)]
    stride = registers.MapOutputParameter('stride', 'params_stride_copy')

    self.asm_emitter.EmitColBlockStride(lanes_count, stride, stride)

    _GenerateClear(self.asm_emitter, 'i16', aggregators)

    if leftovers:
      # Peel the tail off `count`; if nothing else remains, jump straight to
      # the leftovers code at label 2.
      self.asm_emitter.EmitNewline()
      self.asm_emitter.EmitComment('Reduce count by leftovers.')
      self.asm_emitter.EmitSubs(count, count,
                                self.asm_emitter.ImmediateConstant(leftovers))
      self.asm_emitter.EmitBeqFront(2)

    # Main loop: 8 elements per lane per iteration until count reaches zero.
    self.asm_emitter.EmitNewline()
    self.asm_emitter.EmitNumericalLabel(1)
    self.asm_emitter.EmitSubs(count, count,
                              self.asm_emitter.ImmediateConstant(8))

    _GenerateColLoadAggregateStore(self.asm_emitter, registers, lanes_count, 8,
                                   aggregators, input_address, stride,
                                   output_address)

    self.asm_emitter.EmitNewline()
    self.asm_emitter.EmitBneBack(1)

    if leftovers:
      self.asm_emitter.EmitNewline()
      self.asm_emitter.EmitNumericalLabel(2)
      _GenerateColLoadAggregateStore(self.asm_emitter, registers, lanes_count,
                                     leftovers, aggregators, input_address,
                                     stride, output_address)

    _GenerateAggregatorReductionHighRegisters(
        self.asm_emitter, registers, aggregators, output_address)

    self.asm_emitter.EmitAsmEnd(registers)
    self.asm_emitter.PopIndent(len(self.emitter.indent))
def GenerateUInt8x8Streams(cc_emitter, asm_emitter, lanes_count):
  """Specialize all uint8 'with sum' streams up to `lanes_count` lanes.

  Emits every row-major specialization first, then every column-major one.
  Each covers lane counts 1..lanes_count and leftovers 0..7, with pack size 8.

  Args:
    cc_emitter: C++ emitter for the surrounding code.
    asm_emitter: assembly emitter for the stream bodies.
    lanes_count: maximum number of lanes to specialize for.
  """
  row_major_with_sum = RowMajorWithSumUInt8x8(cc_emitter, asm_emitter)
  column_major_with_sum = ColumnMajorWithSumUInt8x8(cc_emitter, asm_emitter)

  # Use a distinct loop variable. Reusing `lanes_count` rebound the parameter,
  # and the second loop's bound only stayed correct by coincidence.
  max_lanes = lanes_count
  for lanes in range(1, 1 + max_lanes):
    for leftovers in range(8):
      row_major_with_sum.SpecializeStream('uint8_t', lanes, 8, leftovers)
  for lanes in range(1, 1 + max_lanes):
    for leftovers in range(8):
      column_major_with_sum.SpecializeStream('uint8_t', lanes, 8, leftovers)