■ JITHub 개발일지 53일차
□ TIL(Today I Learned) ::
OpenCV를 활용한 딥러닝 이미지 처리
※ 딥 러닝 이미지 처리 기능 학습 내용과 들었던 생각
- 이번에 알게된 딥러닝 기술은 이미지 처리기술 중 style transfer이다.
- 딥러닝을 활용해서 이미지를 처리할 때 다양한 기능들을 활용할 수 있었다. 예를들면 두 개의 이미지를 합치는 기능부터 이미지의 특정부분을 이미 학습한 딥러닝 모델을 통해 이식하는 기능, 특정 화풍을 입혀넛는 기능 등이 있다.
- style transfer : 다른 그림의 화풍을 적용시키는 기술로써 2015년부터 나왔던 간단한 기술이라고 한다.
- 이번에 학습한 기능은 style transfer와 가깝다.
- 특강에서 사용했던 코드는 아래와 같다.
# web_interface.py
import asyncio
from dataclasses import dataclass, is_dataclass
import io
import json
from pathlib import Path
from aiohttp import web
import torch
import torch.multiprocessing as mp
from torchvision.transforms import functional as TF
# from . import srgb_profile, STIterate
srgb_profile = (Path(__file__).resolve().parent / 'sRGB Profile.icc').read_bytes()
from style_transfer import STIterate, StyleTransfer
@dataclass
class WIIterate:
iterate: STIterate
image: torch.Tensor
@dataclass
class WIDone:
pass
@dataclass
class WIStop:
pass
class DCJSONEncoder(json.JSONEncoder):
def default(self, obj):
if is_dataclass(obj):
dct = dict(obj.__dict__)
dct['_type'] = type(obj).__name__
return dct
return super().default(obj)
class WebInterface:
def __init__(self, host, port):
self.host = host
self.port = port
self.q = mp.Queue()
self.encoder = DCJSONEncoder()
self.image = None
self.loop = None
self.runner = None
self.wss = []
self.app = web.Application()
self.static_path = Path(__file__).resolve().parent / 'web_static'
self.app.router.add_routes([web.get('/', self.handle_index),
web.get('/image', self.handle_image),
web.get('/websocket', self.handle_websocket),
web.static('/', self.static_path)])
print(f'Starting web interface at http://{self.host}:{self.port}/')
self.process = mp.Process(target=self.run)
self.process.start()
async def run_app(self):
self.runner = web.AppRunner(self.app)
await self.runner.setup()
site = web.TCPSite(self.runner, self.host, self.port, shutdown_timeout=5)
await site.start()
while True:
await asyncio.sleep(3600)
async def process_events(self):
while True:
f = self.loop.run_in_executor(None, self.q.get)
await f
event = f.result()
if isinstance(event, WIIterate):
self.image = event.image
await self.send_websocket_message(event.iterate)
elif isinstance(event, WIDone):
await self.send_websocket_message(event)
if self.wss:
print('Waiting for web clients to finish...')
await asyncio.sleep(5)
elif isinstance(event, WIStop):
for ws in self.wss:
await ws.close()
if self.runner is not None:
await self.runner.cleanup()
self.loop.stop()
return
def compress_image(self):
buf = io.BytesIO()
TF.to_pil_image(self.image).save(buf, format='jpeg', icc_profile=srgb_profile,
quality=95, subsampling=0)
return buf.getvalue()
async def handle_image(self, request):
if self.image is None:
raise web.HTTPNotFound()
f = self.loop.run_in_executor(None, self.compress_image)
await f
return web.Response(body=f.result(), content_type='image/jpeg')
async def handle_index(self, request):
body = (self.static_path / 'index.html').read_bytes()
return web.Response(body=body, content_type='text/html')
async def handle_websocket(self, request):
ws = web.WebSocketResponse()
await ws.prepare(request)
self.wss.append(ws)
async for _ in ws:
pass
try:
self.wss.remove(ws)
except ValueError:
pass
return ws
async def send_websocket_message(self, msg):
for ws in self.wss:
try:
await ws.send_json(msg, dumps=self.encoder.encode)
except ConnectionError:
try:
self.wss.remove(ws)
except ValueError:
pass
def put_iterate(self, iterate, image):
self.q.put_nowait(WIIterate(iterate, image.cpu()))
def put_done(self):
self.q.put(WIDone())
def close(self):
self.q.put(WIStop())
self.process.join(12)
def run(self):
self.loop = asyncio.get_event_loop()
asyncio.ensure_future(self.run_app())
asyncio.ensure_future(self.process_events())
try:
self.loop.run_forever()
except KeyboardInterrupt:
self.q.put(WIStop())
self.loop.run_forever()
# cli.py
"""Neural style transfer (https://arxiv.org/abs/1508.06576) in PyTorch."""
import argparse
import atexit
from dataclasses import asdict
import io
import json
from pathlib import Path
import platform
import sys
import webbrowser
import numpy as np
from PIL import Image, ImageCms
from tifffile import TIFF, TiffWriter
import torch
import torch.multiprocessing as mp
from tqdm import tqdm
# from . import srgb_profile, StyleTransfer, WebInterface
from style_transfer import STIterate, StyleTransfer
from web_interface import WebInterface
srgb_profile = (Path(__file__).resolve().parent / 'sRGB Profile.icc').read_bytes()
def prof_to_prof(image, src_prof, dst_prof, **kwargs):
src_prof = io.BytesIO(src_prof)
dst_prof = io.BytesIO(dst_prof)
return ImageCms.profileToProfile(image, src_prof, dst_prof, **kwargs)
def load_image(path, proof_prof=None):
src_prof = dst_prof = srgb_profile
try:
image = Image.open(path)
if 'icc_profile' in image.info:
src_prof = image.info['icc_profile']
else:
image = image.convert('RGB')
if proof_prof is None:
if src_prof == dst_prof:
return image.convert('RGB')
return prof_to_prof(image, src_prof, dst_prof, outputMode='RGB')
proof_prof = Path(proof_prof).read_bytes()
cmyk = prof_to_prof(image, src_prof, proof_prof, outputMode='CMYK')
return prof_to_prof(cmyk, proof_prof, dst_prof, outputMode='RGB')
except OSError as err:
print_error(err)
sys.exit(1)
def save_pil(path, image):
try:
kwargs = {'icc_profile': srgb_profile}
if path.suffix.lower() in {'.jpg', '.jpeg'}:
kwargs['quality'] = 95
kwargs['subsampling'] = 0
elif path.suffix.lower() == '.webp':
kwargs['quality'] = 95
image.save(path, **kwargs)
except (OSError, ValueError) as err:
print_error(err)
sys.exit(1)
def save_tiff(path, image):
tag = ('InterColorProfile', TIFF.DATATYPES.BYTE, len(srgb_profile), srgb_profile, False)
try:
with TiffWriter(path) as writer:
writer.save(image, photometric='rgb', resolution=(72, 72), extratags=[tag])
except OSError as err:
print_error(err)
sys.exit(1)
def save_image(path, image):
path = Path(path)
tqdm.write(f'Writing image to {path}.')
if isinstance(image, Image.Image):
save_pil(path, image)
elif isinstance(image, np.ndarray) and path.suffix.lower() in {'.tif', '.tiff'}:
save_tiff(path, image)
else:
raise ValueError('Unsupported combination of image type and extension')
def get_safe_scale(w, h, dim):
"""Given a w x h content image and that a dim x dim square does not
exceed GPU memory, compute a safe end_scale for that content image."""
return int(pow(w / h if w > h else h / w, 1/2) * dim)
def setup_exceptions():
try:
from IPython.core.ultratb import FormattedTB
sys.excepthook = FormattedTB(mode='Plain', color_scheme='Neutral')
except ImportError:
pass
def fix_start_method():
if platform.system() == 'Darwin':
mp.set_start_method('spawn')
def print_error(err):
print('\033[31m{}:\033[0m {}'.format(type(err).__name__, err), file=sys.stderr)
class Callback:
def __init__(self, st, args, image_type='pil', web_interface=None):
self.st = st
self.args = args
self.image_type = image_type
self.web_interface = web_interface
self.iterates = []
self.progress = None
def __call__(self, iterate):
self.iterates.append(asdict(iterate))
if iterate.i == 1:
self.progress = tqdm(total=iterate.i_max, dynamic_ncols=True)
msg = 'Size: {}x{}, iteration: {}, loss: {:g}'
tqdm.write(msg.format(iterate.w, iterate.h, iterate.i, iterate.loss))
self.progress.update()
if self.web_interface is not None:
self.web_interface.put_iterate(iterate, self.st.get_image_tensor())
if iterate.i == iterate.i_max:
self.progress.close()
if max(iterate.w, iterate.h) != self.args.end_scale:
save_image(self.args.output, self.st.get_image(self.image_type))
else:
if self.web_interface is not None:
self.web_interface.put_done()
elif iterate.i % self.args.save_every == 0:
save_image(self.args.output, self.st.get_image(self.image_type))
def close(self):
if self.progress is not None:
self.progress.close()
def get_trace(self):
return {'args': self.args.__dict__, 'iterates': self.iterates}
def main():
setup_exceptions()
fix_start_method()
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
def arg_info(arg):
defaults = StyleTransfer.stylize.__kwdefaults__
default_types = StyleTransfer.stylize.__annotations__
return {'default': defaults[arg], 'type': default_types[arg]}
p.add_argument('content', type=str, help='the content image')
p.add_argument('styles', type=str, nargs='+', metavar='style', help='the style images')
p.add_argument('--output', '-o', type=str, default='out.png',
help='the output image')
p.add_argument('--style-weights', '-sw', type=float, nargs='+', default=None,
metavar='STYLE_WEIGHT', help='the relative weights for each style image')
p.add_argument('--devices', type=str, default=[], nargs='+',
help='the device names to use (omit for auto)')
p.add_argument('--random-seed', '-r', type=int, default=0,
help='the random seed')
p.add_argument('--content-weight', '-cw', **arg_info('content_weight'),
help='the content weight')
p.add_argument('--tv-weight', '-tw', **arg_info('tv_weight'),
help='the smoothing weight')
p.add_argument('--min-scale', '-ms', **arg_info('min_scale'),
help='the minimum scale (max image dim), in pixels')
p.add_argument('--end-scale', '-s', type=str, default='512',
help='the final scale (max image dim), in pixels')
p.add_argument('--iterations', '-i', **arg_info('iterations'),
help='the number of iterations per scale')
p.add_argument('--initial-iterations', '-ii', **arg_info('initial_iterations'),
help='the number of iterations on the first scale')
p.add_argument('--save-every', type=int, default=50,
help='save the image every SAVE_EVERY iterations')
p.add_argument('--step-size', '-ss', **arg_info('step_size'),
help='the step size (learning rate)')
p.add_argument('--avg-decay', '-ad', **arg_info('avg_decay'),
help='the EMA decay rate for iterate averaging')
p.add_argument('--init', **arg_info('init'),
choices=['content', 'gray', 'uniform', 'style_mean'],
help='the initial image')
p.add_argument('--style-scale-fac', **arg_info('style_scale_fac'),
help='the relative scale of the style to the content')
p.add_argument('--style-size', **arg_info('style_size'),
help='the fixed scale of the style at different content scales')
p.add_argument('--pooling', type=str, default='max', choices=['max', 'average', 'l2'],
help='the model\'s pooling mode')
p.add_argument('--proof', type=str, default=None,
help='the ICC color profile (CMYK) for soft proofing the content and styles')
p.add_argument('--web', default=False, action='store_true', help='enable the web interface')
p.add_argument('--host', type=str, default='0.0.0.0',
help='the host the web interface binds to')
p.add_argument('--port', type=int, default=8080,
help='the port the web interface binds to')
p.add_argument('--browser', type=str, default='', nargs='?',
help='open a web browser (specify the browser if not system default)')
args = p.parse_args()
content_img = load_image(args.content, args.proof)
style_imgs = [load_image(img, args.proof) for img in args.styles]
image_type = 'pil'
if Path(args.output).suffix.lower() in {'.tif', '.tiff'}:
image_type = 'np_uint16'
devices = [torch.device(device) for device in args.devices]
if not devices:
devices = [torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')]
if len(set(device.type for device in devices)) != 1:
print('Devices must all be the same type.')
sys.exit(1)
if not 1 <= len(devices) <= 2:
print('Only 1 or 2 devices are supported.')
sys.exit(1)
print('Using devices:', ' '.join(str(device) for device in devices))
if devices[0].type == 'cpu':
print('CPU threads:', torch.get_num_threads())
if devices[0].type == 'cuda':
for i, device in enumerate(devices):
props = torch.cuda.get_device_properties(device)
print(f'GPU {i} type: {props.name} (compute {props.major}.{props.minor})')
print(f'GPU {i} RAM:', round(props.total_memory / 1024 / 1024), 'MB')
end_scale = int(args.end_scale.rstrip('+'))
if args.end_scale.endswith('+'):
end_scale = get_safe_scale(*content_img.size, end_scale)
args.end_scale = end_scale
web_interface = None
if args.web:
web_interface = WebInterface(args.host, args.port)
atexit.register(web_interface.close)
for device in devices:
torch.tensor(0).to(device)
torch.manual_seed(args.random_seed)
print('Loading model...')
st = StyleTransfer(devices=devices, pooling=args.pooling)
callback = Callback(st, args, image_type=image_type, web_interface=web_interface)
atexit.register(callback.close)
url = f'http://{args.host}:{args.port}/'
if args.web:
if args.browser:
webbrowser.get(args.browser).open(url)
elif args.browser is None:
webbrowser.open(url)
defaults = StyleTransfer.stylize.__kwdefaults__
st_kwargs = {k: v for k, v in args.__dict__.items() if k in defaults}
try:
st.stylize(content_img, style_imgs, **st_kwargs, callback=callback)
except KeyboardInterrupt:
pass
output_image = st.get_image(image_type)
if output_image is not None:
save_image(args.output, output_image)
with open('trace.json', 'w') as fp:
json.dump(callback.get_trace(), fp, indent=4)
if __name__ == '__main__':
main()
- 딥러닝을 공부하면서 생각이 든 것은 1) 사용할 수 있는 라이브러리를 이해, 2) 라이브러리를 활용하여 학습된 모델을 만드는 것, 3) 이미 만들어진 모델을 활용하여 필요한 기능을 구현해내는 것, 이 3가지 중에서 현재는 3)번을 이용해 기능을 어떻게 구현하는지 학습하고 있는데 학습 순서가 맞는지에 대한 의문이다.
- 학습된 모델이 어떤식으로 사용되고 구동되는지 먼저 잘 알고, 나중에 모델을 학습시킬 때 이에 대해서 생각을 하고 연구하게 될까?라는 생각이 들었다. 현재로썬 프로젝트에 딥러닝 말고도 Backend 부분도 매우 중요하므로 시간할애를 많이 하지 못하고 있는 실정이다.
- 나중에 기회가 되면 딥러닝의 다양한 모델을을 직접 써보고 관심이 가는 분야를 선택해서 공부해보는 것도 필요할 것 같다.
- 이번에 배웠던 기능은 학습된 모델(특정 화가의 화풍을 적용할 수 있도록 만들어진 모델)을 사용하여 임의의 그림을 특정 화풍으로 변환해주는 기능이다. 이를 거꾸로 사용할 수 있을까?하는 생각이 들었다. 예를들면 아래 2가지였다.
1) 특정 화가가 그린 그림을 실물 사진처럼 변환
2) 특정 화가가 그린 그림이 화풍만 보고 어떤 화가가 그린 그림인지 추론
- 이미 학습된 모델을 사용한다면 변환을 거꾸로 사용하는 것이라고 생각했는데, 생각하는 것 보다 복잡하거나, 아닐 수 있다.
1) 첫번째 케이스는 특정 화가가 그린 그림을 실물 사진처럼 변환한다면, '그림을 실물 사진처럼 변환시키는 모델'이 필요하지, '일반 그림이나 사진을 특정 화풍으로 변환시키는 모델'이 필요하지 않을 수 있다.
2) 두번째 케이스의 로직을 생각해보면 특정 그림을 갖고있는 여러 모델들을 돌려가며 화풍적용 변환을 하고 기존 그림과 변환된 그림을 비교하면서 화풍이 '가장 덜 바뀐' 그림을 찾는 것인데, '가장 덜 바뀐' 이라는 기준의 정도를 판단하고, 어떤 것이 바뀌었고 바뀌지 않았는지 판별하는 부분이 가능할 지 의문이다.
- 아무튼 딥러닝을 공부하면서 직접 사용하는 법을 익히는데 집중하고 있지만 관심이 생기니 다양한 의문점도 들고 있다. 딥러닝 기술은 최근 빠른 속도로 발전하고 생겨나고 있기 때문에 이 흐름을 잘 타고 차근차근 영역을 넓혀나가는 것이 필요할 것 같다.
- 모든 기술의 종착점은 서비스, 즉 사용하는 사람의 편의성 제공이라고 생각된다. 딥러닝 기술을 학문으로서 연구하는 것도 중요하지만 이러한 궁극적 목표를 항상 생각하고 개발하는 개발자가 되고싶다.
'DataScience > 딥러닝' 카테고리의 다른 글
딥러닝 :: 밑바닥부터 시작하는 딥러닝 Chap2. 퍼셉트론 (0) | 2023.05.02 |
---|---|
딥러닝 :: 밑바닥부터 시작하는 딥러닝 Chap1. 헬로 파이썬 (0) | 2023.05.01 |
딥러닝 :: 11월 셋째주 WIL #12 (0) | 2022.11.18 |
딥러닝 :: OpenCV를 활용한 영상 처리, 도커 복습_TIL54 (0) | 2022.11.18 |
딥러닝 :: 이미지 처리 구현 _TIL#52 (0) | 2022.11.16 |
댓글