본문 바로가기
DataScience/딥러닝

딥러닝 :: OpenCV, Style Transfer를 활용한 이미지 처리_TIL53

by 올커 2022. 11. 18.

■ JITHub 개발일지 53일차

□  TIL(Today I Learned) ::

OpenCV를 활용한 딥러닝 이미지 처리

 ※ 딥 러닝 이미지 처리 기능 학습 내용과 들었던 생각 

  - 이번에 알게된 딥러닝 기술은 이미지 처리기술 중 style transfer이다.

  - 딥러닝을 활용해서 이미지를 처리할 때 다양한 기능들을 활용할 수 있었다. 예를들면 두 개의 이미지를 합치는 기능부터 이미지의 특정부분을 이미 학습한 딥러닝 모델을 통해 이식하는 기능, 특정 화풍을 입혀넛는 기능 등이 있다.

  - style transfer : 다른 그림의 화풍을 적용시키는 기술로써 2015년부터 나왔던 간단한 기술이라고 한다.

  - 이번에 학습한 기능은 style transfer와 가깝다.

  - 특강에서 사용했던 코드는 아래와 같다.

# web_interface.py

import asyncio
from dataclasses import dataclass, is_dataclass
import io
import json
from pathlib import Path

from aiohttp import web
import torch
import torch.multiprocessing as mp
from torchvision.transforms import functional as TF

# from . import srgb_profile, STIterate
srgb_profile = (Path(__file__).resolve().parent / 'sRGB Profile.icc').read_bytes()
from style_transfer import STIterate, StyleTransfer

@dataclass
class WIIterate:
    iterate: STIterate
    image: torch.Tensor


@dataclass
class WIDone:
    pass


@dataclass
class WIStop:
    pass


class DCJSONEncoder(json.JSONEncoder):
    def default(self, obj):
        if is_dataclass(obj):
            dct = dict(obj.__dict__)
            dct['_type'] = type(obj).__name__
            return dct
        return super().default(obj)


