| graph(%x : Float(*, *) |
| %hx : Float(*, *) |
| %cx : Float(*, *) |
| %w_ih : Float(*, *) |
| %w_hh : Float(*, *) |
| %b_ih : Float(*) |
| %b_hh : Float(*)) { |
| %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih) |
| return (%hy, %cy); |
| } |
| with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *) |
| %1 : Float(*) |
| %2 : Float(*) |
| %3 : Float(*, *) |
| %4 : Float(*, *) |
| %5 : Float(*, *) |
| %6 : Float(*, *)) { |
| %7 : Float(*, *) = aten::t(%6) |
| %8 : Float(*, *) = aten::mm(%5, %7) |
| %9 : Float(*, *) = aten::t(%4) |
| %10 : Float(*, *) = aten::mm(%3, %9) |
| %11 : int[] = aten::size(%8) |
| %12 : int[] = aten::size(%10) |
| %13 : int[] = aten::size(%2) |
| %14 : int[] = aten::size(%1) |
| %15 : Tensor[] = prim::ListConstruct(%1, %2, %8, %10) |
| %16 : Tensor[] = aten::broadcast_tensors(%15) |
| %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16) |
| %21 : int[] = prim::BroadcastSizes(%11, %12) |
| %22 : int[] = prim::BroadcastSizes(%21, %13) |
| %hy : Float(*, *), %24 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17) |
| %30 : int[] = aten::size(%0) |
| %31 : int[] = aten::size(%cellgate.1) |
| %32 : int[] = aten::size(%forgetgate.1) |
| %33 : int[] = aten::size(%ingate.1) |
| %34 : int[] = prim::BroadcastSizes(%32, %30) |
| %35 : int[] = prim::BroadcastSizes(%33, %31) |
| return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %35, %24); |
| } |
| with prim::FusionGroup_0 = graph(%0 : Float(*, *) |
| %1 : Tensor |
| %2 : Tensor |
| %3 : Tensor |
| %4 : Tensor) { |
| %5 : Float(*, *), %6 : Float(*, *), %7 : Float(*, *), %8 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4) |
| %9 : Float(*, *), %10 : Float(*, *), %11 : Float(*, *), %12 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3) |
| %13 : Float(*, *), %14 : Float(*, *), %15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2) |
| %17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1) |
| %21 : int = prim::Constant[value=1]() |
| %22 : Float(*, *) = aten::add(%13, %17, %21) |
| %23 : int = prim::Constant[value=1]() |
| %24 : Float(*, *) = aten::add(%14, %18, %23) |
| %25 : int = prim::Constant[value=1]() |
| %26 : Float(*, *) = aten::add(%15, %19, %25) |
| %27 : int = prim::Constant[value=1]() |
| %28 : Float(*, *) = aten::add(%16, %20, %27) |
| %29 : int = prim::Constant[value=1]() |
| %30 : Float(*, *) = aten::add(%22, %9, %29) |
| %31 : int = prim::Constant[value=1]() |
| %32 : Float(*, *) = aten::add(%24, %10, %31) |
| %33 : int = prim::Constant[value=1]() |
| %34 : Float(*, *) = aten::add(%26, %11, %33) |
| %35 : int = prim::Constant[value=1]() |
| %36 : Float(*, *) = aten::add(%28, %12, %35) |
| %37 : int = prim::Constant[value=1]() |
| %38 : Float(*, *) = aten::add(%30, %5, %37) |
| %39 : int = prim::Constant[value=1]() |
| %40 : Float(*, *) = aten::add(%32, %6, %39) |
| %41 : int = prim::Constant[value=1]() |
| %42 : Float(*, *) = aten::add(%34, %7, %41) |
| %43 : int = prim::Constant[value=1]() |
| %44 : Float(*, *) = aten::add(%36, %8, %43) |
| %ingate.1 : Float(*, *) = aten::sigmoid(%38) |
| %forgetgate.1 : Float(*, *) = aten::sigmoid(%40) |
| %cellgate.1 : Float(*, *) = aten::tanh(%42) |
| %outgate.1 : Float(*, *) = aten::sigmoid(%44) |
| %49 : Float(*, *) = aten::mul(%forgetgate.1, %0) |
| %50 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) |
| %51 : int = prim::Constant[value=1]() |
| %cy : Float(*, *) = aten::add(%49, %50, %51) |
| %53 : Float(*, *) = aten::tanh(%cy) |
| %hy : Float(*, *) = aten::mul(%outgate.1, %53) |
| return (%hy, %53, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1); |
| } |