본문 바로가기

Study/BackEnd

[MeMind 리팩토링] 2. OCP 원칙 적용하기

 

▽ SRP 원칙을 적용하는 과정은 이전 포스팅을 참고해주세요!! ▽

 

[MeMind 리팩토링] 1. SRP 원칙 적용하기

비즈니스 로직을 포함하는 Service 디렉토리의 코드에 SRP 원칙 (단일 책임 원칙)을 적용해보려고 한다. 1. 기존 코드 분석 # service/talk.py """ ChatGPT와의 API 통신을 통해 대화하는 로직 """ from fastapi impo

it-is-my-life.tistory.com

 

1. 기존 코드 분석

 

이전 포스팅에서 SRP 원칙까지 적용 완료된 코드의
가장 큰 문제점은 비슷한 일을 하는 클래스들이 너무 많다는 것이다.

 

# 월별 Conversation 조회
class MonthlyConversationLoader:
    def __init__(self, session) -> None:
        self.session = session

    def _get_conversation_list(self, date_object, nickname) -> List[Conversation]:
        conversation_object = select(Conversation).where(
            Conversation.nickname == nickname,
            Conversation.year == date_object.year,
            Conversation.month == date_object.month
        )

        return self.session.exec(conversation_object).all()

    async def get_conversation_by_month(self, date, nickname) -> Dict[str, List[Any]]:
        """ 해당 월에 나눈 대화 목록을 리턴하는 함수 """
        date_object = DateParser.parsing_date(date)
        conversation_list = self._get_conversation_list(date_object, nickname)

        if(not conversation_list):
            raise ConversationNotFound

        return {
            "conversation_list": conversation_list
        }

# 일별 Conversation 조회
def _get_conversation_id_by_date(self, date, nickname) -> str:
        """ 해당 일자에 conversation_id를 리턴하는 함수 """
        date_object = DateParser.parsing_date(date)

        conversation_object = select(Conversation).where(Conversation.nickname == nickname).where(Conversation.year == date_object.year).where(Conversation.month == date_object.month).where(Conversation.day == date_object.day)
        conversation_object = self.session.exec(conversation_object).first()

        return conversation_object.conversation_id

# 디비에 Conversation 데이터 추가
 # conversation 생성, message load, 길이 check
    async def start_conversation(self, date, nickname) -> Dict[str, Any]:
        """ conversation이 있으면, 대화 내용을 리턴하고, 없으면 새로 conversation을 생성하는 함수 """
        conversation_id = self._get_conversation_id_by_date(date, nickname)

        # conversation이 존재하지 않는 경우, 새로 conversation 생성
        if(not conversation_id):
            conversation_id = uuid4().hex
            conversation = Conversation(
                nickname = nickname,
                conversation_id = conversation_id
            )

            self.session.add(conversation)
            self.session.commit()

            response = MessageSender().send_message_to_chatgpt()

            MessageCreator().create_message(conversation_id, 0, False, response)

        chat_history = FullMessageLoader().get_all_full_messages(conversation_id)

        return {
            "conversation_id": conversation_id,
            "chat_history": chat_history,
            "is_enough": EnoughJudge.is_enough(len(chat_history))
        }

# 디비에 Message 데이터 추가
async def create_message(self, conversation_id, order, is_from_user, message) -> None:
        """ 채팅 이력을 저장하는 함수 """
        message_object = Message(
            conversation_id = conversation_id,
            order = order,
            is_from_user = is_from_user,
            message = message
        )

        try:
            self.session.add(message_object)
            self.session.commit()
        except Exception as e:
            raise NoSuchConversationIdError

 

특히, 날짜를 받아서 월별, 일별 Conversation을 조회하는
MonthlyConversationLoader 클래스와 _get_conversation_by_date 메서드
Conversation이나 Message 객체를 디비에 올리는 부분이 개별적으로 구현되어 있는 것을 볼 수 있다.

여기에 연별 Conversation을 조회하는 새로운 기능이 필요하거나
Report 모델을 디비에 추가하는 기능이 필요할 때도
각각 새로 구현해줘야 하는 번거로움이 생긴다.

즉, 확장에 개방되어 있지 않다.

또한, 디비에 Conversation이나 Message 데이터를 추가하는 부분에서 수정 사항이 생기면
(ex 새로운 컬럼이 추가되었을 경우나 추가적인 판단 로직이 필요한 경우)
해당 로직과는 상관이 없지만, 코드를 품고 있어 결합도가 높은
start_conversation 메서드나 answer_conversation 메서드에 직접적인 수정이 이루어져야 한다.

