Flutter

TensorFlow 사용하기 (YOLO v.8)

hamiric 2024. 12. 20. 12:25

TensorFlow란?

텐서플로우는 구글에서 개발한 오픈소스 러닝머신 라이브러리이다.

신경망 모델 이미지를 제공하여, 딥러닝 모델을 쉽게 만들고 학습시킴으로써 다양한 종류의 머신러닝 작업을 수행할 수 있게 한다.

 

물론, 신경망 모델을 사용하여, 딥러닝 모델을 직접 커스터마이징해 원하는 기능의 딥러닝 모델을 만들 수도 있지만,

일반적으로는 이미 만들어지고 학습된 모델을 사용하는 방법을 사용해보자!

 

LiteRT용 선행 학습된 모델  |  Google AI Edge  |  Google AI for Developers

LiteRT 소개: 온디바이스 AI를 위한 Google의 고성능 런타임(이전 명칭: TensorFlow Lite)입니다. 이 페이지는 Cloud Translation API를 통해 번역되었습니다. 의견 보내기 LiteRT용 선행 학습된 모델 이미 학습된

ai.google.dev

 

이번 포스팅에서는 이미지 객체 감지 모델중 가장 유명한 YOLO 모델을 사용해 볼 예정이다.

 

 

YOLO 란?

이미지를 한번만 보고, 바로 물체를 검출하는 딥러닝 기술을 이용한 객체 검출 딥러닝 모델.

 

YOLO(You Only Look Once) 모델 소개

Object Detection 국가대표 딥러닝 모델 | YOLO(You Only Look Once)는 물체 검출(Object Detection)에 관심이 있는 분들이라면 한번쯤은 들어봤을 Object Detection의 국가대표 딥러닝 모델이라고 할 수 있습니다. 물

brunch.co.kr

 

이번에 사용될 YOLO 모델은

COCO (Common Objects in Context) 데이터 셋을 통해 학습된 YOLO 모델을 사용해 볼 예정이다.

 

 

사용법

  • 학습된 딥러닝 모델을 사용하기 위해, 프로젝트 폴더에 추가

이번에 사용할 YOLO 모델 ( COCO 데이터셋을 통해 학습된 버전 )

 

GitHub - this-is-spartaa/YOLO-v8-model

Contribute to this-is-spartaa/YOLO-v8-model development by creating an account on GitHub.

github.com

yolov8n.tflite : 실제 YOLO 딥러닝 모델

yolo-test.jpg : 이번 실습에 사용할 테스트 이미지 

labels.txt : 결과로 나올 객체들 라벨

 

이 중 yolov8n.tflite, labels.txt 를 assets 폴더에 추가한 후, pubspec.yaml 에 asset 폴더 경로를 추가하자.

  assets:
    - assets/

 

  • Tensorflow 및 이미지 객체 검출기능을 사용하기 위한 라이브러리 설치

플러터에서 텐서플로우 라이트 모델로 추론(학습 완료된 모델을 사용하여 데이터 예측)할 수 있게 해주는 텐서플로우 공식 패키지

// Terminal
flutter pub add tflite_flutter

 

Yolo v8 모델로 추론할 때 고정된 입력 크기, 즉 이미지 크기를 요구

이때, 640 x 640 으로 이미지를 변환해야되는데 image 패키지사용하면 편리하게 변환 가능

// Terminal
flutter pub add image

 

앨범의 이미지를 어플에서 사용하기 위한 이미지 피커

// Terminal
flutter pub add image_picker

 

Yolo 모델로 이미지 추론을 하면 얻게되는 결과는

[ 클래스 인덱스, 신뢰도 점수, 바운딩박스(bbox - 탐지된 객체의 이미지 내에서 영역. x,y좌표 width, height) ] 형태로써

가공되지 않는 다차원 배열로 반환되는데, 이를 쉽게 가공시켜주는 패키지로 yolo_helper 패키지가 있음

다만, 공식 패키지가 아니기 때문에, github에서 가져와야 한다. (직접 pubspec.yaml 에 github 주소 입력)

// pubspec.yaml
  yolo_helper:
    git:
      url: https://github.com/fhelper/flutter-yolo-helper.git
      ref: main

 

 

  • 실제 YOLO 모델을 사용하기 위한 클래스 생성
import 'package:flutter/services.dart';
import 'package:image/image.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:yolo_helper/yolo_helper.dart';

class YoloDetection {
  int? _numClasses;
  List<String>? _labels;
  Interpreter? _interpreter;

  String label(int index) => _labels?[index] ?? '';

  bool get isInitialized => _interpreter != null && _labels != null;

  // 1. 모델 불러오기
  Future<void> initialize() async {
    _interpreter = await Interpreter.fromAsset('assets/yolov8n.tflite');
    final labelAsset = await rootBundle.loadString('assets/labels.txt');
    _labels = labelAsset.split('\n');
    _numClasses = _labels!.length;
  }

