blob: 5bd3f4b06c564309b126ee854307cd72d3223523 [file] [log] [blame]
graph(%input_1 : Float(*, *)
%input : Float(*, *)
%cx : Float(*, *)
%weight_1 : Float(*, *)
%weight : Float(*, *)
%bias_1 : Float(*)
%bias : Float(*)) {
%7 : Float(*, *) = aten::t(%weight_1)
%8 : Float(*, *) = aten::mm(%input_1, %7)
%9 : Float(*, *) = aten::t(%weight)
%10 : Float(*, *) = aten::mm(%input, %9)
%11 : Tensor[] = prim::ListConstruct(%bias_1, %8, %bias, %10)
%12 : Tensor[] = aten::broadcast_tensors(%11)
%13 : Tensor, %14 : Tensor, %15 : Tensor, %16 : Tensor = prim::ListUnpack(%12)
%17 : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %16, %15, %14, %13)
return (%17, %cy);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
%2 : Tensor
%3 : Tensor
%4 : Tensor) {
%5 : Float(*, *), %6 : Float(*, *), %7 : Float(*, *), %8 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
%9 : Float(*, *), %10 : Float(*, *), %11 : Float(*, *), %12 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
%13 : Float(*, *), %14 : Float(*, *), %15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
%17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%21 : int = prim::Constant[value=1]()
%22 : Float(*, *) = aten::add(%13, %17, %21)
%23 : int = prim::Constant[value=1]()
%24 : Float(*, *) = aten::add(%14, %18, %23)
%25 : int = prim::Constant[value=1]()
%26 : Float(*, *) = aten::add(%15, %19, %25)
%27 : int = prim::Constant[value=1]()
%28 : Float(*, *) = aten::add(%16, %20, %27)
%29 : int = prim::Constant[value=1]()
%30 : Float(*, *) = aten::add(%5, %9, %29)
%31 : int = prim::Constant[value=1]()
%32 : Float(*, *) = aten::add(%6, %10, %31)
%33 : int = prim::Constant[value=1]()
%34 : Float(*, *) = aten::add(%7, %11, %33)
%35 : int = prim::Constant[value=1]()
%36 : Float(*, *) = aten::add(%8, %12, %35)
%37 : int = prim::Constant[value=1]()
%38 : Float(*, *) = aten::add(%30, %22, %37)
%39 : int = prim::Constant[value=1]()
%40 : Float(*, *) = aten::add(%32, %24, %39)
%41 : int = prim::Constant[value=1]()
%42 : Float(*, *) = aten::add(%34, %26, %41)
%43 : int = prim::Constant[value=1]()
%44 : Float(*, *) = aten::add(%36, %28, %43)
%ingate : Float(*, *) = aten::sigmoid(%38)
%forgetgate : Float(*, *) = aten::sigmoid(%40)
%cellgate : Float(*, *) = aten::tanh(%42)
%outgate : Float(*, *) = aten::sigmoid(%44)
%49 : Float(*, *) = aten::mul(%forgetgate, %0)
%50 : Float(*, *) = aten::mul(%ingate, %cellgate)
%51 : int = prim::Constant[value=1]()
%cy : Float(*, *) = aten::add(%49, %50, %51)
%53 : Float(*, *) = aten::tanh(%cy)
%54 : Float(*, *) = aten::mul(%outgate, %53)
return (%54, %cy);
}