즉, 수정에 폐쇄되어 있지 않다.

 

2. 기존 코드에 OCP 원칙 적용하기

기존 코드에 OCP 원칙을 적용하여, 앞서 발견한 문제점들을 해결해보았다.

비슷한 일을 하는 클래스나 메서드들의 상위 인터페이스를 만들고,
해당 인터페이스를 상속받아서 개별 구현하는 방식으로 수정하였다.

수정된 부분을 자세히 살펴보자.

 

# get_conversation_by_blah를 인터페이스로
class ConversationGetter(ABC):
    def __init__(self, session) -> None:
        self.session = session

    @abstractmethod
    def get_conversation(self, date, nickname):
        pass

# 월별 Conversation 조회
class MonthlyConversationLoader(ConversationGetter):
    def __init__(self, session):
        super().__init__(session)

    async def get_conversation(self, date, nickname) -> List[Conversation]:
        date_object = DateParser.parsing_date(date)
    
        # 입력받은 월에 해당하는 conversation object를 쿼리
        conversation_list = self.session.exec(select(Conversation).where(
            Conversation.nickname == nickname,
            Conversation.year == date_object.year,
            Conversation.month == date_object.month
        )).all()

        if(not conversation_list):
            raise ConversationNotFound

        return conversation_list

# 일별 Conversation 조회
class DailyConversationLoader(ConversationGetter):
    def __init__(self, session) -> None:
        super().__init__(session)

    async def get_conversation(self, date, nickname) -> List[Conversation]:
        """ 해당 일자에 conversation_id를 리턴하는 함수 """
        date_object = DateParser.parsing_date(date)

        conversation_list = self.session.exec(select(Conversation).where(
            Conversation.nickname == nickname,
            Conversation.year == date_object.year,
            Conversation.month == date_object.month,
            Conversation.day == date_object.day
        )).all()

        return conversation_list

 

 기존 각각 분리된 형태로 구현되어 있었던, 월별/일별 Conversation 조회 기능들을
상위 인터페이스를 상속받은 두개의 클래스가
다형성을 가진 추상 메서드를 개별 구현하는 방식으로 수정하였다.

위 코드에서 ConversationGetter라는 상위 인터페이스
MonthlyConversationLoader 클래스와 DailyConversationLoader 클래스가 상속받아서
get_conversation 메서드를 구현하고 있는 것을 볼 수 있다.

만약 추후에 연별 Conversation을 조회하는 기능이 필요하더라도
처음부터 구현할 필요가 없이 ConversationGetter 인터페이스를 상속받아
get_conversation 메서드를 구현하는 방식으로 쉽게 확장할 수 있다.

즉, 확장에 개방되어 있는 코드로 변화했다고 볼 수 있다.

 

다음은 데이터에 추가하는 로직을 수정한 내용이다.

 

class ObjectCreator(ABC):
	""" 디비에 데이터를 추가하는 creator 인터페이스 """
    def __init__(self, session) -> None:
        self.session = session

    @abstractmethod
    async def create_object(self, object_info):
        pass

class MessageCreator(ObjectCreator):
	""" Message 데이터를 추가하는 클래스 """
    def __init__(self, session) -> None:
        super().__init__(session)

    async def create_object(self, object_info) -> None:
        message_object = Message(
            conversation_id = object_info["conversation_id"],
            order = object_info["order"],
            is_from_user = object_info["is_from_user"],
            message = object_info["message"]
        )

        try:
            self.session.add(message_object)
            self.session.commit()
        except Exception as e:
            raise NoSuchConversationIdError

class ConversationCreator(ObjectCreator):
	""" Conversation 데이터를 추가하는 클래스 """
    def __init__(self, session) -> None:
        super().__init__(session)

    async def create_object(self, object_info) -> None:
        conversation_object = Conversation(
            nickname = object_info["nickname"],
            conversation_id = object_info["conversation_id"]
        )

        try:
            self.session.add(conversation_object)
            self.session.commit()
        except Exception as e:
            raise UserNotFound

 

기존 코드의 경우, 디비에 데이터를 추가하는 로직이 다른 메서드 안에 개별적으로 존재한다.
이렇게 되면, 코드 간의 결합도가 높아지고 여러 로직이 공존하므로 당연히 응집도는 낮아진다.
또한, 컬럼에 수정 사항이 생기거나 추가 로직이 필요한 경우에 관련 없는 메서드에 수정이 이루어지게 된다.

