[Paper review] Sampling-Bias Corrected Neural Modeling for Large Corpus Item Recommendations

chrisjune

--

오늘은 2019년 구글에서 발표한 Sampling-Bias-Corrected Neural Modeling for Large Corpus Item Recommendations 논문에 대해서 알아보도록 하겠습니다. 논문에서 핵심이 되는 키워드 중심으로 정리하였습니다.

Introduction

이 논문은 유튜브 추천 시스템의 일부를 좀 더 자세하게 소개하였습니다. 유튜브에서는 추천 모델로 Two tower모델을 사용하였는데, 이는 SOTA모델의 복잡함을 개선할 수 있고 다양한 피쳐를 활용하여 콜드스타트 문제를 해결할 수 있기 때문입니다.

Two tower DNN model

Two tower모델은 기존 NLP에서 많이 사용하였고, Siamese neural network와 유사한 구조를 가지고 있습니다. 모델 학습시 Query tower와 Candidate tower embedding을 학습하고, 타워마다 임베딩을 Multi layer perceptron을 통과시켜 마지막에 내적을 통해 score를 계산하도록 하였습니다.

모델을 실제 서빙할 때는 Candidate embedding을 indexing하여 Input feature만 실시간으로 embedding하여 inference할 수 있습니다.

https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45530.pdf

유튜브에서는 영상 추천 과정을 두 단계로 나누어 수 백만개에서 수 백개로 추출하고, 이후 수백개를 유저에 맞게 정렬하도록 제안하였습니다. Candidate generation model로 몇 백개의 영상을 추출하고, Ranking model로 정렬하는 과정을 거치도록 하였습니다. 여기서 제안한 모델은 Candidate generation 모델입니다.

논외로 Meta에서는 Ranking 단계에 2nd stage ranking이라는 단계를 추가하여 클릭, 좋아요, 그만보기할 가중치와 각각의 확률로 한번더 정렬하는 과정을 거쳤습니다.

https://engineering.fb.com/2023/08/09/ml-applications/scaling-instagram-explore-recommendations-system/

In-batch negative sampling

https://storage.googleapis.com/gweb-research2023-media/pubtools/pdf/b9f4e78a8830fe5afcf2f0452862fb3c0d6584ea.pdf

유튜브와 같은 유저의 Positive한 액션 이벤트만 발생하는 데이터는 Negative데이터가 없습니다. 학습 데이터로 Impression event 데이터를 사용하지 않고 Click event를 사용하였습니다. Impression에서 노출되는 비디오중 Click한 비디오만 Positive, Not click한 비디오를 Negative로 볼 수 있습니다. 하지만 논문에서는 Impression 비디오도 이유가 있기 때문에 노출되었기 때문에 Not click한 비디오를 무조건 Negative로 보기 어려워 사용하지 않았습니다. (It is hard to tell negatives.)

여기서는 In-batch negative 샘플링을 사용하여 Label을 지정하였습니다. batch 안에서 interaction이 있는 user와 video의 pair는 positive, interaction이 없는 모든 pair들을 negative로 보는 것입니다. 예를 들어 A는 상품1을 좋아하고 B는 상품2를 좋아하는 데이터만 있으면, A는 상품2를 싫어하고 B는 상품1을 싫어한다고 가정하고 학습하는 것입니다.

코드로 짜면 Query embedding과 Item embedding의 백터 내적곱의 Label을 Batch size의 identity matrix로 정의할 수 있습니다.

scores = tf.linalg.matmul(
query_embeddings, candidate_embeddings, transpose_b=True)

num_queries = tf.shape(scores)[0]
num_candidates = tf.shape(scores)[1]

labels = tf.eye(num_queries, num_candidates)

Log Q Correction

실제 데이터가 따르는 확률 분포는 P라고 하지만, 샘플링이 따르는 분포는 Q라고 관용적으로 표현합니다.

in-batch negative sampling은 학습 속도가 빠르고, 많은 negative를 쉽게 구할 수 있는 큰 장점이 있습니다. 하지만 인기 상품들이 negative로 많이 선택되어 Penalized되는 문제가 있습니다. 이러한 문제를 해결하려는 것이 Log-Q correction이라고 부릅니다.

계산된 logit에 샘플링 확률을 빼주면 인기 있는 상품의 logit은 작아지고, 인기 있는 상품의 logit은 커지게 됩니다. 따라서 인기있는 상품의 label = 0으로 되는 학습의 영향을 줄일 수 있게 됩니다.

여기서 pj는 랜덤배치에서 item j의 샘플링 확률입니다.

샘플링 확률의 종류와 방법

1. Naive count ratio

상품이 데이터에서 등장하는 확률을 그대로 사용하는 방법입니다. 직관적이고 실제 학습시에 효과적으로 사용가능합니다. 제 경우 단순빈도 확률을 사용하는 것이 다른 샘플링 확률보다 offline metric에서 훨씬더 높은 성능을 내는데 도움이 되었습니다.

