Fix VideoInputOp memory leak

Summary: VideoInputOp has a memory leak

Differential Revision: D5193802

fbshipit-source-id: a48e309b845e84ec83875119646bbb6f926ac755
diff --git a/caffe2/video/video_decoder.cc b/caffe2/video/video_decoder.cc
index bfef41c..a30ce9b 100644
--- a/caffe2/video/video_decoder.cc
+++ b/caffe2/video/video_decoder.cc
@@ -40,307 +40,346 @@
   AVPacket packet;
   av_init_packet(&packet); // init packet
   SwsContext* scaleContext_ = nullptr;
+  try {
+    inputContext->pb = ioctx.get_avio();
+    inputContext->flags |= AVFMT_FLAG_CUSTOM_IO;
+    int ret = 0;
 
-  inputContext->pb = ioctx.get_avio();
-  inputContext->flags |= AVFMT_FLAG_CUSTOM_IO;
-  int ret = 0;
+    // Determining the input format:
+    int probeSz = 32 * 1024 + AVPROBE_PADDING_SIZE;
+    DecodedFrame::AvDataPtr probe((uint8_t*)av_malloc(probeSz));
 
-  // Determining the input format:
-  int probeSz = 32 * 1024 + AVPROBE_PADDING_SIZE;
-  DecodedFrame::AvDataPtr probe((uint8_t*)av_malloc(probeSz));
+    memset(probe.get(), 0, probeSz);
+    int len = ioctx.read(probe.get(), probeSz - AVPROBE_PADDING_SIZE);
+    if (len < probeSz - AVPROBE_PADDING_SIZE) {
+      LOG(ERROR) << "Insufficient data to determine video format";
+    }
 
-  memset(probe.get(), 0, probeSz);
-  int len = ioctx.read(probe.get(), probeSz - AVPROBE_PADDING_SIZE);
-  if (len < probeSz - AVPROBE_PADDING_SIZE) {
-    LOG(ERROR) << "Insufficient data to determine video format";
-  }
+    // seek back to start of stream
+    ioctx.seek(0, SEEK_SET);
 
-  // seek back to start of stream
-  ioctx.seek(0, SEEK_SET);
+    unique_ptr<AVProbeData> probeData(new AVProbeData());
+    probeData->buf = probe.get();
+    probeData->buf_size = len;
+    probeData->filename = "";
+    // Determine the input-format:
+    inputContext->iformat = av_probe_input_format(probeData.get(), 1);
 
-  unique_ptr<AVProbeData> probeData(new AVProbeData());
-  probeData->buf = probe.get();
-  probeData->buf_size = len;
-  probeData->filename = "";
-  // Determine the input-format:
-  inputContext->iformat = av_probe_input_format(probeData.get(), 1);
+    ret = avformat_open_input(&inputContext, "", nullptr, nullptr);
+    if (ret < 0) {
+      LOG(ERROR) << "Unable to open stream " << ffmpegErrorStr(ret);
+    }
 
-  ret = avformat_open_input(&inputContext, "", nullptr, nullptr);
-  if (ret < 0) {
-    LOG(ERROR) << "Unable to open stream " << ffmpegErrorStr(ret);
-  }
+    ret = avformat_find_stream_info(inputContext, nullptr);
+    if (ret < 0) {
+      LOG(ERROR) << "Unable to find stream info in " << videoName << " "
+                 << ffmpegErrorStr(ret);
+    }
 
-  ret = avformat_find_stream_info(inputContext, nullptr);
-  if (ret < 0) {
-    LOG(ERROR) << "Unable to find stream info in " << videoName << " "
-               << ffmpegErrorStr(ret);
-  }
-
-  // Decode the first video stream
-  int videoStreamIndex_ = params.streamIndex_;
-  if (videoStreamIndex_ == -1) {
-    for (int i = 0; i < inputContext->nb_streams; i++) {
-      auto stream = inputContext->streams[i];
-      if (stream->codec->codec_type == AVMEDIA_TYPE_VIDEO) {
-        videoStreamIndex_ = i;
-        videoStream_ = stream;
-        break;
+    // Decode the first video stream
+    int videoStreamIndex_ = params.streamIndex_;
+    if (videoStreamIndex_ == -1) {
+      for (int i = 0; i < inputContext->nb_streams; i++) {
+        auto stream = inputContext->streams[i];
+        if (stream->codec->codec_type == AVMEDIA_TYPE_VIDEO) {
+          videoStreamIndex_ = i;
+          videoStream_ = stream;
+          break;
+        }
       }
     }
-  }
 
-  if (videoStream_ == nullptr) {
-    LOG(ERROR) << "Unable to find video stream in " << videoName << " "
-               << ffmpegErrorStr(ret);
-  }
+    if (videoStream_ == nullptr) {
+      LOG(ERROR) << "Unable to find video stream in " << videoName << " "
+                 << ffmpegErrorStr(ret);
+    }
 
-  // Initialize codec
-  videoCodecContext_ = videoStream_->codec;
+    // Initialize codec
+    videoCodecContext_ = videoStream_->codec;
 
-  ret = avcodec_open2(
-      videoCodecContext_,
-      avcodec_find_decoder(videoCodecContext_->codec_id),
-      nullptr);
-  if (ret < 0) {
-    LOG(ERROR) << "Cannot open video codec : "
-               << videoCodecContext_->codec->name;
-  }
+    ret = avcodec_open2(
+        videoCodecContext_,
+        avcodec_find_decoder(videoCodecContext_->codec_id),
+        nullptr);
+    if (ret < 0) {
+      LOG(ERROR) << "Cannot open video codec : "
+                 << videoCodecContext_->codec->name;
+    }
 
-  // Calcuate if we need to rescale the frames
-  int outWidth = videoCodecContext_->width;
-  int outHeight = videoCodecContext_->height;
+    // Calculate if we need to rescale the frames
+    int outWidth = videoCodecContext_->width;
+    int outHeight = videoCodecContext_->height;
 
-  if (params.maxOutputDimension_ != -1) {
-    if (videoCodecContext_->width > videoCodecContext_->height) {
-      // dominant width
-      if (params.maxOutputDimension_ < videoCodecContext_->width) {
-        float ratio =
-            (float)params.maxOutputDimension_ / videoCodecContext_->width;
-        outWidth = params.maxOutputDimension_;
-        outHeight = (int)round(videoCodecContext_->height * ratio);
+    if (params.maxOutputDimension_ != -1) {
+      if (videoCodecContext_->width > videoCodecContext_->height) {
+        // dominant width
+        if (params.maxOutputDimension_ < videoCodecContext_->width) {
+          float ratio =
+              (float)params.maxOutputDimension_ / videoCodecContext_->width;
+          outWidth = params.maxOutputDimension_;
+          outHeight = (int)round(videoCodecContext_->height * ratio);
+        }
+      } else {
+        // dominant height
+        if (params.maxOutputDimension_ < videoCodecContext_->height) {
+          float ratio =
+              (float)params.maxOutputDimension_ / videoCodecContext_->height;
+          outWidth = (int)round(videoCodecContext_->width * ratio);
+          outHeight = params.maxOutputDimension_;
+        }
       }
     } else {
-      // dominant height
-      if (params.maxOutputDimension_ < videoCodecContext_->height) {
-        float ratio =
-            (float)params.maxOutputDimension_ / videoCodecContext_->height;
-        outWidth = (int)round(videoCodecContext_->width * ratio);
-        outHeight = params.maxOutputDimension_;
-      }
-    }
-  } else {
-    outWidth = params.outputWidth_ == -1 ? videoCodecContext_->width
-                                         : params.outputWidth_;
-    outHeight = params.outputHeight_ == -1 ? videoCodecContext_->height
-                                           : params.outputHeight_;
-  }
-
-  // Make sure that we have a valid format
-  CAFFE_ENFORCE_NE(videoCodecContext_->pix_fmt, AV_PIX_FMT_NONE);
-
-  // Create a scale context
-  scaleContext_ = sws_getContext(
-      videoCodecContext_->width,
-      videoCodecContext_->height,
-      videoCodecContext_->pix_fmt,
-      outWidth,
-      outHeight,
-      pixFormat,
-      SWS_FAST_BILINEAR,
-      nullptr,
-      nullptr,
-      nullptr);
-
-  // Getting video meta data
-  VideoMeta videoMeta;
-  videoMeta.codec_type = videoCodecContext_->codec_type;
-  videoMeta.width = outWidth;
-  videoMeta.height = outHeight;
-  videoMeta.pixFormat = pixFormat;
-  videoMeta.fps = av_q2d(videoStream_->avg_frame_rate);
-
-  // If sampledFrames is not empty, empty it
-  if (sampledFrames.size() > 0) {
-    sampledFrames.clear();
-  }
-
-  if (params.intervals_.size() == 0) {
-    LOG(ERROR) << "Empty sampling intervals.";
-  }
-
-  std::vector<SampleInterval>::const_iterator itvlIter =
-      params.intervals_.begin();
-  if (itvlIter->timestamp != 0) {
-    LOG(ERROR) << "Sampling interval starting timestamp is not zero.";
-  }
-
-  double currFps = itvlIter->fps;
-  if (currFps < 0 && currFps != SpecialFps::SAMPLE_ALL_FRAMES &&
-      currFps != SpecialFps::SAMPLE_TIMESTAMP_ONLY) {
-    // fps must be 0, -1, -2 or > 0
-    LOG(ERROR) << "Invalid sampling fps.";
-  }
-
-  double prevTimestamp = itvlIter->timestamp;
-  itvlIter++;
-  if (itvlIter != params.intervals_.end() &&
-      prevTimestamp >= itvlIter->timestamp) {
-    LOG(ERROR) << "Sampling interval timestamps must be strictly ascending.";
-  }
-
-  double lastFrameTimestamp = -1.0;
-  double timestamp = -1.0;
-
-  // Initialize frame and packet.
-  // These will be reused across calls.
-  videoStreamFrame_ = av_frame_alloc();
-
-  // frame index in video stream
-  int frameIndex = -1;
-  // frame index of outputed frames
-  int outputFrameIndex = -1;
-
-  int gotPicture = 0;
-  int eof = 0;
-
-  // There is a delay between reading packets from the
-  // transport and getting decoded frames back.
-  // Therefore, after EOF, continue going while
-  // the decoder is still giving us frames.
-  while (!eof || gotPicture) {
-    if (!eof) {
-      ret = av_read_frame(inputContext, &packet);
-
-      if (ret == AVERROR(EAGAIN)) {
-        continue;
-      }
-      // Interpret any other error as EOF
-      if (ret < 0) {
-        eof = 1;
-        continue;
-      }
-
-      // Ignore packets from other streams
-      if (packet.stream_index != videoStreamIndex_) {
-        continue;
-      }
+      outWidth = params.outputWidth_ == -1 ? videoCodecContext_->width
+                                           : params.outputWidth_;
+      outHeight = params.outputHeight_ == -1 ? videoCodecContext_->height
+                                             : params.outputHeight_;
     }
 
-    ret = avcodec_decode_video2(
-        videoCodecContext_, videoStreamFrame_, &gotPicture, &packet);
-    if (ret < 0) {
-      LOG(ERROR) << "Error decoding video frame : " << ffmpegErrorStr(ret);
-    }
+    // Make sure that we have a valid format
+    CAFFE_ENFORCE_NE(videoCodecContext_->pix_fmt, AV_PIX_FMT_NONE);
 
-    // Nothing to do without a picture
-    if (!gotPicture) {
-      continue;
-    }
-
-    frameIndex++;
-
-    timestamp = av_frame_get_best_effort_timestamp(videoStreamFrame_) *
-        av_q2d(videoStream_->time_base);
-
-    // if reaching the next interval, update the current fps
-    // and reset lastFrameTimestamp so the current frame could be sampled
-    // (unless fps == SpecialFps::SAMPLE_NO_FRAME)
-    if (itvlIter != params.intervals_.end() &&
-        timestamp >= itvlIter->timestamp) {
-      lastFrameTimestamp = -1.0;
-      currFps = itvlIter->fps;
-      prevTimestamp = itvlIter->timestamp;
-      itvlIter++;
-      if (itvlIter != params.intervals_.end() &&
-          prevTimestamp >= itvlIter->timestamp) {
-        LOG(ERROR)
-            << "Sampling interval timestamps must be strictly ascending.";
-      }
-    }
-
-    // keyFrame will bypass all checks on fps sampling settings
-    bool keyFrame = params.keyFrames_ && videoStreamFrame_->key_frame;
-    if (!keyFrame) {
-      // if fps == SpecialFps::SAMPLE_NO_FRAME (0), don't sample at all
-      if (currFps == SpecialFps::SAMPLE_NO_FRAME) {
-        continue;
-      }
-
-      // fps is considered reached in the following cases:
-      // 1. lastFrameTimestamp < 0 - start of a new interval (or first frame)
-      // 2. currFps == SpecialFps::SAMPLE_ALL_FRAMES (-1) - sample every frame
-      // 3. timestamp - lastFrameTimestamp has reached target fps and
-      //    currFps > 0 (not special fps setting)
-      // different modes for fps:
-      // SpecialFps::SAMPLE_NO_FRAMES (0):
-      //     disable fps sampling, no frame sampled at all
-      // SpecialFps::SAMPLE_ALL_FRAMES (-1):
-      //     unlimited fps sampling, will sample at native video fps
-      // SpecialFps::SAMPLE_TIMESTAMP_ONLY (-2):
-      //     disable fps sampling, but will get the frame at specific timestamp
-      // others (> 0): decoding at the specified fps
-      bool fpsReached = lastFrameTimestamp < 0 ||
-          currFps == SpecialFps::SAMPLE_ALL_FRAMES ||
-          (currFps > 0 && timestamp >= lastFrameTimestamp + (1 / currFps));
-
-      if (!fpsReached) {
-        continue;
-      }
-    }
-
-    lastFrameTimestamp = timestamp;
-
-    outputFrameIndex++;
-    if (params.maximumOutputFrames_ != -1 &&
-        outputFrameIndex >= params.maximumOutputFrames_) {
-      // enough frames
-      break;
-    }
-
-    AVFrame* rgbFrame = av_frame_alloc();
-    if (!rgbFrame) {
-      LOG(ERROR) << "Error allocating AVframe";
-    }
-
-    // Determine required buffer size and allocate buffer
-    int numBytes = avpicture_get_size(pixFormat, outWidth, outHeight);
-    DecodedFrame::AvDataPtr buffer(
-        (uint8_t*)av_malloc(numBytes * sizeof(uint8_t)));
-
-    int size = avpicture_fill(
-        (AVPicture*)rgbFrame, buffer.get(), pixFormat, outWidth, outHeight);
-
-    sws_scale(
-        scaleContext_,
-        videoStreamFrame_->data,
-        videoStreamFrame_->linesize,
-        0,
+    // Create a scale context
+    scaleContext_ = sws_getContext(
+        videoCodecContext_->width,
         videoCodecContext_->height,
-        rgbFrame->data,
-        rgbFrame->linesize);
+        videoCodecContext_->pix_fmt,
+        outWidth,
+        outHeight,
+        pixFormat,
+        SWS_FAST_BILINEAR,
+        nullptr,
+        nullptr,
+        nullptr);
 
-    unique_ptr<DecodedFrame> frame = make_unique<DecodedFrame>();
-    frame->width_ = outWidth;
-    frame->height_ = outHeight;
-    frame->data_ = move(buffer);
-    frame->size_ = size;
-    frame->index_ = frameIndex;
-    frame->outputFrameIndex_ = outputFrameIndex;
-    frame->timestamp_ = timestamp;
-    frame->keyFrame_ = videoStreamFrame_->key_frame;
+    // Getting video meta data
+    VideoMeta videoMeta;
+    videoMeta.codec_type = videoCodecContext_->codec_type;
+    videoMeta.width = outWidth;
+    videoMeta.height = outHeight;
+    videoMeta.pixFormat = pixFormat;
+    videoMeta.fps = av_q2d(videoStream_->avg_frame_rate);
 
-    sampledFrames.push_back(move(frame));
-    av_frame_free(&rgbFrame);
+    // If sampledFrames is not empty, empty it
+    if (sampledFrames.size() > 0) {
+      sampledFrames.clear();
+    }
+
+    if (params.intervals_.size() == 0) {
+      LOG(ERROR) << "Empty sampling intervals.";
+    }
+
+    std::vector<SampleInterval>::const_iterator itvlIter =
+        params.intervals_.begin();
+    if (itvlIter->timestamp != 0) {
+      LOG(ERROR) << "Sampling interval starting timestamp is not zero.";
+    }
+
+    double currFps = itvlIter->fps;
+    if (currFps < 0 && currFps != SpecialFps::SAMPLE_ALL_FRAMES &&
+        currFps != SpecialFps::SAMPLE_TIMESTAMP_ONLY) {
+      // fps must be 0, -1, -2 or > 0
+      LOG(ERROR) << "Invalid sampling fps.";
+    }
+
+    double prevTimestamp = itvlIter->timestamp;
+    itvlIter++;
+    if (itvlIter != params.intervals_.end() &&
+        prevTimestamp >= itvlIter->timestamp) {
+      LOG(ERROR) << "Sampling interval timestamps must be strictly ascending.";
+    }
+
+    double lastFrameTimestamp = -1.0;
+    double timestamp = -1.0;
+
+    // Initialize frame and packet.
+    // These will be reused across calls.
+    videoStreamFrame_ = av_frame_alloc();
+
+    // frame index in video stream
+    int frameIndex = -1;
+    // frame index of outputted frames
+    int outputFrameIndex = -1;
+
+    int gotPicture = 0;
+    int eof = 0;
+
+    // There is a delay between reading packets from the
+    // transport and getting decoded frames back.
+    // Therefore, after EOF, continue going while
+    // the decoder is still giving us frames.
+    while (!eof || gotPicture) {
+      try {
+        if (!eof) {
+          ret = av_read_frame(inputContext, &packet);
+
+          if (ret == AVERROR(EAGAIN)) {
+            av_free_packet(&packet);
+            continue;
+          }
+          // Interpret any other error as EOF
+          if (ret < 0) {
+            eof = 1;
+            av_free_packet(&packet);
+            continue;
+          }
+
+          // Ignore packets from other streams
+          if (packet.stream_index != videoStreamIndex_) {
+            av_free_packet(&packet);
+            continue;
+          }
+        }
+
+        ret = avcodec_decode_video2(
+            videoCodecContext_, videoStreamFrame_, &gotPicture, &packet);
+        if (ret < 0) {
+          LOG(ERROR) << "Error decoding video frame : " << ffmpegErrorStr(ret);
+        }
+
+        try {
+          // Nothing to do without a picture
+          if (!gotPicture) {
+            av_free_packet(&packet);
+            continue;
+          }
+
+          frameIndex++;
+
+          timestamp = av_frame_get_best_effort_timestamp(videoStreamFrame_) *
+              av_q2d(videoStream_->time_base);
+
+          // if reaching the next interval, update the current fps
+          // and reset lastFrameTimestamp so the current frame could be sampled
+          // (unless fps == SpecialFps::SAMPLE_NO_FRAME)
+          if (itvlIter != params.intervals_.end() &&
+              timestamp >= itvlIter->timestamp) {
+            lastFrameTimestamp = -1.0;
+            currFps = itvlIter->fps;
+            prevTimestamp = itvlIter->timestamp;
+            itvlIter++;
+            if (itvlIter != params.intervals_.end() &&
+                prevTimestamp >= itvlIter->timestamp) {
+              LOG(ERROR)
+                  << "Sampling interval timestamps must be strictly ascending.";
+            }
+          }
+
+          // keyFrame will bypass all checks on fps sampling settings
+          bool keyFrame = params.keyFrames_ && videoStreamFrame_->key_frame;
+          if (!keyFrame) {
+            // if fps == SpecialFps::SAMPLE_NO_FRAME (0), don't sample at all
+            if (currFps == SpecialFps::SAMPLE_NO_FRAME) {
+              av_free_packet(&packet);
+              continue;
+            }
+
+            // fps is considered reached in the following cases:
+            // 1. lastFrameTimestamp < 0 - start of a new interval
+            //    (or first frame)
+            // 2. currFps == SpecialFps::SAMPLE_ALL_FRAMES (-1) - sample every
+            //    frame
+            // 3. timestamp - lastFrameTimestamp has reached target fps and
+            //    currFps > 0 (not special fps setting)
+            // different modes for fps:
+            // SpecialFps::SAMPLE_NO_FRAMES (0):
+            //     disable fps sampling, no frame sampled at all
+            // SpecialFps::SAMPLE_ALL_FRAMES (-1):
+            //     unlimited fps sampling, will sample at native video fps
+            // SpecialFps::SAMPLE_TIMESTAMP_ONLY (-2):
+            //     disable fps sampling, but will get the frame at specific
+            //     timestamp
+            // others (> 0): decoding at the specified fps
+            bool fpsReached = lastFrameTimestamp < 0 ||
+                currFps == SpecialFps::SAMPLE_ALL_FRAMES ||
+                (currFps > 0 && timestamp >=
+                  lastFrameTimestamp + (1 / currFps));
+
+            if (!fpsReached) {
+              av_free_packet(&packet);
+              continue;
+            }
+          }
+
+          lastFrameTimestamp = timestamp;
+
+          outputFrameIndex++;
+          if (params.maximumOutputFrames_ != -1 &&
+              outputFrameIndex >= params.maximumOutputFrames_) {
+            // enough frames
+            av_free_packet(&packet);
+            break;
+          }
+
+          AVFrame* rgbFrame = av_frame_alloc();
+          if (!rgbFrame) {
+            LOG(ERROR) << "Error allocating AVframe";
+          }
+
+          try {
+            // Determine required buffer size and allocate buffer
+            int numBytes = avpicture_get_size(pixFormat, outWidth, outHeight);
+            DecodedFrame::AvDataPtr buffer(
+                (uint8_t*)av_malloc(numBytes * sizeof(uint8_t)));
+
+            int size = avpicture_fill(
+                (AVPicture*)rgbFrame,
+                buffer.get(),
+                pixFormat,
+                outWidth,
+                outHeight);
+
+            sws_scale(
+                scaleContext_,
+                videoStreamFrame_->data,
+                videoStreamFrame_->linesize,
+                0,
+                videoCodecContext_->height,
+                rgbFrame->data,
+                rgbFrame->linesize);
+
+            unique_ptr<DecodedFrame> frame = make_unique<DecodedFrame>();
+            frame->width_ = outWidth;
+            frame->height_ = outHeight;
+            frame->data_ = move(buffer);
+            frame->size_ = size;
+            frame->index_ = frameIndex;
+            frame->outputFrameIndex_ = outputFrameIndex;
+            frame->timestamp_ = timestamp;
+            frame->keyFrame_ = videoStreamFrame_->key_frame;
+
+            sampledFrames.push_back(move(frame));
+            av_frame_free(&rgbFrame);
+          } catch (const std::exception&) {
+            av_frame_free(&rgbFrame);
+          }
+          av_frame_unref(videoStreamFrame_);
+        } catch (const std::exception&) {
+          av_frame_unref(videoStreamFrame_);
+        }
+
+        av_free_packet(&packet);
+      } catch (const std::exception&) {
+        av_free_packet(&packet);
+      }
+    } // end of while loop
+
+    // free all allocated resources
+    sws_freeContext(scaleContext_);
+    av_packet_unref(&packet);
+    av_frame_free(&videoStreamFrame_);
+    avcodec_close(videoCodecContext_);
+    avformat_close_input(&inputContext);
+    avformat_free_context(inputContext);
+  } catch (const std::exception&) {
+    // In case of decoding error
+    // free all allocated resources
+    sws_freeContext(scaleContext_);
+    av_packet_unref(&packet);
+    av_frame_free(&videoStreamFrame_);
+    avcodec_close(videoCodecContext_);
+    avformat_close_input(&inputContext);
+    avformat_free_context(inputContext);
   }
-
-  av_free_packet(&packet);
-  av_frame_unref(videoStreamFrame_);
-  sws_freeContext(scaleContext_);
-  av_packet_unref(&packet);
-  av_frame_free(&videoStreamFrame_);
-  avcodec_close(videoCodecContext_);
-  avformat_close_input(&inputContext);
-  avformat_free_context(inputContext);
 }
 
 void VideoDecoder::decodeMemory(