class WebInterface:
    def __init__(self, host, port):
        self.host = host
        self.port = port
        self.q = mp.Queue()
        self.encoder = DCJSONEncoder()
        self.image = None
        self.loop = None
        self.runner = None
        self.wss = []

        self.app = web.Application()
        self.static_path = Path(__file__).resolve().parent / 'web_static'
        self.app.router.add_routes([web.get('/', self.handle_index),
                                    web.get('/image', self.handle_image),
                                    web.get('/websocket', self.handle_websocket),
                                    web.static('/', self.static_path)])

        print(f'Starting web interface at http://{self.host}:{self.port}/')
        self.process = mp.Process(target=self.run)
        self.process.start()

    async def run_app(self):
        self.runner = web.AppRunner(self.app)
        await self.runner.setup()
        site = web.TCPSite(self.runner, self.host, self.port, shutdown_timeout=5)
        await site.start()
        while True:
            await asyncio.sleep(3600)

    async def process_events(self):
        while True:
            f = self.loop.run_in_executor(None, self.q.get)
            await f
            event = f.result()
            if isinstance(event, WIIterate):
                self.image = event.image
                await self.send_websocket_message(event.iterate)
            elif isinstance(event, WIDone):
                await self.send_websocket_message(event)
                if self.wss:
                    print('Waiting for web clients to finish...')
                    await asyncio.sleep(5)
            elif isinstance(event, WIStop):
                for ws in self.wss:
                    await ws.close()
                if self.runner is not None:
                    await self.runner.cleanup()
                self.loop.stop()
                return

    def compress_image(self):
        buf = io.BytesIO()
        TF.to_pil_image(self.image).save(buf, format='jpeg', icc_profile=srgb_profile,
                                         quality=95, subsampling=0)
        return buf.getvalue()

    async def handle_image(self, request):
        if self.image is None:
            raise web.HTTPNotFound()
        f = self.loop.run_in_executor(None, self.compress_image)
        await f
        return web.Response(body=f.result(), content_type='image/jpeg')

    async def handle_index(self, request):
        body = (self.static_path / 'index.html').read_bytes()
        return web.Response(body=body, content_type='text/html')

    async def handle_websocket(self, request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        self.wss.append(ws)
        async for _ in ws:
            pass
        try:
            self.wss.remove(ws)
        except ValueError:
            pass
        return ws

    async def send_websocket_message(self, msg):
        for ws in self.wss:
            try:
                await ws.send_json(msg, dumps=self.encoder.encode)
            except ConnectionError:
                try:
                    self.wss.remove(ws)
                except ValueError:
                    pass

    def put_iterate(self, iterate, image):
        self.q.put_nowait(WIIterate(iterate, image.cpu()))

    def put_done(self):
        self.q.put(WIDone())

    def close(self):
        self.q.put(WIStop())
        self.process.join(12)

    def run(self):
        self.loop = asyncio.get_event_loop()
        asyncio.ensure_future(self.run_app())
        asyncio.ensure_future(self.process_events())
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            self.q.put(WIStop())
            self.loop.run_forever()
# cli.py

"""Neural style transfer (https://arxiv.org/abs/1508.06576) in PyTorch."""

import argparse
import atexit
from dataclasses import asdict
import io
import json
from pathlib import Path
import platform
import sys
import webbrowser

import numpy as np
from PIL import Image, ImageCms
from tifffile import TIFF, TiffWriter
import torch
import torch.multiprocessing as mp
from tqdm import tqdm

# from . import srgb_profile, StyleTransfer, WebInterface
from style_transfer import STIterate, StyleTransfer
from web_interface import WebInterface
srgb_profile = (Path(__file__).resolve().parent / 'sRGB Profile.icc').read_bytes()

def prof_to_prof(image, src_prof, dst_prof, **kwargs):
    src_prof = io.BytesIO(src_prof)
    dst_prof = io.BytesIO(dst_prof)
    return ImageCms.profileToProfile(image, src_prof, dst_prof, **kwargs)




def load_image(path, proof_prof=None):
    src_prof = dst_prof = srgb_profile
    try:
        image = Image.open(path)
        if 'icc_profile' in image.info:
            src_prof = image.info['icc_profile']
        else:
            image = image.convert('RGB')
        if proof_prof is None:
            if src_prof == dst_prof:
                return image.convert('RGB')
            return prof_to_prof(image, src_prof, dst_prof, outputMode='RGB')
        proof_prof = Path(proof_prof).read_bytes()
        cmyk = prof_to_prof(image, src_prof, proof_prof, outputMode='CMYK')
        return prof_to_prof(cmyk, proof_prof, dst_prof, outputMode='RGB')
    except OSError as err:
        print_error(err)
        sys.exit(1)


def save_pil(path, image):
    try:
        kwargs = {'icc_profile': srgb_profile}
        if path.suffix.lower() in {'.jpg', '.jpeg'}:
            kwargs['quality'] = 95
            kwargs['subsampling'] = 0
        elif path.suffix.lower() == '.webp':
            kwargs['quality'] = 95
        image.save(path, **kwargs)
    except (OSError, ValueError) as err:
        print_error(err)
        sys.exit(1)


def save_tiff(path, image):
    tag = ('InterColorProfile', TIFF.DATATYPES.BYTE, len(srgb_profile), srgb_profile, False)
    try:
        with TiffWriter(path) as writer:
            writer.save(image, photometric='rgb', resolution=(72, 72), extratags=[tag])
    except OSError as err:
        print_error(err)
        sys.exit(1)


def save_image(path, image):
    path = Path(path)
    tqdm.write(f'Writing image to {path}.')
    if isinstance(image, Image.Image):
        save_pil(path, image)
    elif isinstance(image, np.ndarray) and path.suffix.lower() in {'.tif', '.tiff'}:
        save_tiff(path, image)
    else:
        raise ValueError('Unsupported combination of image type and extension')


def get_safe_scale(w, h, dim):
    """Given a w x h content image and that a dim x dim square does not
    exceed GPU memory, compute a safe end_scale for that content image."""
    return int(pow(w / h if w > h else h / w, 1/2) * dim)


def setup_exceptions():
    try:
        from IPython.core.ultratb import FormattedTB
        sys.excepthook = FormattedTB(mode='Plain', color_scheme='Neutral')
    except ImportError:
        pass


def fix_start_method():
    if platform.system() == 'Darwin':
        mp.set_start_method('spawn')


def print_error(err):
    print('\033[31m{}:\033[0m {}'.format(type(err).__name__, err), file=sys.stderr)


class Callback:
    def __init__(self, st, args, image_type='pil', web_interface=None):
        self.st = st
        self.args = args
        self.image_type = image_type
        self.web_interface = web_interface
        self.iterates = []
        self.progress = None

    def __call__(self, iterate):
        self.iterates.append(asdict(iterate))
        if iterate.i == 1:
            self.progress = tqdm(total=iterate.i_max, dynamic_ncols=True)
        msg = 'Size: {}x{}, iteration: {}, loss: {:g}'
        tqdm.write(msg.format(iterate.w, iterate.h, iterate.i, iterate.loss))
        self.progress.update()
        if self.web_interface is not None:
            self.web_interface.put_iterate(iterate, self.st.get_image_tensor())
        if iterate.i == iterate.i_max:
            self.progress.close()
            if max(iterate.w, iterate.h) != self.args.end_scale:
                save_image(self.args.output, self.st.get_image(self.image_type))
            else:
                if self.web_interface is not None:
                    self.web_interface.put_done()
        elif iterate.i % self.args.save_every == 0:
            save_image(self.args.output, self.st.get_image(self.image_type))

    def close(self):
        if self.progress is not None:
            self.progress.close()

    def get_trace(self):
        return {'args': self.args.__dict__, 'iterates': self.iterates}


def main():
    setup_exceptions()
    fix_start_method()

    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    def arg_info(arg):
        defaults = StyleTransfer.stylize.__kwdefaults__
        default_types = StyleTransfer.stylize.__annotations__
        return {'default': defaults[arg], 'type': default_types[arg]}

    p.add_argument('content', type=str, help='the content image')
    p.add_argument('styles', type=str, nargs='+', metavar='style', help='the style images')
    p.add_argument('--output', '-o', type=str, default='out.png',
                   help='the output image')
    p.add_argument('--style-weights', '-sw', type=float, nargs='+', default=None,
                   metavar='STYLE_WEIGHT', help='the relative weights for each style image')
    p.add_argument('--devices', type=str, default=[], nargs='+',
                   help='the device names to use (omit for auto)')
    p.add_argument('--random-seed', '-r', type=int, default=0,
                   help='the random seed')
    p.add_argument('--content-weight', '-cw', **arg_info('content_weight'),
                   help='the content weight')
    p.add_argument('--tv-weight', '-tw', **arg_info('tv_weight'),
                   help='the smoothing weight')
    p.add_argument('--min-scale', '-ms', **arg_info('min_scale'),
                   help='the minimum scale (max image dim), in pixels')
    p.add_argument('--end-scale', '-s', type=str, default='512',
                   help='the final scale (max image dim), in pixels')
    p.add_argument('--iterations', '-i', **arg_info('iterations'),
                   help='the number of iterations per scale')
    p.add_argument('--initial-iterations', '-ii', **arg_info('initial_iterations'),
                   help='the number of iterations on the first scale')
    p.add_argument('--save-every', type=int, default=50,
                   help='save the image every SAVE_EVERY iterations')
    p.add_argument('--step-size', '-ss', **arg_info('step_size'),
                   help='the step size (learning rate)')
    p.add_argument('--avg-decay', '-ad', **arg_info('avg_decay'),
                   help='the EMA decay rate for iterate averaging')
    p.add_argument('--init', **arg_info('init'),
                   choices=['content', 'gray', 'uniform', 'style_mean'],
                   help='the initial image')
    p.add_argument('--style-scale-fac', **arg_info('style_scale_fac'),
                   help='the relative scale of the style to the content')
    p.add_argument('--style-size', **arg_info('style_size'),
                   help='the fixed scale of the style at different content scales')
    p.add_argument('--pooling', type=str, default='max', choices=['max', 'average', 'l2'],
                   help='the model\'s pooling mode')
    p.add_argument('--proof', type=str, default=None,
                   help='the ICC color profile (CMYK) for soft proofing the content and styles')
    p.add_argument('--web', default=False, action='store_true', help='enable the web interface')
    p.add_argument('--host', type=str, default='0.0.0.0',
                   help='the host the web interface binds to')
    p.add_argument('--port', type=int, default=8080,
                   help='the port the web interface binds to')
    p.add_argument('--browser', type=str, default='', nargs='?',
                   help='open a web browser (specify the browser if not system default)')

    args = p.parse_args()

    content_img = load_image(args.content, args.proof)
    style_imgs = [load_image(img, args.proof) for img in args.styles]

    image_type = 'pil'
    if Path(args.output).suffix.lower() in {'.tif', '.tiff'}:
        image_type = 'np_uint16'

    devices = [torch.device(device) for device in args.devices]
    if not devices:
        devices = [torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')]
    if len(set(device.type for device in devices)) != 1:
        print('Devices must all be the same type.')
        sys.exit(1)
    if not 1 <= len(devices) <= 2:
        print('Only 1 or 2 devices are supported.')
        sys.exit(1)
    print('Using devices:', ' '.join(str(device) for device in devices))

    if devices[0].type == 'cpu':
        print('CPU threads:', torch.get_num_threads())
    if devices[0].type == 'cuda':
        for i, device in enumerate(devices):
            props = torch.cuda.get_device_properties(device)
            print(f'GPU {i} type: {props.name} (compute {props.major}.{props.minor})')
            print(f'GPU {i} RAM:', round(props.total_memory / 1024 / 1024), 'MB')

    end_scale = int(args.end_scale.rstrip('+'))
    if args.end_scale.endswith('+'):
        end_scale = get_safe_scale(*content_img.size, end_scale)
    args.end_scale = end_scale

    web_interface = None
    if args.web:
        web_interface = WebInterface(args.host, args.port)
        atexit.register(web_interface.close)

    for device in devices:
        torch.tensor(0).to(device)
    torch.manual_seed(args.random_seed)

    print('Loading model...')
    st = StyleTransfer(devices=devices, pooling=args.pooling)
    callback = Callback(st, args, image_type=image_type, web_interface=web_interface)
    atexit.register(callback.close)

    url = f'http://{args.host}:{args.port}/'
    if args.web:
        if args.browser:
            webbrowser.get(args.browser).open(url)
        elif args.browser is None:
            webbrowser.open(url)

    defaults = StyleTransfer.stylize.__kwdefaults__
    st_kwargs = {k: v for k, v in args.__dict__.items() if k in defaults}
    try:
        st.stylize(content_img, style_imgs, **st_kwargs, callback=callback)
    except KeyboardInterrupt:
        pass

    output_image = st.get_image(image_type)
    if output_image is not None:
        save_image(args.output, output_image)
    with open('trace.json', 'w') as fp:
        json.dump(callback.get_trace(), fp, indent=4)


if __name__ == '__main__':
    main()

  - 딥러닝을 공부하면서 생각이 든 것은 1) 사용할 수 있는 라이브러리를 이해, 2) 라이브러리를 활용하여 학습된 모델을 만드는 것, 3) 이미 만들어진 모델을 활용하여 필요한 기능을 구현해내는 것, 이 3가지 중에서 현재는 3)번을 이용해 기능을 어떻게 구현하는지 학습하고 있는데 학습 순서가 맞는지에 대한 의문이다. 

  - 학습된 모델이 어떤식으로 사용되고 구동되는지 먼저 잘 알고, 나중에 모델을 학습시킬 때 이에 대해서 생각을 하고 연구하게 될까?라는 생각이 들었다. 현재로썬 프로젝트에 딥러닝 말고도 Backend 부분도 매우 중요하므로 시간할애를 많이 하지 못하고 있는 실정이다.

  - 나중에 기회가 되면 딥러닝의 다양한 모델을을 직접 써보고 관심이 가는 분야를 선택해서 공부해보는 것도 필요할 것 같다.


  - 이번에 배웠던 기능은 학습된 모델(특정 화가의 화풍을 적용할 수 있도록 만들어진 모델)을 사용하여 임의의 그림을 특정 화풍으로 변환해주는 기능이다. 이를 거꾸로 사용할 수 있을까?하는 생각이 들었다. 예를들면 아래 2가지였다.

  1) 특정 화가가 그린 그림을 실물 사진처럼 변환

  2) 특정 화가가 그린 그림이 화풍만 보고 어떤 화가가 그린 그림인지 추론

  - 이미 학습된 모델을 사용한다면 변환을 거꾸로 사용하는 것이라고 생각했는데, 생각하는 것 보다 복잡하거나, 아닐 수 있다.

  1) 첫번째 케이스는 특정 화가가 그린 그림을 실물 사진처럼 변환한다면, '그림을 실물 사진처럼 변환시키는 모델'이 필요하지, '일반 그림이나 사진을 특정 화풍으로 변환시키는 모델'이 필요하지 않을 수 있다.

  2) 두번째 케이스의 로직을 생각해보면 특정 그림을 갖고있는 여러 모델들을 돌려가며 화풍적용 변환을 하고 기존 그림과 변환된 그림을 비교하면서 화풍이 '가장 덜 바뀐' 그림을 찾는 것인데, '가장 덜 바뀐' 이라는 기준의 정도를 판단하고, 어떤 것이 바뀌었고 바뀌지 않았는지 판별하는 부분이 가능할 지 의문이다.

 

 - 아무튼 딥러닝을 공부하면서 직접 사용하는 법을 익히는데 집중하고 있지만 관심이 생기니 다양한 의문점도 들고 있다. 딥러닝 기술은 최근 빠른 속도로 발전하고 생겨나고 있기 때문에 이 흐름을 잘 타고 차근차근 영역을 넓혀나가는 것이 필요할 것 같다.

 

 - 모든 기술의 종착점은 서비스, 즉 사용하는 사람의 편의성 제공이라고 생각된다. 딥러닝 기술을 학문으로서 연구하는 것도 중요하지만 이러한 궁극적 목표를 항상 생각하고 개발하는 개발자가 되고싶다.

 

 

 

반응형

댓글