Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add multi-interaction registry support and testing
  • Loading branch information
SwordFaith committed Jun 24, 2025
commit e0699ec2a94ec9b63bccd4cb1c93887e48bdef5f
19 changes: 19 additions & 0 deletions 19 .github/workflows/sgl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ on:
- ".github/workflows/sgl.yml"
- "tests/rollout/*sglang*"
- "tests/rollout/async_rollout_utils.py"
- "tests/interactions/**"
- "tests/workers/rollout/*interaction*"

# Cancel jobs on the same ref if a new one is triggered
concurrency:
Expand Down Expand Up @@ -100,11 +102,28 @@ jobs:
- name: Download Model to Use
run: |
huggingface-cli download 'Qwen/Qwen2-7B-Instruct'
huggingface-cli download 'Qwen/Qwen2.5-0.5B'
export HF_HUB_OFFLINE=1
- name: Test GSM8K Interaction
run: |
cd tests/interactions
pytest -s test_gsm8k_interaction.py
- name: Test Interaction Registry
run: |
cd tests/interactions
pytest -s test_interaction_registry.py
- name: Test the latest SGLang
run: |
cd tests/workers/rollout
torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_spmd.py
- name: Test the latest SGLang Rollout async with interaction
run: |
cd tests/workers/rollout
torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_async_rollout_w_interaction.py
- name: Test the latest SGLang Multi Interaction
run: |
cd tests/workers/rollout
torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_multi_interaction.py
- name: Test the latest SGLang Rollout async with tool
run: |
cd tests/workers/rollout
Expand Down
1 change: 1 addition & 0 deletions 1 examples/data_preprocess/gsm8k_multiturn_w_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def process_fn(example, idx):
"answer": answer_raw,
"question": question_raw,
"interaction_kwargs": {
"name": "gsm8k",
"query": question,
"ground_truth": solution,
},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
interaction:
- class_name: "verl.interactions.gsm8k_interaction.Gsm8kInteraction"
- name: "gsm8k"
class_name: "verl.interactions.gsm8k_interaction.Gsm8kInteraction"
config: {}
19 changes: 18 additions & 1 deletion 19 tests/interactions/test_gsm8k_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ class TestGsm8kInteraction:

def setup_method(self):
"""Set up test environment before each test method."""
self.config = {}
self.config = {"name": "gsm8k"}
self.interaction = Gsm8kInteraction(self.config)

def test_init(self):
"""Test Gsm8kInteraction initialization."""
assert self.interaction._instance_dict == {}
assert self.interaction.config == self.config
assert self.interaction.name == "gsm8k"

@pytest.mark.asyncio
async def test_start_interaction_with_instance_id(self):
Expand Down Expand Up @@ -378,3 +379,19 @@ def test_inheritance_from_base_interaction(self):
assert callable(self.interaction.generate_response)
assert callable(self.interaction.calculate_score)
assert callable(self.interaction.finalize_interaction)

def test_name_attribute_initialization(self):
"""Test name attribute initialization with different configs."""
# Test with explicit name in config
config_with_name = {"name": "custom_gsm8k"}
interaction_with_name = Gsm8kInteraction(config_with_name)
assert interaction_with_name.name == "custom_gsm8k"

# Test with default name when not provided in config
config_without_name = {}
interaction_without_name = Gsm8kInteraction(config_without_name)
assert interaction_without_name.name == "interaction_agent" # Default from BaseInteraction

# Test that name is accessible as attribute
assert hasattr(self.interaction, "name")
assert self.interaction.name == "gsm8k"
177 changes: 177 additions & 0 deletions 177 tests/interactions/test_interaction_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

import pytest
from omegaconf import OmegaConf

from verl.interactions.base import BaseInteraction
from verl.interactions.gsm8k_interaction import Gsm8kInteraction
from verl.interactions.utils.interaction_registry import (
get_interaction_class,
initialize_interactions_from_config,
)


class TestInteractionRegistry:
def test_get_interaction_class(self):
"""Test getting interaction class by name."""
# Test getting base interaction class
base_cls = get_interaction_class("verl.interactions.base.BaseInteraction")
assert base_cls == BaseInteraction

# Test getting gsm8k interaction class
gsm8k_cls = get_interaction_class("verl.interactions.gsm8k_interaction.Gsm8kInteraction")
assert gsm8k_cls == Gsm8kInteraction