def get_candidate_probability(df, target_column_name):
id_frequency = df[target_column_name].value_counts()

# Calculate total number of target_product_ids
total_id_counts = len(df[target_column_name])

# Calculate probabilities for each unique product_id
id_probabilities = id_frequency / total_id_counts

id_probabilities = id_probabilities / np.sum(id_probabilities)

2. Streaming Frequency Estimation

본 논문에서는 제시한 샘플링 확률을 아래와 같이 제안합니다. 배치가 진행이 되면서 빈도를 체크하고 이를 누적하여 영향도를 샘플링 확률로 사용하는 방법입니다.

위 pseudo 코드를 torch로 구현한 코드입니다.

class StreamingLogQCorrectionModule(nn.Module):
def __init__(
self,
num_buckets: int,
hash_offset: int,
alpha: float,
p_init: float,
):
super().__init__()
self.num_buckets = num_buckets
self.hash_offset = hash_offset
self.alpha = alpha
self.register_buffer('b', (1.0 / p_init) * torch.ones((num_buckets,), dtype=torch.float32))
self.register_buffer('a', torch.zeros((num_buckets,), dtype=torch.long))

def forward(self, document_ids: torch.LongTensor) -> torch.Tensor:
h = self.hash_fn(document_ids.view(-1))
return - self.b[h].log().reshape(*document_ids.shape)

def hash_fn(self, document_ids: torch.LongTensor) -> torch.LongTensor:
return (document_ids + self.hash_offset) % self.num_buckets

def train_step(self, document_ids: torch.LongTensor, batch_idx: int) -> None:
h = self.hash_fn(document_ids).unique()
self.b[h] = (1 - self.alpha) * self.b[h] + self.alpha * (batch_idx - self.a[h]).float()
self.a[h] = batch_idx

StreamingLogQCorrection 모듈을 사용하는 일반적인 방법은 Cascade 모듈을 활용하여 사용합니다.

class CascadedStreamingLogQCorrectionModule(nn.Module):
def __init__(
self,
num_buckets: int,
hash_offsets: Tuple[int, ...],
alpha: float,
p_init: float,
):
super().__init__()
self.models = nn.ModuleList([
StreamingLogQCorrectionModule(num_buckets, offset, alpha, p_init)
for offset in hash_offsets
])

def forward(self, document_ids: torch.LongTensor) -> torch.Tensor:
result = torch.empty((0,), device=document_ids.device)
for i, mod in enumerate(self.models):
if i == 0:
result = mod(document_ids)
else:
result = torch.minimum(result, mod(document_ids))
return result

def train_step(self, document_ids: torch.LongTensor, batch_idx: int) -> None:
for mod in self.models:
mod.train_step(document_ids, batch_idx)

이와 다른 방법으로 Locality Sensitive Hash 방법도 있지만 본 논문에서 벗어난 내용이라 링크로 남깁니다.

Neural retrieval system for youtube

Two tower모델을 활용하여 유튜브의 추천모델을 소개하였습니다.

Training label

label이 없는 implicit data이기 때문에 비디오를 모두 시청한 경우를 1, 조금 시청한 경우를 0으로 정의하여 Positive한 데이터만 사용하였습니다.

Query tower (user feature)

시청한 비디오 이력→ Bag of words로 인코딩→ 각각의 Video id embedding을 평균

Candidate tower (item feature)

video_id, chanel_id등의 embedding layer를 사용하였습니다.

학습은 매일 한번 이루어 지고 급상승 하는 인기영상등을 포착하기 위하여 데이터 셔플링은 하지 않습니다. Candidate model은 추천 후보군이 뽑혀야하기 때문에 recall을 중요한 메트릭으로, ranking model은 상위에 가장 연관된 영상이 위치해야 하기 때문에 ndgc와 hit rate의 메트릭이 중요합니다.

Experiments & Conclusion

streaming frequency probability를 활용하여 sampling biased가 조정한 모델이 offline metric과 online metric에서 모두 좋은 성능을 보여주었다고 합니다.

논문에서는 대규모 Retrieval 모델링 프레임워크를 제시하였습니다. 그리고 추천 비디오별 빈도수를 추정하는 알고리즘을 제시하였습니다. 유튜브 실시간 실험에 적용한 결과 유저의 Engagement가 향상되었습니다.

저도 Two tower 모델로 추천 모델을 학습시킬때 candidate sampling probability를 고려하여 학습한 경우 recall이 훨씬 더 높은 결과를 얻을 수 있었습니다. 조금이나마 도움이 되셨길 바랍니다.

Reference

https://www.tensorflow.org/recommenders/api_docs/python/tfrs/layers/loss/SamplingProbablityCorrection

--

--

No responses yet