그래서, 기존 각각 로직 형태로 다른 메서드에 존재했던 데이터 추가 로직을
ObjectCreator라는 상위 인터페이스를 상속받은 MessageCreator와 ConversationCreator 클래스로 분리하고 디비에 데이터를 추가하는 로직은 구체적인 하위 클래스 내에서
create_object 메서드를 구현하는 방식으로 수정하였다.

이렇게 되면, 데이터를 추가하는 구체 클래스들로 분리되므로 코드의 응집도가 높아지고
수정이 필요할 경우, 하위 클래스만 수정해주면 되기 때문에
기존 로직과 start_conversation 메서드 및 answer_conversation 메서드 간의
불필요하게 결합도가 높았던 부분도 낮출 수 있게 된다.

즉, 수정에 폐쇄된 코드로 변화했다고 볼 수 있다.

 

전체 코드는 다음과 같다.

from uuid import uuid4
import os
from datetime import datetime, date
from typing import Dict, List, Any
from abc import ABC, abstractmethod

from sqlmodel import select
import openai

from models import Conversation, Message
from exceptions import ConversationNotFound, UserNotFound, NoSuchConversationIdError

class DateParser:
    @staticmethod
    def parsing_date(date) -> datetime:
        """ date 정보를 year과 month로 파싱하는 함수 """
        return datetime.strptime(date, "%Y-%m-%d")

class EnoughJudge:
    @staticmethod
    def is_enough(conversation_lenght) -> bool:
        return conversation_lenght > 13

# get_conversation_by_blah를 인터페이스로
class ConversationGetter(ABC):
    def __init__(self, session) -> None:
        self.session = session

    @abstractmethod
    def get_conversation(self, date, nickname):
        pass

class MonthlyConversationLoader(ConversationGetter):
    def __init__(self, session):
        super().__init__(session)

    async def get_conversation(self, date, nickname) -> List[Conversation]:
        date_object = DateParser.parsing_date(date)
    
        # 입력받은 월에 해당하는 conversation object를 쿼리
        conversation_list = self.session.exec(select(Conversation).where(
            Conversation.nickname == nickname,
            Conversation.year == date_object.year,
            Conversation.month == date_object.month
        )).all()

        if(not conversation_list):
            raise ConversationNotFound

        return conversation_list

class DailyConversationLoader(ConversationGetter):
    def __init__(self, session) -> None:
        super().__init__(session)

    async def get_conversation(self, date, nickname) -> List[Conversation]:
        """ 해당 일자에 conversation_id를 리턴하는 함수 """
        date_object = DateParser.parsing_date(date)

        conversation_list = self.session.exec(select(Conversation).where(
            Conversation.nickname == nickname,
            Conversation.year == date_object.year,
            Conversation.month == date_object.month,
            Conversation.day == date_object.day
        )).all()

        return conversation_list

class FullMessageLoader:
    def __init__(self, session) -> None:
        self.session = session

    async def get_all_full_messages(self, conversation_id) -> List[Message]:
        """ 해당 conversation에서 나누었던 message들을 모두 리턴하는 함수 """
        try:
            message_object = select(Message).where(Message.conversation_id == conversation_id)
            messages = self.session.exec(message_object).all()
        except Exception as e:
            raise NoSuchConversationIdError

        return messages

class MessageSender:
    def __init__(self) -> None:
        openai.api_key(os.environ["GPT_APIKEY"])
        self.premessage = [
                {"role": "system", "content": "너는 친절한 심리상담가야. 사용자에게 오늘 하루는 어땠는지 물어보고 사용자가 응답하면 더 자세히 물어봐주고 위로해주는 상담가의 역할을 해줘. 사용자에게 보내는 너의 첫 메세지는 '안녕하세요! 오늘 하루는 어땠나요?'로 고정이야. 사용자의 응답에 적절하게 반응해주고 항상 더 자세히 질문해줘야 해. 그리고 2번 이상 응답을 받으면, '충분히 이야기를 나눈 것 같네요. 오늘 하루를 평가한다면 몇 점을 주시겠어요?'라는 말로 대화를 마무리해줘"},
                {"role": "user", "content": "안녕"}
            ]

    async def send_message_to_chatgpt(self, messages = None) -> str:
        # OpenAI GPT-3.5 Turbo 모델에 대화를 요청합니다.
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages = self.premessage.extend(messages)
        )

        return response["choices"][0]["message"]["content"]
        
