blob: 77f8d4c804fe0f0afe2f4bec41d0fc461279d71a [file] [log] [blame]
graph(%0 : Float(*, *),
%1 : Float(*, *),
%2 : UndefinedTensor,
%3 : UndefinedTensor,
%4 : UndefinedTensor,
%5 : UndefinedTensor,
%6 : UndefinedTensor,
%7 : UndefinedTensor,
%8 : UndefinedTensor,
%9 : Float(*, *),
%10 : Float(*, *),
%11 : Float(*, *),
%12 : Float(*, *),
%13 : Float(*, *),
%14 : int[],
%15 : int[],
%16 : int[],
%17 : int[],
%18 : int[],
%19 : int[],
%ingate : Float(*, *),
%forgetgate : Float(*, *),
%cellgate : Float(*, *),
%outgate : Float(*, *),
%24 : int[],
%25 : int[],
%26 : Float(*, *)):
%27 : int = prim::Constant[value=1]()
%28 : int[] = aten::size(%outgate)
%29 : int[] = aten::size(%26)
%30 : int[] = aten::size(%ingate)
%31 : int[] = aten::size(%cellgate)
%32 : int[] = aten::size(%forgetgate)
%33 : int[] = aten::size(%9)
%34 : Tensor = prim::FusionGroup_0(%outgate, %0, %26, %28)
%grad_other.5 : Tensor, %36 : Tensor, %37 : Tensor, %38 : Tensor = prim::FusionGroup_1(%forgetgate, %9, %ingate, %cellgate, %1, %26, %0, %outgate, %33, %32, %24, %31, %30, %25, %29)
%39 : Tensor[] = prim::ListConstruct(%38, %36, %37, %34)
%40 : Tensor = aten::cat(%39, %27)
%41 : Tensor = aten::_grad_sum_to_size(%40, %19)
%42 : Tensor = aten::_grad_sum_to_size(%40, %17)
%43 : Tensor = aten::_grad_sum_to_size(%40, %14)
%44 : Tensor = aten::_grad_sum_to_size(%40, %15)
%45 : Float(*, *) = aten::t(%13)
%grad_self.7 : Float(*, *) = aten::mm(%44, %45)
%47 : Float(*, *) = aten::t(%10)
%grad_mat2.1 : Float(*, *) = aten::mm(%47, %44)
%grad_self.9 : Float(*, *) = aten::t(%grad_mat2.1)
%50 : Float(*, *) = aten::t(%12)
%grad_self.11 : Float(*, *) = aten::mm(%43, %50)
%52 : Float(*, *) = aten::t(%11)
%grad_mat2.3 : Float(*, *) = aten::mm(%52, %43)
%grad_self.13 : Float(*, *) = aten::t(%grad_mat2.3)
return (%grad_other.5, %41, %42, %grad_self.7, %grad_self.9, %grad_self.11, %grad_self.13)
with prim::FusionGroup_0 = graph(%0 : Float(*, *),
%1 : Float(*, *),
%2 : Float(*, *),
%3 : int[]):
%4 : int = prim::Constant[value=1]()
%5 : Float(*, *) = aten::mul(%1, %2)
%grad_self.1 : Tensor = aten::_grad_sum_to_size(%5, %3)
%7 : Float(*, *) = aten::neg(%0)
%8 : Float(*, *) = aten::add(%7, %4, %4)
%9 : Float(*, *) = aten::mul(%8, %0)
%10 : Tensor = aten::mul(%9, %grad_self.1)
return (%10)
with prim::FusionGroup_1 = graph(%0 : Float(*, *),
%1 : Float(*, *),
%2 : Float(*, *),
%3 : Float(*, *),
%4 : Float(*, *),
%5 : Float(*, *),
%6 : Float(*, *),
%7 : Float(*, *),
%8 : int[],
%9 : int[],
%10 : int[],
%11 : int[],
%12 : int[],
%13 : int[],
%14 : int[]):
%15 : int = prim::Constant[value=1]()
%16 : Float(*, *) = aten::neg(%0)
%17 : Float(*, *) = aten::add(%16, %15, %15)
%18 : Float(*, *) = aten::mul(%17, %0)
%19 : Float(*, *) = aten::mul(%3, %3)
%20 : Float(*, *) = aten::neg(%19)
%21 : Float(*, *) = aten::add(%20, %15, %15)
%22 : Float(*, *) = aten::mul(%6, %7)
%grad_other.1 : Tensor = aten::_grad_sum_to_size(%22, %14)
%24 : Float(*, *) = aten::mul(%5, %5)
%25 : Float(*, *) = aten::neg(%24)
%26 : Float(*, *) = aten::add(%25, %15, %15)
%27 : Tensor = aten::mul(%grad_other.1, %26)
%28 : Tensor = aten::add(%4, %27, %15)
%29 : Tensor = aten::_grad_sum_to_size(%28, %13)
%30 : Tensor = aten::mul(%29, %3)
%grad_self.3 : Tensor = aten::_grad_sum_to_size(%30, %12)
%32 : Float(*, *) = aten::neg(%2)
%33 : Float(*, *) = aten::add(%32, %15, %15)
%34 : Float(*, *) = aten::mul(%33, %2)
%35 : Tensor = aten::mul(%34, %grad_self.3)
%36 : Tensor = aten::mul(%29, %2)
%grad_other.3 : Tensor = aten::_grad_sum_to_size(%36, %11)
%38 : Tensor = aten::mul(%grad_other.3, %21)
%39 : Tensor = aten::_grad_sum_to_size(%28, %10)
%40 : Tensor = aten::mul(%39, %1)
%grad_self.5 : Tensor = aten::_grad_sum_to_size(%40, %9)
%42 : Tensor = aten::mul(%18, %grad_self.5)
%43 : Tensor = aten::mul(%39, %0)
%grad_other.5 : Tensor = aten::_grad_sum_to_size(%43, %8)
return (%grad_other.5, %42, %38, %35)