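PAI (Platform for AI) is Alibaba Cloud's managed machine learning platform. Strictly speaking, it is not a single product. It is five separate subproducts that share one console:

- PAI-DSW (Data Science Workshop): notebooks for interactive exploration.
- PAI-DLC (Deep Learning Containers): distributed training at scale.
- PAI-EAS (Elastic Algorithm Service): production model serving.
- PAI-Designer: visual pipelines for people who prefer drag-and-drop.
- Model Gallery: one-click deployment of open-source models.

After eighteen months of running real LLM workloads on it, I've found the components uneven. EAS is excellent and Designer is merely adequate. Once you understand how the pieces work together, though, the platform as a whole is worth far more than the sum of its parts.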
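This article is a breadth-first overview of PAI. For depth-first material, see the dedicated PAI series. Its five articles each dig into one subproduct, covering topics such as instance selection strategy, a survival guide for DLC preemptible (spot) instances, and EAS cold-start mitigation. The goal here is narrower: to help you quickly understand what PAI is, when to use each component, and how to train and deploy a model end to end.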
```python
# Configure the PAI Python SDK session (used by later pai.* calls)
from pai.session import setup_default_session

setup_default_session(region_id="cn-shanghai")

# Create a DSW instance through the DSW OpenAPI SDK
from alibabacloud_pai_dsw20220101.client import Client
from alibabacloud_pai_dsw20220101.models import (
    CreateInstanceRequest,
    CreateInstanceRequestDatasets,
)
from alibabacloud_tea_openapi.models import Config

config = Config(
    access_key_id="<YOUR_AK>",
    access_key_secret="<YOUR_SK>",
    region_id="cn-shanghai",
    endpoint="pai-dsw.cn-shanghai.aliyuncs.com",
)
client = Client(config)

request = CreateInstanceRequest(
    instance_name="dev-notebook",
    ecs_spec="ecs.gn7i-c8g1.2xlarge",       # 1x NVIDIA A10
    image_id="pytorch2.3-gpu-py310-cu124",  # placeholder: use an image ID from your region
    workspace_id="<YOUR_WORKSPACE_ID>",
    datasets=[
        # Mount an OSS-backed dataset so data and checkpoints survive restarts
        CreateInstanceRequestDatasets(dataset_id="<OSS_DATASET_ID>", mount_path="/mnt/data"),
    ],
)
response = client.create_instance(request)
print(f"Instance ID: {response.body.instance_id}")
```
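`create_instance` typically returns once the request is accepted, which can be before the instance is ready to use. Scripts that create a notebook and then connect to it therefore need to poll. The following is a minimal sketch of a generic poller. The function name `wait_for_status` is hypothetical. The status strings `"Running"` and `"Failed"` are assumptions, so check them against the statuses the DSW API actually reports. The status lookup is passed in as a callable, so the sketch does not depend on any particular SDK method.

```python
import time
from typing import Callable, Tuple


def wait_for_status(
    get_status: Callable[[], str],
    target: str = "Running",                # assumed status name
    failed: Tuple[str, ...] = ("Failed",),  # assumed terminal failure status
    timeout: float = 900.0,
    interval: float = 15.0,
    clock: Callable[[], float] = time.monotonic,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll get_status() until it returns target; raise on a failure status or timeout."""
    deadline = clock() + timeout
    while True:
        status = get_status()
        if status == target:
            return status
        if status in failed:
            raise RuntimeError(f"instance entered status {status!r}")
        if clock() >= deadline:
            raise TimeoutError(f"still {status!r} after {timeout:.0f}s")
        sleep(interval)
```

With the DSW SDK, you would pass a lambda that fetches the instance and returns its status field. The exact getter is an assumption to verify against the SDK version you install. The injectable `clock` and `sleep` parameters exist only so the loop can be tested without waiting.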
```bash
# Inside a DSW terminal -- verify the OSS mount
ls /mnt/data/
# Should list your OSS bucket contents

# Good: save checkpoints to the OSS mount (they persist)
python train.py --output_dir /mnt/data/checkpoints/run-001/

# Bad: saving to /root/ (lost on restart)
python train.py --output_dir /root/checkpoints/  # DON'T
```
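A cheap way to avoid the `/root/` mistake is to make `train.py` fail fast when its output directory is not on the persistent mount. The following is a minimal sketch. `ensure_persistent` and `PERSISTENT_ROOT` are hypothetical names, and `/mnt/data` assumes the mount path used above.

```python
from pathlib import PurePosixPath

# Assumed OSS mount point, matching the DSW/DLC examples in this article
PERSISTENT_ROOT = PurePosixPath("/mnt/data")


def ensure_persistent(output_dir: str, root: PurePosixPath = PERSISTENT_ROOT) -> str:
    """Raise ValueError unless output_dir is an absolute path under root; else return it."""
    path = PurePosixPath(output_dir)
    if not path.is_absolute() or ".." in path.parts:
        raise ValueError(f"use an absolute path without '..': {output_dir}")
    if path != root and root not in path.parents:
        raise ValueError(f"{output_dir} is not under {root}; it may be lost on restart")
    return output_dir
```

Call it once at startup, for example as `args.output_dir = ensure_persistent(args.output_dir)`. A typo like `/mnt/database/...` is also rejected, because the check compares whole path components rather than string prefixes.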
For a deeper dive into DSW, see PAI Part 2. It covers image selection, SSH tunneling, and GPU memory profiling.
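```python
import glob

import torch
from transformers import AutoModelForCausalLM

CKPT_ROOT = "/mnt/data/checkpoints"  # OSS mount, survives restarts

# In your training loop -- checkpoint every 500 steps
if global_step % 500 == 0:
    save_path = f"{CKPT_ROOT}/step-{global_step}"
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    torch.save(optimizer.state_dict(), f"{save_path}/optimizer.pt")
    torch.save(lr_scheduler.state_dict(), f"{save_path}/scheduler.pt")
    print(f"Checkpoint saved at step {global_step}")

# On resume -- find the latest checkpoint. Sort by step number:
# a plain string sort would put "step-1000" before "step-500".
checkpoints = sorted(
    glob.glob(f"{CKPT_ROOT}/step-*"),
    key=lambda p: int(p.rsplit("-", 1)[-1]),
)
if checkpoints:
    latest = checkpoints[-1]
    global_step = int(latest.rsplit("-", 1)[-1])
    model = AutoModelForCausalLM.from_pretrained(latest)
    # optimizer and lr_scheduler must be rebuilt over the reloaded model's parameters first
    optimizer.load_state_dict(torch.load(f"{latest}/optimizer.pt"))
    lr_scheduler.load_state_dict(torch.load(f"{latest}/scheduler.pt"))
    print(f"Resumed from {latest}")
```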
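Checkpointing every 500 steps to OSS accumulates storage quickly, because a 7B model plus optimizer state runs to tens of gigabytes per checkpoint. The following is a minimal retention sketch that assumes the `step-N` directory layout above. `list_checkpoints` and `prune_checkpoints` are hypothetical helpers. Like the resume logic, they order checkpoints by numeric step rather than by name.

```python
import os
import re
import shutil
from typing import List, Tuple

STEP_DIR = re.compile(r"^step-(\d+)$")


def list_checkpoints(root: str) -> List[Tuple[int, str]]:
    """Return (step, path) pairs for step-N directories under root, oldest first."""
    found = []
    for name in os.listdir(root):
        match = STEP_DIR.match(name)
        path = os.path.join(root, name)
        if match and os.path.isdir(path):
            found.append((int(match.group(1)), path))
    return sorted(found)


def prune_checkpoints(root: str, keep: int = 3) -> List[str]:
    """Delete all but the newest `keep` checkpoints and return the deleted paths."""
    if keep < 1:
        raise ValueError("keep must be >= 1")
    doomed = [path for _, path in list_checkpoints(root)[:-keep]]
    for path in doomed:
        shutil.rmtree(path)
    return doomed
```

Call `prune_checkpoints(CKPT_ROOT)` right after a successful save, so a crash mid-save never leaves you with zero complete checkpoints.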
For the full DLC treatment, see PAI Part 3. It covers RDMA configuration, DeepSpeed ZeRO settings, and handling preemptible-instance interruptions.
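```python
# Test with Python
import requests

# Placeholder: copy the service endpoint and token from the EAS console.
# The OpenAI-style payload assumes a vLLM-type server, which serves chat at /v1/chat/completions.
endpoint = "https://your-service-id.cn-shanghai.pai-eas.aliyuncs.com/api/predict"
token = "<ACCESS_TOKEN>"

response = requests.post(
    f"{endpoint}/v1/chat/completions",
    headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
    json={
        "model": "qwen-7b-prod",
        "messages": [{"role": "user", "content": "What is PAI-EAS?"}],
        "max_tokens": 128,
    },
    timeout=30,
)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])
```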
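For chat UIs, you usually want tokens as they are generated rather than one final JSON body. This sketch assumes the server behind EAS is OpenAI-compatible (as vLLM is) and honors `"stream": true`. In that case, the response arrives as server-sent events: `data: {...}` lines, terminated by `data: [DONE]`. `iter_stream_deltas` is a hypothetical helper that extracts the text deltas from those lines.

```python
import json
from typing import Iterable, Iterator, Union


def iter_stream_deltas(lines: Iterable[Union[str, bytes]]) -> Iterator[str]:
    """Yield content deltas from OpenAI-style server-sent-event lines."""
    for line in lines:
        if isinstance(line, bytes):
            line = line.decode("utf-8")
        if not line.startswith("data:"):
            continue  # skip blank keep-alive lines and SSE comments
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        for choice in chunk.get("choices", []):
            delta = choice.get("delta", {}).get("content")
            if delta:
                yield delta


# Usage (not executed here): add "stream": True to the payload and stream the response
# resp = requests.post(url, headers=headers, json={**payload, "stream": True},
#                      stream=True, timeout=60)
# for piece in iter_stream_deltas(resp.iter_lines()):
#     print(piece, end="", flush=True)
```

The helper accepts bytes as well as strings, because `iter_lines()` yields bytes unless the response declares an encoding.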
```bash
# Deploy v2 alongside v1
pai eas update qwen-7b-prod \
  --canary-image "pai-image-vllm:0.7-cu124-py310" \
  --canary-model "oss://my-bucket/checkpoints/qwen25-7b-sft-002/" \
  --canary-weight 10  # 10% of traffic goes to v2

# Monitor metrics, then increase if v2 looks good
pai eas update qwen-7b-prod --canary-weight 50

# Full rollover
pai eas update qwen-7b-prod --canary-weight 100

# Or roll back
pai eas update qwen-7b-prod --canary-weight 0
```
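The "monitor metrics, then increase" step is easier to apply consistently if the decision rule is written down. The following is a minimal sketch. It assumes you can read per-version error rates from your EAS monitoring. The function `next_canary_weight`, the 10/50/100 ladder (taken from the commands above), and the 0.5-percentage-point tolerance are all illustrative choices. Feed the result into the `--canary-weight` flag.

```python
ROLLOUT_STEPS = (10, 50, 100)  # canary weights used in the commands above


def next_canary_weight(
    current: int,
    canary_error_rate: float,
    baseline_error_rate: float,
    tolerance: float = 0.005,  # illustrative: allow 0.5 percentage points of noise
) -> int:
    """Return the next canary weight: advance one step, hold at 100, or roll back to 0."""
    if canary_error_rate > baseline_error_rate + tolerance:
        return 0  # canary is measurably worse than v1: roll back
    for step in ROLLOUT_STEPS:
        if step > current:
            return step
    return 100
```

Error rate is only one signal. The same shape works for latency percentiles or answer-quality scores, provided you compare v2 against v1 over the same time window.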
For a deeper look at EAS, see PAI Part 4. It covers cold-start mitigation, warm-pool tuning, and the gotchas in the TPS dashboard.
```python
# Programmatic fine-tuning from Model Gallery
from pai.model import RegisteredModel

# Get the base model from the gallery
base_model = RegisteredModel(model_name="Qwen2.5-7B", model_provider="pai")

# Fine-tune with LoRA on your data
training_job = base_model.fine_tune(
    training_data="oss://my-bucket/datasets/sft-v1/",
    instance_type="ecs.gn7i-c16g1.4xlarge",  # 1x A10 (24 GB)
    instance_count=1,
    hyperparameters={
        "epochs": 3,
        "learning_rate": 2e-5,
        "lora_rank": 16,
        "lora_alpha": 32,
    },
    output_path="oss://my-bucket/fine-tuned/qwen-7b-custom/",
)
training_job.wait()
print(f"Fine-tuned model: {training_job.output_path}")
```
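The two LoRA hyperparameters interact. In the original LoRA formulation, a frozen `d_out x d_in` weight gets a trainable update `B @ A`, where `A` is `rank x d_in` and `B` is `d_out x rank`. That update is scaled by `alpha / rank`, so rank 16 with alpha 32 gives a factor of 2.0. Whether PAI's fine-tuning script follows exactly this convention is an assumption. The 4096 x 4096 layer below is an illustrative size, not Qwen2.5-7B's actual shape.

```python
def lora_extra_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one d_in -> d_out linear layer (A: rank x d_in, B: d_out x rank)."""
    return rank * (d_in + d_out)


def lora_scaling(alpha: float, rank: int) -> float:
    """Scale applied to the low-rank update B @ A in standard LoRA."""
    return alpha / rank


full = 4096 * 4096                       # frozen weights in the illustrative layer
extra = lora_extra_params(4096, 4096, 16)
print(f"LoRA params: {extra:,} ({extra / full:.2%} of the layer)")
print(f"Update scale: {lora_scaling(32, 16)}")
```

For this layer, that works out to 131,072 trainable parameters, about 0.78% of the frozen 16.8M. This is why a 7B LoRA run can fit on a single 24 GB card when full fine-tuning cannot. If you raise `lora_rank`, consider raising `lora_alpha` with it to keep the effective scale constant.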
```text
[Your local machine]
    |
    | (upload training data)
    v
[OSS bucket: my-bucket/datasets/]
    |
    | (mount as /mnt/data in DSW/DLC)
    v
[PAI-DSW: explore data, write training script]
    |
    | (submit job to DLC)
    v
[PAI-DLC: distributed training, 8x GPU]
    |
    | (write checkpoints to OSS)
    v
[OSS bucket: my-bucket/checkpoints/]
    |
    | (deploy to EAS)
    v
[PAI-EAS: model serving endpoint]
    |
    | (HTTPS inference API)
    v
[Your application / users]
```
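OSS is the primary data store for PAI workloads. You upload datasets to OSS, mount them in DSW/DLC, and write results back. For bucket configuration and lifecycle policies for model artifacts, see Part 4: OSS Storage.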
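The pipeline relies on one convention. Job configs and deployment commands address an object as `oss://my-bucket/...`, while code running inside DSW/DLC sees the same object at `/mnt/data/...`. The following is a minimal sketch of that mapping. It assumes the bucket root is mounted at `/mnt/data`, which matches the paths used in this article. If you mount a prefix instead, adjust `mount_root`. Both helper names are hypothetical.

```python
from pathlib import PurePosixPath


def oss_uri_to_mount_path(uri: str, bucket: str = "my-bucket", mount_root: str = "/mnt/data") -> str:
    """Map oss://<bucket>/<key> to its path under a bucket-root mount."""
    prefix = f"oss://{bucket}/"
    if not uri.startswith(prefix):
        raise ValueError(f"{uri} is not in bucket {bucket}")
    return str(PurePosixPath(mount_root, uri[len(prefix):]))


def mount_path_to_oss_uri(path: str, bucket: str = "my-bucket", mount_root: str = "/mnt/data") -> str:
    """Inverse mapping; raises ValueError if path is outside the mount."""
    rel = PurePosixPath(path).relative_to(mount_root)
    return f"oss://{bucket}/{rel}"
```

Keeping this mapping in one place avoids a common class of bug where a training script writes to one prefix and the EAS deployment reads from another.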
```python
# Upload training data to OSS (from your local machine)
import os

import oss2

auth = oss2.Auth("<AK>", "<SK>")
bucket = oss2.Bucket(auth, "https://oss-cn-shanghai.aliyuncs.com", "my-bucket")

# Upload a dataset directory, preserving its layout under datasets/sft-v1/
local_dir = "./datasets/sft-v1/"
for root, _dirs, files in os.walk(local_dir):
    for fname in files:
        local_path = os.path.join(root, fname)
        # OSS keys always use "/", even where os.sep is "\\" (Windows)
        rel_key = os.path.relpath(local_path, local_dir).replace(os.sep, "/")
        oss_key = f"datasets/sft-v1/{rel_key}"
        bucket.put_object_from_file(oss_key, local_path)
        print(f"Uploaded: {oss_key}")
```
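If the upload loop dies partway through a large dataset, re-running it uploads every file again. Separating "decide what to upload" from "upload it" makes the loop restartable and gives you a dry run for free. The following is a minimal sketch. `build_upload_plan` is hypothetical, and it compares keys only, not sizes or checksums.

```python
import os
from typing import AbstractSet, List, Tuple


def build_upload_plan(
    local_dir: str,
    prefix: str,
    existing_keys: AbstractSet[str] = frozenset(),
) -> List[Tuple[str, str]]:
    """Return sorted (local_path, oss_key) pairs for files whose key is not already in OSS."""
    prefix = prefix.rstrip("/") + "/"
    plan = []
    for root, _dirs, files in os.walk(local_dir):
        for fname in files:
            local_path = os.path.join(root, fname)
            key = prefix + os.path.relpath(local_path, local_dir).replace(os.sep, "/")
            if key not in existing_keys:
                plan.append((local_path, key))
    return sorted(plan, key=lambda item: item[1])


# Usage with oss2 (not executed here):
# existing = {obj.key for obj in oss2.ObjectIterator(bucket, prefix="datasets/sft-v1/")}
# for local_path, key in build_upload_plan("./datasets/sft-v1/", "datasets/sft-v1", existing):
#     bucket.put_object_from_file(key, local_path)
```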
```bash
# Or use ossutil (faster for large uploads)
ossutil cp -r ./datasets/sft-v1/ oss://my-bucket/datasets/sft-v1/ \
  --parallel 10 \
  --part-size 104857600  # 100 MiB parts
```
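`--part-size` is given in bytes, and 104857600 is exactly 100 MiB. OSS caps a single multipart upload at 10,000 parts, so 100 MiB parts cover files up to about 976 GiB. The following sketch shows the arithmetic for a given file. Both helper names are hypothetical.

```python
MAX_PARTS = 10_000    # OSS limit on parts per multipart upload
MIB = 1024 * 1024


def part_count(file_size: int, part_size: int = 100 * MIB) -> int:
    """Number of parts a multipart upload splits file_size bytes into (integer ceiling)."""
    return max(1, -(-file_size // part_size))


def min_part_size(file_size: int) -> int:
    """Smallest part size in bytes that keeps file_size within MAX_PARTS parts."""
    return max(1, -(-file_size // MAX_PARTS))


ckpt_size = 15 * 1024 * MIB  # e.g. a 15 GiB checkpoint shard
print(part_count(ckpt_size))  # 154 parts at 100 MiB each
```

Integer ceiling division (`-(-a // b)`) avoids float rounding on multi-terabyte sizes.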
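```python
# Inside a DSW notebook or DLC job -- call DashScope for embeddings
# (the SDK reads the API key from the DASHSCOPE_API_KEY environment variable)
import json
from http import HTTPStatus

from dashscope import TextEmbedding


def get_embeddings(texts: list[str]) -> list[list[float]]:
    """Get embeddings using the DashScope API from within PAI."""
    response = TextEmbedding.call(
        model="text-embedding-v3",
        input=texts,
        dimension=1024,
    )
    if response.status_code != HTTPStatus.OK:
        raise RuntimeError(f"DashScope error {response.code}: {response.message}")
    # Restore input order explicitly via text_index
    items = sorted(response.output["embeddings"], key=lambda item: item["text_index"])
    return [item["embedding"] for item in items]


# Generate embeddings for your training data
with open("/mnt/data/datasets/corpus.jsonl", encoding="utf-8") as f:
    texts = [json.loads(line)["text"] for line in f]

# Batch process: text-embedding-v3 accepts at most 10 texts per request
batch_size = 10
all_embeddings = []
for i in range(0, len(texts), batch_size):
    batch = texts[i : i + batch_size]
    all_embeddings.extend(get_embeddings(batch))
    print(f"Processed {i + len(batch)}/{len(texts)}")
```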
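Over a large corpus, one throttled or failed request halfway through the loop above discards all the work done so far. The following is a minimal sketch that adds per-batch retries with exponential backoff and checks that each batch returns one vector per text. It takes the embedding function as a parameter, so it works with `get_embeddings` above. `embed_all` is hypothetical, and its retry policy is an assumption, not DashScope guidance.

```python
import time
from typing import Callable, List, Sequence


def embed_all(
    texts: Sequence[str],
    embed_fn: Callable[[List[str]], List[List[float]]],
    batch_size: int = 10,
    max_retries: int = 3,
    base_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> List[List[float]]:
    """Embed texts in batches, retrying each failed batch with exponential backoff."""
    if max_retries < 1:
        raise ValueError("max_retries must be >= 1")
    out: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = list(texts[start:start + batch_size])
        for attempt in range(max_retries):
            try:
                vectors = embed_fn(batch)
                break
            except Exception:  # narrow this to the errors your embed_fn raises
                if attempt == max_retries - 1:
                    raise
                sleep(base_delay * 2 ** attempt)
        if len(vectors) != len(batch):
            raise ValueError(f"batch at {start}: got {len(vectors)} vectors for {len(batch)} texts")
        out.extend(vectors)
    return out
```

In the notebook, `all_embeddings = embed_all(texts, get_embeddings)` replaces the manual loop.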
```python
# prepare_dataset.py
import json

raw_data = [
    {
        "instruction": "What is the return policy for electronics?",
        "output": "Electronics can be returned within 15 days of purchase...",
    },
    {
        "instruction": "How do I track my order?",
        "output": "You can track your order by logging into your account...",
    },
    # ... thousands more
]

# Convert to chat format (one JSON object per line)
with open("sft_data.jsonl", "w", encoding="utf-8") as f:
    for item in raw_data:
        record = {
            "messages": [
                {"role": "system", "content": "You are a helpful customer service assistant."},
                {"role": "user", "content": item["instruction"]},
                {"role": "assistant", "content": item["output"]},
            ]
        }
        # ensure_ascii=False writes non-ASCII text as-is, hence the explicit utf-8 encoding
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

print(f"Prepared {len(raw_data)} training examples")
```
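A malformed line in `sft_data.jsonl` often surfaces only after a DLC job has been queued and started. Validating each line locally is cheaper. The following is a minimal sketch. It assumes the fixed system/user/assistant layout produced by the script above, so multi-turn data would need a looser role check. `validate_record` is a hypothetical helper.

```python
import json
from typing import List

EXPECTED_ROLES = ["system", "user", "assistant"]  # layout produced by prepare_dataset.py


def validate_record(line: str) -> List[str]:
    """Return the problems found in one JSONL line (an empty list means it looks fine)."""
    try:
        record = json.loads(line)
    except json.JSONDecodeError as exc:
        return [f"invalid JSON: {exc.msg}"]
    messages = record.get("messages") if isinstance(record, dict) else None
    if not isinstance(messages, list):
        return ["missing 'messages' list"]
    problems = []
    roles = [m.get("role") if isinstance(m, dict) else None for m in messages]
    if roles != EXPECTED_ROLES:
        problems.append(f"roles {roles} != {EXPECTED_ROLES}")
    for i, m in enumerate(messages):
        content = m.get("content") if isinstance(m, dict) else None
        if not isinstance(content, str) or not content.strip():
            problems.append(f"message {i} has empty content")
    return problems
```

Running it over every line and printing the line numbers with non-empty results is usually enough to catch broken exports before upload.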
```python
# In a DSW Jupyter notebook
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load and inspect the dataset
with open("/mnt/data/datasets/customer-service/sft_data.jsonl", encoding="utf-8") as f:
    data = [json.loads(line) for line in f]

print(f"Total examples: {len(data)}")
avg_user = sum(len(d["messages"][1]["content"]) for d in data) / len(data)
avg_assistant = sum(len(d["messages"][2]["content"]) for d in data) / len(data)
print(f"Avg user message length: {avg_user:.0f} chars")
print(f"Avg assistant message length: {avg_assistant:.0f} chars")

# Check for quality issues
short_responses = [d for d in data if len(d["messages"][2]["content"]) < 20]
print(f"Short responses (<20 chars): {len(short_responses)}")

# Quick test: load the base model and generate a response WITHOUT fine-tuning
model_path = "/mnt/data/models/Qwen2.5-7B"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto"
)

# Test the base model on a customer service question
messages = [
    {"role": "system", "content": "You are a helpful customer service assistant."},
    {"role": "user", "content": "What is the return policy for electronics?"},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
# do_sample=True is required for temperature to take effect
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)

print("Base model response:")
# Decode only the newly generated tokens, not the prompt
print(tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))
```
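Another quick quality check alongside the short-response count is counting repeated user questions. Exact duplicates tend to over-weight a handful of intents during fine-tuning. The following is a minimal sketch using `collections.Counter` on the `data` list loaded in the notebook. `duplicate_questions` is hypothetical, and its normalization (lowercasing and collapsing whitespace) is a deliberately crude choice.

```python
from collections import Counter
from typing import Dict, List


def duplicate_questions(data: List[dict], min_count: int = 2) -> Dict[str, int]:
    """Map each normalized user question seen at least min_count times to its count, most frequent first."""
    counts = Counter(
        " ".join(d["messages"][1]["content"].lower().split()) for d in data
    )
    return {q: n for q, n in counts.most_common() if n >= min_count}
```

In the notebook, `list(duplicate_questions(data).items())[:10]` shows the worst offenders.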
```python
# Production client with error handling
import time

import requests


class CustomerServiceClient:
    def __init__(self, endpoint: str, token: str):
        self.endpoint = endpoint
        self.session = requests.Session()
        self.session.headers.update({
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        })

    def ask(self, question: str, max_retries: int = 3) -> str:
        if max_retries < 1:
            raise ValueError("max_retries must be >= 1")
        payload = {
            "model": "customer-service",
            "messages": [
                {"role": "system", "content": "You are a helpful customer service assistant."},
                {"role": "user", "content": question},
            ],
            "max_tokens": 512,
            "temperature": 0.3,
        }
        for attempt in range(max_retries):
            try:
                resp = self.session.post(
                    f"{self.endpoint}/v1/chat/completions", json=payload, timeout=30
                )
                resp.raise_for_status()
                return resp.json()["choices"][0]["message"]["content"]
            except requests.exceptions.RequestException:
                if attempt == max_retries - 1:
                    raise
                time.sleep(2 ** attempt)  # back off 1s, 2s, ... before retrying


# Usage
client = CustomerServiceClient(
    endpoint="https://your-service.cn-shanghai.pai-eas.aliyuncs.com",
    token="<ACCESS_TOKEN>",
)
answer = client.ask("How do I track my order?")
print(answer)
```
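The client retries on every `RequestException`, and that includes the `HTTPError` raised by `raise_for_status()`. As a result, a 400 from a malformed payload or a 401 from a bad token is retried pointlessly. The following is a minimal sketch of a narrower policy. It retries connection errors (no status code), 429, and common 5xx codes. It also randomizes the delay ("full jitter"), so that many clients recovering from the same EAS hiccup do not retry in lockstep. The status set and jitter scheme are common conventions, not EAS requirements, and both function names are hypothetical.

```python
import random
from typing import Optional

RETRYABLE_STATUS = {429, 500, 502, 503, 504}


def should_retry(status_code: Optional[int]) -> bool:
    """Retry connection errors (no status) and throttling/server errors, but not other 4xx."""
    return status_code is None or status_code in RETRYABLE_STATUS


def backoff_delay(attempt: int, base: float = 1.0, cap: float = 30.0,
                  rng: Optional[random.Random] = None) -> float:
    """Full-jitter backoff: a random delay in [0, min(cap, base * 2**attempt)]."""
    return (rng or random).uniform(0, min(cap, base * 2 ** attempt))
```

Inside `ask`, catch the exception as `exc` and read `exc.response.status_code` when `exc.response` is not None (otherwise use None). Re-raise immediately when `should_retry` returns False. Otherwise, sleep for `backoff_delay(attempt)` instead of `2 ** attempt`.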