  // 2. 이미지 입력받아서 추론
  List<DetectedObject> runInference(Image image) {
    if (!isInitialized) {
      throw Exception('The model must be initialized');
    }

    // 3. 이미지를 YOLO v8 input 에 맞게 640x640 사이즈로 변환
    final imgResized = copyResize(image, width: 640, height: 640);

    // 4. 변환된 이미지 픽셀 nomalize(정규화)
    // 640x640 이미지에서 각 픽셀값을 가져와서
    // 0~255 사이의 값인 RGB 값을 0~1 로 변환
    final imgNormalized = List.generate(
      640,
      (y) => List.generate(
        640,
        (x) {
          final pixel = imgResized.getPixel(x, y);
          return [pixel.rNormalized, pixel.gNormalized, pixel.bNormalized];
        },
      ),
    );

    final output = [
      List<List<double>>.filled(4 + _numClasses!, List<double>.filled(8400, 0))
    ];
    _interpreter!.run([imgNormalized], output);
    // 원본 이미지 사이즈 넘기기!!!
    return YoloHelper.parse(output[0], image.width, image.height);
  }
}

 

  • 실제 사용하기

이미지 피커로 이미지를 넣은후, 해당 이미지를 YOLO를 사용하여 객체를 검출하기

import 'dart:typed_data';

import 'package:flutter/material.dart';
import 'package:flutter_tflite_temp/yolo_detection.dart';
import 'package:image/image.dart' as img;
import 'package:image_picker/image_picker.dart';
import 'package:yolo_helper/yolo_helper.dart';

class HomePage extends StatefulWidget {
  const HomePage({super.key});

  @override
  State<HomePage> createState() => _HomePageState();
}

class _HomePageState extends State<HomePage> {
  final YoloDetection model = YoloDetection();
  final ImagePicker picker = ImagePicker();
  List<DetectedObject>? detectedObjects;
  Uint8List? imageBytes;
  int? imageWidth;
  int? imageHeight;

  @override
  void initState() {
    super.initState();
    model.initialize();
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(title: const Text('YOLO')),
      body: GestureDetector(
        onTap: () async {
          // 이미지 피커를 이용한 이미지 가져오기
          if (!model.isInitialized) {
            return;
          }
          final XFile? newImageFile =
              await picker.pickImage(source: ImageSource.gallery);
          if (newImageFile != null) {
            final bytes = await newImageFile.readAsBytes();
            final image = img.decodeImage(bytes)!;
            imageWidth = image.width;
            imageHeight = image.height;
            setState(() {
              imageBytes = bytes;
            });

            // YOLO 모델을 통해 이미지에서 객체 검출
            // [pixel.rNormalized, pixel.gNormalized, pixel.bNormalized] 형태
            detectedObjects = model.runInference(image);
          }
        },
        child: ListView(
          children: [
            if (imageBytes == null)
              const Icon(
                Icons.file_open_outlined,
                size: 80,
              )
            else
              Stack(
                children: [
                  AspectRatio(
                    aspectRatio: imageWidth! / imageHeight!,
                    child: Image.memory(
                      imageBytes!,
                      fit: BoxFit.cover,
                    ),
                  ),
                  if (detectedObjects != null)
                    ...detectedObjects!.map(
                      (e) => Bbox(
                        detectedObject: detectedObject,
                        imageWidth: imageWidth!,
                        imageHeihgt: imageHeight!,
                        label: model.label(e.labelIndex),
                      ),
                    ),
                ],
              ),
          ],
        ),
      ),
    );
  }
}

 

객체를 감싸는 박스

class Bbox extends StatelessWidget {
  final DetectedObject detectedObject;
  final int imageWidth;
  final int imageHeight;
  final String label;
  
  Bbox({
    required this.detectedObject,
    required this.imageWidth,
    required this.imageHeight,
    required this.label,
  });
  
  @override
  Widget build(BuildContext context){
    final deviceWidth = MediaQuery,sizeOf(context).width;
    final resizeFactor = devieceWidth / imageWidth;
    
    // 객체를 감싸고 있는 박스의 중간 좌표
    final resizedX = detectedObject.x * resizeFactor;
    final resizedY = detectedObject.y * resizeFactor;
    
    // 박스의 크기
    final resizedW = detectedObject.width * resizeFactor;
    final resizedH = detectedObject.height * resizeFactor;
  
    return Positioned(
      left : resizedX - resizedW / 2,
      top : resizedY - resizedH / 2,
      child : Container(
        width : resizedW,
        height : resizedH,
        decoration : BoxDecoration(
          border: Border.all(
            color: Colors.amber,
            width: 3,
          )
        )
        child: Text(
          label,
          style: TextStyle(
            color: Colors.red,
            fontWeight: FontWeight.bold,
            fontSize: 11,
          ),
        ),
      );
    );
  }
}

 

 

 

 

 

## 참고

혹시 머신러닝 관련 커스텀 모델을 만들고 싶다면 참고

 

Tensorflow/텐서플로우를 이용한 사진 분류 딥러닝 모델 만들기.ipynb at master · boringariel/Tensorflow

Contribute to boringariel/Tensorflow development by creating an account on GitHub.

github.com