TensorFlow 사용하기 (YOLO v.8)
TensorFlow란?
텐서플로우는 구글에서 개발한 오픈소스 러닝머신 라이브러리이다.
신경망 모델 이미지를 제공하여, 딥러닝 모델을 쉽게 만들고 학습시킴으로써 다양한 종류의 머신러닝 작업을 수행할 수 있게 한다.
물론, 신경망 모델을 사용하여, 딥러닝 모델을 직접 커스터마이징해 원하는 기능의 딥러닝 모델을 만들 수도 있지만,
일반적으로는 이미 만들어지고 학습된 모델을 사용하는 방법을 사용해보자!
LiteRT용 선행 학습된 모델 | Google AI Edge | Google AI for Developers
LiteRT 소개: 온디바이스 AI를 위한 Google의 고성능 런타임(이전 명칭: TensorFlow Lite)입니다. 이 페이지는 Cloud Translation API를 통해 번역되었습니다. 의견 보내기 LiteRT용 선행 학습된 모델 이미 학습된
ai.google.dev
이번 포스팅에서는 이미지 객체 감지 모델중 가장 유명한 YOLO 모델을 사용해 볼 예정이다.
YOLO 란?
이미지를 한번만 보고, 바로 물체를 검출하는 딥러닝 기술을 이용한 객체 검출 딥러닝 모델.
YOLO(You Only Look Once) 모델 소개
Object Detection 국가대표 딥러닝 모델 | YOLO(You Only Look Once)는 물체 검출(Object Detection)에 관심이 있는 분들이라면 한번쯤은 들어봤을 Object Detection의 국가대표 딥러닝 모델이라고 할 수 있습니다. 물
brunch.co.kr
이번에 사용될 YOLO 모델은
COCO (Common Objects in Context) 데이터 셋을 통해 학습된 YOLO 모델을 사용해 볼 예정이다.
사용법
- 학습된 딥러닝 모델을 사용하기 위해, 프로젝트 폴더에 추가
이번에 사용할 YOLO 모델 ( COCO 데이터셋을 통해 학습된 버전 )
GitHub - this-is-spartaa/YOLO-v8-model
Contribute to this-is-spartaa/YOLO-v8-model development by creating an account on GitHub.
github.com
yolov8n.tflite : 실제 YOLO 딥러닝 모델
yolo-test.jpg : 이번 실습에 사용할 테스트 이미지
labels.txt : 결과로 나올 객체들 라벨
이 중 yolov8n.tflite, labels.txt 를 assets 폴더에 추가한 후, pubspec.yaml 에 asset 폴더 경로를 추가하자.
assets:
- assets/
- Tensorflow 및 이미지 객체 검출기능을 사용하기 위한 라이브러리 설치
플러터에서 텐서플로우 라이트 모델로 추론(학습 완료된 모델을 사용하여 데이터 예측)할 수 있게 해주는 텐서플로우 공식 패키지
// Terminal
flutter pub add tflite_flutter
Yolo v8 모델로 추론할 때 고정된 입력 크기, 즉 이미지 크기를 요구
이때, 640 x 640 으로 이미지를 변환해야되는데 image 패키지사용하면 편리하게 변환 가능
// Terminal
flutter pub add image
앨범의 이미지를 어플에서 사용하기 위한 이미지 피커
// Terminal
flutter pub add image_picker
Yolo 모델로 이미지 추론을 하면 얻게되는 결과는
[ 클래스 인덱스, 신뢰도 점수, 바운딩박스(bbox - 탐지된 객체의 이미지 내에서 영역. x,y좌표 width, height) ] 형태로써
가공되지 않는 다차원 배열로 반환되는데, 이를 쉽게 가공시켜주는 패키지로 yolo_helper 패키지가 있음
다만, 공식 패키지가 아니기 때문에, github에서 가져와야 한다. (직접 pubspec.yaml 에 github 주소 입력)
// pubspec.yaml
yolo_helper:
git:
url: https://github.com/fhelper/flutter-yolo-helper.git
ref: main
- 실제 YOLO 모델을 사용하기 위한 클래스 생성
import 'package:flutter/services.dart';
import 'package:image/image.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:yolo_helper/yolo_helper.dart';
class YoloDetection {
int? _numClasses;
List<String>? _labels;
Interpreter? _interpreter;
String label(int index) => _labels?[index] ?? '';
bool get isInitialized => _interpreter != null && _labels != null;
// 1. 모델 불러오기
Future<void> initialize() async {
_interpreter = await Interpreter.fromAsset('assets/yolov8n.tflite');
final labelAsset = await rootBundle.loadString('assets/labels.txt');
_labels = labelAsset.split('\n');
_numClasses = _labels!.length;
}
// 2. 이미지 입력받아서 추론
List<DetectedObject> runInference(Image image) {
if (!isInitialized) {
throw Exception('The model must be initialized');
}
// 3. 이미지를 YOLO v8 input 에 맞게 640x640 사이즈로 변환
final imgResized = copyResize(image, width: 640, height: 640);
// 4. 변환된 이미지 픽셀 nomalize(정규화)
// 640x640 이미지에서 각 픽셀값을 가져와서
// 0~255 사이의 값인 RGB 값을 0~1 로 변환
final imgNormalized = List.generate(
640,
(y) => List.generate(
640,
(x) {
final pixel = imgResized.getPixel(x, y);
return [pixel.rNormalized, pixel.gNormalized, pixel.bNormalized];
},
),
);
final output = [
List<List<double>>.filled(4 + _numClasses!, List<double>.filled(8400, 0))
];
_interpreter!.run([imgNormalized], output);
// 원본 이미지 사이즈 넘기기!!!
return YoloHelper.parse(output[0], image.width, image.height);
}
}
- 실제 사용하기
이미지 피커로 이미지를 넣은후, 해당 이미지를 YOLO를 사용하여 객체를 검출하기
import 'dart:typed_data';
import 'package:flutter/material.dart';
import 'package:flutter_tflite_temp/yolo_detection.dart';
import 'package:image/image.dart' as img;
import 'package:image_picker/image_picker.dart';
import 'package:yolo_helper/yolo_helper.dart';
class HomePage extends StatefulWidget {
const HomePage({super.key});
@override
State<HomePage> createState() => _HomePageState();
}
class _HomePageState extends State<HomePage> {
final YoloDetection model = YoloDetection();
final ImagePicker picker = ImagePicker();
List<DetectedObject>? detectedObjects;
Uint8List? imageBytes;
int? imageWidth;
int? imageHeight;
@override
void initState() {
super.initState();
model.initialize();
}
@override
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(title: const Text('YOLO')),
body: GestureDetector(
onTap: () async {
// 이미지 피커를 이용한 이미지 가져오기
if (!model.isInitialized) {
return;
}
final XFile? newImageFile =
await picker.pickImage(source: ImageSource.gallery);
if (newImageFile != null) {
final bytes = await newImageFile.readAsBytes();
final image = img.decodeImage(bytes)!;
imageWidth = image.width;
imageHeight = image.height;
setState(() {
imageBytes = bytes;
});
// YOLO 모델을 통해 이미지에서 객체 검출
// [pixel.rNormalized, pixel.gNormalized, pixel.bNormalized] 형태
detectedObjects = model.runInference(image);
}
},
child: ListView(
children: [
if (imageBytes == null)
const Icon(
Icons.file_open_outlined,
size: 80,
)
else
Stack(
children: [
AspectRatio(
aspectRatio: imageWidth! / imageHeight!,
child: Image.memory(
imageBytes!,
fit: BoxFit.cover,
),
),
if (detectedObjects != null)
...detectedObjects!.map(
(e) => Bbox(
detectedObject: detectedObject,
imageWidth: imageWidth!,
imageHeihgt: imageHeight!,
label: model.label(e.labelIndex),
),
),
],
),
],
),
),
);
}
}
객체를 감싸는 박스
class Bbox extends StatelessWidget {
final DetectedObject detectedObject;
final int imageWidth;
final int imageHeight;
final String label;
Bbox({
required this.detectedObject,
required this.imageWidth,
required this.imageHeight,
required this.label,
});
@override
Widget build(BuildContext context){
final deviceWidth = MediaQuery,sizeOf(context).width;
final resizeFactor = devieceWidth / imageWidth;
// 객체를 감싸고 있는 박스의 중간 좌표
final resizedX = detectedObject.x * resizeFactor;
final resizedY = detectedObject.y * resizeFactor;
// 박스의 크기
final resizedW = detectedObject.width * resizeFactor;
final resizedH = detectedObject.height * resizeFactor;
return Positioned(
left : resizedX - resizedW / 2,
top : resizedY - resizedH / 2,
child : Container(
width : resizedW,
height : resizedH,
decoration : BoxDecoration(
border: Border.all(
color: Colors.amber,
width: 3,
)
)
child: Text(
label,
style: TextStyle(
color: Colors.red,
fontWeight: FontWeight.bold,
fontSize: 11,
),
),
);
);
}
}
## 참고
혹시 머신러닝 관련 커스텀 모델을 만들고 싶다면 참고
Tensorflow/텐서플로우를 이용한 사진 분류 딥러닝 모델 만들기.ipynb at master · boringariel/Tensorflow
Contribute to boringariel/Tensorflow development by creating an account on GitHub.
github.com