class ConversationStarter:
    def __init__(self, session) -> None:
        self.session = session

    # conversation 생성, message load, 길이 check
    async def start_conversation(self, date, nickname) -> Dict[str, Any]:
        """ conversation이 있으면, 대화 내용을 리턴하고, 없으면 새로 conversation을 생성하는 함수 """
        conversation_object = await DailyConversationLoader(self.session).get_conversation(date, nickname)

        # conversation이 존재하지 않는 경우, 새로 conversation 생성
        if(not conversation_object):
            object_info = {
                "nickname": nickname,
                "conversation_id": uuid4().hex
            }
            await ConversationCreator(self.session).create_object(object_info)

            response = await MessageSender().send_message_to_chatgpt()

            object_info["order"] = 0
            object_info["is_from_user"] = False
            object_info["message"] = response
            await MessageCreator(self.session).create_object(object_info)

        # 존재 하는 경우, 이전 대화 내용 로드
        chat_history = await FullMessageLoader().get_all_full_messages(object_info["conversation_id"])

        return {
            "conversation_id": object_info["conversation_id"],
            "chat_history": chat_history,
            "is_enough": EnoughJudge.is_enough(len(chat_history))
        }

class MessageGetter:
    def __init__(self, session) -> None:
        self.session = session

    async def classify_writer(self, conversation_id) -> List[Dict[str, str]]:
        """ message들을 화자에 따라서 분류하여 리턴하는 함수 """
        chat_history = []
        messages = FullMessageLoader(self.session).get_all_full_messages(conversation_id)
    
        for m in messages:
            chat_history.append({"role": "user" if m.is_from_user else "assistant", "content": m.message})

        return chat_history

class ObjectCreator(ABC):
    def __init__(self, session) -> None:
        self.session = session

    @abstractmethod
    async def create_object(self, object_info):
        pass

class MessageCreator(ObjectCreator):
    def __init__(self, session) -> None:
        super().__init__(session)

    async def create_object(self, object_info) -> None:
        message_object = Message(
            conversation_id = object_info["conversation_id"],
            order = object_info["order"],
            is_from_user = object_info["is_from_user"],
            message = object_info["message"]
        )

        try:
            self.session.add(message_object)
            self.session.commit()
        except Exception as e:
            raise NoSuchConversationIdError

class ConversationCreator(ObjectCreator):
    def __init__(self, session) -> None:
        super().__init__(session)

    async def create_object(self, object_info) -> None:
        conversation_object = Conversation(
            nickname = object_info["nickname"],
            conversation_id = object_info["conversation_id"]
        )

        try:
            self.session.add(conversation_object)
            self.session.commit()
        except Exception as e:
            raise UserNotFound

class MessageRespondent:
    def __init__(self, session) -> None:
        self.session = session

    async def answer_conversation(self, user_answer, conversation_id) -> Dict[str, Any]:
        """ 사용자의 응답에 대한 AI의 응답을 리턴하는 함수 """
        # 이전 대화 내역이 있으면 채팅에 추가합니다.
        messages = await FullMessageLoader().get_all_full_messages(conversation_id)

        # 사용자 입력을 채팅에 추가합니다.
        messages.append({"role": "user", "content": user_answer})
        order = len(messages)

        # message object를 생성하기 위한 args dict 생성
        object_info = {
            "conversation_id": conversation_id,
            "order": order,
            "is_from_user": True,
            "message": user_answer
        }
        await MessageCreator(self.session).create_object(object_info)
    
        # OpenAI GPT-3.5 Turbo 모델에 대화를 요청합니다.
        response = await MessageSender().send_message_to_chatgpt(messages)

        response_info = {
            "conversation_id": conversation_id,
            "order": order + 1,
            "is_from_user": False,
            "message": response
        }

        await MessageCreator(self.session).create_object(response_info)

        # 챗봇의 답변을 사용자 메시지와 함께 반환합니다.
        return {
            "message": response,
            "is_enough": EnoughJudge.is_enough(order)
        }

 

여기까지 OCP 원칙을 적용하는 과정을 포스팅 해보았다!!

언제나 그렇듯이 오류나 첨언은 언제든지 환영한다는 말과 함께 이번 포스팅을 마친다.