ANN(Approximate Nearest Neighbor) Vector search model인 ScaNN을 사용하여 Tensorflow 모델 서빙과 인퍼런스를 빠르게 할 수 있는 방법을 공유합니다.
ANN(Approximate Nearest Neighbor) 모델에는 여러 종류가 있으며, 그중에서도 Annoy와 Faiss가 많이 알려져 있습니다. 저는 TensorFlow Recommenders(TFRS)를 사용해왔고, TFRS 내에서 동작하는 ANN 모델로 ScaNN을 활용했습니다. 그러나 TFRS 의존성을 줄이고자 ScaNN을 별도로 구현하여, 학습된 TensorFlow 모델에서 ScaNN을 사용하는 방법을 공유하려 합니다.
Faiss로 학습 결과를 추출한 뒤 Flask나 FastAPI로 서빙할 수도 있지만, TF-serving의 추론 속도가 10배 이상 빠르기 때문에 이를 활용하여 별도의 서빙용 모델을 제작했습니다.
핵심은 학습 모델과 서빙 모델을 분리하고, 서빙 모델에 ScaNN layer를 추가하는 것입니다.
Dual Encoder 모델 정의 예시
학습할 모델 정의
먼저, Dual Encoder (Two-Tower) 모델을 정의하고 학습합니다.
import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, Dot, Dense, Flatten
from tensorflow.keras.models import Model
def create_model(ids):
# Define the embedding dimensions
embedding_dim = 32
user_vocab_size = 1000 # Adjust according to your dataset
product_vocab_size = 1000 # Adjust according to your dataset
# User tower
user_input = Input(shape=(1,), name='user_id')
user_embedding = Embedding(input_dim=user_vocab_size, output_dim=embedding_dim, name='user_embedding')(user_input)
user_embedding = Flatten()(user_embedding)
user_embedding = Dnese(embedding_dim, name='user_embedding')(user_embedding)
# Candidate (product) tower
product_input = Input(shape=(1,), name='product_id')
product_embedding = Embedding(input_dim=product_vocab_size, output_dim=embedding_dim, name='product_embedding')(product_input)
product_embedding = Flatten()(product_embedding)
product_embedding = Dnese(embedding_dim, name='product_embedding')(product_embedding)
# Create the model
model = Model(inputs=[user_input, product_input], outputs=[user_embedding, product_embedding])
return model
ScaNN 라이브러리 설치 및 ScaNN 클래스 정의
ScaNN 라이브러리를 설치하고, 이를 기반으로 Searcher를 사용하는 클래스를 정의합니다. 해당 코드는 TFRS ScaNN 클래스를 참고하였습니다.
> pip install scann
import scann
import tensorflow as tf
class ScaNN(tf.keras.layers.Layer):
def __init__(self,
num_neighbors=10,
distance_measure='dot_product',
num_leaves=500,
num_leaves_to_search=50,
training_iterations=10,
dimensions_per_block=2,
num_reordering_candidates=550,
parallelize_batch_searches=True,
name=None):
super().__init__(name=name)
self.num_neighbors = num_neighbors
self.distance_measure = distance_measure
self.num_leaves = num_leaves
self.num_leaves_to_search = num_leaves_to_search
self.training_iterations = training_iterations
self.dimensions_per_block = dimensions_per_block
self.num_reordering_candidates = num_reordering_candidates
self.parallelize_batch_searches = parallelize_batch_searches
self._serialized_searcher = None
self._identifiers = None
def index(self, candidates, identifiers=None):
if identifiers is None:
identifiers = tf.range(tf.shape(candidates)[0])
# Build the ScaNN searcher
builder = scann.scann_ops.builder(
candidates,
num_neighbors=self.num_neighbors,
distance_measure=self.distance_measure
)
builder = builder.tree(
num_leaves=self.num_leaves,
num_leaves_to_search=self.num_leaves_to_search,
training_iterations=self.training_iterations
)
builder = builder.score_ah(dimensions_per_block=self.dimensions_per_block)
if self.num_reordering_candidates is not None:
builder = builder.reorder(self.num_reordering_candidates)
searcher = builder.build()
# Serialize the searcher
self._serialized_searcher = searcher.serialize_to_module()
self._identifiers = tf.Variable(identifiers, trainable=False, name='identifiers')
@tf.function
def call(self, queries, k=None):
if k is None:
k = self.num_neighbors
if self._serialized_searcher is None:
raise ValueError("The `index` method must be called before querying.")
# Deserialize the searcher
searcher = scann.scann_ops.searcher_from_module(self._serialized_searcher)
# Perform the search
if tf.rank(queries) == 2:
if self.parallelize_batch_searches:
result = searcher.search_batched_parallel(queries, final_num_neighbors=k)
else:
result = searcher.search_batched(queries, final_num_neighbors=k)
indices = result.indices
distances = result.distances
else:
result = searcher.search(queries, final_num_neighbors=k)
indices = result.index
distances = result.distance
# Map indices back to identifiers
identifiers = tf.gather(self._identifiers, indices)
return distances, identifiers
서빙을 위한 모델 재정의
학습된 모델의 임베딩과 ScaNN index를 활용하여 서빙 모델을 정의합니다.
from scann_model import ScaNN
import tensorflow as tf
class RetrievalModel(tf.keras.Model):
def __init__(self, query_model, scann_layer):
super().__init__()
self.query_model = query_model
self.scann_layer = scann_layer
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32)])
def call(self, product_id):
# Compute query embeddings
query_embeddings = self.query_model(product_id)
# Retrieve top candidates
distances, identifiers = self.scann_layer(query_embeddings)
return distances, identifiers
unique_item_ids = dataset['product_id'].unique().tolist()
candidate_model = tf.keras.Model(
inputs=model.get_layer('product_id').input,
outputs=model.get_layer('product_embedding').output
)
# Compute candidate embeddings
candidate_embeddings = candidate_model(tf.constant(unique_item_ids))
# Step 3: Build the ScaNN index
scann_layer = ScaNN(num_neighbors=30)
scann_layer.index(candidate_embeddings, identifiers=tf.constant(unique_item_ids))
query_model = tf.keras.Model(
inputs=model.get_layer('product_id').input,
outputs=model.get_layer('product_embedding').output
)
retrieval_model = RetrievalModel(query_model=query_model, scann_layer=scann_layer)
return retrieval_model
모델 저장 및 로드
# 모델 저장
path = "./scann_model"
tf.saved_model.save(
retrieval_model,
path,
signatures={'serving_default': retrieval_model.call}
)
# 모델 로드
loaded_model = tf.saved_model.load(path)
tf-serving을 통한 서빙
아래 Dockerfile을 사용해 AWS ECR 또는 Kubernetes 클러스터에 서빙할 수 있습니다.
FROM google/tf-serving-scann:2.11.0
ENV MODEL_BASE_PATH={model_path}
ENV MODEL_NAME=scann_model
ScaNN 하이퍼파라미터 Rule-of-thumb
num_neighbors
: 추론 시 필요한 TopK 개수.num_leaves
: 파티션 개수. 데이터 수의 제곱근으로 설정해 정확도와 추론 속도의 균형을 맞출 수 있습니다.num_leaves_to_search
: 추론 시 검색할 파티션 수. 값이 높을수록 정확도는 높아지나 성능이 감소할 수 있습니다.- 데이터가 2만 미만일 경우 브루트포스, 10만 미만이면 비대칭 해싱(AH)과 재스코어링 사용, 10만 이상이면 파티셔닝+AH+재스코어링 사용.
- AH 사용 시
dimensions_per_block
은 2로 고정. reordering_num_neighbors
: 재정렬할 TopK 개수,num_neighbors
보다 커야 합니다.