Skip to content

Instantly share code, notes, and snippets.

@NegatioN
Last active January 13, 2022 11:39
Show Gist options
  • Save NegatioN/ae44f7004d81c39ecdf567728ba8b6ab to your computer and use it in GitHub Desktop.
Save NegatioN/ae44f7004d81c39ecdf567728ba8b6ab to your computer and use it in GitHub Desktop.
MVP: inconsistent `labels` output shape across repeated forward passes
'''
#Conda environment file (env.yml). Install with `conda env create -f env.yml`
channels:
- nvidia
- rapidsai
- anaconda
- conda-forge
dependencies:
- python=3.8
- pip
- nomkl
- pylint
- pandas==1.2.*
- pillow==8.2.0
- python-confluent-kafka==1.6.*
- ipykernel==6.2.*
- psutil==5.8.*
- google-cloud-sdk==342.0.*
- tqdm==4.61.*
- prometheus_client==0.10.*
- fire==0.3.*
- pyspark==3.1.2
- cudf==21.08.03
- nvtabular==0.8.0
- cudatoolkit==11.0.*
- pip:
- crcmod==1.7
- joblib==0.11
- annoy==1.16.3
- records==0.5.3
- psycopg2-binary==2.8.6
- pynvml
- fastcore
- transformers4rec==0.1.4
- https://download.pytorch.org/whl/cu113/torch-1.10.0%2Bcu113-cp38-cp38-linux_x86_64.whl
- torchmetrics
name: transformers4rec
'''
'''
./schema/mvp_schema.pbtxt
feature {
name: "item_id-list_seq"
type: INT
value_count {
min: 2
max: 20
}
int_domain {
name: "item_id/list"
min: 0
max: 161452
is_categorical: true
}
annotation {
tag: "categorical"
tag: "list"
tag: "item_id"
tag: "item"
}
}
'''
from merlin_standard_lib import Schema
from transformers4rec import torch as tr
import torch
# Maximum number of items per session; passed to both the input module and
# the XLNet config so the two agree on sequence length.
SESSIONS_MAX_LENGTH = 20
# Number of forward passes used to observe the shape of ``labels``.
N_FORWARD_PASSES = 10
# Fixed seed so this (stochastic) reproduction prints the same sequence of
# shapes on every run; it also fixes the model's weight initialisation.
torch.manual_seed(0)

schema = Schema().from_proto_text('./schema/mvp_schema.pbtxt').select_by_name(['item_id-list_seq'])
d_model = 320
input_module = tr.TabularSequenceFeatures.from_schema(
    schema,
    max_sequence_length=SESSIONS_MAX_LENGTH,
    aggregation="concat",
    d_output=d_model,
    masking="mlm",
)
prediction_task = tr.NextItemPredictionTask(hf_format=True, weight_tying=True)
model_config = tr.XLNetConfig.build(d_model=d_model, n_head=8, n_layer=2, total_seq_length=SESSIONS_MAX_LENGTH)
model = model_config.to_torch_model(input_module, prediction_task)

# A single session of three item ids: shape (1, 3) after unsqueeze(0).
items = torch.as_tensor([5, 7, 8]).unsqueeze(0)
item_tensor = {'item_id-list_seq': items}

# NOTE(review): ``model`` is never switched to eval mode, so it runs with
# torch.nn.Module's default ``training=True``. With ``masking="mlm"`` the
# masked positions are presumably re-sampled on every training-mode forward
# pass, which would explain why the size of ``labels`` changes between calls.
# Calling ``model.eval()`` before this loop should give a stable shape —
# TODO confirm against the transformers4rec masking documentation.
with torch.no_grad():  # forward-only reproduction; no autograd graph needed
    for _ in range(N_FORWARD_PASSES):
        ans = model(item_tensor)
        print(ans['labels'].size())
# Sample output captured from one run of the loop above (no fixed RNG seed at
# the time): the size of ``labels`` varied between 1 and 2 across passes.
'''
torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([2])
torch.Size([2])
torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([2])
'''
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment