Commit 9f2d2d6

Reworked w/ pre-allocated matrices, verrrrrrrry slow
1 parent: 0c3ccc6

2 files changed: +30 −67

test/torchtext_unittest/prototype/test_generate.py (+1, −1)
@@ -90,7 +90,7 @@ def test_hf_DELETE(self) -> None:
             test_sequence_tk,
             max_len=100,
             pad_idx=t5.config.pad_token_id,
-            num_beams=10,
+            num_beams=7,
             beam_size_token=t5.config.vocab_size,
         )
         end = time.time() - start
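An editorial aside on the timing pattern this test uses: `time.time()` is wall-clock and fairly coarse, while `time.perf_counter()` is the higher-resolution timer for benchmarks like this one. A self-contained sketch (the workload is a stand-in, not the test's generate call):

import time

start = time.perf_counter()
total = sum(i * i for i in range(1_000_000))  # stand-in workload for the timed generate() call
elapsed = time.perf_counter() - start
print(f"elapsed: {elapsed:.4f}s")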

torchtext/prototype/generate.py (+29, −66)
@@ -223,6 +223,12 @@ def beam_search(
         encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
         encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]
 
+        num_sequences = input_ids.shape[0]
+
+        # Pre-allocate everything
+        token_idxs = torch.full((num_sequences, num_beams, 1), eos_idx).to(dtype=torch.long, device=device)
+        beam_idxs = torch.zeros((num_sequences, num_beams, 1)).to(dtype=torch.long, device=device)
+
         def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep):
             # `emissions` and `N` are unused in this current implementation
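The two buffers above are allocated once per beam_search call and then written in place inside update_func. A minimal, self-contained sketch of that discipline with toy values (only the fill-with-eos_idx, slice-assign, and reset steps mirror the diff; the rest is illustrative):

import torch

num_sequences, num_beams, eos_idx = 2, 5, 1
device = torch.device("cpu")

# Allocate once, outside the per-step callback
token_idxs = torch.full((num_sequences, num_beams, 1), eos_idx, dtype=torch.long, device=device)
beam_idxs = torch.zeros((num_sequences, num_beams, 1), dtype=torch.long, device=device)

i = 0  # index of the current source sequence
live = torch.tensor([42, 17, 99])  # suppose only 3 of 5 hypotheses survived this step

# Slice-assign instead of building a fresh tensor each step; the slots past
# the live hypotheses keep eos_idx, so stale beams read as already finished
token_idxs[i, : len(live), 0] = live
curr_token_idxs = token_idxs[i, :, 0].reshape(num_beams, 1)
print(curr_token_idxs.squeeze(1))  # tensor([42, 17, 99,  1,  1])

# ... forward pass would run here ...

# Reset both buffers for the next step, as the diff does after the forward pass
token_idxs[i, :, 0] = eos_idx
beam_idxs[i, :, 0] = 0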
@@ -231,16 +237,8 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
             # For first timestep, create previous step token_idxs and model_states
             if timestep == 0:
                 prev_step_token_idxs = [-1]
-                prev_step_model_states = [
-                    create_emitting_model_state(
-                        Seq2SeqModelState(timestep=0, sequence=input_ids[i].unsqueeze(0), lm_scores=None)
-                    )
-                ]
 
             encoder_output_for_curr_seq = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None
-            prev_model_state_sequences = [
-                get_obj_from_emitting_model_state(state).sequence for state in prev_step_model_states
-            ]
             out_probs, model_states = [], []
 
             start = 0
@@ -256,66 +254,32 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                 if end > curr_beam_size:
                     end = curr_beam_size
 
-                num_samples = end - start
-
                 if prev_step_token_idxs != [-1]:
-                    state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
-                    token_indices = (
-                        torch.Tensor(prev_step_token_idxs[start:end])
-                        .to(dtype=torch.long, device=device)
-                        .reshape(num_samples, 1)
-                    )
-
-                    state_and_tokens = torch.cat(
-                        [state_sequences, token_indices], dim=-1
-                    )  # [batch_size x (timestep + 1)]
-                    assert state_and_tokens.shape == (
-                        num_samples,
-                        timestep + 1,
-                    ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
+                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(dtype=torch.long, device=device)
+                    token_idxs[i, : len(token_indices), 0] = token_indices
+                    curr_token_idxs = token_idxs[i, :, 0].reshape(num_beams, 1)
                 else:
-                    assert len(prev_model_state_sequences) == 1
-                    state_and_tokens = token_indices = prev_model_state_sequences[0].expand(
-                        num_beams, -1
-                    )  # TODO: Make this more robust
-
-                # Cleanup -- combine this with the above
-                if self.is_encoder_decoder:
-                    # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
-                    # This is a view-only operation and doesn't copy
-                    model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
-                        num_samples if timestep > 0 else num_beams, -1, -1
-                    )
+                    if self.is_encoder_decoder:
+                        # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
+                        # This is a view-only operation and doesn't copy
+                        model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
+                            num_beams, -1, -1
+                        )
+                    curr_token_idxs = torch.zeros((num_beams, 1)).to(dtype=torch.long, device=device)
+
 
                 # Preprocess inputs for generation
                 model_inputs = self.model.prepare_inputs_for_generation(
-                    token_indices, **model_kwargs
+                    curr_token_idxs, **model_kwargs
                 )  # This should technically work with state_and_tokens, but the prepare function has to splice if past (like HF does)
                 if self.is_huggingface_model:
                     model_inputs.update(self._huggingface_model_input_values)
                 if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
-                    beam_idxs = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
-
-                    # We could store this in model_kwargs
-                    num_hyps_in_prev_step = model_kwargs["past"][0][0].shape[0]
-
-                    num_finished_hyps_in_step = num_hyps_in_prev_step - len(prev_step_hyp_idxs)
-                    if num_finished_hyps_in_step > 0:
-                        beam_idxs = F.pad(beam_idxs, (0, num_finished_hyps_in_step), "constant", 0)
-
-                    beam_idxs = torch.clamp(beam_idxs, max=len(prev_step_hyp_idxs) - 1)
-
-                    reordered_cached = self.model._reorder_cache(model_kwargs["past"], beam_idxs)
-
-                    if num_finished_hyps_in_step > 0:
-                        sliced_cache = ()
-                        for states in reordered_cached:
-                            sliced_state = ()
-                            for state in states:
-                                sliced_state = sliced_state + (state[: len(prev_step_hyp_idxs)],)
-                            sliced_cache = sliced_cache + (sliced_state,)
-                        reordered_cached = sliced_cache
+                    beam_indices = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
+                    beam_idxs[i, : len(prev_step_hyp_idxs), 0] = beam_indices
+                    curr_beam_idxs = beam_idxs[i, :, 0]
 
+                    reordered_cached = self.model._reorder_cache(model_kwargs["past"], curr_beam_idxs)
                     model_inputs["past_key_values"] = reordered_cached
 
                 # Forward pass
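The tail of this hunk leans on the model's `_reorder_cache` to realign each layer's cached key/value tensors with the hypotheses the beam kept. A standalone sketch of the same operation on a toy cache (the `reorder_cache` helper and the shapes are illustrative, not torchtext's or Hugging Face's actual implementation):

import torch

def reorder_cache(past, beam_idxs):
    # `past` is a tuple (one entry per layer) of tuples of tensors whose first
    # dimension indexes hypotheses; pick rows in the order the beam kept them.
    return tuple(
        tuple(state.index_select(0, beam_idxs) for state in layer)
        for layer in past
    )

# Toy cache: 2 layers, 5 hypotheses, key/value shaped [hyps, heads, seq, dim]
past = tuple((torch.randn(5, 8, 4, 64), torch.randn(5, 8, 4, 64)) for _ in range(2))
beam_idxs = torch.tensor([2, 2, 0, 1, 4])  # hypothesis 2 survived twice

reordered = reorder_cache(past, beam_idxs)
assert torch.equal(reordered[0][0][0], past[0][0][2])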
@@ -329,18 +293,21 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                 if self.is_huggingface_model:
                     self._update_model_kwargs_for_generation(outputs, model_kwargs)
 
+                # Reset
+                token_idxs[i, :, 0] = eos_idx
+                beam_idxs[i, :, 0] = 0
+
                 # Keep track of probabilities over vocab for this pairing
-                # TODO: fix how we track the number here?
-                for i in range(lm_scores.shape[0]):
+                for i in range(num_beams):
                     sample_lm_scores = lm_scores[i, -1]
                     out_probs.append(sample_lm_scores.tolist())
                     # Keep track of sequence and decoder hidden states
                     model_states.append(
                         create_emitting_model_state(
                             Seq2SeqModelState(
                                 timestep=timestep,
-                                sequence=state_and_tokens[i].unsqueeze(0),
-                                lm_scores=sample_lm_scores,
+                                sequence=[],
+                                lm_scores=0,
                             )
                         )
                     )
@@ -386,10 +353,6 @@ def is_not_neg_one(elem: int) -> bool:
             if not self.is_encoder_decoder:
                 final_tokens = input_ids[timestep].tolist() + final_tokens
 
-            # Makeshift padding so that we can stack the tensors
-            while len(final_tokens) < max_len:
-                final_tokens += [0]
-
             # Convert from list to tensors
             final_tokens_as_tensors = torch.Tensor(final_tokens).to(torch.long)
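The removed while-loop right-padded every hypothesis to max_len so the result tensors could be stacked downstream. If padding like that is ever needed again, `torch.nn.utils.rnn.pad_sequence` is the idiomatic one-call form (an aside, not something this commit adds):

import torch
from torch.nn.utils.rnn import pad_sequence

hyps = [torch.tensor([5, 9, 2]), torch.tensor([5, 7, 7, 3, 2])]
padded = pad_sequence(hyps, batch_first=True, padding_value=0)
print(padded)  # rows right-padded with 0 to the longest hypothesis, shape [2, 5]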