def test_initialize_single_interaction_from_config(self):
"""Test initializing single interaction from config."""
# Create temporary config file
config_content = {"interaction": [{"name": "test_gsm8k", "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}}]}

with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
OmegaConf.save(config_content, f.name)
temp_config_path = f.name

try:
interaction_map = initialize_interactions_from_config(temp_config_path)

# Check that interaction was created
assert len(interaction_map) == 1
assert "test_gsm8k" in interaction_map
assert isinstance(interaction_map["test_gsm8k"], Gsm8kInteraction)
assert interaction_map["test_gsm8k"].name == "test_gsm8k"
finally:
os.unlink(temp_config_path)

def test_initialize_multiple_interactions_from_config(self):
"""Test initializing multiple interactions from config."""
config_content = {"interaction": [{"name": "gsm8k_solver", "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}}, {"name": "base_agent", "class_name": "verl.interactions.base.BaseInteraction", "config": {"custom_param": "test_value"}}]}

with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
OmegaConf.save(config_content, f.name)
temp_config_path = f.name

try:
interaction_map = initialize_interactions_from_config(temp_config_path)

# Check that both interactions were created
assert len(interaction_map) == 2
assert "gsm8k_solver" in interaction_map
assert "base_agent" in interaction_map

# Check types
assert isinstance(interaction_map["gsm8k_solver"], Gsm8kInteraction)
assert isinstance(interaction_map["base_agent"], BaseInteraction)

# Check names were injected
assert interaction_map["gsm8k_solver"].name == "gsm8k_solver"
assert interaction_map["base_agent"].name == "base_agent"

# Check custom config was passed
assert interaction_map["base_agent"].config.get("custom_param") == "test_value"
finally:
os.unlink(temp_config_path)

def test_initialize_interaction_without_explicit_name(self):
"""Test that interaction name is derived from class name when not specified."""
config_content = {"interaction": [{"class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}}]}

with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
OmegaConf.save(config_content, f.name)
temp_config_path = f.name

try:
interaction_map = initialize_interactions_from_config(temp_config_path)

# Check that interaction name was derived from class name
assert len(interaction_map) == 1
assert "gsm8k" in interaction_map # Should be "gsm8k" after removing "interaction" suffix
assert isinstance(interaction_map["gsm8k"], Gsm8kInteraction)
assert interaction_map["gsm8k"].name == "gsm8k"
finally:
os.unlink(temp_config_path)

def test_initialize_empty_config(self):
"""Test initializing from empty config."""
config_content = {"interaction": []}

with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
OmegaConf.save(config_content, f.name)
temp_config_path = f.name

try:
interaction_map = initialize_interactions_from_config(temp_config_path)
assert len(interaction_map) == 0
finally:
os.unlink(temp_config_path)

def test_invalid_class_name(self):
"""Test handling of invalid class name."""
config_content = {"interaction": [{"name": "invalid", "class_name": "invalid.module.InvalidClass", "config": {}}]}

with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
OmegaConf.save(config_content, f.name)
temp_config_path = f.name

try:
with pytest.raises(ModuleNotFoundError):
initialize_interactions_from_config(temp_config_path)
finally:
os.unlink(temp_config_path)

def test_duplicate_interaction_names(self):
"""Test handling of duplicate interaction names."""
config_content = {
"interaction": [
{"name": "duplicate", "class_name": "verl.interactions.base.BaseInteraction", "config": {}},
{"name": "duplicate", "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}},
]
}

with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
OmegaConf.save(config_content, f.name)
temp_config_path = f.name

try:
with pytest.raises(ValueError, match="Duplicate interaction name 'duplicate' found"):
initialize_interactions_from_config(temp_config_path)
finally:
os.unlink(temp_config_path)

def test_auto_name_generation_edge_cases(self):
"""Test automatic name generation for various class name patterns."""
config_content = {
"interaction": [
{"class_name": "verl.interactions.base.BaseInteraction", "config": {}},
{"class_name": "some.module.CustomAgent", "config": {}},
]
}

with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
OmegaConf.save(config_content, f.name)
temp_config_path = f.name

try:
interaction_map = initialize_interactions_from_config(temp_config_path)

# Check that names were generated correctly
assert len(interaction_map) == 2
assert "base" in interaction_map # BaseInteraction -> base
assert "customagent" in interaction_map # CustomAgent -> customagent
finally:
os.unlink(temp_config_path)
24 changes: 20 additions & 4 deletions 24 tests/workers/rollout/test_sglang_async_rollout_w_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def test_async_sglang_rollout_w_interaction():
]
]
interaction_kwargs = [
{"query": "Who won the Champions League in 2019?", "ground_truth": "Real Madrid"},
{"query": "The founder of Apple is", "ground_truth": "Steve Jobs"},
{"query": "What's the best way to learn python?", "ground_truth": "Learn python from scratch"},
{"name": "gsm8k", "query": "Who won the Champions League in 2019?", "ground_truth": "Real Madrid"},
{"name": "gsm8k", "query": "The founder of Apple is", "ground_truth": "Steve Jobs"},
{"name": "gsm8k", "query": "What's the best way to learn python?", "ground_truth": "Learn python from scratch"},
]
prompts = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in preencode_prompts]
input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length)
Expand All @@ -82,7 +82,18 @@ def test_async_sglang_rollout_w_interaction():
device_mesh=fsdp_device_mesh,
)

rollout_config = get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, None, None)
# Create a temporary interaction config file for testing
import tempfile

from omegaconf import OmegaConf

interaction_config = {"interaction": [{"name": "gsm8k", "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}}]}

with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
OmegaConf.save(interaction_config, f.name)
interaction_config_path = f.name

rollout_config = get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, None, interaction_config_path)
rollout = SGLangRollout(actor_module=local_model_path, config=rollout_config, processing_class=tokenizer, model_hf_config=actor_model.config)

rollout_sharding_manager = FSDPSGLangShardingManager(
Expand Down Expand Up @@ -130,6 +141,11 @@ def test_async_sglang_rollout_w_interaction():
assert are_lists_similar(hf_response_tokens, sglang_response_tokens)
print("SGLang w interaction Test Passed!")

# Clean up temporary config file
import os

os.unlink(interaction_config_path)

torch.distributed.barrier()
torch.distributed.destroy_process_group()

Expand Down
Loading
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.